anonICPC committed on
Commit
5b90262
1 Parent(s): 9c4d60a

Upload comp.py

Files changed (1)
  1. comp.py +283 -0
comp.py ADDED
@@ -0,0 +1,283 @@
+ #!/usr/bin/env python
+ # coding: utf-8
+
+ # In[1]:
+
+
+ import pandas as pd
+ import numpy as np
+ import torch
+ from torch import nn
+ from torch.nn import init, MarginRankingLoss
+ from torch.optim import Adam
+ from distutils.version import LooseVersion
+ from torch.utils.data import Dataset, DataLoader
+ from torch.autograd import Variable
+ import math
+ from transformers import AutoConfig, AutoModel, AutoTokenizer
+ import nltk
+ import re
+ import torch.optim as optim
+ from tqdm import tqdm
+ from transformers import AutoModelForMaskedLM
+ import torch.nn.functional as F
+ import random
+
+
+ # In[2]:
+
+
+ maskis = []
+ n_y = []
+ class MyDataset(Dataset):
+     # Builds, per CSV row, fixed-shape model inputs from code whose variable name
+     # is replaced by mask tokens. Uses the module-level tokenizer defined in In[3].
+     def __init__(self, file_name):
+         global maskis
+         global n_y
+         df = pd.read_csv(file_name)
+         df = df.fillna("")
+         self.inp_dicts = []
+         for r in range(df.shape[0]):
+             X_init = df['X'][r]
+             y = df['y'][r]
+             n_y.append(y)
+             # Split the camelCase label into lowercase words, e.g. "maxScore" -> "max score".
+             nl = re.findall(r'[A-Z](?:[a-z]+|[A-Z]*(?=[A-Z]|$))|[a-z]+|\d+', y)
+             lb = ' '.join(nl).lower()
+             x = tokenizer.tokenize(lb)
+             num_sub_tokens_label = len(x)
+             # Expand the single [MASK] placeholder to one mask token per sub-token of the label.
+             X_init = X_init.replace("[MASK]", " ".join([tokenizer.mask_token] * num_sub_tokens_label))
+             tokens = tokenizer.encode_plus(X_init, add_special_tokens=False, return_tensors='pt')
+             # Chunk into 510-token pieces, leaving room for the CLS/SEP tokens added below.
+             input_id_chunki = tokens['input_ids'][0].split(510)
+             input_id_chunks = []
+             mask_chunks = []
+             mask_chunki = tokens['attention_mask'][0].split(510)
+             for tensor in input_id_chunki:
+                 input_id_chunks.append(tensor)
+             for tensor in mask_chunki:
+                 mask_chunks.append(tensor)
+             # Special-token ids taken from the tokenizer (the original hard-coded the
+             # BERT ids 101/102, which do not match this RoBERTa-style vocabulary).
+             xi = torch.full((1,), fill_value=tokenizer.cls_token_id)
+             yi = torch.full((1,), fill_value=1)
+             zi = torch.full((1,), fill_value=tokenizer.sep_token_id)
+             for j in range(len(input_id_chunks)):
+                 input_id_chunks[j] = torch.cat([xi, input_id_chunks[j]], dim=-1)
+                 input_id_chunks[j] = torch.cat([input_id_chunks[j], zi], dim=-1)
+                 mask_chunks[j] = torch.cat([yi, mask_chunks[j]], dim=-1)
+                 mask_chunks[j] = torch.cat([mask_chunks[j], yi], dim=-1)
+             # Right-pad every chunk to length 512: pad-token id for the inputs,
+             # zeros for the attention mask (the original padded both with 0).
+             di = torch.full((1,), fill_value=tokenizer.pad_token_id)
+             d0 = torch.full((1,), fill_value=0)
+             for i in range(len(input_id_chunks)):
+                 pad_len = 512 - input_id_chunks[i].shape[0]
+                 if pad_len > 0:
+                     for p in range(pad_len):
+                         input_id_chunks[i] = torch.cat([input_id_chunks[i], di], dim=-1)
+                         mask_chunks[i] = torch.cat([mask_chunks[i], d0], dim=-1)
+             vb = torch.ones_like(input_id_chunks[0])
+             fg = torch.zeros_like(input_id_chunks[0])
+             maski = []
+             # Record, per chunk, the position of the first mask token of each run of
+             # consecutive mask tokens.
+             for l in range(len(input_id_chunks)):
+                 masked_pos = []
+                 for i in range(len(input_id_chunks[l])):
+                     if input_id_chunks[l][i] == tokenizer.mask_token_id:
+                         if i != 0 and input_id_chunks[l][i - 1] == tokenizer.mask_token_id:
+                             continue
+                         masked_pos.append(i)
+                 maski.append(masked_pos)
+             maskis.append(maski)
+             # Pad the chunk list to a fixed count of 250 with dummy chunks; their
+             # all-zero attention masks keep them out of the computation.
+             while len(input_id_chunks) < 250:
+                 input_id_chunks.append(vb)
+                 mask_chunks.append(fg)
+             input_ids = torch.stack(input_id_chunks)
+             attention_mask = torch.stack(mask_chunks)
+             input_dict = {
+                 'input_ids': input_ids.long(),
+                 'attention_mask': attention_mask.int()
+             }
+             self.inp_dicts.append(input_dict)
+             del input_dict, input_ids, attention_mask, maski, mask_chunks, input_id_chunks, di, d0, fg, vb, mask_chunki, input_id_chunki, X_init, y, tokens, x, lb, nl
+         del df
+     def __len__(self):
+         return len(self.inp_dicts)
+     def __getitem__(self, idx):
+         return self.inp_dicts[idx]
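+
+ # Note (comment added for clarity): each dataset item is a dict holding
+ # 'input_ids' and 'attention_mask' tensors of shape (250, 512). The module-level
+ # lists maskis and n_y are filled in row order and consumed by index in cell
+ # In[5], which is why the DataLoader below must keep shuffle=False.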
+
+
+ # In[3]:
+
+
+ tokenizer = AutoTokenizer.from_pretrained("microsoft/graphcodebert-base")
+ model = AutoModelForMaskedLM.from_pretrained("microsoft/graphcodebert-base")
+ base_model = AutoModelForMaskedLM.from_pretrained("microsoft/graphcodebert-base")
+ # Load the fine-tuned weights into `model`; `base_model` keeps the pretrained
+ # checkpoint so the two can be compared in In[5].
+ model.load_state_dict(torch.load('var_runs/model_26_2'))
+ model.eval()
+ base_model.eval()
+ myDs = MyDataset('d_t.csv')
+ train_loader = DataLoader(myDs, batch_size=1, shuffle=False)
+
+
+ # In[4]:
+
+
+ variable_names = [
+     # One-word Variable Names
+     'count', 'value', 'result', 'flag', 'max', 'min', 'data', 'input', 'output', 'name', 'index', 'status', 'error', 'message', 'price', 'quantity', 'total', 'length', 'size', 'score',
+
+     # Two-word Variable Names
+     'studentName', 'accountBalance', 'isFound', 'maxScore', 'userAge', 'carModel', 'bookTitle', 'arrayLength', 'employeeID', 'itemPrice', 'customerAddress', 'productCategory', 'orderNumber', 'transactionType', 'bankAccount', 'shippingMethod', 'deliveryDate', 'purchaseAmount', 'inventoryItem', 'salesRevenue',
+
+     # Three-word Variable Names
+     'numberOfStudents', 'averageTemperature', 'userIsLoggedIn', 'totalSalesAmount', 'employeeSalaryRate', 'maxAllowedAttempts', 'selectedOption', 'shippingAddress', 'manufacturingDate', 'connectionPool', 'customerAccountBalance', 'employeeSalaryReport', 'productInventoryCount', 'transactionProcessingStatus', 'userAuthenticationToken', 'orderShippingAddress', 'databaseConnectionPoolSize', 'vehicleEngineTemperature', 'sensorDataProcessingRate', 'employeePayrollSystem',
+
+     # Four-word Variable Names
+     'customerAccountBalanceValue', 'employeeSalaryReportData', 'productInventoryItemCount', 'transactionProcessingStatusFlag', 'userAuthenticationTokenKey', 'orderShippingAddressDetails', 'databaseConnectionPoolMaxSize', 'vehicleEngineTemperatureReading', 'sensorDataProcessingRateLimit', 'employeePayrollSystemData', 'customerOrderShippingAddress', 'productCatalogItemNumber', 'transactionProcessingSuccessFlag', 'userAuthenticationAccessToken', 'databaseConnectionPoolConfig', 'vehicleEngineTemperatureSensor', 'sensorDataProcessingRateLimitation', 'employeePayrollSystemConfiguration', 'customerAccountBalanceHistoryData', 'transactionProcessingStatusTracking'
+ ]
+ # Bucket candidate names by sub-token count: var_list[n-1] holds the names that
+ # tokenize to exactly n sub-tokens.
+ var_list = []
+ for j in range(6):
+     var_list.append([])
+ for var in variable_names:
+     try:
+         var_list[len(tokenizer.tokenize(var)) - 1].append(var)
+     except IndexError:
+         # Names longer than six sub-tokens are skipped (the original used a bare except).
+         continue
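+
+ # Added sanity check (not in the original script): confirm no bucket is empty,
+ # since random.choice() on an empty bucket in In[5] would raise IndexError.
+ for n, bucket in enumerate(var_list, start=1):
+     print(n, "sub-token names:", len(bucket))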
+
+
+ # In[5]:
+
+
+ # For every example, compare the pseudo-log-likelihood (PLL) of the ground-truth
+ # variable name against a random "mock" name of the same sub-token length, under
+ # both the fine-tuned model and the base model.
+ tot_pll = 0.0
+ base_tot_pll = 0.0
+ loop = tqdm(train_loader, leave=True)
+ cntr = 0
+ for batch in loop:
+     for i in range(len(batch['input_ids'])):
+         cntr += 1
+         # maskis and n_y were filled in dataset order, so cntr-1 indexes the current row.
+         maski = maskis[cntr - 1]
+         li = len(maski)
+         # Drop the dummy chunks appended in MyDataset; only the first li chunks are real.
+         input_ids = batch['input_ids'][i][:li]
+         att_mask = batch['attention_mask'][i][:li]
+         y = n_y[cntr - 1]
+         ty = tokenizer.encode(y)[1:-1]
+         num_sub_tokens_label = len(ty)
+         if num_sub_tokens_label > 6:
+             continue
+         print("Ground truth:", y)
+         # Pick a random "mock" name with the same number of sub-tokens as the ground truth.
+         m_y = random.choice(var_list[num_sub_tokens_label - 1])
+         m_ty = tokenizer.encode(m_y)[1:-1]
+         print("Mock truth:", m_y)
+         # input_ids, att_mask = input_ids.to(device), att_mask.to(device)
+         # Inference only, so no gradients are needed (the original built a graph here).
+         with torch.no_grad():
+             outputs = model(input_ids, attention_mask=att_mask)
+             base_outputs = base_model(input_ids, attention_mask=att_mask)
+         last_hidden_state = outputs[0].squeeze()
+         base_last_hidden_state = base_outputs[0].squeeze()
+         # l_o_l_sa[t] collects, across chunks, the output vector at the t-th masked sub-token.
+         l_o_l_sa = []
+         base_l_o_l_sa = []
+         sum_state = []
+         base_sum_state = []
+         for t in range(num_sub_tokens_label):
+             l_o_l_sa.append([])
+             base_l_o_l_sa.append([])
+         if len(maski) == 1:
+             masked_pos = maski[0]
+             for k in masked_pos:
+                 for t in range(num_sub_tokens_label):
+                     l_o_l_sa[t].append(last_hidden_state[k + t])
+                     base_l_o_l_sa[t].append(base_last_hidden_state[k + t])
+         else:
+             for p in range(len(maski)):
+                 masked_pos = maski[p]
+                 for k in masked_pos:
+                     for t in range(num_sub_tokens_label):
+                         # A mask run can straddle a chunk boundary; spill over into the next chunk.
+                         if (k + t) >= len(last_hidden_state[p]):
+                             l_o_l_sa[t].append(last_hidden_state[p + 1][k + t - len(last_hidden_state[p])])
+                             base_l_o_l_sa[t].append(base_last_hidden_state[p + 1][k + t - len(base_last_hidden_state[p])])
+                             continue
+                         l_o_l_sa[t].append(last_hidden_state[p][k + t])
+                         base_l_o_l_sa[t].append(base_last_hidden_state[p][k + t])
+         # Average each sub-token position's outputs over all occurrences of the mask run.
+         for t in range(num_sub_tokens_label):
+             sum_state.append(l_o_l_sa[t][0])
+             base_sum_state.append(base_l_o_l_sa[t][0])
+         for j in range(1, len(l_o_l_sa[0])):
+             for t in range(num_sub_tokens_label):
+                 sum_state[t] = sum_state[t] + l_o_l_sa[t][j]
+                 base_sum_state[t] = base_sum_state[t] + base_l_o_l_sa[t][j]
+         yip = len(l_o_l_sa[0])
+         val = 0.0
+         m_val = 0.0
+         m_base_val = 0.0
+         base_val = 0.0
+         for t in range(num_sub_tokens_label):
+             sum_state[t] /= yip
+             base_sum_state[t] /= yip
+             probs = F.softmax(sum_state[t], dim=0)
+             base_probs = F.softmax(base_sum_state[t], dim=0)
+             # Accumulate the negative log-likelihood of each sub-token of the true and mock names.
+             val = val - torch.log(probs[ty[t]])
+             m_val = m_val - torch.log(probs[m_ty[t]])
+             base_val = base_val - torch.log(base_probs[ty[t]])
+             m_base_val = m_base_val - torch.log(base_probs[m_ty[t]])
+         # Length-normalise to a per-sub-token score.
+         val = val / num_sub_tokens_label
+         base_val = base_val / num_sub_tokens_label
+         m_val = m_val / num_sub_tokens_label
+         m_base_val = m_base_val / num_sub_tokens_label
+         print("Sent PLL:")
+         print(val)
+         print("Base Sent PLL:")
+         print(base_val)
+         print("Net % difference:")
+         diff = (val - base_val) * 100 / base_val
+         print(diff)
+         tot_pll += val
+         base_tot_pll += base_val
+         print()
+         print()
+         print("Mock Sent PLL:")
+         print(m_val)
+         print("Mock Base Sent PLL:")
+         print(m_base_val)
+         print("Mock Net % difference:")
+         m_diff = (m_val - m_base_val) * 100 / m_base_val
+         print(m_diff)
+         del sum_state, base_sum_state, l_o_l_sa, base_l_o_l_sa
+         del maski, input_ids, att_mask, last_hidden_state, base_last_hidden_state
+     print("Tot PLL: ", tot_pll)
+     print("Base Tot PLL: ", base_tot_pll)
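+
+ # PLL recap (comment added for clarity): for a name with sub-tokens s_1..s_n,
+ # "Sent PLL" above is -(1/n) * sum_t log P(s_t | masked context), with each
+ # sub-token's distribution averaged over every chunk occurrence of the mask run;
+ # lower values mean the model finds the name more plausible.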
+
+
+ # In[ ]:
+
+
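+ # Minimal single-snippet PLL sketch (added for illustration; pll_of_name is a
+ # hypothetical helper, not part of the original pipeline). It assumes the cells
+ # above have run, the snippet fits in one 512-token chunk, and "[MASK]" occurs
+ # exactly once.
+ def pll_of_name(code_with_mask, name):
+     ids = tokenizer.encode(name)[1:-1]  # sub-token ids of the candidate name
+     text = code_with_mask.replace("[MASK]", " ".join([tokenizer.mask_token] * len(ids)))
+     enc = tokenizer(text, return_tensors='pt', truncation=True, max_length=512)
+     pos = (enc['input_ids'][0] == tokenizer.mask_token_id).nonzero(as_tuple=True)[0]
+     with torch.no_grad():
+         logits = model(**enc).logits[0]
+     logp = F.log_softmax(logits[pos], dim=-1)
+     return -sum(logp[t, ids[t]] for t in range(len(ids))) / len(ids)
+
+ # Example: pll_of_name("int [MASK] = input.length;", "arrayLength")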