post_generator / app.py
ltomczak1's picture
Update app.py
e9469ba
raw
history blame
No virus
695 Bytes
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load the GPT-2 large checkpoint once, at import time, so every request
# served by the app reuses the same tokenizer and model objects.
_CHECKPOINT = 'gpt2-large'
tokenizer = GPT2Tokenizer.from_pretrained(_CHECKPOINT)
# The model's pad token id is set to the tokenizer's EOS id — presumably
# because GPT-2's tokenizer defines no dedicated pad token; confirm if changed.
model = GPT2LMHeadModel.from_pretrained(
    _CHECKPOINT,
    pad_token_id=tokenizer.eos_token_id,
)
def generator(text, *, max_length=500, num_beams=5, no_repeat_ngram_size=2):
    """Continue *text* with GPT-2 and return the prompt plus its continuation.

    Uses the module-level ``tokenizer`` and ``model``.

    Args:
        text: Prompt string (the contents of the Gradio input textbox).
        max_length: Maximum total length, in tokens, of the generated
            sequence including the prompt (forwarded to ``model.generate``).
        num_beams: Beam-search width.
        no_repeat_ngram_size: Size of n-grams that may not repeat in the output.

    Returns:
        The decoded output (prompt followed by generated text) with special
        tokens removed, or ``""`` when *text* is empty or None.
    """
    # An empty prompt tokenizes to a zero-length sequence, which generate()
    # cannot start from; return early instead of raising inside the model.
    if not text:
        return ""
    # Calling the tokenizer directly (rather than .encode) also yields the
    # attention mask. Passing it explicitly matters here because the model's
    # pad_token_id equals the EOS id, so generate() cannot infer the mask
    # from pad tokens on its own.
    encoded = tokenizer(text, return_tensors='pt')
    output = model.generate(
        encoded['input_ids'],
        attention_mask=encoded['attention_mask'],
        max_length=max_length,
        num_beams=num_beams,
        no_repeat_ngram_size=no_repeat_ngram_size,
        early_stopping=True,
    )
    return tokenizer.decode(output[0], skip_special_tokens=True)
# Expose generator() as a web UI: one multi-line text input, one text output.
# Uses the top-level gr.Textbox component; the gr.inputs / gr.outputs
# namespaces were deprecated in Gradio 3.x and removed in Gradio 4.
interface = gr.Interface(
    fn=generator,
    inputs=gr.Textbox(lines=5, label="Input Text"),
    outputs=gr.Textbox(label="Generated Post"),
)
interface.launch()