Upload folder using huggingface_hub
Browse files- README.md +3 -9
- webdemo.py +231 -0
README.md
CHANGED
@@ -1,12 +1,6 @@
|
|
1 |
---
|
2 |
-
title:
|
3 |
-
|
4 |
-
colorFrom: gray
|
5 |
-
colorTo: blue
|
6 |
sdk: gradio
|
7 |
-
sdk_version:
|
8 |
-
app_file: app.py
|
9 |
-
pinned: false
|
10 |
---
|
11 |
-
|
12 |
-
Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
|
|
|
1 |
---
|
2 |
+
title: secgpt
|
3 |
+
app_file: webdemo.py
|
|
|
|
|
4 |
sdk: gradio
|
5 |
+
sdk_version: 3.37.0
|
|
|
|
|
6 |
---
|
|
|
|
webdemo.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import json
|
2 |
+
import sys
|
3 |
+
from threading import Thread
|
4 |
+
from queue import Queue
|
5 |
+
import os
|
6 |
+
|
7 |
+
import gradio as gr
|
8 |
+
import torch
|
9 |
+
from peft import PeftModel
|
10 |
+
from transformers import GenerationConfig, AutoTokenizer, AutoModelForCausalLM
|
11 |
+
import time
|
12 |
+
|
13 |
+
|
14 |
+
os.environ['MallocStackLogging'] = '0'
|
15 |
+
if torch.cuda.is_available():
|
16 |
+
device = "auto"
|
17 |
+
else:
|
18 |
+
device = "cpu"
|
19 |
+
|
20 |
+
|
21 |
+
def reformat_sft(instruction, input):
|
22 |
+
if input:
|
23 |
+
prefix = (
|
24 |
+
"Below is an instruction that describes a task, paired with an input that provides further context. "
|
25 |
+
"Write a response that appropriately completes the request.\n"
|
26 |
+
f"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
|
27 |
+
)
|
28 |
+
else:
|
29 |
+
prefix = (
|
30 |
+
"Below is an instruction that describes a task. "
|
31 |
+
"Write a response that appropriately completes the request.\n"
|
32 |
+
f"### Instruction:\n{instruction}\n\n### Response:"
|
33 |
+
)
|
34 |
+
return prefix
|
35 |
+
|
36 |
+
|
37 |
+
class TextIterStreamer:
|
38 |
+
def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=False):
|
39 |
+
self.tokenizer = tokenizer
|
40 |
+
self.skip_prompt = skip_prompt
|
41 |
+
self.skip_special_tokens = skip_special_tokens
|
42 |
+
self.tokens = []
|
43 |
+
self.text_queue = Queue()
|
44 |
+
# self.text_queue = []
|
45 |
+
self.next_tokens_are_prompt = True
|
46 |
+
|
47 |
+
def put(self, value):
|
48 |
+
if self.skip_prompt and self.next_tokens_are_prompt:
|
49 |
+
self.next_tokens_are_prompt = False
|
50 |
+
else:
|
51 |
+
if len(value.shape) > 1:
|
52 |
+
value = value[0]
|
53 |
+
self.tokens.extend(value.tolist())
|
54 |
+
word = self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens)
|
55 |
+
# self.text_queue.append(word)
|
56 |
+
self.text_queue.put(word)
|
57 |
+
|
58 |
+
def end(self):
|
59 |
+
# self.text_queue.append(None)
|
60 |
+
self.text_queue.put(None)
|
61 |
+
|
62 |
+
def __iter__(self):
|
63 |
+
return self
|
64 |
+
|
65 |
+
def __next__(self):
|
66 |
+
value = self.text_queue.get()
|
67 |
+
if value is None:
|
68 |
+
raise StopIteration()
|
69 |
+
else:
|
70 |
+
return value
|
71 |
+
|
72 |
+
|
73 |
+
def main(
|
74 |
+
base_model: str = "",
|
75 |
+
lora_weights: str = "",
|
76 |
+
share_gradio: bool = False,
|
77 |
+
):
|
78 |
+
tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
|
79 |
+
model = AutoModelForCausalLM.from_pretrained(
|
80 |
+
base_model,
|
81 |
+
device_map=device,
|
82 |
+
trust_remote_code=True,
|
83 |
+
torch_dtype=torch.float16
|
84 |
+
)
|
85 |
+
if lora_weights:
|
86 |
+
model = PeftModel.from_pretrained(
|
87 |
+
model,
|
88 |
+
lora_weights
|
89 |
+
)
|
90 |
+
|
91 |
+
model.eval()
|
92 |
+
|
93 |
+
def evaluate(
|
94 |
+
instruction,
|
95 |
+
temperature=0.1,
|
96 |
+
top_p=0.75,
|
97 |
+
max_new_tokens=128,
|
98 |
+
repetition_penalty=1.1,
|
99 |
+
**kwargs,
|
100 |
+
):
|
101 |
+
print(instruction,
|
102 |
+
temperature,
|
103 |
+
top_p,
|
104 |
+
max_new_tokens,
|
105 |
+
repetition_penalty,
|
106 |
+
**kwargs)
|
107 |
+
if not instruction:
|
108 |
+
return
|
109 |
+
prompt = reformat_sft(instruction, "")
|
110 |
+
|
111 |
+
inputs = tokenizer(prompt, return_tensors="pt")
|
112 |
+
input_ids = inputs["input_ids"].cuda()
|
113 |
+
|
114 |
+
if not (1 > temperature > 0):
|
115 |
+
temperature = 1
|
116 |
+
if not (1 > top_p > 0):
|
117 |
+
top_p = 1
|
118 |
+
if not (2000 > max_new_tokens > 0):
|
119 |
+
max_new_tokens = 200
|
120 |
+
if not (5 > repetition_penalty > 0):
|
121 |
+
repetition_penalty = 1.1
|
122 |
+
|
123 |
+
output = ['', '', '']
|
124 |
+
for i in range(3):
|
125 |
+
if i > 0:
|
126 |
+
time.sleep(0.5)
|
127 |
+
streamer = TextIterStreamer(tokenizer)
|
128 |
+
generation_config = dict(
|
129 |
+
temperature=temperature,
|
130 |
+
top_p=top_p,
|
131 |
+
max_new_tokens=max_new_tokens,
|
132 |
+
do_sample=True,
|
133 |
+
repetition_penalty=repetition_penalty,
|
134 |
+
streamer=streamer,
|
135 |
+
)
|
136 |
+
c = Thread(target=lambda: model.generate(input_ids=input_ids, **generation_config))
|
137 |
+
c.start()
|
138 |
+
for text in streamer:
|
139 |
+
output[i] = text
|
140 |
+
yield output[0], output[1], output[2]
|
141 |
+
print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
|
142 |
+
print(output)
|
143 |
+
|
144 |
+
def fk_select(select_option):
|
145 |
+
def inner(context, answer1, answer2, answer3, fankui):
|
146 |
+
print("反馈", select_option, context, answer1, answer2, answer3, fankui)
|
147 |
+
gr.Info("反馈成功")
|
148 |
+
data = {
|
149 |
+
"context": context,
|
150 |
+
"answer": [answer1, answer2, answer3],
|
151 |
+
"choose": ""
|
152 |
+
}
|
153 |
+
if select_option == 1:
|
154 |
+
data["choose"] = answer1
|
155 |
+
elif select_option == 2:
|
156 |
+
data["choose"] = answer2
|
157 |
+
elif select_option == 3:
|
158 |
+
data["choose"] = answer3
|
159 |
+
elif select_option == 4:
|
160 |
+
data["choose"] = fankui
|
161 |
+
with open("fankui.jsonl", 'a+', encoding="utf-8") as f:
|
162 |
+
f.write(json.dumps(data, ensure_ascii=False) + "\n")
|
163 |
+
|
164 |
+
return inner
|
165 |
+
|
166 |
+
with gr.Blocks() as demo:
|
167 |
+
gr.Markdown(
|
168 |
+
"# 云起无垠SecGPT模型RLHF测试\n\nHuggingface: https://huggingface.co/w8ay/secgpt\nGithub: https://github.com/Clouditera/secgpt")
|
169 |
+
with gr.Row():
|
170 |
+
with gr.Column(): # 列排列
|
171 |
+
context = gr.Textbox(
|
172 |
+
lines=3,
|
173 |
+
label="Instruction",
|
174 |
+
placeholder="Tell me ..",
|
175 |
+
)
|
176 |
+
temperature = gr.Slider(
|
177 |
+
minimum=0, maximum=1, value=0.3, label="Temperature"
|
178 |
+
)
|
179 |
+
topp = gr.Slider(
|
180 |
+
minimum=0, maximum=1, value=0.7, label="Top p"
|
181 |
+
)
|
182 |
+
max_tokens = gr.Slider(
|
183 |
+
minimum=1, maximum=2000, step=1, value=300, label="Max tokens"
|
184 |
+
)
|
185 |
+
repetion = gr.Slider(
|
186 |
+
minimum=0, maximum=10, value=1.1, label="repetition_penalty"
|
187 |
+
)
|
188 |
+
with gr.Column():
|
189 |
+
answer1 = gr.Textbox(
|
190 |
+
lines=4,
|
191 |
+
label="回答1",
|
192 |
+
)
|
193 |
+
fk1 = gr.Button("选这个")
|
194 |
+
answer2 = gr.Textbox(
|
195 |
+
lines=4,
|
196 |
+
label="回答2",
|
197 |
+
)
|
198 |
+
fk2 = gr.Button("选这个")
|
199 |
+
answer3 = gr.Textbox(
|
200 |
+
lines=4,
|
201 |
+
label="回答3",
|
202 |
+
)
|
203 |
+
fk3 = gr.Button("选这个")
|
204 |
+
fankui = gr.Textbox(
|
205 |
+
lines=4,
|
206 |
+
label="反馈回答",
|
207 |
+
)
|
208 |
+
fk4 = gr.Button("都不好,反馈")
|
209 |
+
with gr.Row():
|
210 |
+
submit = gr.Button("submit", variant="primary")
|
211 |
+
gr.ClearButton([context, answer1, answer2, answer3, fankui])
|
212 |
+
submit.click(fn=evaluate, inputs=[context, temperature, topp, max_tokens, repetion],
|
213 |
+
outputs=[answer1, answer2, answer3])
|
214 |
+
fk1.click(fn=fk_select(1), inputs=[context, answer1, answer2, answer3, fankui])
|
215 |
+
fk2.click(fn=fk_select(2), inputs=[context, answer1, answer2, answer3, fankui])
|
216 |
+
fk3.click(fn=fk_select(3), inputs=[context, answer1, answer2, answer3, fankui])
|
217 |
+
fk4.click(fn=fk_select(4), inputs=[context, answer1, answer2, answer3, fankui])
|
218 |
+
|
219 |
+
demo.queue().launch(server_name="0.0.0.0", share=True)
|
220 |
+
# Old testing code follows.
|
221 |
+
|
222 |
+
|
223 |
+
if __name__ == "__main__":
|
224 |
+
import argparse
|
225 |
+
|
226 |
+
parser = argparse.ArgumentParser(description='云起无垠SecGPT模型RLHF测试')
|
227 |
+
parser.add_argument("--base_model", type=str, required=True, help="基础模型")
|
228 |
+
parser.add_argument("--lora", type=str, help="lora模型")
|
229 |
+
parser.add_argument("--share_gradio", type=bool, default=False, help="开放外网访问")
|
230 |
+
args = parser.parse_args()
|
231 |
+
main(args.base_model, args.lora, args.share_gradio)
|