|
import logging |
|
import pathlib |
|
import gradio as gr |
|
import pandas as pd |
|
from gt4sd.algorithms.generation.pgt import ( |
|
PGT, |
|
PGTCoherenceChecker, |
|
PGTEditor, |
|
PGTGenerator, |
|
) |
|
from gt4sd.algorithms.registry import ApplicationsRegistry |
|
|
|
|
|
# Module-level logger. Attaching a NullHandler keeps Python's last-resort
# stderr handler from firing when the running application has not configured
# logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())


# Lookup from application name to its PGT configuration class; the keys are
# the same model-type strings that run_inference dispatches on.
MODEL_FN = dict(
    PGTGenerator=PGTGenerator,
    PGTEditor=PGTEditor,
    PGTCoherenceChecker=PGTCoherenceChecker,
)
|
|
|
|
|
def run_inference(
    model_type: str,
    generator_task: str,
    editor_task: str,
    checker_task: str,
    prompt: str,
    second_prompt: str,
    length: int,
    k: int,
    p: float,
) -> str:
    """Build a PGT configuration from the UI inputs and return one sample.

    Args:
        model_type: ``"PGTGenerator"``, ``"PGTEditor"`` or
            ``"PGTCoherenceChecker"``; selects which configuration is built.
        generator_task: passed as ``task`` to ``PGTGenerator`` (generator only).
        editor_task: passed as ``input_type`` to ``PGTEditor`` (editor only).
        checker_task: passed as ``coherence_type`` to ``PGTCoherenceChecker``
            (coherence checker only).
        prompt: primary input; ``input_text`` for the generator/editor and
            ``input_a`` for the coherence checker.
        second_prompt: passed as ``input_b``; only used by the coherence checker.
        length: forwarded as ``max_length``.
        k: forwarded as ``top_k``.
        p: forwarded as ``top_p``.

    Returns:
        The first item produced by ``PGT(config).sample(1)``, shown in the
        UI's output text box.

    Raises:
        ValueError: if ``model_type`` is not one of the supported types.
    """
    # Gradio sliders may hand integral values over as floats (e.g. 512.0);
    # coerce them so max_length / top_k always reach the model as ints.
    kwargs = {"max_length": int(length), "top_k": int(k), "top_p": float(p)}

    # Each application names its task selector and inputs differently, so the
    # configuration has to be built per model type.
    if model_type == "PGTGenerator":
        config = PGTGenerator(task=generator_task, input_text=prompt, **kwargs)
    elif model_type == "PGTEditor":
        config = PGTEditor(input_type=editor_task, input_text=prompt, **kwargs)
    elif model_type == "PGTCoherenceChecker":
        config = PGTCoherenceChecker(
            coherence_type=checker_task,
            input_a=prompt,
            input_b=second_prompt,
            **kwargs,
        )
    else:
        # Without this branch `config` would be unbound and the call below
        # would fail with an opaque UnboundLocalError.
        raise ValueError(
            f"Unsupported model type {model_type!r}; expected one of "
            "'PGTGenerator', 'PGTEditor', 'PGTCoherenceChecker'."
        )

    model = PGT(config)
    # A single sample is requested; return the first item produced.
    return list(model.sample(1))[0]
|
|
|
|
|
if __name__ == "__main__":
    # Offer every registered PGT algorithm application in the "Model type"
    # dropdown; run_inference only knows how to configure PGTGenerator,
    # PGTEditor and PGTCoherenceChecker.
    algos = [
        algo["algorithm_application"]
        for algo in ApplicationsRegistry.list_available()
        if "PGT" in algo["algorithm_name"]
    ]

    # UI assets live in a "model_cards" directory next to this script.
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    # examples.csv is "|"-separated with no header row. Its rows are passed to
    # gradio as examples, so each row should supply one value per input
    # component below; missing cells become "" instead of NaN.
    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), sep="|", header=None
    ).fillna("")
    example_rows = examples.values.tolist()
    print("Examples: ", example_rows)

    # Decode the markdown as UTF-8 explicitly instead of relying on the
    # platform's locale-dependent default encoding.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    gen_tasks = [
        "title-to-abstract",
        "abstract-to-title",
        "abstract-to-claim",
        "claim-to-abstract",
    ]

    demo = gr.Interface(
        fn=run_inference,
        title="Patent Generative Transformer",
        # gradio passes these values to run_inference positionally, so their
        # order must match that function's parameters.
        inputs=[
            gr.Dropdown(algos, label="Model type", value="PGTGenerator"),
            gr.Dropdown(gen_tasks, label="Generator task", value="title-to-abstract"),
            gr.Dropdown(["abstract", "claim"], label="Editor task", value="abstract"),
            gr.Dropdown(
                ["title-abstract", "title-claim", "abstract-claim"],
                label="Checker task",
                value="title-abstract",
            ),
            gr.Textbox(
                label="Primary Text prompt",
                placeholder="Artificial intelligence and machine learning infrastructure",
                lines=5,
            ),
            gr.Textbox(
                label="Secondary text prompt (only coherence checker)",
                placeholder="",
                lines=1,
            ),
            gr.Slider(
                minimum=5, maximum=1024, value=512, label="Maximal length", step=1
            ),
            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
            gr.Slider(minimum=0.5, maximum=1.0, value=0.95, label="Top-p"),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=example_rows,
    )
    demo.launch(debug=True, show_error=True)
|
|