Spaces:
Sleeping
Sleeping
Fix multi-class
Browse files
app.py
CHANGED
@@ -1,7 +1,9 @@
|
|
1 |
import gradio as gr
|
2 |
from transformers import BartForSequenceClassification, BartTokenizer
|
3 |
-
import torch.nn.functional as F
|
4 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
|
|
|
|
|
|
5 |
|
6 |
te_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
|
7 |
te_model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
|
@@ -10,15 +12,15 @@ qa_model = T5ForConditionalGeneration.from_pretrained("google/flan-t5-small", de
|
|
10 |
|
11 |
def predict(context, intent, multi_class):
|
12 |
input_text = "In one word, what is the opposite of: " + intent + "?"
|
13 |
-
input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids
|
14 |
opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0])
|
15 |
input_text = "In one word, what is the following describing: " + context
|
16 |
-
input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids
|
17 |
object_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0])
|
18 |
batch = ['I think the ' + object_output + ' are long.', 'I think the ' + object_output + ' are ' + opposite_output, 'I think the ' + object_output + ' are the perfect']
|
19 |
outputs = []
|
20 |
for i, hypothesis in enumerate(batch):
|
21 |
-
input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt')
|
22 |
# -> [contradiction, neutral, entailment]
|
23 |
logits = te_model(input_ids)[0][0]
|
24 |
|
@@ -38,20 +40,19 @@ def predict(context, intent, multi_class):
|
|
38 |
pn_tensor[2] = pn_tensor[2] * outputs[2][1]
|
39 |
pn_tensor[0] = pn_tensor[0] * outputs[2][1]
|
40 |
|
41 |
-
pn_tensor = F.normalize(pn_tensor, p=1, dim=0)
|
42 |
if (multi_class):
|
43 |
-
pn_tensor =
|
44 |
else:
|
45 |
pn_tensor = pn_tensor.softmax(dim=0)
|
46 |
pn_tensor = pn_tensor.tolist()
|
47 |
-
return {"
|
48 |
|
49 |
gradio_app = gr.Interface(
|
50 |
predict,
|
51 |
-
inputs=[gr.Text("Sentence"), gr.Text("Class"), gr.Checkbox("Allow multiple true classes")],
|
52 |
outputs=[gr.Label(num_top_classes=3)],
|
53 |
title="Intent Analysis",
|
|
|
54 |
)
|
55 |
|
56 |
-
|
57 |
-
gradio_app.launch()
|
|
|
1 |
import gradio as gr
|
2 |
from transformers import BartForSequenceClassification, BartTokenizer
|
|
|
3 |
from transformers import T5Tokenizer, T5ForConditionalGeneration
|
4 |
+
import torch
|
5 |
+
|
6 |
+
device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
|
7 |
|
8 |
te_tokenizer = BartTokenizer.from_pretrained('facebook/bart-large-mnli')
|
9 |
te_model = BartForSequenceClassification.from_pretrained('facebook/bart-large-mnli')
|
|
|
12 |
|
13 |
def predict(context, intent, multi_class):
|
14 |
input_text = "In one word, what is the opposite of: " + intent + "?"
|
15 |
+
input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
16 |
opposite_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0])
|
17 |
input_text = "In one word, what is the following describing: " + context
|
18 |
+
input_ids = qa_tokenizer(input_text, return_tensors="pt").input_ids.to(device)
|
19 |
object_output = qa_tokenizer.decode(qa_model.generate(input_ids, max_length=2)[0])
|
20 |
batch = ['I think the ' + object_output + ' are long.', 'I think the ' + object_output + ' are ' + opposite_output, 'I think the ' + object_output + ' are the perfect']
|
21 |
outputs = []
|
22 |
for i, hypothesis in enumerate(batch):
|
23 |
+
input_ids = te_tokenizer.encode(context, hypothesis, return_tensors='pt').to(device)
|
24 |
# -> [contradiction, neutral, entailment]
|
25 |
logits = te_model(input_ids)[0][0]
|
26 |
|
|
|
40 |
pn_tensor[2] = pn_tensor[2] * outputs[2][1]
|
41 |
pn_tensor[0] = pn_tensor[0] * outputs[2][1]
|
42 |
|
|
|
43 |
if (multi_class):
|
44 |
+
pn_tensor = torch.sigmoid(pn_tensor)
|
45 |
else:
|
46 |
pn_tensor = pn_tensor.softmax(dim=0)
|
47 |
pn_tensor = pn_tensor.tolist()
|
48 |
+
return {"agree": pn_tensor[0], "neutral": pn_tensor[1], "disagree": pn_tensor[2]}
|
49 |
|
50 |
gradio_app = gr.Interface(
|
51 |
predict,
|
52 |
+
inputs=[gr.Text(label="Sentence"), gr.Text(label="Class"), gr.Checkbox(label="Allow multiple true classes")],
|
53 |
outputs=[gr.Label(num_top_classes=3)],
|
54 |
title="Intent Analysis",
|
55 |
+
description="This model predicts whether or not the **class** describes the **object described in the sentence.**"
|
56 |
)
|
57 |
|
58 |
+
gradio_app.launch()
|
|