# app.py -- Hugging Face file by ERmak1581.
# Last commit: "Update app.py" (2a116fd, verified).
# File size: 1.38 kB; hub scan status: no virus.
import gradio as gr
import torch
from transformers import GPT2Tokenizer, GPT2LMHeadModel
# Hub repository id shared by the tokenizer and the model weights.
MODEL_REPO = "ERmak1581/rugpt3large_for_qna_400k1"

# Both objects are created once at import time and reused by every request.
tokenizer = GPT2Tokenizer.from_pretrained(MODEL_REPO)
model = GPT2LMHeadModel.from_pretrained(MODEL_REPO)
def gen(request, temperature, maxnewtokens):
    """Generate an assistant reply for a single user request.

    The request is wrapped in the ``<s> [user] ... [assistant]`` prompt
    template, a continuation is sampled from the module-level ``model``,
    and only the text produced after the prompt is returned.

    Args:
        request: User text to answer.
        temperature: Sampling temperature forwarded to ``model.generate``.
        maxnewtokens: Maximum number of tokens to generate; coerced to
            ``int`` because UI sliders may deliver a float.

    Returns:
        The generated reply with a trailing ``</s>`` marker and surrounding
        whitespace removed.
    """
    input_text = f"<s> [user] {request} [assistant]"
    input_ids = tokenizer.encode(input_text, return_tensors='pt')
    prompt_length = input_ids.shape[-1]

    output = model.generate(
        input_ids,
        do_sample=True,
        max_new_tokens=int(maxnewtokens),
        temperature=temperature,
        no_repeat_ngram_size=3,
    )

    # The returned sequence starts with the prompt tokens. Decoding only the
    # tail avoids searching the text for "[assistant]", which picked the wrong
    # segment when the user's request contained that marker and raised
    # IndexError when the marker was absent from the decoded text.
    generated_ids = output[0][prompt_length:]
    res = tokenizer.decode(generated_ids, skip_special_tokens=True)

    # Keep the earlier behaviour of stopping at a repeated "[assistant]"
    # marker if the model emits one.
    res = res.split("[assistant]", 1)[0]

    # str.removesuffix returns a new string; the result must be kept.
    # Strip first so trailing whitespace cannot hide the marker.
    res = res.strip().removesuffix("</s>").strip()
    return res
# Input widgets, in the positional order expected by gen():
# request text, sampling temperature, max new tokens.
inputs = [
    gr.Textbox(label="Input Text", lines=5),
    gr.Slider(
        label="Temperature",
        minimum=0.1,
        maximum=1.9,
        step=0.05,
        value=1.0,
    ),
    gr.Slider(
        label="Max New Tokens",
        minimum=10,
        maximum=200,
        step=5,
        value=50,
    ),
]

# Single text box that displays the string returned by gen().
output = gr.Textbox(label="Output Text")

# Wire the generation function to the widgets and start the web UI.
interface = gr.Interface(
    fn=gen,
    inputs=inputs,
    outputs=output,
    title="GPT-2 Text Generation",
    theme="compact",
    description=(
        "Демонстрация "
        "<a href=https://huggingface.co/ERmak1581/rugpt3large_for_qna_400k1>"
        "модели</a> для задачи Question-Answer"
    ),
)
interface.launch(share=True)