vilarin committed
Commit edb9e8a
1 Parent(s): fece758

Update app.py

Files changed (1): app.py (+12 -5)
app.py CHANGED
@@ -2,9 +2,9 @@ import torch
 from PIL import Image
 import gradio as gr
 import spaces
-from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig
+from transformers import AutoModelForCausalLM, AutoTokenizer, BitsAndBytesConfig, TextIteratorStreamer
 import os
-import time
+from threading import Thread
 
 
 HF_TOKEN = os.environ.get("HF_TOKEN", None)
@@ -69,17 +69,24 @@ def stream_chat(message: str, history: list, temperature: float, max_new_tokens:
 
     input_ids = tokenizer.apply_chat_template(conversation, tokenize=True, add_generation_prompt=True, return_tensors="pt").to(model.device)
 
-    gen_tokens = model.generate(
-        input_ids,
+    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True, clean_up_tokenization_spaces=False)
+
+    generate_kwargs = dict(
+        input_ids=input_ids,
+        streamer=streamer,
         max_new_tokens=max_new_tokens,
         do_sample=True,
         temperature=temperature,
     )
-
-    gen_text = tokenizer.batch_decode(gen_tokens[0], skip_special_tokens=True)
-
-    return gen_text
+
+    thread = Thread(target=model.generate, kwargs=generate_kwargs)
+    thread.start()
+
+    buffer = ""
+    for new_text in streamer:
+        buffer += new_text
+        yield buffer
 
 
 chatbot = gr.Chatbot(height=450)
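The commit replaces blocking generation (run model.generate to completion, decode, return the full reply) with token streaming: model.generate runs on a background Thread, decoded chunks land in a TextIteratorStreamer, and stream_chat becomes a generator that yields the accumulated text so the gr.Chatbot defined above repaints after each chunk. Below is a minimal, self-contained sketch of the same pattern, assuming a placeholder model ("gpt2") and a plain string prompt instead of this Space's actual model, quantization config, and chat template.

# Minimal sketch of the streaming pattern this commit introduces.
# Assumptions (not from the diff): the placeholder model "gpt2" and a
# plain string prompt rather than the Space's chat template.
from threading import Thread

from transformers import AutoModelForCausalLM, AutoTokenizer, TextIteratorStreamer

model_id = "gpt2"  # placeholder model
tokenizer = AutoTokenizer.from_pretrained(model_id)
model = AutoModelForCausalLM.from_pretrained(model_id)


def stream_reply(prompt: str, temperature: float = 0.8, max_new_tokens: int = 64):
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(model.device)

    # skip_prompt=True keeps the echoed input out of the streamed output
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # generate() blocks until generation finishes, so it runs on a worker
    # thread while the caller drains the streamer on this thread
    thread = Thread(
        target=model.generate,
        kwargs=dict(
            input_ids=input_ids,
            streamer=streamer,
            max_new_tokens=max_new_tokens,
            do_sample=True,
            temperature=temperature,
        ),
    )
    thread.start()

    buffer = ""
    for new_text in streamer:  # blocks until the next decoded chunk arrives
        buffer += new_text
        yield buffer  # yield the whole reply so far, not just the delta
    thread.join()


for partial in stream_reply("The capital of France is"):
    print(partial)

Yielding the accumulated buffer rather than each delta matches how Gradio's chat components consume a streaming handler: each yield replaces the displayed reply with the full text so far.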