Commit f97e6ec
qgyd2021 committed
1 Parent(s): 0cfc467

[update]add main

Files changed (2):
  1. main.py +19 -14
  2. requirements.txt +1 -0
main.py CHANGED
@@ -19,7 +19,9 @@ def greet(question: str, history: List[Tuple[str, str]]):
 model_map: dict = dict()
 
 
-def init_model(pretrained_model_name_or_path: str, device: str):
+def init_model(pretrained_model_name_or_path: str):
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
     global model_map
     if pretrained_model_name_or_path not in model_map.keys():
         # clear
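This hunk drops the device parameter from init_model and resolves the device inside the helper instead. A minimal sketch of what the updated helper plausibly looks like; the loading calls and the cache-eviction step are assumptions, since the diff only shows the signature, the device line, and the cache lookup:

import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

model_map: dict = dict()


def init_model(pretrained_model_name_or_path: str):
    # Device is now resolved here instead of being threaded through
    # the Gradio callback signature.
    device: str = "cuda" if torch.cuda.is_available() else "cpu"

    global model_map
    if pretrained_model_name_or_path not in model_map.keys():
        # Assumption: evict any previously loaded model so only one
        # stays resident (the "# clear" branch in the diff).
        model_map.clear()

        tokenizer = AutoTokenizer.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True
        )
        model = AutoModelForCausalLM.from_pretrained(
            pretrained_model_name_or_path, trust_remote_code=True
        ).to(device).eval()
        model_map[pretrained_model_name_or_path] = (model, tokenizer)
    return model_map[pretrained_model_name_or_path]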
@@ -70,18 +72,24 @@ def chat_with_llm_non_stream(question: str,
                              history: List[Tuple[str, str]],
                              pretrained_model_name_or_path: str,
                              max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
-                             device: str
                              ):
-    model, tokenizer = init_model(pretrained_model_name_or_path, device)
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+    model, tokenizer = init_model(pretrained_model_name_or_path)
 
-    input_ids = tokenizer(
-        question,
-        return_tensors="pt",
-        add_special_tokens=False,
-    ).input_ids.to(device)
-    bos_token_id = torch.tensor([[tokenizer.bos_token_id]], dtype=torch.long).to(device)
-    eos_token_id = torch.tensor([[tokenizer.eos_token_id]], dtype=torch.long).to(device)
-    input_ids = torch.concat([bos_token_id, input_ids, eos_token_id], dim=1)
+    text_list = list()
+    for pair in history:
+        text_list.extend(pair)
+    text_list.append(question)
+
+    text_encoded = tokenizer.__call__(text_list, add_special_tokens=False)
+    batch_input_ids = text_encoded["input_ids"]
+
+    input_ids = [tokenizer.bos_token_id]
+    for input_ids_ in batch_input_ids:
+        input_ids.extend(input_ids_)
+    input_ids.append(tokenizer.eos_token_id)
+    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)
 
     with torch.no_grad():
         outputs = model.generate(
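This hunk replaces single-question encoding with a flattened history: every (user, bot) pair plus the current question is tokenized as a batch, then spliced into one id sequence bracketed by BOS and EOS. The same logic in isolation, with made-up token ids standing in for a real tokenizer:

# Illustration of the flattening logic with fabricated token ids.
history = [("hi", "hello!"), ("how are you", "fine")]
question = "tell me a joke"

text_list = list()
for pair in history:
    text_list.extend(pair)     # user turn, then bot turn
text_list.append(question)     # current question goes last
# text_list == ["hi", "hello!", "how are you", "fine", "tell me a joke"]

# Pretend each text maps to these ids; a real tokenizer called on a
# list of strings returns one id list per input string.
batch_input_ids = [[11], [12, 13], [14], [15], [16, 17]]
bos_token_id, eos_token_id = 1, 2

input_ids = [bos_token_id]
for input_ids_ in batch_input_ids:
    input_ids.extend(input_ids_)
input_ids.append(eos_token_id)
# input_ids == [1, 11, 12, 13, 14, 15, 16, 17, 2]

Note the flattened sequence carries no role markers between turns; whether that matters depends on the prompt format the loaded model expects.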
@@ -106,8 +114,6 @@ def main():
     chat llm
     """
 
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
-
     with gr.Blocks() as blocks:
         gr.Markdown(value=description)
 
 
@@ -143,7 +149,6 @@ def main():
         inputs = [
             text_box, chatbot, model_name,
             max_new_tokens, top_p, temperature, repetition_penalty,
-            device
         ]
         outputs = [
             chatbot
 
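The last two hunks delete the device variable from main() and drop it from the component list wired into the callback. A self-contained sketch of that wiring under gradio 3.38, with a stub standing in for the real chat function and hypothetical widget ranges and defaults:

import gradio as gr


def chat_stub(question, history, model_name,
              max_new_tokens, top_p, temperature, repetition_penalty):
    # Stand-in for chat_with_llm_non_stream; note no device argument.
    return history + [(question, "stub reply")]


with gr.Blocks() as blocks:
    chatbot = gr.Chatbot()
    text_box = gr.Textbox()
    model_name = gr.Textbox(value="hypothetical/model-id")
    max_new_tokens = gr.Slider(1, 1024, value=512, label="max_new_tokens")
    top_p = gr.Slider(0.0, 1.0, value=0.85, label="top_p")
    temperature = gr.Slider(0.0, 2.0, value=0.35, label="temperature")
    repetition_penalty = gr.Slider(1.0, 2.0, value=1.2, label="repetition_penalty")
    submit = gr.Button("submit")

    submit.click(
        fn=chat_stub,
        inputs=[text_box, chatbot, model_name,
                max_new_tokens, top_p, temperature, repetition_penalty],
        outputs=[chatbot],
    )

blocks.queue().launch()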
requirements.txt CHANGED
@@ -1,3 +1,4 @@
 gradio==3.38.0
 transformers==4.30.2
 torch==1.13.0
+tiktoken==0.5.1
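The only requirements change pins tiktoken, presumably because a trust_remote_code tokenizer imports it at load time; the commit itself does not say which model needs it. A quick check that the pinned package imports and round-trips text:

# Assumes tiktoken==0.5.1 as pinned above; the association with any
# specific model is an inference, not stated in the commit.
import tiktoken

enc = tiktoken.get_encoding("cl100k_base")
ids = enc.encode("hello world")
assert enc.decode(ids) == "hello world"
print(ids)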