qgyd2021 committed on
Commit
7cb717b
1 Parent(s): 573f68e
Files changed (1) hide show
  1. main.py +6 -2
main.py CHANGED
@@ -36,7 +36,8 @@ def init_model(pretrained_model_name_or_path: str):
36
  pretrained_model_name_or_path,
37
  trust_remote_code=True,
38
  low_cpu_mem_usage=True,
39
- torch_dtype=torch.bfloat16,
 
40
  device_map="auto",
41
  offload_folder="./offload",
42
  offload_state_dict=True,
@@ -45,7 +46,8 @@ def init_model(pretrained_model_name_or_path: str):
45
  if model.config.model_type == "chatglm":
46
  model = model.eval()
47
  else:
48
- model = model.bfloat16().eval()
 
49
 
50
  tokenizer = AutoTokenizer.from_pretrained(
51
  pretrained_model_name_or_path,
@@ -146,6 +148,8 @@ def chat_with_llm_streaming(question: str,
146
  for idx, (h_question, h_answer) in enumerate(history):
147
  if model.config.model_type == "chatglm":
148
  h_question = "[Round {}]\n\n问:{}\n\n答:".format(idx, h_question)
 
 
149
  utterances.append(h_question)
150
  utterances.append(h_answer)
151
  utterances.append(question)
 
36
  pretrained_model_name_or_path,
37
  trust_remote_code=True,
38
  low_cpu_mem_usage=True,
39
+ # torch_dtype=torch.bfloat16,
40
+ torch_dtype=torch.float16,
41
  device_map="auto",
42
  offload_folder="./offload",
43
  offload_state_dict=True,
 
46
  if model.config.model_type == "chatglm":
47
  model = model.eval()
48
  else:
49
+ # model = model.bfloat16().eval()
50
+ model = model.eval()
51
 
52
  tokenizer = AutoTokenizer.from_pretrained(
53
  pretrained_model_name_or_path,
 
148
  for idx, (h_question, h_answer) in enumerate(history):
149
  if model.config.model_type == "chatglm":
150
  h_question = "[Round {}]\n\n问:{}\n\n答:".format(idx, h_question)
151
+ elif model.config.model_type == "llama2":
152
+ h_question = "Question: {}\n\nAnswer: ".format(h_question)
153
  utterances.append(h_question)
154
  utterances.append(h_answer)
155
  utterances.append(question)