[update] add main

- main.py +19 -14
- requirements.txt +1 -0
main.py CHANGED

@@ -19,7 +19,9 @@ def greet(question: str, history: List[Tuple[str, str]]):
 model_map: dict = dict()
 
 
-def init_model(pretrained_model_name_or_path: str, device: str):
+def init_model(pretrained_model_name_or_path: str):
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
     global model_map
     if pretrained_model_name_or_path not in model_map.keys():
         # clear
@@ -70,18 +72,24 @@ def chat_with_llm_non_stream(question: str,
                              history: List[Tuple[str, str]],
                              pretrained_model_name_or_path: str,
                              max_new_tokens: int, top_p: float, temperature: float, repetition_penalty: float,
-                             device: str
                              ):
-
+    device: str = "cuda" if torch.cuda.is_available() else "cpu"
+
+    model, tokenizer = init_model(pretrained_model_name_or_path)
 
-
-
-
-
-
-
-
-
+    text_list = list()
+    for pair in history:
+        text_list.extend(pair)
+    text_list.append(question)
+
+    text_encoded = tokenizer.__call__(text_list, add_special_tokens=False)
+    batch_input_ids = text_encoded["input_ids"]
+
+    input_ids = [tokenizer.bos_token_id]
+    for input_ids_ in batch_input_ids:
+        input_ids.extend(input_ids_)
+        input_ids.append(tokenizer.eos_token_id)
+    input_ids = torch.tensor([input_ids], dtype=torch.long).to(device)
 
     with torch.no_grad():
         outputs = model.generate(
@@ -106,8 +114,6 @@ def main():
     chat llm
     """
 
-    device: str = "cuda" if torch.cuda.is_available() else "cpu"
-
     with gr.Blocks() as blocks:
         gr.Markdown(value=description)
 
@@ -143,7 +149,6 @@ def main():
        inputs = [
            text_box, chatbot, model_name,
            max_new_tokens, top_p, temperature, repetition_penalty,
-           device
        ]
        outputs = [
            chatbot
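For context on the main.py change: device selection moves inside init_model and chat_with_llm_non_stream, so the Gradio inputs list no longer has to carry a device component, and the prompt is now built by flattening the chat history and joining the per-utterance token ids with BOS/EOS markers. A minimal standalone sketch of that encoding step follows; build_input_ids is a hypothetical name, and the placement of the per-utterance EOS inside the loop is inferred from the diff (the rendered page drops indentation):

from typing import List, Tuple

import torch


def build_input_ids(tokenizer,
                    question: str,
                    history: List[Tuple[str, str]],
                    device: str = "cpu") -> torch.Tensor:
    # Flatten (user, assistant) pairs into a single list of utterances,
    # ending with the current question.
    text_list: List[str] = []
    for pair in history:
        text_list.extend(pair)
    text_list.append(question)

    # Tokenize each utterance on its own, without special tokens.
    batch_input_ids = tokenizer(text_list, add_special_tokens=False)["input_ids"]

    # Join as: BOS, then every utterance followed by EOS.
    input_ids = [tokenizer.bos_token_id]
    for ids in batch_input_ids:
        input_ids.extend(ids)
        input_ids.append(tokenizer.eos_token_id)
    return torch.tensor([input_ids], dtype=torch.long).to(device)

With a Hugging Face tokenizer this would be called as build_input_ids(tokenizer, "How are you?", [("Hi", "Hello!")]), and the resulting tensor fed straight to model.generate, matching the with torch.no_grad(): block that follows in the diff.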
requirements.txt CHANGED

@@ -1,3 +1,4 @@
 gradio==3.38.0
 transformers==4.30.2
 torch==1.13.0
+tiktoken==0.5.1
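On the dependency side, nothing in the pinned gradio/transformers/torch stack imports tiktoken itself, so the new pin is presumably required by one of the selectable models' tokenizers (Qwen-style remote-code tokenizers, for example, load tiktoken at import time). A quick sanity check that the pinned package is installed and working:

import tiktoken

# cl100k_base ships with tiktoken 0.5.1; round-trip a string to confirm the install.
enc = tiktoken.get_encoding("cl100k_base")
assert enc.decode(enc.encode("hello world")) == "hello world"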