imamnurby committed on
Commit
f084e2e
1 Parent(s): f778d0e

Create backend_utils.py

Browse files
Files changed (1) hide show
  1. backend_utils.py +461 -0
backend_utils.py ADDED
@@ -0,0 +1,461 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from cherche import retrieve
2
+ from sentence_transformers import SentenceTransformer, util
3
+ from transformers import RobertaTokenizer, RobertaModel, EncoderDecoderModel
4
+ from config import classifier_class_mapping, config
5
+ import pandas as pd
6
+ import numpy as np
7
+ import pickle
8
+ import torch
9
+ from sklearn.multiclass import OneVsRestClassifier
10
+ from sklearn.ensemble import RandomForestClassifier
11
+
12
class wrappedTokenizer(RobertaTokenizer):
    '''
    RobertaTokenizer subclass whose call interface returns the result of `tokenize`

    Calling an instance with a text forwards the text to `self.tokenize`, so the instance
    can be used wherever a plain `text -> tokens` callable is expected (in this file it is
    passed as the `tokenizer` argument of the BM25 retriever).
    '''
    def __call__(self, text_input):
        tokens = self.tokenize(text_input)
        return tokens
15
+
16
def generate_index(db):
    '''
    Function to build the list of documents that is indexed by the retrieval models

    Params:
    db (pandas dataframe): a two-column dataframe whose rows are (library id, library name)

    Returns:
    index_list (list): a list of dictionaries with the keys 'id' and 'library',
    where 'library' holds the lower-cased library name
    '''
    # the rows are only read, so the dataframe does not need a defensive copy
    return [
        {'id': id_, 'library': dirname.lower()}
        for id_, dirname in db.values
    ]
26
+
27
def load_db(db_metadata_path, db_constructor_path):
    '''
    Function to read the two CSV databases and discard incomplete rows

    Params:
    db_metadata_path (string): the path to the db_metadata file
    db_constructor_path (string): the path to the db_constructor file

    Output:
    db_metadata (pandas dataframe): a dataframe containing metadata information about the library
    db_constructor (pandas dataframe): a dataframe containing the mapping of library names to valid constructor
    '''
    # both files get the same treatment: read the CSV, then drop every row with a missing value
    db_metadata, db_constructor = (
        pd.read_csv(csv_path).dropna()
        for csv_path in (db_metadata_path, db_constructor_path)
    )
    return db_metadata, db_constructor
44
+
45
+
46
+
47
def load_retrieval_model_lexical(tokenizer_path, max_k, db_metadata):
    '''
    Function to build the BM25 retrieval model over the library names

    Params:
    tokenizer_path (string): the path to a tokenizer (can be a path to either a huggingface model or local directory)
    max_k (int): the maximum number of returned sequences
    db_metadata (pandas dataframe): a dataframe containing metadata information about the library

    Returns:
    retrieval_model: a retrieval model
    '''
    # documents to index: one {'id', 'library'} dictionary per metadata row
    documents = generate_index(db_metadata[['id', 'library']])

    # the wrapped tokenizer is callable and returns tokens, which is how it is handed to BM25
    bm25_tokenizer = wrappedTokenizer.from_pretrained(tokenizer_path)

    return retrieve.BM25Okapi(
        key='id',
        on='library',
        documents=documents,
        k=max_k,
        tokenizer=bm25_tokenizer,
    )
72
+
73
+
74
def load_retrieval_model_deep_learning(model_path, max_k, db_metadata):
    '''
    Function to load a deep learning-based retrieval model and index the library names

    Params:
    model_path (string): the path to the model (can be a path to either a huggingface model or local directory)
    max_k (int): the maximum number of returned sequences
    db_metadata (pandas dataframe): a dataframe containing metadata information about the library

    Returns:
    retrieval_model: a retrieval model
    '''
    # generate index
    index_list = generate_index(db_metadata[['id', 'library']])

    # load model
    retrieval_model = retrieve.Encoder(
        key='id',
        on='library',
        encoder=SentenceTransformer(model_path).encode,
        k=max_k,
        path="../temp/dl.pkl"
    )
    # the documents have to be added to the retriever created above; the previous code
    # referenced an undefined name here and always failed with a NameError
    retrieval_model = retrieval_model.add(documents=index_list)

    return retrieval_model
100
+
101
def load_generative_model_codebert(model_path):
    '''
    Function to load the tokenizer and the encoder-decoder model from one checkpoint

    Params:
    model_path (string): path to the model (can be a path to either a huggingface model or local directory)

    Returns:
    tokenizer: a huggingface tokenizer
    generative_model: a generative model to generate API pattern given the library name as the input
    '''
    # tokenizer first, model second: both are read from the same checkpoint location
    return (
        RobertaTokenizer.from_pretrained(model_path),
        EncoderDecoderModel.from_pretrained(model_path),
    )
