|
import gradio as gr |
|
import re |
|
|
|
from transformers import ( |
|
AutoTokenizer, |
|
AutoModelForSeq2SeqLM, |
|
) |
|
|
|
def clean_text(text):
    """Normalise pasted article text before it is tokenized.

    The text is processed in this order: non-ASCII characters are dropped,
    URLs (``http`` followed by any run of non-whitespace) are removed,
    newlines and tabs become spaces, the literal marker ``ADVERTISEMENT``
    is blanked out, runs of spaces are collapsed to one, and surrounding
    whitespace is trimmed.

    Args:
        text: Raw text to clean. Must be a ``str``.

    Returns:
        The cleaned string; empty if nothing survives the cleaning.
    """
    # Characters outside ASCII are discarded, not transliterated.
    text = text.encode("ascii", errors="ignore").decode("ascii")
    text = re.sub(r"http\S+", "", text)
    # One pass handles both newlines and tabs. A separate "\n\n" rule is
    # unnecessary: it could never match once single newlines are replaced.
    text = re.sub(r"[\n\t]", " ", text)
    text = re.sub(r"ADVERTISEMENT", " ", text)
    # The final strip() already removes leading/trailing spaces, so no
    # intermediate strip(" ") is needed before collapsing.
    return re.sub(r" +", " ", text).strip()
|
|
|
|
|
# Hugging Face Hub checkpoint used to generate headlines.
model_name = "chinhon/pegasus-newsroom-headline_writer_57k"

# (tokenizer, model) pairs keyed by checkpoint name. Loading a checkpoint is
# by far the most expensive step, so it is done once per process instead of
# once per request.
_loaded_checkpoints = {}


def _get_tokenizer_and_model(checkpoint):
    """Return the (tokenizer, model) pair for *checkpoint*, loading it on first use."""
    if checkpoint not in _loaded_checkpoints:
        _loaded_checkpoints[checkpoint] = (
            AutoTokenizer.from_pretrained(checkpoint),
            AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
        )
    return _loaded_checkpoints[checkpoint]


def headline_writer(text):
    """Generate a headline for a news story.

    The text is cleaned with ``clean_text``, tokenized (truncated to the
    tokenizer's limit), passed through ``model.generate`` with the model's
    default generation settings, and the first decoded sequence is returned.

    Args:
        text: Raw story text, typically the first few paragraphs.

    Returns:
        The generated headline, or an empty string when the input contains
        no usable text after cleaning.
    """
    input_text = clean_text(text)
    if not input_text:
        # Nothing to summarise; skip model loading and inference entirely.
        return ""

    tokenizer, model = _get_tokenizer_and_model(model_name)

    # Encode as model *input*. The target-tokenizer context manager is meant
    # for encoding labels and is deprecated, so it is not used here.
    batch = tokenizer(
        input_text,
        truncation=True,
        padding="longest",
        return_tensors="pt",
    )

    raw_write = model.generate(**batch)

    # min_length / length_penalty are generation options and have no effect
    # on decoding, so they are not passed here. They are deliberately not
    # moved to generate() either, to keep the produced headlines unchanged.
    headline = tokenizer.batch_decode(raw_write, skip_special_tokens=True)

    return headline[0]
|
|
|
|
|
# Build the input and output widgets first so the Interface call stays short.
# NOTE(review): gr.inputs / gr.outputs, theme="darkdefault" and enable_queue
# look like legacy Gradio API - confirm against the pinned Gradio version
# before upgrading.
_story_input = gr.inputs.Textbox(
    lines=20,
    label="Paste the first few paras of your news story here",
)
_headline_output = gr.outputs.Textbox(label="Suggested Headline")

gradio_ui = gr.Interface(
    fn=headline_writer,
    inputs=_story_input,
    outputs=_headline_output,
    title="Generate News Headlines with AI",
    description="Too busy or tired to write a headline? Try this instead.",
    theme="darkdefault",
)

# Start the web app as soon as the script runs, with request queueing on.
gradio_ui.launch(enable_queue=True)
|
|