import logging
import pathlib

import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.pgt import (
    PGT,
    PGTCoherenceChecker,
    PGTEditor,
    PGTGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry

logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

# Maps each supported "Model type" dropdown value to its PGT configuration class.
MODEL_FN = {
    "PGTGenerator": PGTGenerator,
    "PGTEditor": PGTEditor,
    "PGTCoherenceChecker": PGTCoherenceChecker,
}


def run_inference(
    model_type: str,
    generator_task: str,
    editor_task: str,
    checker_task: str,
    prompt: str,
    second_prompt: str,
    length: int,
    k: int,
    p: float,
):
    """Build a PGT configuration for ``model_type``, sample once, and return the text.

    Args:
        model_type: Which configuration to build; one of the keys of ``MODEL_FN``.
        generator_task: ``task`` for ``PGTGenerator`` (ignored by other models).
        editor_task: ``input_type`` for ``PGTEditor`` (ignored by other models).
        checker_task: ``coherence_type`` for ``PGTCoherenceChecker``
            (ignored by other models).
        prompt: Primary input text (``input_text``, or ``input_a`` for the
            coherence checker).
        second_prompt: ``input_b`` for the coherence checker (ignored otherwise).
        length: Passed as ``max_length``.
        k: Passed as ``top_k``.
        p: Passed as ``top_p``.

    Returns:
        The first item produced by ``PGT(config).sample(1)``.

    Raises:
        ValueError: If ``model_type`` is not one of the supported models.
    """
    # Gradio sliders may hand back floats; max_length and top_k are integer counts.
    kwargs = {"max_length": int(length), "top_k": int(k), "top_p": float(p)}

    if model_type == "PGTGenerator":
        config = PGTGenerator(task=generator_task, input_text=prompt, **kwargs)
    elif model_type == "PGTEditor":
        config = PGTEditor(input_type=editor_task, input_text=prompt, **kwargs)
    elif model_type == "PGTCoherenceChecker":
        config = PGTCoherenceChecker(
            coherence_type=checker_task,
            input_a=prompt,
            input_b=second_prompt,
            **kwargs,
        )
    else:
        # Previously this fell through and crashed with UnboundLocalError on `config`.
        raise ValueError(
            f"Unsupported model type {model_type!r}; "
            f"expected one of {sorted(MODEL_FN)}"
        )

    model = PGT(config)
    text = list(model.sample(1))[0]
    return text


if __name__ == "__main__":
    # Preparation (retrieve all available PGT algorithm applications)
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        algo["algorithm_application"]
        for algo in all_algos
        if "PGT" in algo["algorithm_name"]
    ]

    # Load metadata (examples, article and description shown in the UI)
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), sep="|", header=None
    ).fillna("")
    print("Examples: ", examples.values.tolist())

    article = metadata_root.joinpath("article.md").read_text(encoding="utf-8")
    description = metadata_root.joinpath("description.md").read_text(
        encoding="utf-8"
    )

    gen_tasks = [
        "title-to-abstract",
        "abstract-to-title",
        "abstract-to-claim",
        "claim-to-abstract",
    ]

    demo = gr.Interface(
        fn=run_inference,
        title="Patent Generative Transformer",
        inputs=[
            gr.Dropdown(algos, label="Model type", value="PGTGenerator"),
            gr.Dropdown(gen_tasks, label="Generator task", value="title-to-abstract"),
            gr.Dropdown(["abstract", "claim"], label="Editor task", value="abstract"),
            gr.Dropdown(
                ["title-abstract", "title-claim", "abstract-claim"],
                label="Checker task",
                value="title-abstract",
            ),
            gr.Textbox(
                label="Primary Text prompt",
                placeholder="Artificial intelligence and machine learning infrastructure",
                lines=5,
            ),
            gr.Textbox(
                label="Secondary text prompt (only coherence checker)",
                placeholder="",
                lines=1,
            ),
            gr.Slider(
                minimum=5, maximum=1024, value=512, label="Maximal length", step=1
            ),
            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
            gr.Slider(minimum=0.5, maximum=1.0, value=0.95, label="Top-p"),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)