binxu.wang commited on
Commit
671920a
1 Parent(s): cb2ce74

add min step

Browse files
Files changed (1) hide show
  1. app.py +4 -5
app.py CHANGED
@@ -22,8 +22,7 @@ generator = pipeline('text-generation', model=model, tokenizer='bert-base-chines
22
 
23
  def generate(prompt, num_beams, max_length, repetition_penalty, seed):
24
 
25
- if seed is not None:
26
- torch.manual_seed(seed)
27
 
28
  outputs = generator(prompt, max_length=max_length, num_return_sequences=5, num_beams=num_beams, repetition_penalty=repetition_penalty)
29
  output_texts = [output['generated_text'] for output in outputs]
@@ -37,15 +36,15 @@ examples = [["子曰", 10, 50, 1.5, 42],
37
  ["子路问仁", 10, 50, 1.5, 42],
38
  ["孙行者笑道", 10, 50, 1.5, 42],
39
  ["牛魔王与红孩儿", 10, 50, 1.5, 42],
40
- ["鲲鹏", 10, 50, 1.5, 42],
41
  ["宝玉道", 10, 50, 1.5, 42],
42
  ["黛玉行至贾母处", 10, 50, 1.5, 42],]
43
 
44
 
45
  iface = gr.Interface(fn=generate,
46
  inputs=[gr.inputs.Textbox(lines=2, label="Prompt"),
47
- gr.inputs.Slider(minimum=1, maximum=20, default=10, label="Number of beams"),
48
- gr.inputs.Slider(minimum=10, maximum=100, default=50, label="Max length"),
49
  gr.inputs.Slider(minimum=1, maximum=5, default=1.5, label="Repetition penalty"),
50
  gr.inputs.Number(default=0, label="Seed")],
51
  outputs=gr.outputs.Textbox(label="Generated Text"),
 
22
 
23
  def generate(prompt, num_beams, max_length, repetition_penalty, seed):
24
 
25
+ torch.manual_seed(seed)
 
26
 
27
  outputs = generator(prompt, max_length=max_length, num_return_sequences=5, num_beams=num_beams, repetition_penalty=repetition_penalty)
28
  output_texts = [output['generated_text'] for output in outputs]
 
36
  ["子路问仁", 10, 50, 1.5, 42],
37
  ["孙行者笑道", 10, 50, 1.5, 42],
38
  ["牛魔王与红孩儿", 10, 50, 1.5, 42],
39
+ ["鲲鹏", 10, 50, 1.5, 42],
40
  ["宝玉道", 10, 50, 1.5, 42],
41
  ["黛玉行至贾母处", 10, 50, 1.5, 42],]
42
 
43
 
44
  iface = gr.Interface(fn=generate,
45
  inputs=[gr.inputs.Textbox(lines=2, label="Prompt"),
46
+ gr.inputs.Slider(minimum=1, maximum=20, default=10, step=1, label="Number of beams"),
47
+ gr.inputs.Slider(minimum=10, maximum=100, default=50, step=1, label="Max length"),
48
  gr.inputs.Slider(minimum=1, maximum=5, default=1.5, label="Repetition penalty"),
49
  gr.inputs.Number(default=0, label="Seed")],
50
  outputs=gr.outputs.Textbox(label="Generated Text"),