jiamandu committed on
Commit
8766b45
1 Parent(s): 2167401

Upload folder using huggingface_hub

Files changed (2)
  1. README.md +3 -9
  2. webdemo.py +231 -0
README.md CHANGED
@@ -1,12 +1,6 @@
  ---
- title: Secgpt
- emoji: 🏆
- colorFrom: gray
- colorTo: blue
+ title: secgpt
+ app_file: webdemo.py
  sdk: gradio
- sdk_version: 4.8.0
- app_file: app.py
- pinned: false
+ sdk_version: 3.37.0
  ---
-
- Check out the configuration reference at https://huggingface.co/docs/hub/spaces-config-reference
webdemo.py ADDED
@@ -0,0 +1,231 @@
+ import json
+ import sys
+ from threading import Thread
+ from queue import Queue
+ import os
+
+ import gradio as gr
+ import torch
+ from peft import PeftModel
+ from transformers import GenerationConfig, AutoTokenizer, AutoModelForCausalLM
+ import time
+
+
+ os.environ['MallocStackLogging'] = '0'
+ if torch.cuda.is_available():
+     device = "auto"
+ else:
+     device = "cpu"
+
+
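+ # Wrap a user instruction in the Alpaca-style prompt template the model expects;
+ # "input" carries optional extra context.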
+ def reformat_sft(instruction, input):
+     if input:
+         prefix = (
+             "Below is an instruction that describes a task, paired with an input that provides further context. "
+             "Write a response that appropriately completes the request.\n"
+             f"### Instruction:\n{instruction}\n\n### Input:\n{input}\n\n### Response:"
+         )
+     else:
+         prefix = (
+             "Below is an instruction that describes a task. "
+             "Write a response that appropriately completes the request.\n"
+             f"### Instruction:\n{instruction}\n\n### Response:"
+         )
+     return prefix
+
+
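+ # Minimal streaming adapter: transformers' generate() calls put() with newly produced
+ # token ids and end() when finished; decoded text is handed to the consuming thread
+ # through a Queue, so the object can be iterated while generation runs in the background.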
+ class TextIterStreamer:
+     def __init__(self, tokenizer, skip_prompt=True, skip_special_tokens=False):
+         self.tokenizer = tokenizer
+         self.skip_prompt = skip_prompt
+         self.skip_special_tokens = skip_special_tokens
+         self.tokens = []
+         self.text_queue = Queue()
+         self.next_tokens_are_prompt = True
+
+     def put(self, value):
+         if self.skip_prompt and self.next_tokens_are_prompt:
+             self.next_tokens_are_prompt = False
+         else:
+             if len(value.shape) > 1:
+                 value = value[0]
+             self.tokens.extend(value.tolist())
+             word = self.tokenizer.decode(self.tokens, skip_special_tokens=self.skip_special_tokens)
+             self.text_queue.put(word)
+
+     def end(self):
+         self.text_queue.put(None)
+
+     def __iter__(self):
+         return self
+
+     def __next__(self):
+         value = self.text_queue.get()
+         if value is None:
+             raise StopIteration()
+         else:
+             return value
+
+
+ def main(
+         base_model: str = "",
+         lora_weights: str = "",
+         share_gradio: bool = False,
+ ):
+     tokenizer = AutoTokenizer.from_pretrained(base_model, trust_remote_code=True)
+     model = AutoModelForCausalLM.from_pretrained(
+         base_model,
+         device_map=device,
+         trust_remote_code=True,
+         torch_dtype=torch.float16
+     )
+     if lora_weights:
+         model = PeftModel.from_pretrained(
+             model,
+             lora_weights
+         )
+
+     model.eval()
+
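+     # Stream three candidate answers for the same instruction, one per output textbox,
+     # so the user can pick the best response as RLHF-style feedback.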
+     def evaluate(
+             instruction,
+             temperature=0.1,
+             top_p=0.75,
+             max_new_tokens=128,
+             repetition_penalty=1.1,
+             **kwargs,
+     ):
+         print(instruction,
+               temperature,
+               top_p,
+               max_new_tokens,
+               repetition_penalty,
+               kwargs)
+         if not instruction:
+             return
+         prompt = reformat_sft(instruction, "")
+
+         inputs = tokenizer(prompt, return_tensors="pt")
+         # Move inputs to wherever the model was placed (GPU if available, CPU otherwise).
+         input_ids = inputs["input_ids"].to(model.device)
+
+         # Reset out-of-range slider values to sensible defaults.
+         if not (1 > temperature > 0):
+             temperature = 1
+         if not (1 > top_p > 0):
+             top_p = 1
+         if not (2000 > max_new_tokens > 0):
+             max_new_tokens = 200
+         if not (5 > repetition_penalty > 0):
+             repetition_penalty = 1.1
+
+         output = ['', '', '']
+         for i in range(3):
+             if i > 0:
+                 time.sleep(0.5)
+             streamer = TextIterStreamer(tokenizer)
+             generation_config = dict(
+                 temperature=temperature,
+                 top_p=top_p,
+                 max_new_tokens=max_new_tokens,
+                 do_sample=True,
+                 repetition_penalty=repetition_penalty,
+                 streamer=streamer,
+             )
+             # Run generation in a background thread and consume the streamer here.
+             c = Thread(target=lambda: model.generate(input_ids=input_ids, **generation_config))
+             c.start()
+             for text in streamer:
+                 output[i] = text
+                 yield output[0], output[1], output[2]
+         print(time.strftime("%Y-%m-%d %H:%M:%S", time.localtime()))
+         print(output)
+
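+     # Feedback handler factory: each button records which answer (or the free-form
+     # correction) the user preferred, appending one JSON object to fankui.jsonl.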
+     def fk_select(select_option):
+         def inner(context, answer1, answer2, answer3, fankui):
+             print("反馈", select_option, context, answer1, answer2, answer3, fankui)
+             gr.Info("反馈成功")
+             data = {
+                 "context": context,
+                 "answer": [answer1, answer2, answer3],
+                 "choose": ""
+             }
+             if select_option == 1:
+                 data["choose"] = answer1
+             elif select_option == 2:
+                 data["choose"] = answer2
+             elif select_option == 3:
+                 data["choose"] = answer3
+             elif select_option == 4:
+                 data["choose"] = fankui
+             with open("fankui.jsonl", 'a+', encoding="utf-8") as f:
+                 f.write(json.dumps(data, ensure_ascii=False) + "\n")
+
+         return inner
+
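+     # UI layout: instruction box and sampling controls in the left column; three
+     # candidate answers with selection buttons and a free-form feedback box on the right.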
+     with gr.Blocks() as demo:
+         gr.Markdown(
+             "# 云起无垠SecGPT模型RLHF测试\n\nHuggingface: https://huggingface.co/w8ay/secgpt\nGithub: https://github.com/Clouditera/secgpt")
+         with gr.Row():
+             with gr.Column():  # left column: prompt and sampling controls
+                 context = gr.Textbox(
+                     lines=3,
+                     label="Instruction",
+                     placeholder="Tell me ..",
+                 )
+                 temperature = gr.Slider(
+                     minimum=0, maximum=1, value=0.3, label="Temperature"
+                 )
+                 topp = gr.Slider(
+                     minimum=0, maximum=1, value=0.7, label="Top p"
+                 )
+                 max_tokens = gr.Slider(
+                     minimum=1, maximum=2000, step=1, value=300, label="Max tokens"
+                 )
+                 repetion = gr.Slider(
+                     minimum=0, maximum=10, value=1.1, label="repetition_penalty"
+                 )
+             with gr.Column():  # right column: candidate answers and feedback
+                 answer1 = gr.Textbox(
+                     lines=4,
+                     label="回答1",
+                 )
+                 fk1 = gr.Button("选这个")
+                 answer2 = gr.Textbox(
+                     lines=4,
+                     label="回答2",
+                 )
+                 fk2 = gr.Button("选这个")
+                 answer3 = gr.Textbox(
+                     lines=4,
+                     label="回答3",
+                 )
+                 fk3 = gr.Button("选这个")
+                 fankui = gr.Textbox(
+                     lines=4,
+                     label="反馈回答",
+                 )
+                 fk4 = gr.Button("都不好,反馈")
+         with gr.Row():
+             submit = gr.Button("submit", variant="primary")
+             gr.ClearButton([context, answer1, answer2, answer3, fankui])
+         submit.click(fn=evaluate, inputs=[context, temperature, topp, max_tokens, repetion],
+                      outputs=[answer1, answer2, answer3])
+         fk1.click(fn=fk_select(1), inputs=[context, answer1, answer2, answer3, fankui])
+         fk2.click(fn=fk_select(2), inputs=[context, answer1, answer2, answer3, fankui])
+         fk3.click(fn=fk_select(3), inputs=[context, answer1, answer2, answer3, fankui])
+         fk4.click(fn=fk_select(4), inputs=[context, answer1, answer2, answer3, fankui])
+
+     # Honor the share_gradio flag instead of always creating a public share link.
+     demo.queue().launch(server_name="0.0.0.0", share=share_gradio)
+
+
+ if __name__ == "__main__":
+     import argparse
+
+     parser = argparse.ArgumentParser(description='云起无垠SecGPT模型RLHF测试')
+     parser.add_argument("--base_model", type=str, required=True, help="基础模型")
+     parser.add_argument("--lora", type=str, help="lora模型")
+     # argparse's type=bool treats any non-empty string as True, so expose this as a flag.
+     parser.add_argument("--share_gradio", action="store_true", help="开放外网访问")
+     args = parser.parse_args()
+     main(args.base_model, args.lora, args.share_gradio)
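+ # Example launch (the model id below is illustrative; point --base_model at any
+ # SecGPT-compatible checkpoint, optionally with a LoRA adapter via --lora):
+ #   python webdemo.py --base_model w8ay/secgpt --share_gradio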