File size: 4,875 Bytes
d1dc1ec
 
 
 
 
9ba191d
 
 
d1dc1ec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
import functools

import gradio as gr

import torch
import numpy as np
from transformers import AutoTokenizer
from src.modeling.modeling_bert import BertForSequenceClassification
from src.modeling.modeling_electra import ElectraForSequenceClassification
from src.attention_rollout import AttentionRollout

import seaborn as sns
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import matplotlib.gridspec as gridspec
import matplotlib.backends.backend_pdf

# Dropdown choice -> Hugging Face Hub checkpoint. BERT and ELECTRA checkpoints are
# supported; other TehranNLP-org checkpoints (e.g. "TehranNLP-org/bert-base-uncased-cls-mnli",
# "TehranNLP-org/bert-large-hateXplain", "TehranNLP-org/electra-base-mnli") can be added here.
_CHECKPOINTS = {
    "bert-base-uncased-cls-sst2": "TehranNLP-org/bert-base-uncased-cls-sst2",
    "bert-large-sst2": "TehranNLP-org/bert-large-sst2",
    "electra-base-sst2": "TehranNLP-org/electra-base-sst2",
}
# Checkpoint used for any choice not listed above (including no selection / None).
_DEFAULT_CHECKPOINT = "TehranNLP-org/electra-base-sst2"


@functools.lru_cache(maxsize=len(_CHECKPOINTS))
def _load_checkpoint(checkpoint):
    """Load the tokenizer and sequence classifier for *checkpoint*, once per checkpoint.

    Loading weights is the most expensive part of a request, so the result is
    memoized and reused by later calls with the same checkpoint name.

    Raises:
        ValueError: if *checkpoint* is neither a BERT nor an ELECTRA model.
    """
    tokenizer = AutoTokenizer.from_pretrained(checkpoint)
    if "bert" in checkpoint:
        model = BertForSequenceClassification.from_pretrained(checkpoint)
    elif "electra" in checkpoint:
        model = ElectraForSequenceClassification.from_pretrained(checkpoint)
    else:
        raise ValueError(f"Not implented model: {checkpoint}")
    return tokenizer, model


def inference(text, model):
    """Compute GlobEnc token attributions for *text* and plot them as a heatmap.

    Args:
        text: Sentence to run through the classifier.
        model: Dropdown choice naming the checkpoint (a key of ``_CHECKPOINTS``).
            Any other value, including ``None``, falls back to ELECTRA-base SST-2.

    Returns:
        A matplotlib ``Figure`` whose heatmap has one row per layer (highest layer
        on top) and one column per token; each row is scaled so its maximum is 1.
    """
    checkpoint = _CHECKPOINTS.get(model, _DEFAULT_CHECKPOINT)
    tokenizer, classifier = _load_checkpoint(checkpoint)
    tokenized_sentence = tokenizer.encode_plus(text, return_tensors="pt")

    # Extract single-layer attributions.
    with torch.no_grad():
        logits, attentions, norms = classifier(
            **tokenized_sentence,
            output_attentions=True,
            output_norms=True,
            return_dict=False,
        )
        num_layers = len(attentions)
        # norms[i][4] is taken as layer i's N-Enc token attribution; the meaning of
        # index 4 is defined by the project's modified model (TODO confirm).
        norm_nenc = torch.stack([norms[i][4] for i in range(num_layers)]).squeeze().cpu().numpy()
        print("Single layer N-Enc token attribution:", norm_nenc.shape)

        # Aggregate the per-layer attributions across layers (GlobEnc).
        globenc = AttentionRollout().compute_flows([norm_nenc], output_hidden_states=True)[0]
        globenc = np.array(globenc)
        print("Aggregated N-Enc token attribution (GlobEnc):", globenc.shape)

    tokens = tokenizer.convert_ids_to_tokens(tokenized_sentence["input_ids"][0])

    # Row 0 of every layer's matrix (the first token, [CLS]), flipped along the
    # layer axis so the highest layer comes first.
    norm_cls = np.flip(globenc[:, 0, :], axis=0)
    # Scale each layer's row so that its largest attribution becomes 1.
    row_max = norm_cls.max(axis=1)
    norm_cls = norm_cls / row_max[:, np.newaxis]
    df = pd.DataFrame(norm_cls, columns=tokens, index=range(len(norm_cls), 0, -1))

    # Draw on an explicit figure/axes rather than pyplot's implicit current figure,
    # so the plot does not depend on global pyplot state shared between requests.
    fig, ax = plt.subplots(figsize=(14, 8))
    sns.heatmap(df, cmap="Reds", square=True, ax=ax)
    # NOTE(review): this +/-0.5 y-limit adjustment is the usual workaround for a
    # heatmap clipping bug in matplotlib 3.1.1; on other versions it may add half a
    # blank row at the top and bottom -- confirm it is still wanted.
    bottom, top = ax.get_ylim()
    ax.set_ylim(bottom + 0.5, top - 0.5)
    ax.set_title("GlobEnc", fontsize=16)
    ax.set_ylabel("Layer", fontsize=16)
    ax.tick_params(axis="x", labelrotation=90, labelsize=16)
    ax.tick_params(axis="y", labelsize=13)
    fig.subplots_adjust(bottom=0.2)
    print("logits:", logits)

    # Unregister the figure from pyplot so figures do not pile up across requests;
    # the returned Figure object can still be rendered and saved by the caller.
    plt.close(fig)
    return fig
  
# Gradio UI: a sentence box and a model selector feed `inference`, whose heatmap
# figure is shown in the Plot component when "Run" is clicked.
demo = gr.Blocks()

with demo:
    gr.Markdown(
    """
    # GlobEnc token attribution
    Enter a sentence, choose a model, and click **Run** to see the per-layer GlobEnc attribution heatmap.
    """)
    inp = [
        gr.Textbox(label="Sentence"),
        # The explicit default matches the model `inference` uses when nothing is selected.
        gr.Dropdown(
            choices=['bert-base-uncased-cls-sst2', 'bert-large-sst2', 'electra-base-sst2'],
            value='electra-base-sst2',
            label="Model",
        ),
    ]
    out = gr.Plot()

    button = gr.Button(value="Run")
    button.click(fn=inference,
                 inputs=inp,
                 outputs=out)

# Only start the server when run as a script, not when the module is imported.
if __name__ == "__main__":
    demo.launch()