qgyd2021 committed on
Commit 0cfc467
1 Parent(s): 4e39876

[update]add main

Files changed (1)
  1. main.py +61 -33
main.py CHANGED
@@ -16,39 +16,63 @@ def greet(question: str, history: List[Tuple[str, str]]):
     return result
 
 
+model_map: dict = dict()
+
+
+def init_model(pretrained_model_name_or_path: str, device: str):
+    global model_map
+    if pretrained_model_name_or_path not in model_map.keys():
+        # clear
+        for k1, v1 in model_map.items():
+            for k2, v2 in v1.items():
+                del v2
+        model_map = dict()
+
+        # build model
+        model = AutoModelForCausalLM.from_pretrained(
+            pretrained_model_name_or_path,
+            trust_remote_code=True,
+            low_cpu_mem_usage=True,
+            torch_dtype=torch.bfloat16,
+            device_map="auto",
+            offload_folder="./offload",
+            offload_state_dict=True,
+            # load_in_4bit=True,
+        )
+        model = model.to(device)
+        model = model.bfloat16().eval()
+
+        tokenizer = AutoTokenizer.from_pretrained(
+            pretrained_model_name_or_path,
+            trust_remote_code=True,
+            # llama does not support the fast tokenizer
+            use_fast=False if model.config.model_type == "llama" else True,
+            padding_side="left"
+        )
+
+        # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id are all None; eod_id corresponds to the <|endoftext|> token
+        if tokenizer.__class__.__name__ == "QWenTokenizer":
+            tokenizer.pad_token_id = tokenizer.eod_id
+            tokenizer.bos_token_id = tokenizer.eod_id
+            tokenizer.eos_token_id = tokenizer.eod_id
+
+        model_map[pretrained_model_name_or_path] = {
+            "model": model,
+            "tokenizer": tokenizer,
+        }
+    else:
+        model = model_map[pretrained_model_name_or_path]["model"]
+        tokenizer = model_map[pretrained_model_name_or_path]["tokenizer"]
+    return model, tokenizer
+
+
 def chat_with_llm_non_stream(question: str,
                              history: List[Tuple[str, str]],
                              pretrained_model_name_or_path: str,
                              max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
+                             device: str
                              ):
-    device = "cuda" if torch.cuda.is_available() else "cpu"
-
-    model = AutoModelForCausalLM.from_pretrained(
-        pretrained_model_name_or_path,
-        trust_remote_code=True,
-        low_cpu_mem_usage=True,
-        torch_dtype=torch.bfloat16,
-        device_map="auto",
-        offload_folder="./offload",
-        offload_state_dict=True,
-        # load_in_4bit=True,
-    )
-    model = model.to(device)
-    model = model.bfloat16().eval()
-
-    tokenizer = AutoTokenizer.from_pretrained(
-        pretrained_model_name_or_path,
-        trust_remote_code=True,
-        # llama does not support the fast tokenizer
-        use_fast=False if model.config.model_type == "llama" else True,
-        padding_side="left"
-    )
-
-    # QWenTokenizer is special: pad_token_id, bos_token_id and eos_token_id are all None; eod_id corresponds to the <|endoftext|> token
-    if tokenizer.__class__.__name__ == "QWenTokenizer":
-        tokenizer.pad_token_id = tokenizer.eod_id
-        tokenizer.bos_token_id = tokenizer.eod_id
-        tokenizer.eos_token_id = tokenizer.eod_id
+    model, tokenizer = init_model(pretrained_model_name_or_path, device)
 
     input_ids = tokenizer(
         question,
@@ -70,10 +94,11 @@ def chat_with_llm_non_stream(question: str,
         eos_token_id=tokenizer.eos_token_id
     )
     outputs = outputs.tolist()[0][len(input_ids[0]):]
-    response = tokenizer.decode(outputs)
-    response = response.strip().replace(tokenizer.eos_token, "").strip()
+    answer = tokenizer.decode(outputs)
+    answer = answer.strip().replace(tokenizer.eos_token, "").strip()
 
-    return
+    result = history + [(question, answer)]
+    return result
 
 
 def main():
@@ -81,8 +106,10 @@ def main():
     chat llm
     """
 
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
     with gr.Blocks() as blocks:
-        gr.Markdown(value="gradio demo")
+        gr.Markdown(value=description)
 
         chatbot = gr.Chatbot([], elem_id="chatbot", height=400)
         with gr.Row():
@@ -115,7 +142,8 @@ def main():
 
         inputs = [
             text_box, chatbot, model_name,
-            max_new_tokens, top_p, temperature, repetition_penalty
+            max_new_tokens, top_p, temperature, repetition_penalty,
+            device
         ]
         outputs = [
            chatbot
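
The main change in this commit is that model and tokenizer loading moves out of chat_with_llm_non_stream into init_model, which caches a single {"model": ..., "tokenizer": ...} entry in the module-level model_map and evicts the previous checkpoint before loading a new one. Below is a minimal sketch (not part of the commit) of how that cache behaves when called repeatedly; it assumes main.py is importable as main, and the checkpoint names are only illustrative placeholders.

from main import init_model  # assumes main.py from this commit is on the import path

# First call loads the checkpoint and stores it in model_map.
model, tokenizer = init_model("Qwen/Qwen-7B-Chat", device="cpu")

# Same checkpoint again: cache hit, the very same objects are returned.
model_again, tokenizer_again = init_model("Qwen/Qwen-7B-Chat", device="cpu")
assert model is model_again and tokenizer is tokenizer_again

# A different checkpoint clears model_map before loading, so at most one
# model (plus its tokenizer) stays resident at a time.
other_model, other_tokenizer = init_model("baichuan-inc/Baichuan2-7B-Chat", device="cpu")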
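
chat_with_llm_non_stream also changes shape: it now takes device explicitly and returns the updated history (a list of (question, answer) tuples) instead of None, which is what the chatbot output component consumes. A hedged example of calling it directly, outside Gradio; the checkpoint name and sampling values are arbitrary placeholders, not part of the commit.

import torch

from main import chat_with_llm_non_stream  # assumes main.py is importable

device = "cuda" if torch.cuda.is_available() else "cpu"

history = chat_with_llm_non_stream(
    question="Hello, who are you?",
    history=[],                                           # previous (question, answer) pairs
    pretrained_model_name_or_path="Qwen/Qwen-7B-Chat",    # illustrative checkpoint
    max_new_tokens=512,
    top_p=0.9,
    temperature=0.35,
    repetition_penalty=1.0,
    device=device,
)
print(history[-1][1])  # the newly generated answer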