qgyd2021 committed on
Commit
7cee691
1 Parent(s): 297c6ff
Files changed (1) hide show
  1. main.py +6 -6
main.py CHANGED
@@ -137,12 +137,6 @@ def chat_with_llm_streaming(question: str,
137
 
138
  model, tokenizer = init_model(pretrained_model_name_or_path)
139
 
140
- # input_ids
141
- if model.config.model_type == "chatglm":
142
- input_ids = []
143
- else:
144
- input_ids = [tokenizer.bos_token_id]
145
-
146
  # history
147
  utterances = list()
148
  for idx, (h_question, h_answer) in enumerate(history):
@@ -157,6 +151,12 @@ def chat_with_llm_streaming(question: str,
157
  encoded_utterances = tokenizer.__call__(utterances, add_special_tokens=False)
158
  encoded_utterances = encoded_utterances["input_ids"]
159
 
 
 
 
 
 
 
160
  for encoded_utterance in encoded_utterances:
161
  input_ids.extend(encoded_utterance)
162
  if model.config.model_type != "chatglm":
 
137
 
138
  model, tokenizer = init_model(pretrained_model_name_or_path)
139
 
 
 
 
 
 
 
140
  # history
141
  utterances = list()
142
  for idx, (h_question, h_answer) in enumerate(history):
 
151
  encoded_utterances = tokenizer.__call__(utterances, add_special_tokens=False)
152
  encoded_utterances = encoded_utterances["input_ids"]
153
 
154
+ # input_ids
155
+ if model.config.model_type == "chatglm":
156
+ input_ids = []
157
+ else:
158
+ input_ids = [tokenizer.bos_token_id]
159
+
160
  for encoded_utterance in encoded_utterances:
161
  input_ids.extend(encoded_utterance)
162
  if model.config.model_type != "chatglm":