115
+
116
+
117
def _description_is_missing(value):
    # a description counts as missing when it is the literal string 'nan' or a real NaN/None
    if isinstance(value, str):
        return value == 'nan'
    return bool(pd.isna(value))


def get_metadata_library(predictions, db_metadata):
    '''
    Function to add the library metadata to every prediction using the library unique id

    Each prediction dictionary receives the keys 'Sensor Type', 'Github URL' and 'Description'.
    The dictionaries are updated in place; the returned list is a new list that holds the
    same dictionaries.

    Params:
    predictions (list): a list of dictionary containing the prediction details (each needs the key 'id')
    db_metadata: a dataframe containing metadata information about the library

    Returns:
    predictions_cp (list): the list of prediction dictionaries enriched with the metadata
    '''
    predictions_cp = predictions.copy()
    for prediction_dict in predictions_cp:
        temp_db = db_metadata[db_metadata.id == prediction_dict.get('id')]
        # every predicted id must match exactly one metadata row
        assert len(temp_db) == 1
        row = temp_db.iloc[0]

        prediction_dict['Sensor Type'] = row['cat'].capitalize()
        prediction_dict['Github URL'] = row['url']

        # prefer the description from the arduino library list, if not found use the repo description
        description = "Description not found"
        for candidate in (row['desc_ardulib'], row['desc_repo']):
            if not _description_is_missing(candidate):
                description = candidate
                break
        prediction_dict['Description'] = description

        print(prediction_dict)
        print("-----------------------------------------------------------------")
    return predictions_cp
148
+
149
def id_to_libname(id_, db_metadata):
    '''
    Function to look up the library name that belongs to a library id

    Params:
    id_ (int): a unique library id
    db_metadata (pandas dataframe): a dataframe containing metadata information about the library

    Returns:
    library_name (string): the library name that corresponds to the input id
    '''
    matching_rows = db_metadata[db_metadata.id == id_]
    # the id has to identify exactly one row
    assert len(matching_rows) == 1
    only_row = matching_rows.iloc[0]
    return only_row['library']
164
+
165
+
166
def retrieve_libraries(retrieval_model, model_input, db_metadata):
    '''
    Function to query the retrieval model and resolve the returned ids to library names

    Params:
    retrieval_model: a model to perform retrieval
    model_input (string): an input query from the user
    db_metadata (pandas dataframe): a dataframe containing metadata information about the library

    Returns:
    library_ids (list): a list of library unique ids
    library_names (list): a list of library names
    '''
    # every retrieved item is read through its 'id' key
    library_ids = [hit.get('id') for hit in retrieval_model(model_input)]
    return library_ids, [
        id_to_libname(library_id, db_metadata)
        for library_id in library_ids
    ]
182
+
183
def prepare_input_generative_model(library_ids, db_constructor):
    '''
    Function to build the generative model inputs for every library id

    Every row of db_constructor is unpacked positionally into four values
    (id, library name, methods, constructor); only the library name and the
    constructor are used.

    Params:
    library_ids (list): a list of library ids
    db_constructor (pandas dataframe): a dataframe containing the mapping of library names to valid constructor

    Returns:
    output_dict (dictionary): a dictionary where the key is library id and the value is a list of valid inputs
    '''
    return {
        library_id: [
            f'{library_name} [SEP] {constructor}'
            for _row_id, library_name, _methods, constructor
            in db_constructor[db_constructor.id == library_id].values
        ]
        for library_id in library_ids
    }
203
+
204
def generate_api_usage_patterns(generative_model, tokenizer, model_input, num_beams, num_return_sequences):
    '''
    Function to generate API usage patterns for a single model input

    Params:
    generative_model: a huggingface model
    tokenizer: a huggingface tokenizer
    model_input (string): a string in the form of <library-name> [SEP] constructor
    num_beams (int): the beam width used for decoding
    num_return_sequences (int): how many API usage patterns are returned by the model

    Returns:
    api_usage_patterns (list): a list of API usage patterns
    '''
    # only the token ids of the encoded input are passed to the model
    encoded_input = tokenizer(model_input, return_tensors='pt')
    generated_ids = generative_model.generate(
        encoded_input.input_ids,
        num_beams=num_beams,
        num_return_sequences=num_return_sequences,
    )
    return tokenizer.batch_decode(generated_ids, skip_special_tokens=True)
