add files
Browse files- .gitignore +165 -0
- chatglm-6b-int8-onnx-merged/chatglm-6b-int8.onnx +3 -0
- chatglm-6b-int8-onnx-merged/model_weights_0.bin +3 -0
- chatglm-6b-int8-onnx-merged/model_weights_1.bin +3 -0
- chatglm-6b-int8-onnx-merged/model_weights_2.bin +3 -0
- chatglm-6b-int8-onnx-merged/model_weights_3.bin +3 -0
- chatglm-6b-int8-onnx-merged/model_weights_4.bin +3 -0
- chatglm-6b-int8-onnx-merged/model_weights_5.bin +3 -0
- chatglm-6b-int8-onnx-merged/model_weights_6.bin +3 -0
- chatglm-6b-int8-onnx-merged/sentencepiece.model +3 -0
- model.py +125 -0
- requirements.txt +5 -0
- tokenizer.py +75 -0
- web-ui.py +74 -0
.gitignore
ADDED
@@ -0,0 +1,165 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Project ignores
|
2 |
+
models/
|
3 |
+
scripts/
|
4 |
+
data/
|
5 |
+
|
6 |
+
# Byte-compiled / optimized / DLL files
|
7 |
+
__pycache__/
|
8 |
+
*.py[cod]
|
9 |
+
*$py.class
|
10 |
+
|
11 |
+
# C extensions
|
12 |
+
*.so
|
13 |
+
|
14 |
+
# Distribution / packaging
|
15 |
+
.Python
|
16 |
+
build/
|
17 |
+
develop-eggs/
|
18 |
+
dist/
|
19 |
+
downloads/
|
20 |
+
eggs/
|
21 |
+
.eggs/
|
22 |
+
lib/
|
23 |
+
lib64/
|
24 |
+
parts/
|
25 |
+
sdist/
|
26 |
+
var/
|
27 |
+
wheels/
|
28 |
+
share/python-wheels/
|
29 |
+
*.egg-info/
|
30 |
+
.installed.cfg
|
31 |
+
*.egg
|
32 |
+
MANIFEST
|
33 |
+
|
34 |
+
# PyInstaller
|
35 |
+
# Usually these files are written by a python script from a template
|
36 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
37 |
+
*.manifest
|
38 |
+
*.spec
|
39 |
+
|
40 |
+
# Installer logs
|
41 |
+
pip-log.txt
|
42 |
+
pip-delete-this-directory.txt
|
43 |
+
|
44 |
+
# Unit test / coverage reports
|
45 |
+
htmlcov/
|
46 |
+
.tox/
|
47 |
+
.nox/
|
48 |
+
.coverage
|
49 |
+
.coverage.*
|
50 |
+
.cache
|
51 |
+
nosetests.xml
|
52 |
+
coverage.xml
|
53 |
+
*.cover
|
54 |
+
*.py,cover
|
55 |
+
.hypothesis/
|
56 |
+
.pytest_cache/
|
57 |
+
cover/
|
58 |
+
|
59 |
+
# Translations
|
60 |
+
*.mo
|
61 |
+
*.pot
|
62 |
+
|
63 |
+
# Django stuff:
|
64 |
+
*.log
|
65 |
+
local_settings.py
|
66 |
+
db.sqlite3
|
67 |
+
db.sqlite3-journal
|
68 |
+
|
69 |
+
# Flask stuff:
|
70 |
+
instance/
|
71 |
+
.webassets-cache
|
72 |
+
|
73 |
+
# Scrapy stuff:
|
74 |
+
.scrapy
|
75 |
+
|
76 |
+
# Sphinx documentation
|
77 |
+
docs/_build/
|
78 |
+
|
79 |
+
# PyBuilder
|
80 |
+
.pybuilder/
|
81 |
+
target/
|
82 |
+
|
83 |
+
# Jupyter Notebook
|
84 |
+
.ipynb_checkpoints
|
85 |
+
|
86 |
+
# IPython
|
87 |
+
profile_default/
|
88 |
+
ipython_config.py
|
89 |
+
|
90 |
+
# pyenv
|
91 |
+
# For a library or package, you might want to ignore these files since the code is
|
92 |
+
# intended to run in multiple environments; otherwise, check them in:
|
93 |
+
# .python-version
|
94 |
+
|
95 |
+
# pipenv
|
96 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
97 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
98 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
99 |
+
# install all needed dependencies.
|
100 |
+
#Pipfile.lock
|
101 |
+
|
102 |
+
# poetry
|
103 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
104 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
105 |
+
# commonly ignored for libraries.
|
106 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
107 |
+
#poetry.lock
|
108 |
+
|
109 |
+
# pdm
|
110 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
111 |
+
#pdm.lock
|
112 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
113 |
+
# in version control.
|
114 |
+
# https://pdm.fming.dev/#use-with-ide
|
115 |
+
.pdm.toml
|
116 |
+
|
117 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
118 |
+
__pypackages__/
|
119 |
+
|
120 |
+
# Celery stuff
|
121 |
+
celerybeat-schedule
|
122 |
+
celerybeat.pid
|
123 |
+
|
124 |
+
# SageMath parsed files
|
125 |
+
*.sage.py
|
126 |
+
|
127 |
+
# Environments
|
128 |
+
.env
|
129 |
+
.venv
|
130 |
+
env/
|
131 |
+
venv/
|
132 |
+
ENV/
|
133 |
+
env.bak/
|
134 |
+
venv.bak/
|
135 |
+
|
136 |
+
# Spyder project settings
|
137 |
+
.spyderproject
|
138 |
+
.spyproject
|
139 |
+
|
140 |
+
# Rope project settings
|
141 |
+
.ropeproject
|
142 |
+
|
143 |
+
# mkdocs documentation
|
144 |
+
/site
|
145 |
+
|
146 |
+
# mypy
|
147 |
+
.mypy_cache/
|
148 |
+
.dmypy.json
|
149 |
+
dmypy.json
|
150 |
+
|
151 |
+
# Pyre type checker
|
152 |
+
.pyre/
|
153 |
+
|
154 |
+
# pytype static type analyzer
|
155 |
+
.pytype/
|
156 |
+
|
157 |
+
# Cython debug symbols
|
158 |
+
cython_debug/
|
159 |
+
|
160 |
+
# PyCharm
|
161 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
162 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
163 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
164 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
165 |
+
.idea/
|
chatglm-6b-int8-onnx-merged/chatglm-6b-int8.onnx
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:93c988ddb30e2eb97aafe05fd8086f56faec47e8488bc2bb6dbd19ee50ce36ae
|
3 |
+
size 459821
|
chatglm-6b-int8-onnx-merged/model_weights_0.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:721f5497129c8f2bbffe685892a99bdc87e00fd29b70d54d5f75df8810811cf1
|
3 |
+
size 1069807488
|
chatglm-6b-int8-onnx-merged/model_weights_1.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:320f96165f0ba496292eb4dd35979d5fb5c0bbfc0fbaf83b0e8150a9959d4c8d
|
3 |
+
size 948125696
|
chatglm-6b-int8-onnx-merged/model_weights_2.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:92bc601207b27b08803e223b6a414eb533d3f4eeab26ed9c3b75ca4b0b977f41
|
3 |
+
size 1006960640
|
chatglm-6b-int8-onnx-merged/model_weights_3.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:26218891b8d13a8c3b3b5cc15b47c6ba1b5b140a614cd9a5ffb95a69e5180025
|
3 |
+
size 1006960640
|
chatglm-6b-int8-onnx-merged/model_weights_4.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:22f6b5087d50d39c566079a8677c1e1ef41e3b16763f4d022e00d385d4dc88af
|
3 |
+
size 1006960640
|
chatglm-6b-int8-onnx-merged/model_weights_5.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:d5c6502fdf30878a5e75be2da7e2e134e5bfe3a132b1e98880880687cce1e703
|
3 |
+
size 1006960640
|
chatglm-6b-int8-onnx-merged/model_weights_6.bin
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:82b140850685302b6939fca378a4174246304c4afb7b58b26aaecad370d2a15a
|
3 |
+
size 671842304
|
chatglm-6b-int8-onnx-merged/sentencepiece.model
ADDED
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
1 |
+
version https://git-lfs.github.com/spec/v1
|
2 |
+
oid sha256:5e974d9a69c242ce014c88c2b26089270f6198f3c0b700a887666cd3e816f17e
|
3 |
+
size 2706249
|
model.py
ADDED
@@ -0,0 +1,125 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
import numpy as np
|
3 |
+
from tokenizer import ChatGLMTokenizer
|
4 |
+
# import torch
|
5 |
+
from onnxruntime import InferenceSession, SessionOptions
|
6 |
+
|
7 |
+
|
8 |
+
# Currently `MatMulInteger` and `DynamicQuantizeLinear` are only supported on CPU,
|
9 |
+
# although they are documented as supported on CUDA.
|
10 |
+
providers = ["CPUExecutionProvider"]
|
11 |
+
|
12 |
+
# if torch.cuda.is_available():
|
13 |
+
# providers = ["CUDAExecutionProvider"] + providers
|
14 |
+
|
15 |
+
|
16 |
+
# Default paths
|
17 |
+
tokenizer_path = "chatglm-6b-int8-onnx-merged/sentencepiece.model"
|
18 |
+
onnx_model_path = "chatglm-6b-int8-onnx-merged/chatglm-6b-int8.onnx"
|
19 |
+
|
20 |
+
|
21 |
+
# input & output names
|
22 |
+
past_names = [f"past_{name}_{i}" for i in range(28) for name in ["key", "value"]]
|
23 |
+
present_names = [f"present_{name}_{i}" for i in range(28) for name in ["key", "value"]]
|
24 |
+
output_names = ["logits"] + present_names
|
25 |
+
|
26 |
+
|
27 |
+
# default kv_cache for first inference
|
28 |
+
default_past_key_values = {
|
29 |
+
k: np.zeros((1, 0, 32, 128), dtype=np.float32) for k in past_names
|
30 |
+
}
|
31 |
+
|
32 |
+
|
33 |
+
def chat_template(history: list[tuple[str, str]], current: str):
|
34 |
+
prompt = ""
|
35 |
+
chat_round = 0
|
36 |
+
for question, answer in history:
|
37 |
+
prompt += f"[Round {chat_round}]\n问:{question}\n答:{answer}\n"
|
38 |
+
chat_round += 1
|
39 |
+
prompt += f"[Round {chat_round}]\n问:{current}\n答:"
|
40 |
+
return prompt
|
41 |
+
|
42 |
+
|
43 |
+
def process_response(response: str):
|
44 |
+
response = response.strip()
|
45 |
+
response = response.replace("[[训练时间]]", "2023年")
|
46 |
+
punkts = [
|
47 |
+
[",", ","],
|
48 |
+
["!", "!"],
|
49 |
+
[":", ":"],
|
50 |
+
[";", ";"],
|
51 |
+
["\?", "?"],
|
52 |
+
]
|
53 |
+
for item in punkts:
|
54 |
+
response = re.sub(r"([\u4e00-\u9fff])%s" % item[0], r"\1%s" % item[1], response)
|
55 |
+
response = re.sub(r"%s([\u4e00-\u9fff])" % item[0], r"%s\1" % item[1], response)
|
56 |
+
return response
|
57 |
+
|
58 |
+
|
59 |
+
class ChatGLMModel():
|
60 |
+
|
61 |
+
def __init__(self, onnx_model_path=onnx_model_path, tokenizer_path=tokenizer_path, profile=False) -> None:
|
62 |
+
self.tokenizer = ChatGLMTokenizer(tokenizer_path)
|
63 |
+
options = SessionOptions()
|
64 |
+
options.enable_profiling = profile
|
65 |
+
self.session = InferenceSession(onnx_model_path, options, providers=providers)
|
66 |
+
self.eop_token_id = self.tokenizer["<eop>"]
|
67 |
+
|
68 |
+
|
69 |
+
def prepare_input(self, prompt: str):
|
70 |
+
input_ids, prefix_mask = self.tokenizer.encode(prompt)
|
71 |
+
|
72 |
+
input_ids = np.array([input_ids], dtype=np.longlong)
|
73 |
+
prefix_mask = np.array([prefix_mask], dtype=np.longlong)
|
74 |
+
|
75 |
+
return input_ids, prefix_mask, default_past_key_values
|
76 |
+
|
77 |
+
|
78 |
+
def sample_next_token(self, logits: np.ndarray, top_k=50, top_p=0.7, temperature=1):
|
79 |
+
# softmax with temperature
|
80 |
+
exp_logits = np.exp(logits / temperature)
|
81 |
+
probs = exp_logits / np.sum(exp_logits)
|
82 |
+
|
83 |
+
# top k
|
84 |
+
top_k_idx = np.argsort(-probs)[:top_k]
|
85 |
+
top_k_probs = probs[top_k_idx]
|
86 |
+
|
87 |
+
# top p
|
88 |
+
cumsum_probs = np.cumsum(top_k_probs)
|
89 |
+
top_k_probs[(cumsum_probs - top_k_probs) > top_p] = 0.0
|
90 |
+
top_k_probs = top_k_probs / np.sum(top_k_probs)
|
91 |
+
|
92 |
+
# sample
|
93 |
+
next_token = np.random.choice(top_k_idx, size=1, p=top_k_probs)
|
94 |
+
return next_token[0].item()
|
95 |
+
|
96 |
+
|
97 |
+
def generate_iterate(self, prompt: str, max_generated_tokens=100, top_k=50, top_p=0.7, temperature=1):
|
98 |
+
input_ids, prefix_mask, past_key_values = self.prepare_input(prompt)
|
99 |
+
output_tokens = []
|
100 |
+
|
101 |
+
while True:
|
102 |
+
inputs = {
|
103 |
+
"input_ids": input_ids,
|
104 |
+
"prefix_mask": prefix_mask,
|
105 |
+
"use_past": np.array(len(output_tokens) > 0),
|
106 |
+
}
|
107 |
+
inputs.update(past_key_values)
|
108 |
+
|
109 |
+
logits, *past_key_values = self.session.run(output_names, inputs)
|
110 |
+
past_key_values = { k: v for k, v in zip(past_names, past_key_values) }
|
111 |
+
|
112 |
+
next_token = self.sample_next_token(logits[0, -1], top_k=top_k, top_p=top_p, temperature=temperature)
|
113 |
+
|
114 |
+
output_tokens += [next_token]
|
115 |
+
|
116 |
+
if next_token == self.eop_token_id or len(output_tokens) > max_generated_tokens:
|
117 |
+
break
|
118 |
+
|
119 |
+
input_ids = np.array([[next_token]], dtype=np.longlong)
|
120 |
+
prefix_mask = np.concatenate([prefix_mask, np.array([[0]], dtype=np.longlong)], axis=1)
|
121 |
+
|
122 |
+
yield process_response(self.tokenizer.decode(output_tokens))
|
123 |
+
|
124 |
+
return process_response(self.tokenizer.decode(output_tokens))
|
125 |
+
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
numpy
|
2 |
+
onnxruntime
|
3 |
+
sentencepiece
|
4 |
+
streamlit
|
5 |
+
streamlit-chat
|
tokenizer.py
ADDED
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import re
|
2 |
+
from sentencepiece import SentencePieceProcessor
|
3 |
+
|
4 |
+
|
5 |
+
def replace_spaces_with_blank(match: re.Match[str]):
|
6 |
+
return f"<|blank_{len(match.group())}|>"
|
7 |
+
|
8 |
+
|
9 |
+
def replace_blank_with_spaces(match: re.Match[str]):
|
10 |
+
return " " * int(match.group(1))
|
11 |
+
|
12 |
+
|
13 |
+
class ChatGLMTokenizer:
|
14 |
+
def __init__(self, vocab_file):
|
15 |
+
assert vocab_file is not None
|
16 |
+
self.vocab_file = vocab_file
|
17 |
+
self.special_tokens = ["[MASK]", "[gMASK]", "[sMASK]", "<unused_0>", "<sop>", "<eop>", "<ENC>", "<dBLOCK>"]
|
18 |
+
self.text_tokenizer = SentencePieceProcessor(str(vocab_file))
|
19 |
+
|
20 |
+
def __len__(self):
|
21 |
+
return len(self.text_tokenizer)
|
22 |
+
|
23 |
+
def __getitem__(self, key: str):
|
24 |
+
return self.text_tokenizer[key]
|
25 |
+
|
26 |
+
|
27 |
+
def preprocess(self, text: str, linebreak=True, whitespaces=True):
|
28 |
+
if linebreak:
|
29 |
+
text = text.replace("\n", "<n>")
|
30 |
+
if whitespaces:
|
31 |
+
text = text.replace("\t", "<|tab|>")
|
32 |
+
text = re.sub(r" {2,80}", replace_spaces_with_blank, text)
|
33 |
+
return text
|
34 |
+
|
35 |
+
|
36 |
+
def encode(
|
37 |
+
self, text: str, text_pair: str = None,
|
38 |
+
linebreak=True, whitespaces=True,
|
39 |
+
add_dummy_prefix=True, special_tokens=True,
|
40 |
+
) -> tuple[list[int], list[int]]:
|
41 |
+
"""
|
42 |
+
text: Text to encode. Bidirectional part with a [gMASK] and an <sop> for causal LM.
|
43 |
+
text_pair: causal LM part.
|
44 |
+
linebreak: Whether to encode newline (\n) in text.
|
45 |
+
whitespaces: Whether to encode multiple whitespaces or tab in text, useful for source code encoding.
|
46 |
+
special_tokens: Whether to encode special token ([MASK], [gMASK], etc.) in text.
|
47 |
+
add_dummy_prefix: Whether to add dummy blank space in the beginning.
|
48 |
+
"""
|
49 |
+
text = self.preprocess(text, linebreak, whitespaces)
|
50 |
+
if not add_dummy_prefix:
|
51 |
+
text = "<n>" + text
|
52 |
+
|
53 |
+
tokens = self.text_tokenizer.encode(text)
|
54 |
+
prefix_mask = [1] * len(tokens)
|
55 |
+
if special_tokens:
|
56 |
+
tokens += [self.text_tokenizer["[gMASK]"], self.text_tokenizer["<sop>"]]
|
57 |
+
prefix_mask += [1, 0]
|
58 |
+
|
59 |
+
if text_pair is not None:
|
60 |
+
pair_tokens = self.text_tokenizer.encode(text_pair)
|
61 |
+
tokens += pair_tokens
|
62 |
+
prefix_mask += [0] * len(pair_tokens)
|
63 |
+
if special_tokens:
|
64 |
+
tokens += [self.text_tokenizer["<eop>"]]
|
65 |
+
prefix_mask += [0]
|
66 |
+
|
67 |
+
return (tokens if add_dummy_prefix else tokens[2:]), prefix_mask
|
68 |
+
|
69 |
+
|
70 |
+
def decode(self, text_ids: list[int]) -> str:
|
71 |
+
text = self.text_tokenizer.decode(text_ids)
|
72 |
+
text = text.replace("<n>", "\n")
|
73 |
+
text = text.replace("<|tab|>", "\t")
|
74 |
+
text = re.sub(r"<\|blank_(\d\d?)\|>", replace_blank_with_spaces, text)
|
75 |
+
return text
|
web-ui.py
ADDED
@@ -0,0 +1,74 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import streamlit as st
|
2 |
+
from streamlit_chat import message
|
3 |
+
from model import ChatGLMModel, chat_template
|
4 |
+
|
5 |
+
|
6 |
+
# page state
|
7 |
+
|
8 |
+
@st.cache_resource
|
9 |
+
def create_model():
|
10 |
+
return ChatGLMModel()
|
11 |
+
|
12 |
+
with st.spinner("加载模型中..."):
|
13 |
+
model = create_model()
|
14 |
+
|
15 |
+
|
16 |
+
if "history" not in st.session_state:
|
17 |
+
st.session_state["history"] = []
|
18 |
+
|
19 |
+
|
20 |
+
# parameters
|
21 |
+
|
22 |
+
with st.sidebar:
|
23 |
+
st.markdown("## 采样参数")
|
24 |
+
|
25 |
+
max_tokens = st.number_input("max_tokens", min_value=1, max_value=500, value=200)
|
26 |
+
temperature = st.number_input("temperature", min_value=0.1, max_value=4.0, value=1.0)
|
27 |
+
top_p = st.number_input("top_p", min_value=0.1, max_value=1.0, value=0.7)
|
28 |
+
top_k = st.number_input("top_k", min_value=1, max_value=500, value=50)
|
29 |
+
|
30 |
+
if st.button("清空上下文"):
|
31 |
+
st.session_state.message = ""
|
32 |
+
st.session_state.history = []
|
33 |
+
|
34 |
+
st.markdown("""
|
35 |
+
[ChatGLM](https://huggingface.co/THUDM/chatglm-6b) + [ONNXRuntime](https://onnxruntime.ai/)
|
36 |
+
""")
|
37 |
+
|
38 |
+
|
39 |
+
# main body
|
40 |
+
|
41 |
+
st.markdown("## ChatGLM + ONNXRuntime")
|
42 |
+
|
43 |
+
history: list[tuple[str, str]] = st.session_state.history
|
44 |
+
|
45 |
+
if len(history) == 0:
|
46 |
+
st.caption("请在下方输入消息开始会话")
|
47 |
+
|
48 |
+
|
49 |
+
for idx, (question, answer) in enumerate(history):
|
50 |
+
message(question, is_user=True, key=f"history_question_{idx}")
|
51 |
+
message(answer, key=f"history_answer_{idx}")
|
52 |
+
|
53 |
+
|
54 |
+
next_answer = st.container()
|
55 |
+
|
56 |
+
question = st.text_area(label="消息", key="message")
|
57 |
+
|
58 |
+
if st.button("发送") and len(question.strip()):
|
59 |
+
with next_answer:
|
60 |
+
message(question, is_user=True, key="message_question")
|
61 |
+
with st.spinner("正在回复中"):
|
62 |
+
with st.empty():
|
63 |
+
prompt = chat_template(history, question)
|
64 |
+
for answer in model.generate_iterate(
|
65 |
+
prompt,
|
66 |
+
max_generated_tokens=max_tokens,
|
67 |
+
top_k=top_k,
|
68 |
+
top_p=top_p,
|
69 |
+
temperature=temperature,
|
70 |
+
):
|
71 |
+
st.write(answer)
|
72 |
+
message(answer, key="message_answer")
|
73 |
+
|
74 |
+
st.session_state.history = history + [(question, answer)]
|