File size: 3,472 Bytes
e3475d1 c564047 e3475d1 c564047 e3475d1 c564047 e3475d1 c564047 e3475d1 c564047 e3475d1 c564047 e3475d1 69c3e34 e3475d1 c564047 e3475d1 c564047 e3475d1 c564047 e3475d1 c564047 e3475d1 69c3e34 c564047 e3475d1 69c3e34 c564047 69c3e34 c564047 e3475d1 571b9f0 e3475d1 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 |
import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.pgt import (
PGT,
PGTCoherenceChecker,
PGTEditor,
PGTGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry
# Module logger; the NullHandler keeps this module silent unless the
# embedding application configures logging itself.
logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Name -> configuration class for each PGT application exposed by the demo.
MODEL_FN = dict(
    PGTGenerator=PGTGenerator,
    PGTEditor=PGTEditor,
    PGTCoherenceChecker=PGTCoherenceChecker,
)
def run_inference(
    model_type: str,
    generator_task: str,
    editor_task: str,
    checker_task: str,
    prompt: str,
    second_prompt: str,
    length: int,
    k: int,
    p: float,
):
    """Build the selected PGT configuration and return one sampled output.

    Args:
        model_type: One of "PGTGenerator", "PGTEditor", "PGTCoherenceChecker".
        generator_task: Task passed as ``task`` to PGTGenerator.
        editor_task: Value passed as ``input_type`` to PGTEditor.
        checker_task: Value passed as ``coherence_type`` to PGTCoherenceChecker.
        prompt: Primary input text (``input_text`` / ``input_a``).
        second_prompt: Secondary text; only used by the coherence checker
            (``input_b``).
        length: Maximum generation length (``max_length``).
        k: Top-k sampling parameter (``top_k``).
        p: Top-p (nucleus) sampling parameter (``top_p``).

    Returns:
        The first item yielded by ``PGT(config).sample(1)``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported names.
    """
    # UI sliders may deliver numbers as floats; max_length/top_k must be ints.
    kwargs = {"max_length": int(length), "top_k": int(k), "top_p": p}

    if model_type == "PGTGenerator":
        config = PGTGenerator(task=generator_task, input_text=prompt, **kwargs)
    elif model_type == "PGTEditor":
        config = PGTEditor(input_type=editor_task, input_text=prompt, **kwargs)
    elif model_type == "PGTCoherenceChecker":
        config = PGTCoherenceChecker(
            coherence_type=checker_task, input_a=prompt, input_b=second_prompt, **kwargs
        )
    else:
        # Without this branch `config` would be unbound and PGT(config) would
        # fail with an opaque UnboundLocalError.
        raise ValueError(f"Unsupported model type: {model_type!r}")

    model = PGT(config)
    text = list(model.sample(1))[0]
    return text
if __name__ == "__main__":
    # Preparation: keep only the registered applications of PGT algorithms.
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        x["algorithm_application"]
        for x in all_algos
        if "PGT" in x["algorithm_name"]
    ]

    # Load metadata (examples table plus markdown shown around the UI).
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), sep="|", header=None
    ).fillna("")
    example_rows = examples.values.tolist()
    print("Examples: ", example_rows)

    # Read explicitly as UTF-8 so the app does not depend on the host locale.
    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    gen_tasks = [
        "title-to-abstract",
        "abstract-to-title",
        "abstract-to-claim",
        "claim-to-abstract",
    ]

    demo = gr.Interface(
        fn=run_inference,
        title="Patent Generative Transformer",
        # Input order must match run_inference's positional parameters.
        inputs=[
            gr.Dropdown(algos, label="Model type", value="PGTGenerator"),
            gr.Dropdown(gen_tasks, label="Generator task", value="title-to-abstract"),
            gr.Dropdown(["abstract", "claim"], label="Editor task", value="abstract"),
            gr.Dropdown(
                ["title-abstract", "title-claim", "abstract-claim"],
                label="Checker task",
                value="title-abstract",
            ),
            gr.Textbox(
                label="Primary Text prompt",
                placeholder="Artificial intelligence and machine learning infrastructure",
                lines=5,
            ),
            gr.Textbox(
                label="Secondary text prompt (only coherence checker)",
                placeholder="",
                lines=1,
            ),
            gr.Slider(
                minimum=5, maximum=1024, value=512, label="Maximal length", step=1
            ),
            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
            gr.Slider(minimum=0.5, maximum=1.0, value=0.95, label="Top-p"),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=example_rows,
    )
    demo.launch(debug=True, show_error=True)
|