|
import gradio as gr |
|
import transformers |
|
|
|
import torch |
|
from transformers import pipeline, set_seed |
|
from transformers import GPT2Model, GPT2Config, GPT2LMHeadModel, AutoModel |
|
from transformers import BertTokenizerFast, BertTokenizer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
# Load the Ziyue-GPT2 checkpoint (GPT-2 architecture with a language-modelling head).
model = GPT2LMHeadModel.from_pretrained("binxu/Ziyue-GPT2")

# Wrap the model in a text-generation pipeline. Text is tokenized with the
# bert-base-chinese tokenizer, not a GPT-2 tokenizer.
generator = pipeline(
    task="text-generation",
    model=model,
    tokenizer="bert-base-chinese",
)
|
|
|
def generate(prompt, num_beams, max_length, repetition_penalty, seed, num_return_sequences=5):
    """Generate continuations of ``prompt`` with the module-level ``generator`` pipeline.

    Args:
        prompt: Text to continue.
        num_beams: Beam width. UI widgets may deliver this as a float, so it is
            rounded to the nearest int and clamped to at least 1.
        max_length: Maximum length passed to the pipeline; rounded to an int.
        repetition_penalty: Passed to the pipeline unchanged.
        seed: If not None, seeds torch's global RNG before generating.
            NOTE(review): this only changes the output when the pipeline samples;
            plain beam search is deterministic.
        num_return_sequences: Upper bound on how many continuations to return.
            The value actually requested is capped at ``num_beams``, because
            transformers rejects asking beam/greedy search for more sequences
            than beams. NOTE(review): if sampling is enabled in the model's
            generation config, num_beams=1 will now yield a single sequence.

    Returns:
        The generated texts joined by blank lines.
    """
    # Sliders/number boxes can hand back floats (e.g. 10.0); generation needs ints.
    beams = max(1, int(round(num_beams)))
    length = int(round(max_length))

    if seed is not None:
        torch.manual_seed(int(seed))

    # Requesting more sequences than beams raises ValueError inside transformers
    # (e.g. at the slider minimum num_beams=1), so never ask for more than we keep.
    n_sequences = max(1, min(int(num_return_sequences), beams))

    outputs = generator(
        prompt,
        max_length=length,
        num_return_sequences=n_sequences,
        num_beams=beams,
        repetition_penalty=repetition_penalty,
    )

    return "\n\n".join(output["generated_text"] for output in outputs)
|
|
|
# Prompts offered as clickable examples in the UI.
examples = ["子曰", "子墨子曰", "孟子", "秦王", "子路问仁", "孙行者笑道", "牛魔王与红孩儿", "鲲鹏", "宝玉道", "黛玉行至贾母处"]

# Default control values, shared by the input widgets and the example rows so
# the two cannot drift apart.
_DEFAULT_NUM_BEAMS = 10
_DEFAULT_MAX_LENGTH = 50
_DEFAULT_REPETITION_PENALTY = 1.5
_DEFAULT_SEED = 0

# With more than one input component, Gradio expects each example to be a full
# row holding one value per input. A bare list of prompt strings is only
# accepted when the interface has a single input.
example_rows = [
    [prompt, _DEFAULT_NUM_BEAMS, _DEFAULT_MAX_LENGTH, _DEFAULT_REPETITION_PENALTY, _DEFAULT_SEED]
    for prompt in examples
]

iface = gr.Interface(
    fn=generate,
    inputs=[
        gr.inputs.Textbox(lines=2, label="Prompt"),
        # step=1 keeps the beam count and length on whole numbers instead of
        # whatever fractional step Gradio would otherwise choose.
        gr.inputs.Slider(minimum=1, maximum=20, step=1, default=_DEFAULT_NUM_BEAMS, label="Number of beams"),
        gr.inputs.Slider(minimum=10, maximum=100, step=1, default=_DEFAULT_MAX_LENGTH, label="Max length"),
        gr.inputs.Slider(minimum=1, maximum=5, default=_DEFAULT_REPETITION_PENALTY, label="Repetition penalty"),
        gr.inputs.Number(default=_DEFAULT_SEED, label="Seed"),
    ],
    outputs=gr.outputs.Textbox(label="Generated Text"),
    examples=example_rows,
)

iface.launch()
|
|