import gradio as gr
from transformers import BartForSequenceClassification, BartTokenizer
from transformers import T5Tokenizer, T5ForConditionalGeneration
import torch

# Prefer Apple's Metal (MPS) backend when available; otherwise fall back to CPU.
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")

# Keep both models on the same device the tokenized inputs are moved to below;
# mixing device_map="auto" with a manual .to(device) on the inputs can leave the
# model and its inputs on different devices. Both checkpoints are fetched from
# the Hugging Face Hub on the first run (several GB combined).
te_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
te_model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli').to(device)
qa_tokenizer = T5Tokenizer.from_pretrained("google/flan-t5-large")
qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-large").to(device)

def predict(context, intent, multi_class):
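    """Score how well `intent` describes the object in `context`.

    FLAN-T5 first answers two helper questions (the opposite of the class, and
    the object the statement describes). BART-MNLI then scores four hypotheses
    built from those answers, and the probabilities are post-processed into
    agree/neutral/disagree scores.
    """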
    input_text = "What is the opposite of " + intent + "?"
    input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0], skip_special_tokens=True)
    input_text = "What object is the following describing: " + context
    input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
    object_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0], skip_special_tokens=True)
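    # Four NLI hypotheses: positive, negative (antonym), and the negation of each.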
    batch = [
        'The ' + object_output + ' is ' + intent,
        'The ' + object_output + ' is ' + opposite_output,
        'The ' + object_output + ' is not ' + intent,
        'The ' + object_output + ' is not ' + opposite_output,
    ]

    outputs = []
    for i, hypothesis in enumerate(batch):
        input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)

        # bart-large-mnli label order: [contradiction, neutral, entailment]
        logits = te_model(input_ids).logits[0]

        if i >= 2:
            # For the negated hypotheses, drop the neutral logit: [contradiction, entailment]
            probs = logits[[0, 2]].softmax(dim=0)
        else:
            probs = logits.softmax(dim=0)
        outputs.append(probs)
        
    # Stochastic vector for the statement asserting neither the positive nor the negative class:
    # a strong contradiction of a negated hypothesis means one of the classes is asserted.
    perfect_prob = [0, 0]
    perfect_prob[1] = max(float(outputs[2][0]), float(outputs[3][0]))
    perfect_prob[0] = 1 - perfect_prob[1]
    # -> [entailment, contradiction] for the negated ("perfect") hypotheses
    
    # Flip the positive hypothesis to [entailment, neutral, contradiction], i.e. [agree, neutral, disagree]
    outputs[0] = outputs[0].flip(dims=[0])
    
    # Average the positive and negative hypotheses; the negative one is left unflipped,
    # since contradicting the opposite class amounts to agreeing with the intent.
    aggregated = (outputs[0] + outputs[1]) / 2
    
    # Scale the neutral probability by the chance that neither class is asserted.
    aggregated[1] = aggregated[1] * perfect_prob[0]
    
    # If the statement asserts neither the positive nor the negative class, the neutral
    # class is more likely, so damp both agree and disagree.
    if perfect_prob[0] > perfect_prob[1]:
        aggregated[2] = aggregated[2] * perfect_prob[1]
        aggregated[0] = aggregated[0] * perfect_prob[1]
    else:
        # One of the classes is likely asserted: shrink the likelier class less
        # (scale by perfect_prob[1], the larger factor) and the other class more.
        if aggregated[0] > aggregated[2]:
            aggregated[2] = aggregated[2] * perfect_prob[0]
            aggregated[0] = aggregated[0] * perfect_prob[1]
        # Otherwise the negative class is more likely, so it is the one damped less.
        else:
            aggregated[2] = aggregated[2] * perfect_prob[1]
            aggregated[0] = aggregated[0] * perfect_prob[0]
    
    # Exaggerate differences with exp(x) - 1, which still maps 0 to 0.
    aggregated = aggregated.exp() - 1
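    # e.g. exp(0.5) - 1 ≈ 0.649 while exp(1.0) - 1 ≈ 1.718, so larger scores grow disproportionately.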

    # Multiple classes may be true: squash each score independently.
    if multi_class:
        aggregated = torch.sigmoid(aggregated)
    # Exactly one true class: normalize into a single distribution.
    else:
        aggregated = aggregated.softmax(dim=0)
    aggregated = aggregated.tolist()
    # Convert the raw (no-postprocessing) tensor entries to plain floats for gr.Label.
    return {"agree": aggregated[0], "neutral": aggregated[1], "disagree": aggregated[2]}, {"agree": float(outputs[0][0]), "neutral": float(outputs[0][1]), "disagree": float(outputs[0][2])}

examples = [["The pants fit great, even the waist will fit me fine once I'm back to my normal weight, but the bottom is what's large. You can roll up the bottom part of the legs, or the top at the waist band for hanging out at the house, but if you have one nearby, simply have them re-hemmed.", "long"]]

gradio_app = gr.Interface(
    predict,
    examples=examples,
    inputs=[gr.Text(label="Statement"), gr.Text(label="Class"), gr.Checkbox(label="Allow multiple true classes")],
    outputs=[gr.Label(num_top_classes=3, label="With Postprocessing"), gr.Label(num_top_classes=3, label="Without Postprocessing")],
    title="Intent Analysis",
    description="This model predicts whether or not the **_class_** describes the **_object described in the sentence_**. <br /> The two outputs shows what TE would predict with and without the postprocessing. An example edge case for normal TE is shown below. <br /> **_It is recommended that you clone the repository to speed up processing time_**.",
    cache_examples=True
)

# share=True also exposes the app through a temporary public gradio.live link.
gradio_app.launch(share=True)