# Gradio demo app for a token-based language model (Hugging Face Space).
import gradio as gr
import numpy as np
import time
from tokenizer import encode, decode, vocab_size
from model import *
# --- Model setup (runs once at import time) ---
# TokenBasedLanguageModel, device, and torch are provided by `from model import *`.
model = TokenBasedLanguageModel()
m = model.to(device)
print("Loading checkpoint from file")
# map_location forces CPU so the checkpoint loads on Hugging Face's CPU-only hardware.
checkpoint = torch.load("improved-v5.bin", map_location=torch.device('cpu'))
model.load_state_dict(checkpoint["model_state_dict"])
# Switch to inference mode: disables dropout / freezes norm statistics, which
# would otherwise make generation noisy. (Was missing before.)
model.eval()
print("State restored")
def generate_llm(prompt, max_tokens = 512, analyze_probs = False):
    """Stream text generated by the model, as a Gradio generator.

    Args:
        prompt: Seed text; encoded with the project tokenizer and echoed at
            the start of the output.
        max_tokens: Maximum number of new tokens to generate.
        analyze_probs: When True, also report the top-25 next-token
            probabilities at each step (slower).

    Yields:
        [output_text, stats_string, probability_text] — one triple per token.
    """
    prompt_encoded = encode(prompt)
    # Shape (1, T): a batch of one sequence for model.generate().
    context = torch.tensor(prompt_encoded, dtype = torch.long, device = device).view(1, len(prompt_encoded))
    output = prompt[:]
    start_time = time.time()
    token_count = 0
    probtext = ""

    def _stats():
        # Guard against a zero elapsed time on a very fast first token.
        elapsed = max(time.time() - start_time, 1e-9)
        return str(token_count / elapsed) + "tok/s " + str(token_count) + " tokens generated."

    for encoded_token_pair in model.generate(context, max_new_tokens=max_tokens,
                                             stream = True, stream_probs = analyze_probs):
        if analyze_probs:
            encoded_token, probs = encoded_token_pair
            # Top 25 token ids by probability. Sorting ids directly avoids
            # materializing a full [id, prob] pair list for the whole vocab.
            # reverse-sort is stable, so tie order matches the original code.
            top_ids = sorted(range(vocab_size), key = lambda t: probs[t], reverse = True)[:25]
            probtext = "".join(f'"{decode([t])}": {probs[t]}\n' for t in top_ids)
        else:
            encoded_token = encoded_token_pair
            probtext = "Feature disabled."
        output += decode([encoded_token])
        token_count += 1
        yield [output, _stats(), probtext]
    return [output, _stats(), probtext]
# UI wiring: prompt box + generation controls in; streamed text, stats, and
# probability report out.
_prompt_box = gr.TextArea(placeholder = "In the midst of chaos.", value = "Once upon a time")
_max_tokens = gr.Number(value = 512, maximum = 2048, minimum = 1, step = 1, label = "Max tokens")
_show_probs = gr.Checkbox(label = "Show probs, 10x slower if run on gpu")
demo = gr.Interface(
    generate_llm,
    inputs=[_prompt_box, _max_tokens, _show_probs],
    outputs=[
        gr.TextArea(label = "Output"),
        gr.Text(placeholder = "tok/s and other stats", label = "Stats"),
        gr.TextArea(label = "Probability stats"),
    ],
)
if __name__ == "__main__":
    # share=False: serve locally only (the Spaces runtime provides its own URL).
    # Note: the original line carried trailing page-scrape residue (" |") that
    # made it a syntax error; removed here.
    demo.launch(share = False)