qgyd2021 committed on
Commit
9d715da
1 Parent(s): 6e2102d

[update]add main

Browse files
Files changed (1) hide show
  1. main.py +33 -25
main.py CHANGED
@@ -42,8 +42,11 @@ def init_model(pretrained_model_name_or_path: str):
42
  offload_state_dict=True,
43
  # load_in_4bit=True,
44
  )
45
- model = model.to(device)
46
- model = model.bfloat16().eval()
 
 
 
47
 
48
  tokenizer = AutoTokenizer.from_pretrained(
49
  pretrained_model_name_or_path,
@@ -79,18 +82,27 @@ def chat_with_llm_non_stream(question: str,
79
 
80
  model, tokenizer = init_model(pretrained_model_name_or_path)
81
 
82
- text_list = list()
83
- for pair in history:
84
- text_list.extend(pair)
85
- text_list.append(question)
 
86
 
87
- text_encoded = tokenizer.__call__(text_list, add_special_tokens=False)
88
- batch_input_ids = text_encoded["input_ids"]
 
 
 
 
 
 
 
 
 
 
 
 
89
 
90
- input_ids = [tokenizer.bos_token_id]
91
- for input_ids_ in batch_input_ids:
92
- input_ids.extend(input_ids_)
93
- input_ids.append(tokenizer.eos_token_id)
94
  input_ids = torch.tensor([input_ids], dtype=torch.long)
95
  input_ids = input_ids[:, -history_max_len:].to(device)
96
 
@@ -122,31 +134,27 @@ def chat_with_llm_streaming(question: str,
122
 
123
  model, tokenizer = init_model(pretrained_model_name_or_path)
124
 
 
125
  if model.config.model_type == "chatglm":
126
  input_ids = []
127
  else:
128
  input_ids = [tokenizer.bos_token_id]
129
 
130
  # history
 
131
  for idx, (h_question, h_answer) in enumerate(history):
132
  if model.config.model_type == "chatglm":
133
  h_question = "[Round {}]\n\n问:{}\n\n答:".format(idx, h_question)
134
- h_question = tokenizer.__call__(h_question, add_special_tokens=False)
135
- h_answer = tokenizer.__call__(h_answer, add_special_tokens=False)
 
136
 
137
- input_ids.append(h_question)
138
- if model.config.model_type != "chatglm":
139
- input_ids.append(tokenizer.eos_token_id)
140
- input_ids.append(h_answer)
141
- if model.config.model_type != "chatglm":
142
  input_ids.append(tokenizer.eos_token_id)
143
 
144
- # question
145
- question = tokenizer.__call__(question, add_special_tokens=False)
146
- input_ids.append(question)
147
- if model.config.model_type != "chatglm":
148
- input_ids.append(tokenizer.eos_token_id)
149
-
150
  input_ids = torch.tensor([input_ids], dtype=torch.long)
151
  input_ids = input_ids[:, -history_max_len:].to(device)
152
 
 
42
  offload_state_dict=True,
43
  # load_in_4bit=True,
44
  )
45
+ if model.config.model_type == "chatglm":
46
+ model = model.eval()
47
+ else:
48
+ model = model.to(device)
49
+ model = model.bfloat16().eval()
50
 
51
  tokenizer = AutoTokenizer.from_pretrained(
52
  pretrained_model_name_or_path,
 
82
 
83
  model, tokenizer = init_model(pretrained_model_name_or_path)
84
 
85
+ # input_ids
86
+ if model.config.model_type == "chatglm":
87
+ input_ids = []
88
+ else:
89
+ input_ids = [tokenizer.bos_token_id]
90
 
91
+ # history
92
+ utterances = list()
93
+ for idx, (h_question, h_answer) in enumerate(history):
94
+ if model.config.model_type == "chatglm":
95
+ h_question = "[Round {}]\n\n问:{}\n\n答:".format(idx, h_question)
96
+ utterances.append(h_question)
97
+ utterances.append(h_answer)
98
+ utterances.append(question)
99
+
100
+ encoded_utterances = tokenizer.__call__(utterances, add_special_tokens=False)
101
+ for encoded_utterance in encoded_utterances:
102
+ input_ids.extend(encoded_utterance)
103
+ if model.config.model_type == "chatglm":
104
+ input_ids.append(tokenizer.eos_token_id)
105
 
 
 
 
 
106
  input_ids = torch.tensor([input_ids], dtype=torch.long)
107
  input_ids = input_ids[:, -history_max_len:].to(device)
108
 
 
134
 
135
  model, tokenizer = init_model(pretrained_model_name_or_path)
136
 
137
+ # input_ids
138
  if model.config.model_type == "chatglm":
139
  input_ids = []
140
  else:
141
  input_ids = [tokenizer.bos_token_id]
142
 
143
  # history
144
+ utterances = list()
145
  for idx, (h_question, h_answer) in enumerate(history):
146
  if model.config.model_type == "chatglm":
147
  h_question = "[Round {}]\n\n问:{}\n\n答:".format(idx, h_question)
148
+ utterances.append(h_question)
149
+ utterances.append(h_answer)
150
+ utterances.append(question)
151
 
152
+ encoded_utterances = tokenizer.__call__(utterances, add_special_tokens=False)
153
+ for encoded_utterance in encoded_utterances:
154
+ input_ids.extend(encoded_utterance)
155
+ if model.config.model_type == "chatglm":
 
156
  input_ids.append(tokenizer.eos_token_id)
157
 
 
 
 
 
 
 
158
  input_ids = torch.tensor([input_ids], dtype=torch.long)
159
  input_ids = input_ids[:, -history_max_len:].to(device)
160