qgyd2021 committed
Commit
6e2102d
1 Parent(s): 2b1a555

[update]add main

Files changed (1)
  1. main.py +24 -11
main.py CHANGED
@@ -122,18 +122,31 @@ def chat_with_llm_streaming(question: str,
 
     model, tokenizer = init_model(pretrained_model_name_or_path)
 
-    text_list = list()
-    for pair in history:
-        text_list.extend(pair)
-    text_list.append(question)
-
-    text_encoded = tokenizer.__call__(text_list, add_special_tokens=False)
-    batch_input_ids = text_encoded["input_ids"]
-
-    input_ids = [tokenizer.bos_token_id]
-    for input_ids_ in batch_input_ids:
-        input_ids.extend(input_ids_)
+    if model.config.model_type == "chatglm":
+        input_ids = []
+    else:
+        input_ids = [tokenizer.bos_token_id]
+
+    # history
+    for idx, (h_question, h_answer) in enumerate(history):
+        if model.config.model_type == "chatglm":
+            h_question = "[Round {}]\n\n问:{}\n\n答:".format(idx, h_question)
+        h_question = tokenizer.__call__(h_question, add_special_tokens=False)
+        h_answer = tokenizer.__call__(h_answer, add_special_tokens=False)
+
+        input_ids.append(h_question)
+        if model.config.model_type != "chatglm":
+            input_ids.append(tokenizer.eos_token_id)
+        input_ids.append(h_answer)
+        if model.config.model_type != "chatglm":
+            input_ids.append(tokenizer.eos_token_id)
+
+    # question
+    question = tokenizer.__call__(question, add_special_tokens=False)
+    input_ids.append(question)
+    if model.config.model_type != "chatglm":
         input_ids.append(tokenizer.eos_token_id)
+
     input_ids = torch.tensor([input_ids], dtype=torch.long)
     input_ids = input_ids[:, -history_max_len:].to(device)
 
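For reference, the prompt-assembly logic this hunk introduces can be read as the minimal, self-contained sketch below. The stub encode() and the BOS_TOKEN_ID/EOS_TOKEN_ID values are hypothetical stand-ins for the real tokenizer, and where the committed code appends tokenizer.__call__ outputs onto input_ids directly, the sketch flattens them to plain token ids with extend() so the final torch.tensor call is well-typed; treat it as a reading aid, not the commit's exact implementation. (问 and 答 in the ChatGLM round template mean "Question" and "Answer".)

import torch

# Hypothetical stand-ins for tokenizer.bos_token_id / tokenizer.eos_token_id.
BOS_TOKEN_ID = 1
EOS_TOKEN_ID = 2

def encode(text: str) -> list:
    # Stub for tokenizer(text, add_special_tokens=False)["input_ids"]:
    # maps each character to a fake token id so the example runs standalone.
    return [ord(ch) % 1000 for ch in text]

def build_input_ids(question, history, model_type, history_max_len=1024, device="cpu"):
    # ChatGLM carries the dialogue structure in a per-round text template,
    # so it starts with no BOS token; other models start with BOS and use
    # EOS as a separator between utterances.
    input_ids = [] if model_type == "chatglm" else [BOS_TOKEN_ID]

    # history
    for idx, (h_question, h_answer) in enumerate(history):
        if model_type == "chatglm":
            # 问 = "Question", 答 = "Answer" in the ChatGLM round template.
            h_question = "[Round {}]\n\n问:{}\n\n答:".format(idx, h_question)
        input_ids.extend(encode(h_question))
        if model_type != "chatglm":
            input_ids.append(EOS_TOKEN_ID)
        input_ids.extend(encode(h_answer))
        if model_type != "chatglm":
            input_ids.append(EOS_TOKEN_ID)

    # question
    input_ids.extend(encode(question))
    if model_type != "chatglm":
        input_ids.append(EOS_TOKEN_ID)

    # Keep only the most recent history_max_len token positions, as the commit does.
    input_ids = torch.tensor([input_ids], dtype=torch.long)
    return input_ids[:, -history_max_len:].to(device)

print(build_input_ids("hello", [("hi", "hey there")], model_type="chatglm").shape)

The sketch prints the shape of a (1, N) tensor. The design point of the commit: the old code treated all models alike, joining every utterance with eos_token_id, whereas ChatGLM expects its dialogue structure to be expressed in text (the "[Round {}]" template) rather than with BOS/EOS separator tokens.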