edwardjiang committed on
Commit
9a34d4f
1 Parent(s): 20272bc

Update app.py

Files changed (1)
app.py +3 -323
app.py CHANGED
@@ -1,323 +1,3 @@
- #!/usr/bin/env python3
- import argparse
-
- import torch
- import transformers
- from distutils.util import strtobool
- from tokenizers import pre_tokenizers
-
- from transformers.generation.utils import logger
- import mdtex2html
- import gradio as gr
- import warnings
- import os
-
- logger.setLevel("ERROR")
- warnings.filterwarnings("ignore")
-
- import os
- os.system("export PYTORCH_CUDA_ALLOC_CONF=max_split_size_mb:20")
- os.system("export batch_size=1")
-
- warnings.filterwarnings("ignore")
-
-
- def _strtobool(x):
-     return bool(strtobool(x))
-
-
- QA_SPECIAL_TOKENS = {
-     "Question": "<|prompter|>",
-     "Answer": "<|assistant|>",
-     "System": "<|system|>",
-     "StartPrefix": "<|prefix_begin|>",
-     "EndPrefix": "<|prefix_end|>",
-     "InnerThought": "<|inner_thoughts|>",
-     "EndOfThought": "<eot>"
- }
-
-
- def format_pairs(pairs, eos_token, add_initial_reply_token=False):
-     conversations = [
-         "{}{}{}".format(
-             QA_SPECIAL_TOKENS["Question" if i % 2 == 0 else "Answer"], pairs[i], eos_token)
-         for i in range(len(pairs))
-     ]
-     if add_initial_reply_token:
-         conversations.append(QA_SPECIAL_TOKENS["Answer"])
-     return conversations
-
-
- def format_system_prefix(prefix, eos_token):
-     return "{}{}{}".format(
-         QA_SPECIAL_TOKENS["System"],
-         prefix,
-         eos_token,
-     )
-
-
- def get_specific_model(
-     model_name, seq2seqmodel=False, without_head=False, cache_dir=".cache", quantization=False, **kwargs
- ):
-     # encoder-decoder support for Flan-T5 like models
-     # for now, we can use an argument but in the future,
-     # we can automate this
-
-     model = transformers.LlamaForCausalLM.from_pretrained(model_name, torch_dtype=torch.float16).half().cuda()
-
-     return model
-
-
- parser = argparse.ArgumentParser()
- parser.add_argument("--model_path", type=str, required=True)
- parser.add_argument("--max_new_tokens", type=int, default=200)
- parser.add_argument("--top_k", type=int, default=40)
- parser.add_argument("--do_sample", type=_strtobool, default=True)
- # parser.add_argument("--system_prefix", type=str, default=None)
- parser.add_argument("--per-digit-tokens", action="store_true")
-
-
- args = parser.parse_args()
-
- # # Open-domain QA
- # system_prefix = \
- # "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
- # - EduChat是一个由华东师范大学开发的对话式语言模型。
- # EduChat的工具
- # - Web search: Disable.
- # - Calculators: Disable.
- # EduChat的能力
- # - Inner Thought: Disable.
- # 对话主题
- # - General: Enable.
- # - Psychology: Disable.
- # - Socrates: Disable.'''"</s>"
-
- # # Socratic (heuristic) teaching
- # system_prefix = \
- # "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
- # - EduChat是一个由华东师范大学开发的对话式语言模型。
- # EduChat的工具
- # - Web search: Disable.
- # - Calculators: Disable.
- # EduChat的能力
- # - Inner Thought: Disable.
- # 对话主题
- # - General: Disable.
- # - Psychology: Disable.
- # - Socrates: Enable.'''"</s>"
-
- # Emotional support
- system_prefix = \
- "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
- - EduChat是一个由华东师范大学开发的对话式语言模型。
- EduChat的工具
- - Web search: Disable.
- - Calculators: Disable.
- EduChat的能力
- - Inner Thought: Disable.
- 对话主题
- - General: Disable.
- - Psychology: Enable.
- - Socrates: Disable.'''"</s>"
-
- # # Emotional support (with InnerThought)
- # system_prefix = \
- # "<|system|>"'''你是一个人工智能助手,名字叫EduChat。
- # - EduChat是一个由华东师范大学开发的对话式语言模型。
- # EduChat的工具
- # - Web search: Disable.
- # - Calculators: Disable.
- # EduChat的能力
- # - Inner Thought: Enable.
- # 对话主题
- # - General: Disable.
- # - Psychology: Enable.
- # - Socrates: Disable.'''"</s>"
-
-
- print('Loading model......')
-
- model = get_specific_model(args.model_path)
-
- model.gradient_checkpointing_enable()  # reduce number of stored activations
-
- print('Loading tokenizer...')
- tokenizer = transformers.LlamaTokenizer.from_pretrained(args.model_path)
-
- tokenizer.add_special_tokens(
-     {
-         "pad_token": "</s>",
-         "eos_token": "</s>",
-         "sep_token": "<s>",
-     }
- )
- additional_special_tokens = (
-     []
-     if "additional_special_tokens" not in tokenizer.special_tokens_map
-     else tokenizer.special_tokens_map["additional_special_tokens"]
- )
- additional_special_tokens = list(
-     set(additional_special_tokens + list(QA_SPECIAL_TOKENS.values())))
-
- print("additional_special_tokens:", additional_special_tokens)
-
- tokenizer.add_special_tokens(
-     {"additional_special_tokens": additional_special_tokens})
-
- if args.per_digit_tokens:
-     tokenizer._tokenizer.pre_processor = pre_tokenizers.Digits(True)
-
- human_token_id = tokenizer.additional_special_tokens_ids[
-     tokenizer.additional_special_tokens.index(QA_SPECIAL_TOKENS["Question"])
- ]
-
- print('Type "quit" to exit')
- print("Press Control + C to restart conversation (spam to exit)")
-
- conversation_history = []
-
-
- """Override Chatbot.postprocess"""
-
-
- def postprocess(self, y):
-     if y is None:
-         return []
-     for i, (message, response) in enumerate(y):
-         y[i] = (
-             None if message is None else mdtex2html.convert((message)),
-             None if response is None else mdtex2html.convert(response),
-         )
-     return y
-
-
- gr.Chatbot.postprocess = postprocess
-
-
- def parse_text(text):
-     """copy from https://github.com/GaiZhenbiao/ChuanhuChatGPT/"""
-     lines = text.split("\n")
-     lines = [line for line in lines if line != ""]
-     count = 0
-     for i, line in enumerate(lines):
-         if "```" in line:
-             count += 1
-             items = line.split('`')
-             if count % 2 == 1:
-                 lines[i] = f'<pre><code class="language-{items[-1]}">'
-             else:
-                 lines[i] = f'<br></code></pre>'
-         else:
-             if i > 0:
-                 if count % 2 == 1:
-                     line = line.replace("`", "\`")
-                     line = line.replace("<", "&lt;")
-                     line = line.replace(">", "&gt;")
-                     line = line.replace(" ", "&nbsp;")
-                     line = line.replace("*", "&ast;")
-                     line = line.replace("_", "&lowbar;")
-                     line = line.replace("-", "&#45;")
-                     line = line.replace(".", "&#46;")
-                     line = line.replace("!", "&#33;")
-                     line = line.replace("(", "&#40;")
-                     line = line.replace(")", "&#41;")
-                     line = line.replace("$", "&#36;")
-                 lines[i] = "<br>" + line
-     text = "".join(lines)
-     return text
-
-
- def predict(input, chatbot, max_length, top_p, temperature, history):
-     query = parse_text(input)
-     chatbot.append((query, ""))
-     conversation_history = []
-     for i, (old_query, response) in enumerate(history):
-         conversation_history.append(old_query)
-         conversation_history.append(response)
-
-     conversation_history.append(query)
-
-     query_str = "".join(format_pairs(conversation_history,
-                                      tokenizer.eos_token, add_initial_reply_token=True))
-
-     if system_prefix:
-         query_str = system_prefix + query_str
-     print("query:", query_str)
-
-     batch = tokenizer.encode(
-         query_str,
-         return_tensors="pt",
-     )
-
-     with torch.cuda.amp.autocast():
-         out = model.generate(
-             input_ids=batch.to(model.device),
-             # The maximum number of tokens to generate, ignoring the number of tokens in the prompt.
-             max_new_tokens=args.max_new_tokens,
-             do_sample=args.do_sample,
-             max_length=max_length,
-             top_k=args.top_k,
-             top_p=top_p,
-             temperature=temperature,
-             eos_token_id=tokenizer.eos_token_id,
-             pad_token_id=tokenizer.eos_token_id,
-         )
-
-     if out[0][-1] == tokenizer.eos_token_id:
-         response = out[0][:-1]
-     else:
-         response = out[0]
-
-     response = tokenizer.decode(out[0]).split(QA_SPECIAL_TOKENS["Answer"])[-1]
-
-     conversation_history.append(response)
-
-     with open("./educhat_query_record.txt", 'a+') as f:
-         f.write(str(conversation_history) + '\n')
-
-     chatbot[-1] = (query, parse_text(response))
-     history = history + [(query, response)]
-     print(f"chatbot is {chatbot}")
-     print(f"history is {history}")
-
-     return chatbot, history
-
-
- def reset_user_input():
-     return gr.update(value='')
-
-
- def reset_state():
-     return [], []
-
-
- with gr.Blocks() as demo:
-     gr.HTML("""<h1 align="center">欢迎使用 EduChat 人工智能助手!</h1>""")
-
-     chatbot = gr.Chatbot()
-     with gr.Row():
-         with gr.Column(scale=4):
-             with gr.Column(scale=12):
-                 user_input = gr.Textbox(show_label=False, placeholder="Input...", lines=10).style(
-                     container=False)
-             with gr.Column(min_width=32, scale=1):
-                 submitBtn = gr.Button("Submit", variant="primary")
-         with gr.Column(scale=1):
-             emptyBtn = gr.Button("Clear History")
-             max_length = gr.Slider(
-                 0, 2048, value=2048, step=1.0, label="Maximum length", interactive=True)
-             top_p = gr.Slider(0, 1, value=0.2, step=0.01,
-                               label="Top P", interactive=True)
-             temperature = gr.Slider(
-                 0, 1, value=1, step=0.01, label="Temperature", interactive=True)
-
-     history = gr.State([])  # (message, bot_message)
-
-     submitBtn.click(predict, [user_input, chatbot, max_length, top_p, temperature, history], [chatbot, history],
-                     show_progress=True)
-     submitBtn.click(reset_user_input, [], [user_input])
-
-     emptyBtn.click(reset_state, outputs=[chatbot, history], show_progress=True)
-
- demo.queue().launch(inbrowser=True, share=True)
+ import subprocess
+ command = "python educhat_gradio.py --model_path ecnu-icalk/educhat-base-002-7b --top_k 50 --do_sample True --max_new_tokens 512"
+ subprocess.Popen(command, shell=True)
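
Note on the new launcher: subprocess.Popen returns immediately, so the wrapper app.py exits while the Gradio server is still starting, and shell=True routes the command string through a shell. A minimal alternative sketch, assuming educhat_gradio.py sits next to app.py and accepts the same flags as in the commit above (this is not part of the commit itself):

import subprocess
import sys

# Launch the Gradio app with the same flags as the commit above.
# Passing an argument list (no shell=True) avoids shell quoting issues,
# and sys.executable reuses the interpreter running this wrapper
# instead of whatever "python" resolves to on PATH.
cmd = [
    sys.executable, "educhat_gradio.py",
    "--model_path", "ecnu-icalk/educhat-base-002-7b",
    "--top_k", "50",
    "--do_sample", "True",
    "--max_new_tokens", "512",
]

# subprocess.run blocks until the child exits, keeping the wrapper
# process alive for as long as the app is serving; check=True surfaces
# a non-zero exit status as an exception.
subprocess.run(cmd, check=True)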