qgyd2021 committed on
Commit
c64dba3
1 Parent(s): 3c85855

[update]add main

Browse files
Files changed (1) hide show
  1. main.py +18 -6
main.py CHANGED
@@ -73,6 +73,7 @@ def chat_with_llm_non_stream(question: str,
73
  history: List[Tuple[str, str]],
74
  pretrained_model_name_or_path: str,
75
  max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
 
76
  ):
77
  device: str = "cuda" if torch.cuda.is_available() else "cpu"
78
 
@@ -90,7 +91,8 @@ def chat_with_llm_non_stream(question: str,
90
  for input_ids_ in batch_input_ids:
91
  input_ids.extend(input_ids_)
92
  input_ids.append(tokenizer.eos_token_id)
93
- input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)
 
94
 
95
  with torch.no_grad():
96
  outputs = model.generate(
@@ -114,6 +116,7 @@ def chat_with_llm_streaming(question: str,
114
  history: List[Tuple[str, str]],
115
  pretrained_model_name_or_path: str,
116
  max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
 
117
  ):
118
  device: str = "cuda" if torch.cuda.is_available() else "cpu"
119
 
@@ -131,7 +134,8 @@ def chat_with_llm_streaming(question: str,
131
  for input_ids_ in batch_input_ids:
132
  input_ids.extend(input_ids_)
133
  input_ids.append(tokenizer.eos_token_id)
134
- input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)
 
135
 
136
  streamer = TextIteratorStreamer(tokenizer=tokenizer)
137
 
@@ -190,17 +194,25 @@ def main():
190
  temperature = gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature")
191
  with gr.Column(scale=1):
192
  repetition_penalty = gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty")
 
 
193
 
194
  with gr.Row():
195
- model_name = gr.Dropdown(choices=["Qwen/Qwen-7B-Chat"],
196
- value="Qwen/Qwen-7B-Chat",
197
- label="model_name",
198
- )
 
 
 
 
 
199
  gr.Examples(examples=["你好"], inputs=text_box)
200
 
201
  inputs = [
202
  text_box, chatbot, model_name,
203
  max_new_tokens, top_p, temperature, repetition_penalty,
 
204
  ]
205
  outputs = [
206
  chatbot
 
73
  history: List[Tuple[str, str]],
74
  pretrained_model_name_or_path: str,
75
  max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
76
+ history_max_len: int,
77
  ):
78
  device: str = "cuda" if torch.cuda.is_available() else "cpu"
79
 
 
91
  for input_ids_ in batch_input_ids:
92
  input_ids.extend(input_ids_)
93
  input_ids.append(tokenizer.eos_token_id)
94
+ input_ids = torch.tensor([input_ids], dtype=torch.long)
95
+ input_ids = input_ids[:, -history_max_len:].to(device)
96
 
97
  with torch.no_grad():
98
  outputs = model.generate(
 
116
  history: List[Tuple[str, str]],
117
  pretrained_model_name_or_path: str,
118
  max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
119
+ history_max_len: int,
120
  ):
121
  device: str = "cuda" if torch.cuda.is_available() else "cpu"
122
 
 
134
  for input_ids_ in batch_input_ids:
135
  input_ids.extend(input_ids_)
136
  input_ids.append(tokenizer.eos_token_id)
137
+ input_ids = torch.tensor([input_ids], dtype=torch.long)
138
+ input_ids = input_ids[:, -history_max_len:].to(device)
139
 
140
  streamer = TextIteratorStreamer(tokenizer=tokenizer)
141
 
 
194
  temperature = gr.Slider(minimum=0, maximum=1, value=0.35, step=0.01, label="temperature")
195
  with gr.Column(scale=1):
196
  repetition_penalty = gr.Slider(minimum=0, maximum=2, value=1.2, step=0.01, label="repetition_penalty")
197
+ with gr.Column(scale=1):
198
+ history_max_len = gr.Slider(minimum=0, maximum=4096, value=1024, step=1, label="history_max_len")
199
 
200
  with gr.Row():
201
+ model_name = gr.Dropdown(
202
+ choices=[
203
+ "Qwen/Qwen-7B-Chat",
204
+ "THUDM/chatglm2-6b",
205
+ "baichuan-inc/Baichuan2-7B-Chat",
206
+ ],
207
+ value="Qwen/Qwen-7B-Chat",
208
+ label="model_name",
209
+ )
210
  gr.Examples(examples=["你好"], inputs=text_box)
211
 
212
  inputs = [
213
  text_box, chatbot, model_name,
214
  max_new_tokens, top_p, temperature, repetition_penalty,
215
+ history_max_len
216
  ]
217
  outputs = [
218
  chatbot