akhaliq HF staff committed on
Commit
d1dc1ec
1 Parent(s): 04c51ab

Create new file

Browse files
Files changed (1) hide show
  1. app.py +126 -0
app.py ADDED
@@ -0,0 +1,126 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import gradio as gr
2
+
3
+ import torch
4
+ import numpy as np
5
+ from transformers import AutoTokenizer
6
+ from GlobEnc.src.modeling.modeling_bert import BertForSequenceClassification
7
+ from GlobEnc.src.modeling.modeling_electra import ElectraForSequenceClassification
8
+ from GlobEnc.src.attention_rollout import AttentionRollout
9
+
10
+ import seaborn as sns
11
+ import pandas as pd
12
+ import numpy as np
13
+ import matplotlib.pyplot as plt
14
+ import matplotlib.gridspec as gridspec
15
+ import matplotlib.backends.backend_pdf
16
+
17
# Dropdown label -> Hugging Face Hub checkpoint served by this demo.
# As of now, BERT and ELECTRA are supported; any checkpoint of these model
# families can be added here.
_CHECKPOINTS = {
    "bert-base-uncased-cls-sst2": "TehranNLP-org/bert-base-uncased-cls-sst2",
    "bert-large-sst2": "TehranNLP-org/bert-large-sst2",
    "electra-base-sst2": "TehranNLP-org/electra-base-sst2",
}

# Any unrecognised selection (including None, when nothing is picked in the
# dropdown) falls back to ELECTRA, matching the previous if/elif/else chain.
_DEFAULT_CHECKPOINT = _CHECKPOINTS["electra-base-sst2"]

# checkpoint name -> (tokenizer, classifier). Filled lazily so that a
# checkpoint is loaded once per process instead of once per request.
_LOADED = {}


def _load_checkpoint(checkpoint):
    """Return the cached ``(tokenizer, classifier)`` pair for *checkpoint*.

    Raises:
        ValueError: if the checkpoint name contains neither "bert" nor
            "electra", i.e. no matching model class is available.
    """
    if checkpoint not in _LOADED:
        if "bert" in checkpoint:
            model_cls = BertForSequenceClassification
        elif "electra" in checkpoint:
            model_cls = ElectraForSequenceClassification
        else:
            raise ValueError(f"Not implemented model: {checkpoint}")
        _LOADED[checkpoint] = (
            AutoTokenizer.from_pretrained(checkpoint),
            model_cls.from_pretrained(checkpoint),
        )
    return _LOADED[checkpoint]


def inference(text, model):
    """Compute GlobEnc token attributions for *text* and plot them as a heatmap.

    Args:
        text: The sentence to analyse.
        model: One of the keys of ``_CHECKPOINTS``. Any other value selects
            the ELECTRA checkpoint.

    Returns:
        A matplotlib ``Figure`` holding the per-layer heatmap. The figure is
        created without pyplot's global figure registry, so repeated calls do
        not accumulate open figures.

    Raises:
        ValueError: propagated from ``_load_checkpoint`` for an unsupported
            checkpoint name.
    """
    if isinstance(model, str):
        checkpoint = _CHECKPOINTS.get(model, _DEFAULT_CHECKPOINT)
    else:
        checkpoint = _DEFAULT_CHECKPOINT
    tokenizer, classifier = _load_checkpoint(checkpoint)

    # truncation=True keeps over-long input within the tokenizer's configured
    # maximum length instead of failing inside the model.
    tokenized_sentence = tokenizer.encode_plus(
        text, return_tensors="pt", truncation=True
    )

    # Extract single layer attentions
    with torch.no_grad():
        logits, attentions, norms = classifier(
            **tokenized_sentence,
            output_attentions=True,
            output_norms=True,
            return_dict=False,
        )
    num_layers = len(attentions)
    # NOTE(review): index 4 of each per-layer norms entry is taken as the
    # N-Enc attribution, per the print label below - confirm against GlobEnc.
    norm_nenc = (
        torch.stack([norms[i][4] for i in range(num_layers)])
        .squeeze()
        .cpu()
        .numpy()
    )
    print("Single layer N-Enc token attribution:", norm_nenc.shape)

    # Aggregate and compute GlobEnc
    globenc = AttentionRollout().compute_flows([norm_nenc], output_hidden_states=True)[0]
    globenc = np.array(globenc)
    print("Aggregated N-Enc token attribution (GlobEnc):", globenc.shape)

    tokenized_text = tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][0])

    # Take the attribution row at position 0 (presumably the [CLS] token for
    # these tokenizers), reverse the first axis so the last entry is drawn on
    # top, and scale every row by its own maximum.
    norm_cls = np.flip(globenc[:, 0, :], axis=0)
    row_max = norm_cls.max(axis=1)
    norm_cls = norm_cls / row_max[:, np.newaxis]
    df = pd.DataFrame(
        norm_cls,
        columns=tokenized_text,
        index=range(len(norm_cls), 0, -1),
    )

    # A standalone Figure (not plt.figure) is never registered with pyplot,
    # so nothing has to be closed after the caller has rendered it.
    fig = plt.Figure(figsize=(14, 8))
    ax = fig.subplots()
    sns.heatmap(df, cmap="Reds", square=True, ax=ax)
    # NOTE(review): the half-cell padding looks like the workaround for
    # matplotlib 3.1.1 clipping heatmap rows - confirm it is still wanted.
    bottom, top = ax.get_ylim()
    ax.set_ylim(bottom + 0.5, top - 0.5)
    ax.set_title("GlobEnc", fontsize=16)
    ax.set_ylabel("Layer", fontsize=16)
    ax.tick_params(axis="x", labelrotation=90, labelsize=16)
    ax.tick_params(axis="y", labelsize=13)
    fig.subplots_adjust(bottom=0.2)
    print("logits:", logits)

    return fig
109
+
110
# Gradio UI: a text box and a model picker feed `inference`; its result is
# rendered in a plot component when the "Run" button is clicked.
with gr.Blocks() as demo:
    gr.Markdown(
    """
    # Hello World!
    Start typing below to see the output.
    """)
    # Components are created in display order: inputs, then output, then button.
    inp = [
        gr.Textbox(),
        gr.Dropdown(
            choices=[
                'bert-base-uncased-cls-sst2',
                'bert-large-sst2',
                'electra-base-sst2',
            ]
        ),
    ]
    out = gr.Plot()

    button = gr.Button(value="Run")
    button.click(fn=inference, inputs=inp, outputs=out)

demo.launch()