229
+
230
def generate_api_usage_patterns_batch(generative_model, tokenizer, library_ids, db_constructor, num_beams, num_return_sequences):
    '''
    Function to generate API usage patterns in batch

    One prediction dictionary is produced per library id. A library without any
    constructor row gets an empty 'usage_patterns' dictionary and 'library_name' None.

    Params:
    generative_model: a huggingface model
    tokenizer: a huggingface tokenizer
    library_ids (list): a list of library ids
    db_constructor (pandas dataframe): a dataframe containing the mapping of library names to valid constructor
    num_beams (int): the beam width used for decoding
    num_return_sequences (int): how many API usage patterns are returned by the model

    Returns:
    predictions (list): a list of dictionary containing the api usage patterns, library name, and id
    ('hw_config' is initialised to None and is expected to be filled in by the caller)
    '''
    input_generative_model_dict = prepare_input_generative_model(library_ids, db_constructor)

    predictions = []
    for id_, model_inputs in input_generative_model_dict.items():
        # reset for every library, so a library without constructors can neither raise a
        # NameError nor inherit the name of the previously processed library
        library_name = None
        usage_patterns = {}
        for input_generative_model in model_inputs:
            # split on the first separator only, so the constructor text is kept whole
            name_part, constructor_part = input_generative_model.split("[SEP]", 1)
            library_name = name_part.strip()
            constructor = constructor_part.strip()

            # every constructor of a library must be unique
            assert constructor not in usage_patterns
            usage_patterns[constructor] = generate_api_usage_patterns(
                generative_model,
                tokenizer,
                input_generative_model,
                num_beams,
                num_return_sequences
            )

        predictions.append({
            'id': id_,
            'library_name': library_name,
            'hw_config': None,
            'usage_patterns': usage_patterns
        })
    return predictions
275
+
276
+ # def generate_api_usage_patterns(generative_model, tokenizer, model_inputs, num_beams, num_return_sequences):
277
+ # '''
278
+ # Function to generate API usage patterns
279
+
280
+ # Params:
281
+ # generative_model: a huggingface model
282
+ # tokenizer: a huggingface tokenizer
283
+ # model_inputs (list): a list of <library-name> [SEP] <constructor>
284
+ # num_beams (int): the beam width used for decoding
285
+ # num_return_sequences (int): how many API usage patterns are returned by the model
286
+
287
+ # Returns:
288
+ # api_usage_patterns (list): a list of API usage patterns
289
+ # '''
290
+ # model_inputs = tokenizer(
291
+ # model_inputs,
292
+ # max_length=max_length,
293
+ # padding='max_length',
294
+ # return_tensors='pt',
295
+ # truncation=True)
296
+
297
+ # model_output = generative_model.generate(
298
+ # **model_inputs,
299
+ # num_beams=num_beams,
300
+ # num_return_sequences=num_return_sequences
301
+ # )
302
+ # api_usage_patterns = tokenizer.batch_decode(
303
+ # model_output,
304
+ # skip_special_tokens=True
305
+ # )
306
+
307
+ # api_usage_patterns = [api_usage_patterns[i:i+num_return_sequences] for i in range(0, len(api_usage_patterns), num_return_sequences)]
308
+ # return api_usage_patterns
309
+
310
def prepare_input_classification_model(id_, db_metadata):
    '''
    Function to fetch the classification feature of a library using its id

    Params:
    id_ (int): a unique library id
    db_metadata (pandas dataframe): a dataframe containing metadata information about the library

    Returns:
    feature (string): a feature used for the classification model input
    '''
    matching_rows = db_metadata[db_metadata.id == id_]
    # the id has to identify exactly one row
    assert len(matching_rows) == 1
    only_row = matching_rows.iloc[0]
    return only_row['features']
325
+
326
def load_hw_classifier(model_path_classifier, model_path_classifier_head):
    '''
    Function to load the classifier encoder, its tokenizer and the pickled classifier head

    Params:
    model_path_classifier (string): path to the classifier checkpoint (can be either huggingface path or local directory)
    model_path_classifier_head (string): path to the pickled classifier head (it is opened as a file)

    Returns:
    classifier_model: a huggingface model
    classifier_head: a classifier model (can be either svm or rf)
    tokenizer: a huggingface tokenizer
    '''
    hf_tokenizer = RobertaTokenizer.from_pretrained(model_path_classifier)
    encoder = RobertaModel.from_pretrained(model_path_classifier)

    # NOTE(review): unpickling executes arbitrary code, so this path must point to a trusted file
    with open(model_path_classifier_head, 'rb') as head_file:
        head = pickle.load(head_file)

    return encoder, head, hf_tokenizer
