# openthaigpt / app.py
import gradio as gr
from transformers import GPT2LMHeadModel, GPT2Tokenizer
# Load the OpenThaiGPT GPT-2 instruction-tuned checkpoint and its tokenizer, registering
# the special tokens used during fine-tuning, then resize the embeddings to match.
pretrained_name = "kobkrit/openthaigpt-gpt2-instructgpt-poc-0.0.4"
tokenizer = GPT2Tokenizer.from_pretrained(pretrained_name, bos_token='<|startoftext|>',
                                          unk_token='<|unk|>', eos_token='<|endoftext|>',
                                          pad_token='<|pad|>')
model = GPT2LMHeadModel.from_pretrained(pretrained_name).cuda()  # assumes a CUDA-capable GPU
model.resize_token_embeddings(len(tokenizer))
def gen(prompt):
    # Prepend the BOS token, move the input ids to the GPU, and decode with beam search.
    # Note: without do_sample=True, beam search may ignore the sampling parameters
    # (top_k, top_p, temperature); they are kept here as in the original app.
    generated = tokenizer("<|startoftext|>" + prompt, return_tensors="pt").input_ids.cuda()
    output = model.generate(generated, top_k=50, num_beams=5, no_repeat_ngram_size=2,
                            early_stopping=True, max_length=300, top_p=0.95, temperature=1.9)
    return tokenizer.decode(output[0], skip_special_tokens=True)
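
# Illustrative sketch (not part of the original app): gen() could be smoke-tested directly,
# e.g. with the same Thai prompt used as the UI default, before launching the interface:
#   print(gen("Q: อยากลดความอ้วน ทำอย่างไร\n\nA:"))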
# Gradio text-in/text-out UI; the default prompt is Thai for "Q: How do I lose weight?\n\nA:".
demo = gr.Interface(fn=gen, inputs=gr.Textbox(lines=3, label="Input Text",
                    value="Q: อยากลดความอ้วน ทำอย่างไร\n\nA:"), outputs="text")
demo.launch()