Raymond commited on
Commit
a533705
1 Parent(s): ea507ec
Files changed (1) hide show
  1. app.py +4 -2
app.py CHANGED
@@ -9,7 +9,8 @@ model = TokenBasedLanguageModel()
9
  m = model.to(device)
10
 
11
  print("Loading checkpoint from file")
12
- checkpoint = torch.load("improved-v5.bin")
 
13
  model.load_state_dict(checkpoint["model_state_dict"])
14
  print("State restored")
15
 
@@ -44,4 +45,5 @@ demo = gr.Interface(generate_llm,
44
  inputs=[gr.TextArea(placeholder = "In the midst of chaos."), gr.Number(value = 512, maximum = 2048, minimum = 1, step = 1, label = "Max tokens"), gr.Checkbox(label = "Show probs, 10x slower")],
45
  outputs=[gr.TextArea(label = "Output"), gr.Text(placeholder = "tok/s and other stats", label = "Stats"), gr.TextArea(label = "Probability stats")])
46
 
47
- demo.launch(share = True)
 
 
9
  m = model.to(device)
10
 
11
  print("Loading checkpoint from file")
12
+ # second param is to work on huggingface cpu
13
+ checkpoint = torch.load("improved-v5.bin", map_location=torch.device('cpu'))
14
  model.load_state_dict(checkpoint["model_state_dict"])
15
  print("State restored")
16
 
 
45
  inputs=[gr.TextArea(placeholder = "In the midst of chaos."), gr.Number(value = 512, maximum = 2048, minimum = 1, step = 1, label = "Max tokens"), gr.Checkbox(label = "Show probs, 10x slower")],
46
  outputs=[gr.TextArea(label = "Output"), gr.Text(placeholder = "tok/s and other stats", label = "Stats"), gr.TextArea(label = "Probability stats")])
47
 
48
+ if __name__ == "__main__":
49
+ demo.launch(share = True)