vk / app.py
MesonWarrior's picture
Update app.py
0933f21
raw
history blame
2.26 kB
import functools
import os

import gradio as gr
from huggingface_hub import login
from transformers import pipeline
# NOTE(security): an earlier revision hard-coded a Hugging Face access token on
# this line. Anything committed to a public Space is public, so that token must
# be revoked. The credential is now read from the environment instead
# (e.g. a Space secret named HF_TOKEN).
_hf_token = os.environ.get("HF_TOKEN")
if _hf_token:
    # Store the token so later Hub downloads that request the stored token
    # (``use_auth_token=True`` in generate_text) can authenticate.
    login(token=_hf_token)
# Dropdown label -> Hugging Face Hub repository id.
# Only the "Бугро" fine-tune is wired up so far; any other label falls back to
# it, which matches the previous behaviour (the label used to be ignored).
# TODO: add the repository ids for "Юморески" and "Калик".
_MODEL_REPOS = {
    "Бугро": "MesonWarrior/gpt2-vk-bugro",
}
_DEFAULT_REPO = "MesonWarrior/gpt2-vk-bugro"


@functools.lru_cache(maxsize=None)
def _load_pipeline(repo_id):
    """Build the text-generation pipeline for ``repo_id`` once and cache it.

    The key set is bounded by the repo ids in ``_MODEL_REPOS`` plus the
    default, so an unbounded cache cannot grow without limit.
    """
    return pipeline(
        'text-generation',
        model=repo_id,
        tokenizer=repo_id,
        use_auth_token=True,
    )


def generate_text(
    model,
    text,
    min_length,
    max_length,
    do_not_truncate=True,
):
    """Continue ``text`` with the fine-tuned GPT-2 model chosen in the UI.

    Args:
        model: Dropdown label. Labels without a configured repository fall
            back to the "Бугро" model.
        text: Prompt to continue.
        min_length: Minimum generation length, forwarded to the pipeline call.
        max_length: Maximum generation length, forwarded to the pipeline call.
        do_not_truncate: Currently ignored (the option is not passed to the
            pipeline); kept so the five-input Gradio callback stays valid.

    Returns:
        The ``generated_text`` field of the first pipeline result.
    """
    repo_id = _MODEL_REPOS.get(model, _DEFAULT_REPO)
    # Reuse the cached pipeline instead of reloading the model per request.
    pipe = _load_pipeline(repo_id)
    print('generating...')
    # Generation lengths are passed per call, so one cached pipeline serves
    # every slider setting. Slider values may arrive as floats; generation
    # lengths must be integers.
    output = pipe(
        text,
        min_length=int(min_length),
        max_length=int(max_length),
    )
    print(output)
    return output[0]['generated_text']
def interface():
    """Create the generation controls and bind the Generate button.

    Must be called inside an open ``gr.Blocks`` context. Components are
    created in place; nothing is returned.
    """
    # NOTE(review): the source lost its indentation; this row/column nesting
    # is a reconstruction. Confirm it against the intended layout.
    with gr.Row():
        with gr.Column():
            with gr.Row():
                model = gr.Dropdown(
                    ["Бугро", "Юморески", "Калик"], label="Model", value="Бугро",
                )
            text = gr.Textbox(lines=7, label="Input text")
            output = gr.Textbox(lines=12, label="Output text")
            with gr.Row():
                with gr.Column():
                    min_length = gr.Slider(
                        minimum=0, maximum=128, value=32, step=1,
                        label="Min Length",
                    )
                    max_length = gr.Slider(
                        minimum=0, maximum=512, value=96, step=1,
                        label="Max Length",
                    )
                    # The "Do not truncate" checkbox is disabled because the
                    # option is not forwarded to the pipeline yet. Without a
                    # stand-in, the name below would be undefined and building
                    # the UI would raise NameError. A hidden State supplies the
                    # checkbox's former default so generate_text still gets
                    # all five arguments.
                    do_not_truncate = gr.State(True)
                with gr.Column():
                    with gr.Row():
                        # ``label`` is not an accepted Button argument in recent
                        # Gradio releases; the button's text is its value.
                        generate_btn = gr.Button("Generate", variant="primary")
                    generate_btn.click(
                        fn=generate_text,
                        inputs=[
                            model,
                            text,
                            min_length,
                            max_length,
                            do_not_truncate,
                        ],
                        outputs=output,
                    )
# Markdown header shown above the generation UI.
_DESCRIPTION = """
## GPT2 VK
Файнтюны модели [ai-forever/rugpt3medium_based_on_gpt2](https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2) по вашим любимым пабликам ВКонтакте.
"""

# Assemble the page: description first, then the controls from interface().
with gr.Blocks(title="GPT2 VK") as demo:
    gr.Markdown(_DESCRIPTION)
    interface()

# queue() configures the request queue on ``demo`` itself, then the server
# is started.
demo.queue()
demo.launch()