[update]add main
main.py CHANGED
```diff
@@ -5,6 +5,7 @@ from threading import Thread
 
 import gradio as gr
 from transformers import AutoModelForCausalLM, AutoTokenizer
+from transformers.generation.streamers import TextIteratorStreamer
 import torch
 
 from project_settings import project_path
@@ -109,6 +110,57 @@ def chat_with_llm_non_stream(question: str,
     return result
 
 
+def chat_with_llm_streaming(question: str,
+                            history: List[Tuple[str, str]],
+                            pretrained_model_name_or_path: str,
+                            max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
+                            ):
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+    model, tokenizer = init_model(pretrained_model_name_or_path)
+
+    text_list = list()
+    for pair in history:
+        text_list.extend(pair)
+    text_list.append(question)
+
+    text_encoded = tokenizer.__call__(text_list, add_special_tokens=False)
+    batch_input_ids = text_encoded["input_ids"]
+
+    input_ids = [tokenizer.bos_token_id]
+    for input_ids_ in batch_input_ids:
+        input_ids.extend(input_ids_)
+        input_ids.append(tokenizer.eos_token_id)
+    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)
+
+    streamer = TextIteratorStreamer(tokenizer=tokenizer)
+
+    generation_kwargs = dict(
+        inputs=input_ids,
+        max_new_tokens=max_new_tokens,
+        do_sample=True,
+        top_p=top_p,
+        temperature=temperature,
+        repetition_penalty=repetition_penalty,
+        eos_token_id=tokenizer.eos_token_id,
+        pad_token_id=tokenizer.pad_token_id,
+        streamer=streamer,
+    )
+    thread = Thread(target=model.generate, kwargs=generation_kwargs)
+    thread.start()
+
+    answer = ""
+    for output_ in streamer:
+        output_ = output_.replace(question, "")
+        output_ = output_.replace(tokenizer.eos_token, "")
+
+        answer += output_
+
+        result = [(question, answer)]
+
+        yield history + result
+
+
 def main():
     description = """
     chat llm
@@ -153,8 +205,8 @@ def main():
     outputs = [
         chatbot
     ]
-    text_box.submit(
-    submit_button.click(
+    text_box.submit(chat_with_llm_streaming, inputs, outputs)
+    submit_button.click(chat_with_llm_streaming, inputs, outputs)
     clear_button.click(
         fn=lambda: ('', ''),
         outputs=[text_box, chatbot],
```
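The pattern the new `chat_with_llm_streaming` function relies on: `model.generate()` blocks until generation finishes, so the commit runs it on a background `Thread` and consumes decoded text incrementally from a `TextIteratorStreamer`; because the Gradio handler is a generator, every `yield` re-renders the chatbot with the partial answer. Below is a minimal, self-contained sketch of the same pattern; the `stream_reply` name and the `"gpt2"` checkpoint are illustrative stand-ins, not part of the commit.

```python
from threading import Thread

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer
from transformers.generation.streamers import TextIteratorStreamer


def stream_reply(prompt: str, model, tokenizer, max_new_tokens: int = 256):
    """Yield the partial answer after every decoded chunk."""
    # Put the inputs on whatever device the model already lives on.
    device = next(model.parameters()).device
    input_ids = tokenizer(prompt, return_tensors="pt").input_ids.to(device)

    # skip_prompt=True drops the echoed input and skip_special_tokens=True
    # strips eos/bos from decoded chunks, replacing the manual
    # str.replace() cleanup the commit does by hand.
    streamer = TextIteratorStreamer(tokenizer, skip_prompt=True,
                                    skip_special_tokens=True)

    # generate() blocks until finished, so it runs on a worker thread
    # while this generator drains the streamer.
    thread = Thread(target=model.generate, kwargs=dict(
        inputs=input_ids,
        max_new_tokens=max_new_tokens,
        streamer=streamer,
    ))
    thread.start()

    answer = ""
    for chunk in streamer:
        answer += chunk
        yield answer  # each yield is the reply so far
    thread.join()


if __name__ == "__main__":
    # "gpt2" is only a stand-in checkpoint for this sketch.
    tok = AutoTokenizer.from_pretrained("gpt2")
    lm = AutoModelForCausalLM.from_pretrained("gpt2")
    for partial in stream_reply("Hello,", lm, tok, max_new_tokens=20):
        print(partial)
```

Wiring such a generator to `text_box.submit` / `submit_button.click`, as the hunk at lines 208-209 does, is what makes Gradio repaint the chatbot token by token instead of waiting for the full reply.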
|