binxu.wang
commited on
Commit
•
3a34f41
1
Parent(s):
ec0497c
add additional inputs
Browse files
app.py
CHANGED
@@ -20,16 +20,24 @@ from transformers import BertTokenizerFast, BertTokenizer
|
|
20 |
model = GPT2LMHeadModel.from_pretrained("binxu/Ziyue-GPT2")
|
21 |
generator = pipeline('text-generation', model=model, tokenizer='bert-base-chinese')
|
22 |
|
23 |
-
def generate(prompt):
|
24 |
-
|
|
|
|
|
|
|
|
|
25 |
output_texts = [output['generated_text'] for output in outputs]
|
26 |
output_all = "\n\n".join(output_texts)
|
27 |
return output_all
|
28 |
|
29 |
-
examples = ["子曰", "子墨子曰", "孟子", "秦王", "子路问仁"]
|
30 |
|
31 |
iface = gr.Interface(fn=generate,
|
32 |
-
inputs=gr.inputs.Textbox(lines=
|
|
|
|
|
|
|
|
|
33 |
outputs=gr.outputs.Textbox(label="Generated Text"),
|
34 |
examples=examples)
|
35 |
iface.launch()
|
|
|
20 |
model = GPT2LMHeadModel.from_pretrained("binxu/Ziyue-GPT2")
|
21 |
generator = pipeline('text-generation', model=model, tokenizer='bert-base-chinese')
|
22 |
|
23 |
+
def generate(prompt, num_beams, max_length, repetition_penalty, seed):
|
24 |
+
|
25 |
+
if seed is not None:
|
26 |
+
torch.manual_seed(seed)
|
27 |
+
|
28 |
+
outputs = generator(prompt, max_length=max_length, num_return_sequences=5, num_beams=num_beams, repetition_penalty=repetition_penalty)
|
29 |
output_texts = [output['generated_text'] for output in outputs]
|
30 |
output_all = "\n\n".join(output_texts)
|
31 |
return output_all
|
32 |
|
33 |
+
examples = ["子曰", "子墨子曰", "孟子", "秦王", "子路问仁", "孙行者笑道", "牛魔王与红孩儿", "鲲鹏", "宝玉道", "黛玉行至贾母处"]
|
34 |
|
35 |
iface = gr.Interface(fn=generate,
|
36 |
+
inputs=[gr.inputs.Textbox(lines=2, label="Prompt"),
|
37 |
+
gr.inputs.Slider(minimum=1, maximum=20, default=10, label="Number of beams"),
|
38 |
+
gr.inputs.Slider(minimum=10, maximum=100, default=50, label="Max length"),
|
39 |
+
gr.inputs.Slider(minimum=1, maximum=5, default=1.5, label="Repetition penalty"),
|
40 |
+
gr.inputs.Number(default=0, label="Seed")],
|
41 |
outputs=gr.outputs.Textbox(label="Generated Text"),
|
42 |
examples=examples)
|
43 |
iface.launch()
|