import logging
import pathlib
import gradio as gr
import pandas as pd
from gt4sd.algorithms.generation.pgt import (
    PGT,
    PGTCoherenceChecker,
    PGTEditor,
    PGTGenerator,
)
from gt4sd.algorithms.registry import ApplicationsRegistry


logger = logging.getLogger(__name__)
logger.addHandler(logging.NullHandler())

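# Available PGT application configurations, keyed by the model type checked in run_inference.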
MODEL_FN = {
    "PGTGenerator": PGTGenerator,
    "PGTEditor": PGTEditor,
    "PGTCoherenceChecker": PGTCoherenceChecker,
}


def run_inference(
    model_type: str,
    generator_task: str,
    editor_task: str,
    checker_task: str,
    prompt: str,
    second_prompt: str,
    length: int,
    k: int,
    p: float,
):
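    """Instantiate the selected PGT application and return a single sampled output."""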

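    # Decoding parameters shared by all three PGT applications.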
    kwargs = {"max_length": length, "top_k": k, "top_p": p}

    if model_type == "PGTGenerator":
        config = PGTGenerator(task=generator_task, input_text=prompt, **kwargs)
    elif model_type == "PGTEditor":
        config = PGTEditor(input_type=editor_task, input_text=prompt, **kwargs)
    elif model_type == "PGTCoherenceChecker":
        config = PGTCoherenceChecker(
            coherence_type=checker_task, input_a=prompt, input_b=second_prompt, **kwargs
        )
    else:
        raise ValueError(f"Unknown model type: {model_type}")

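    # Run the configured algorithm and take the first (and only) sample.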
    model = PGT(config)
    text = list(model.sample(1))[0]

    return text


if __name__ == "__main__":

    # Preparation (retrieve all available algorithms)
    all_algos = ApplicationsRegistry.list_available()
    algos = [
        x["algorithm_application"] for x in all_algos if "PGT" in x["algorithm_name"]
    ]

    # Load metadata
    metadata_root = pathlib.Path(__file__).parent.joinpath("model_cards")

    examples = pd.read_csv(
        metadata_root.joinpath("examples.csv"), sep="|", header=None
    ).fillna("")
    print("Examples: ", examples.values.tolist())

    with open(metadata_root.joinpath("article.md"), "r") as f:
        article = f.read()
    with open(metadata_root.joinpath("description.md"), "r") as f:
        description = f.read()

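    # Tasks exposed for the PGTGenerator application.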
    gen_tasks = [
        "title-to-abstract",
        "abstract-to-title",
        "abstract-to-claim",
        "claim-to-abstract",
    ]

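    # Gradio interface: the input widgets are passed positionally to run_inference.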
    demo = gr.Interface(
        fn=run_inference,
        title="Patent Generative Transformer",
        inputs=[
            gr.Dropdown(algos, label="Model type", value="PGTGenerator"),
            gr.Dropdown(gen_tasks, label="Generator task", value="title-to-abstract"),
            gr.Dropdown(["abstract", "claim"], label="Editor task", value="abstract"),
            gr.Dropdown(
                ["title-abstract", "title-claim", "abstract-claim"],
                label="Checker task",
                value="title-abstract",
            ),
            gr.Textbox(
                label="Primary Text prompt",
                placeholder="Artificial intelligence and machine learning infrastructure",
                lines=5,
            ),
            gr.Textbox(
                label="Secondary text prompt (only coherence checker)",
                placeholder="",
                lines=1,
            ),
            gr.Slider(
                minimum=5, maximum=1024, value=512, label="Maximum length", step=1
            ),
            gr.Slider(minimum=2, maximum=500, value=50, label="Top-k", step=1),
            gr.Slider(minimum=0.5, maximum=1.0, value=0.95, label="Top-p"),
        ],
        outputs=gr.Textbox(label="Output"),
        article=article,
        description=description,
        examples=examples.values.tolist(),
    )
    demo.launch(debug=True, show_error=True)