qgyd2021 committed
Commit 3c85855
1 Parent(s): db14530

[update]add main

Files changed (1)
  1. main.py +54 -2
main.py CHANGED
@@ -5,6 +5,7 @@ from threading import Thread
 
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.streamers import TextIteratorStreamer
 import torch
 
 from project_settings import project_path
@@ -109,6 +110,57 @@ def chat_with_llm_non_stream(question: str,
     return result
 
 
+def chat_with_llm_streaming(question: str,
+                            history: List[Tuple[str, str]],
+                            pretrained_model_name_or_path: str,
+                            max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
+                            ):
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+    model, tokenizer = init_model(pretrained_model_name_or_path)
+
+    text_list = list()
+    for pair in history:
+        text_list.extend(pair)
+    text_list.append(question)
+
+    text_encoded = tokenizer.__call__(text_list, add_special_tokens=False)
+    batch_input_ids = text_encoded["input_ids"]
+
+    input_ids = [tokenizer.bos_token_id]
+    for input_ids_ in batch_input_ids:
+        input_ids.extend(input_ids_)
+        input_ids.append(tokenizer.eos_token_id)
+    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)
+
+    streamer = TextIteratorStreamer(tokenizer=tokenizer)
+
+    generation_kwargs = dict(
+        inputs=input_ids,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        temperature=temperature,
+        repetition_penalty=repetition_penalty,
+        eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.pad_token_id,
+        streamer=streamer,
+    )
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    answer = ""
+    for output_ in streamer:
+        output_ = output_.replace(question, "")
+        output_ = output_.replace(tokenizer.eos_token, "")
+
+        answer += output_
+
+        result = [(question, answer)]
+
+        yield history + result
+
+
 def main():
     description = """
     chat llm
@@ -153,8 +205,8 @@ def main():
         outputs = [
             chatbot
         ]
-        text_box.submit(chat_with_llm_non_stream, inputs, outputs)
-        submit_button.click(chat_with_llm_non_stream, inputs, outputs)
+        text_box.submit(chat_with_llm_streaming, inputs, outputs)
+        submit_button.click(chat_with_llm_streaming, inputs, outputs)
         clear_button.click(
             fn=lambda: ('', ''),
             outputs=[text_box, chatbot],
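
For context: the change swaps the blocking chat_with_llm_non_stream handler for a generator. model.generate() runs on a background Thread while the foreground loop iterates a TextIteratorStreamer and yields the growing answer, so Gradio re-renders the Chatbot after every yield. Below is a minimal, self-contained sketch of that pattern; it is illustrative only, not part of the commit, and "gpt2" is a stand-in for whatever model init_model() actually loads.

    # Illustrative sketch only (not from this commit); "gpt2" is a stand-in model.
    from threading import Thread

    from transformers import AutoModelForCausalLM, AutoTokenizer
    from transformers.generation.streamers import TextIteratorStreamer

    tokenizer = AutoTokenizer.from_pretrained("gpt2")
    model = AutoModelForCausalLM.from_pretrained("gpt2")

    input_ids = tokenizer("Hello, I am", return_tensors="pt").input_ids

    # skip_prompt=True keeps the prompt out of the stream; the committed handler
    # instead strips the question text inside its loop with str.replace.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True, skip_special_tokens=True)

    # model.generate() blocks until generation finishes, so it runs on a background
    # thread while the main thread consumes decoded text from the streamer.
    thread = Thread(
        target=model.generate,
        kwargs=dict(inputs=input_ids, max_new_tokens=32, streamer=streamer),
    )
    thread.start()

    answer = ""
    for chunk in streamer:  # yields decoded text piece by piece as tokens arrive
        answer += chunk
        print(answer)       # a Gradio generator handler would `yield` the updated history here
    thread.join()

Note that the committed handler constructs TextIteratorStreamer without skip_prompt, so the prompt text comes back through the stream and is removed manually before the partial answer is yielded to the Chatbot.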