344
+
345
def predict_hw_config(classifier_model, classifier_tokenizer, classifier_head, library_ids, db_metadata, max_length):
    '''
    Function to predict hardware configs

    Params:
    classifier_model: a huggingface model to convert a feature to a feature vector
    classifier_tokenizer: a huggingface tokenizer
    classifier_head: a classifier head
    library_ids (list): a list of library ids
    db_metadata (pandas dataframe): a dataframe containing metadata information about the library
    max_length (int): max length of the tokenizer output

    Returns:
    prediction (list): a list of prediction, one per library id (empty when no id is given)
    '''
    # nothing to classify: do not send an empty batch to the tokenizer and the model
    if len(library_ids) == 0:
        return []

    features = [prepare_input_classification_model(id_, db_metadata) for id_ in library_ids]
    tokenized_features = classifier_tokenizer(
        features,
        max_length=max_length,
        padding='max_length',
        return_tensors='pt',
        truncation=True
    )
    # inference only, so no gradients are tracked
    with torch.no_grad():
        embedding_features = classifier_model(**tokenized_features).pooler_output.numpy()

    # pick the most probable class per library and translate the index to its label
    probabilities = classifier_head.predict_proba(embedding_features)
    class_indices = np.argmax(probabilities, axis=1).tolist()
    return [classifier_class_mapping.get(idx) for idx in class_indices]
375
+
376
+
377
def initialize_all_components(config):
    '''
    Function to load the databases and every model used by ArduProg

    Params:
    config (dict): a dictionary containing the configuration to initialize all components

    Returns:
    db_metadata (pandas dataframe): a dataframe containing metadata information about the library
    db_constructor (pandas dataframe): a dataframe containing the mapping of library names to valid constructor
    model_retrieval: the lexical (BM25) retrieval model
    model_generative: a huggingface model
    tokenizer_generative: a huggingface tokenizer
    model_classifier: a huggingface model
    classifier_head: the classifier head loaded from the configured path
    tokenizer_classifier: a huggingface tokenizer
    '''
    # databases
    metadata_path = config.get('db_metadata_path')
    constructor_path = config.get('db_constructor_path')
    db_metadata, db_constructor = load_db(metadata_path, constructor_path)

    # retrieval model (built on top of the metadata database)
    retrieval_tokenizer_path = config.get('tokenizer_path_retrieval')
    max_k = config.get('max_k')
    model_retrieval = load_retrieval_model_lexical(retrieval_tokenizer_path, max_k, db_metadata)

    # generative model
    generative_path = config.get('model_path_generative')
    tokenizer_generative, model_generative = load_generative_model_codebert(generative_path)

    # hardware config classifier
    classifier_path = config.get('model_path_classifier')
    classifier_head_path = config.get('classifier_head_path')
    model_classifier, classifier_head, tokenizer_classifier = load_hw_classifier(
        classifier_path,
        classifier_head_path,
    )

    return (
        db_metadata,
        db_constructor,
        model_retrieval,
        model_generative,
        tokenizer_generative,
        model_classifier,
        classifier_head,
        tokenizer_classifier,
    )
413
+
414
def make_predictions(input_query,
                     model_retrieval,
                     model_generative,
                     model_classifier, classifier_head,
                     tokenizer_generative, tokenizer_classifier,
                     db_metadata, db_constructor,
                     config):
    '''
    Function to retrieve relevant libraries, generate API usage patterns, and predict the hw configs

    Params:
    input_query (string): a query from the user
    model_retrieval: a retrieval model
    model_generative, model_classifier: a huggingface model
    classifier_head: a classifier head
    tokenizer_generative, tokenizer_classifier: a huggingface tokenizer
    db_metadata (pandas dataframe): a dataframe containing metadata information about the library
    db_constructor (pandas dataframe): a dataframe containing the mapping of library names to valid constructor
    config (dict): a dictionary containing the configuration ('num_beams', 'num_return_sequences' and 'max_length' are read)

    Returns:
    predictions (list): a list of dictionary containing the prediction details (empty when nothing is retrieved)
    '''
    # the library names are not needed here, only the ids
    library_ids, _ = retrieve_libraries(model_retrieval, input_query, db_metadata)

    # nothing retrieved: skip generation and classification altogether
    if len(library_ids) == 0:
        return []

    predictions = generate_api_usage_patterns_batch(
        model_generative,
        tokenizer_generative,
        library_ids,
        db_constructor,
        config.get('num_beams'),
        config.get('num_return_sequences')
    )

    hw_configs = predict_hw_config(
        model_classifier,
        tokenizer_classifier,
        classifier_head,
        library_ids,
        db_metadata,
        config.get('max_length')
    )

    # attach the hardware config through the library id instead of the list position, so the
    # pairing stays correct even if the predictions and the ids differ in length
    hw_config_by_id = dict(zip(library_ids, hw_configs))
    for output_dict in predictions:
        output_dict['hw_config'] = hw_config_by_id.get(output_dict.get('id'))

    predictions = get_metadata_library(predictions, db_metadata)

    return predictions