binxu.wang commited on
Commit
3a34f41
1 Parent(s): ec0497c

add additional inputs

Browse files
Files changed (1) hide show
  1. app.py +12 -4
app.py CHANGED
@@ -20,16 +20,24 @@ from transformers import BertTokenizerFast, BertTokenizer
20
  model = GPT2LMHeadModel.from_pretrained("binxu/Ziyue-GPT2")
21
  generator = pipeline('text-generation', model=model, tokenizer='bert-base-chinese')
22
 
23
- def generate(prompt):
24
- outputs = generator(prompt, max_length=50, num_return_sequences=5, num_beams=10, repetition_penalty=1.5)
 
 
 
 
25
  output_texts = [output['generated_text'] for output in outputs]
26
  output_all = "\n\n".join(output_texts)
27
  return output_all
28
 
29
- examples = ["子曰", "子墨子曰", "孟子", "秦王", "子路问仁"]
30
 
31
  iface = gr.Interface(fn=generate,
32
- inputs=gr.inputs.Textbox(lines=5, label="Input Text"),
 
 
 
 
33
  outputs=gr.outputs.Textbox(label="Generated Text"),
34
  examples=examples)
35
  iface.launch()
 
20
  model = GPT2LMHeadModel.from_pretrained("binxu/Ziyue-GPT2")
21
  generator = pipeline('text-generation', model=model, tokenizer='bert-base-chinese')
22
 
23
+ def generate(prompt, num_beams, max_length, repetition_penalty, seed):
24
+
25
+ if seed is not None:
26
+ torch.manual_seed(seed)
27
+
28
+ outputs = generator(prompt, max_length=max_length, num_return_sequences=5, num_beams=num_beams, repetition_penalty=repetition_penalty)
29
  output_texts = [output['generated_text'] for output in outputs]
30
  output_all = "\n\n".join(output_texts)
31
  return output_all
32
 
33
+ examples = ["子曰", "子墨子曰", "孟子", "秦王", "子路问仁", "孙行者笑道", "牛魔王与红孩儿", "鲲鹏", "宝玉道", "黛玉行至贾母处"]
34
 
35
  iface = gr.Interface(fn=generate,
36
+ inputs=[gr.inputs.Textbox(lines=2, label="Prompt"),
37
+ gr.inputs.Slider(minimum=1, maximum=20, default=10, label="Number of beams"),
38
+ gr.inputs.Slider(minimum=10, maximum=100, default=50, label="Max length"),
39
+ gr.inputs.Slider(minimum=1, maximum=5, default=1.5, label="Repetition penalty"),
40
+ gr.inputs.Number(default=0, label="Seed")],
41
  outputs=gr.outputs.Textbox(label="Generated Text"),
42
  examples=examples)
43
  iface.launch()