binxu.wang
commited on
Commit
•
671920a
1
Parent(s):
cb2ce74
add min step
Browse files
app.py
CHANGED
@@ -22,8 +22,7 @@ generator = pipeline('text-generation', model=model, tokenizer='bert-base-chines
|
|
22 |
|
23 |
def generate(prompt, num_beams, max_length, repetition_penalty, seed):
|
24 |
|
25 |
-
|
26 |
-
torch.manual_seed(seed)
|
27 |
|
28 |
outputs = generator(prompt, max_length=max_length, num_return_sequences=5, num_beams=num_beams, repetition_penalty=repetition_penalty)
|
29 |
output_texts = [output['generated_text'] for output in outputs]
|
@@ -37,15 +36,15 @@ examples = [["子曰", 10, 50, 1.5, 42],
|
|
37 |
["子路问仁", 10, 50, 1.5, 42],
|
38 |
["孙行者笑道", 10, 50, 1.5, 42],
|
39 |
["牛魔王与红孩儿", 10, 50, 1.5, 42],
|
40 |
-
["鲲鹏", 10, 50, 1.5, 42],
|
41 |
["宝玉道", 10, 50, 1.5, 42],
|
42 |
["黛玉行至贾母处", 10, 50, 1.5, 42],]
|
43 |
|
44 |
|
45 |
iface = gr.Interface(fn=generate,
|
46 |
inputs=[gr.inputs.Textbox(lines=2, label="Prompt"),
|
47 |
-
gr.inputs.Slider(minimum=1, maximum=20, default=10, label="Number of beams"),
|
48 |
-
gr.inputs.Slider(minimum=10, maximum=100, default=50, label="Max length"),
|
49 |
gr.inputs.Slider(minimum=1, maximum=5, default=1.5, label="Repetition penalty"),
|
50 |
gr.inputs.Number(default=0, label="Seed")],
|
51 |
outputs=gr.outputs.Textbox(label="Generated Text"),
|
|
|
22 |
|
23 |
def generate(prompt, num_beams, max_length, repetition_penalty, seed):
|
24 |
|
25 |
+
torch.manual_seed(seed)
|
|
|
26 |
|
27 |
outputs = generator(prompt, max_length=max_length, num_return_sequences=5, num_beams=num_beams, repetition_penalty=repetition_penalty)
|
28 |
output_texts = [output['generated_text'] for output in outputs]
|
|
|
36 |
["子路问仁", 10, 50, 1.5, 42],
|
37 |
["孙行者笑道", 10, 50, 1.5, 42],
|
38 |
["牛魔王与红孩儿", 10, 50, 1.5, 42],
|
39 |
+
["鲲鹏", 10, 50, 1.5, 42],
|
40 |
["宝玉道", 10, 50, 1.5, 42],
|
41 |
["黛玉行至贾母处", 10, 50, 1.5, 42],]
|
42 |
|
43 |
|
44 |
iface = gr.Interface(fn=generate,
|
45 |
inputs=[gr.inputs.Textbox(lines=2, label="Prompt"),
|
46 |
+
gr.inputs.Slider(minimum=1, maximum=20, default=10, step=1, label="Number of beams"),
|
47 |
+
gr.inputs.Slider(minimum=10, maximum=100, default=50, step=1, label="Max length"),
|
48 |
gr.inputs.Slider(minimum=1, maximum=5, default=1.5, label="Repetition penalty"),
|
49 |
gr.inputs.Number(default=0, label="Seed")],
|
50 |
outputs=gr.outputs.Textbox(label="Generated Text"),
|