codys12 committed on
Commit
4e4fd76
1 Parent(s): a5f97a2
Files changed (1)
  1. app.py +14 -19
app.py CHANGED
@@ -50,25 +50,20 @@ def generate(
         input_ids = input_ids[-MAX_INPUT_TOKEN_LENGTH:]
         gr.Warning("Trimmed input from conversation as it was longer than {MAX_INPUT_TOKEN_LENGTH} tokens.")
 
-    streamer = TextIteratorStreamer(tokenizer, timeout=10.0, skip_prompt=True, skip_special_tokens=True)
-    generate_kwargs = dict(
-        {"input_ids": input_ids},
-        streamer=streamer,
-        max_new_tokens=max_new_tokens,
-        do_sample=True,
-        top_p=top_p,
-        top_k=top_k,
-        temperature=temperature,
-        num_beams=1,
-        repetition_penalty=repetition_penalty,
-    )
-    t = Thread(target=model.generate, kwargs=generate_kwargs)
-    t.start()
-
-    outputs = []
-    for text in streamer:
-        outputs.append(text)
-        yield "".join(outputs)
+    input_ids = tokenizer(current_input, return_tensors="pt").to("cuda")
+
+    # Generate
+    output_ids = model.generate(input_ids=input_ids,
+                                max_new_tokens=max_new_tokens,
+                                do_sample=True,
+                                top_p=top_p,
+                                top_k=top_k,
+                                temperature=temperature,
+                                repetition_penalty=repetition_penalty)
+
+    # Stream output
+    for id in output_ids.tolist()[0]:
+        yield tokenizer.decode(id)
 
 
 chat_interface = gr.ChatInterface(
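
The diff above replaces the threaded TextIteratorStreamer loop with a single blocking model.generate() call whose output ids are then decoded and yielded one token at a time. A minimal standalone sketch of that post-commit flow is below, assuming a CUDA device; the model id, prompt, helper name generate_stream, and sampling defaults are illustrative placeholders, not taken from the commit.

# Sketch of the new flow: one blocking generate() call, then per-token decoding.
# Assumes CUDA is available; "gpt2" is a placeholder model for illustration.
from transformers import AutoModelForCausalLM, AutoTokenizer

model_id = "gpt2"  # placeholder, not the Space's actual model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id).to("cuda")


def generate_stream(prompt, max_new_tokens=64, temperature=0.7,
                    top_p=0.9, top_k=50, repetition_penalty=1.1):
    # tokenizer(...) returns a BatchEncoding; pass its tensors explicitly
    inputs = tokenizer(prompt, return_tensors="pt").to("cuda")
    output_ids = model.generate(
        input_ids=inputs["input_ids"],
        attention_mask=inputs["attention_mask"],
        max_new_tokens=max_new_tokens,
        do_sample=True,
        temperature=temperature,
        top_p=top_p,
        top_k=top_k,
        repetition_penalty=repetition_penalty,
    )
    # generate() returns prompt + completion; skip the prompt tokens,
    # then decode and yield one token id at a time
    prompt_len = inputs["input_ids"].shape[1]
    for token_id in output_ids[0, prompt_len:].tolist():
        yield tokenizer.decode(token_id, skip_special_tokens=True)


if __name__ == "__main__":
    for piece in generate_stream("Hello, world"):
        print(piece, end="", flush=True)

Unlike the removed streamer-based version, this yields tokens only after the full generation has finished, so the output is produced in one burst rather than incrementally.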