Spaces:
Sleeping
Sleeping
import gradio as gr | |
from transformers import pipeline | |
def generate_text( | |
model_name, | |
text, | |
min_length, | |
max_length, | |
temperature, | |
top_k, | |
top_p | |
): | |
models_map = { | |
"Мои любимые юморески": "gpt2-vk-aneki", | |
"бугро тред": "gpt2-vk-bugro", | |
"Калик)": "gpt2-vk-kalik" | |
} | |
model = "MesonWarrior/" + models_map[model_name] | |
pipe = pipeline( | |
'text-generation', | |
model=model, | |
tokenizer=model, | |
min_length=min_length, | |
max_length=max_length | |
) | |
return pipe(text, temperature=temperature, top_k=top_k, top_p=top_p, do_sample=True)[0]['generated_text'] | |
def interface(): | |
with gr.Row(): | |
with gr.Column(): | |
with gr.Row(): | |
model = gr.Dropdown( | |
["Мои любимые юморески", "бугро тред", "Калик)"], | |
label="Модель (Текст какого паблика генерировать)", | |
value="Мои любимые юморески", | |
) | |
text = gr.Textbox(lines=7, label="Входной текст", placeholder="Введите текст который продолжит нейросеть...") | |
output = gr.Textbox(lines=12, label="Выходной текст", placeholder="Здесь будет текст сгенерированный нейросетью...") | |
with gr.Row(): | |
with gr.Column(): | |
min_length = gr.Slider( | |
minimum=0, maximum=100, value=32, step=1, | |
label="Минимальная длина", | |
info="Минимальное количество символов в выходном тексте." | |
) | |
max_length = gr.Slider( | |
minimum=0, maximum=200, value=64, step=1, | |
label="Максимальная длина", | |
info="Максимальное количество символов в выходном тексте." | |
) | |
temperature = gr.Slider( | |
minimum=0.05, maximum=1.95, value=0.9, step=0.05, | |
label="Температура", | |
info="Чем выше тем рандомнее, чем ниже тем больше повторений." | |
) | |
top_k = gr.Slider( | |
minimum=0, maximum=100, value=50, step=0.05, | |
label="Top K", | |
) | |
top_p = gr.Slider( | |
minimum=0, maximum=1, value=0.9, step=0.05, | |
label="Top P", | |
) | |
with gr.Column(): | |
with gr.Row(): | |
generate_btn = gr.Button( | |
"Сгенерировать", variant="primary", label="Generate", | |
) | |
generate_btn.click( | |
fn=generate_text, | |
inputs=[ | |
model, | |
text, | |
min_length, | |
max_length, | |
temperature, | |
top_k, | |
top_p | |
], | |
outputs=output, | |
) | |
with gr.Blocks( | |
title="GPT2 VK") as demo: | |
gr.Markdown(""" | |
# GPT2 VK | |
Файнтюны [этой](https://huggingface.co/ai-forever/rugpt3medium_based_on_gpt2) модели по вашим любимым пабликам ВКонтакте. | |
#### Паблики представленные в моделях: | |
- [Мои любимые юморески 🎩](https://huggingface.co/MesonWarrior/gpt2-vk-aneki) | |
- [бугро тред 💥](https://huggingface.co/MesonWarrior/gpt2-vk-bugro) | |
- [Калик) 🍏🍎💨](https://huggingface.co/MesonWarrior/gpt2-vk-kalik) <sub><sup>(Обучено на спорном датасете из постов и комментариев, надо бы переобучить на данных получше)</sup></sub> | |
""") | |
interface() | |
demo.queue().launch() |