imamnurby committed on
Commit
15a540f
1 Parent(s): 2487c5b

Update backend_utils.py

Browse files
Files changed (1) hide show
  1. backend_utils.py +48 -17
backend_utils.py CHANGED
@@ -24,7 +24,7 @@ def generate_index(db):
24
  })
25
  return index_list
26
 
27
- def load_db(db_metadata_path, db_constructor_path):
28
  '''
29
  Function to load dataframe
30
 
@@ -40,7 +40,9 @@ def load_db(db_metadata_path, db_constructor_path):
40
  db_metadata.dropna(inplace=True)
41
  db_constructor = pd.read_csv(db_constructor_path)
42
  db_constructor.dropna(inplace=True)
43
- return db_metadata, db_constructor
 
 
44
 
45
 
46
 
@@ -142,8 +144,6 @@ def get_metadata_library(predictions, db_metadata):
142
 
143
  else:
144
  prediction_dict['Description'] = "Description not found"
145
- print(prediction_dict)
146
- print("-----------------------------------------------------------------")
147
  return predictions_cp
148
 
149
  def id_to_libname(id_, db_metadata):
@@ -201,7 +201,7 @@ def prepare_input_generative_model(library_ids, db_constructor):
201
  )
202
  return output_dict
203
 
204
- def generate_api_usage_patterns(generative_model, tokenizer, model_input, num_beams, num_return_sequences):
205
  '''
206
  Function to generate API usage patterns
207
 
@@ -221,7 +221,7 @@ def generate_api_usage_patterns(generative_model, tokenizer, model_input, num_be
221
  num_beams=num_beams,
222
  num_return_sequences=num_return_sequences,
223
  early_stopping=True,
224
- max_length=50
225
  )
226
  api_usage_patterns = tokenizer.batch_decode(
227
  model_output,
@@ -229,7 +229,36 @@ def generate_api_usage_patterns(generative_model, tokenizer, model_input, num_be
229
  )
230
  return api_usage_patterns
231
 
232
- def generate_api_usage_patterns_batch(generative_model, tokenizer, library_ids, db_constructor, num_beams, num_return_sequences):
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
233
  '''
234
  Function to generate API usage patterns in batch
235
 
@@ -260,7 +289,8 @@ def generate_api_usage_patterns_batch(generative_model, tokenizer, library_ids,
260
  tokenizer,
261
  input_generative_model,
262
  num_beams,
263
- num_return_sequences
 
264
  )
265
 
266
  temp = input_generative_model.split("[SEP]")
@@ -268,6 +298,7 @@ def generate_api_usage_patterns_batch(generative_model, tokenizer, library_ids,
268
  constructor = temp[1].strip()
269
 
270
  assert(constructor not in temp_dict.get('usage_patterns'))
 
271
  temp_dict['usage_patterns'][constructor] = api_usage_patterns
272
 
273
  assert(temp_dict.get('library_name')==None)
@@ -392,9 +423,10 @@ def initialize_all_components(config):
392
  classifier_head: a random forest model
393
  '''
394
  # load db
395
- db_metadata, db_constructor = load_db(
396
  config.get('db_metadata_path'),
397
- config.get('db_constructor_path')
 
398
  )
399
 
400
  # load model
@@ -411,14 +443,14 @@ def initialize_all_components(config):
411
  config.get('classifier_head_path')
412
  )
413
 
414
- return db_metadata, db_constructor, model_retrieval, model_generative, tokenizer_generative, model_classifier, classifier_head, tokenizer_classifier
415
 
416
  def make_predictions(input_query,
417
  model_retrieval,
418
  model_generative,
419
  model_classifier, classifier_head,
420
  tokenizer_generative, tokenizer_classifier,
421
- db_metadata, db_constructor,
422
  config):
423
  '''
424
  Function to retrieve relevant libraries, generate API usage patterns, and predict the hw configs
@@ -435,20 +467,19 @@ def make_predictions(input_query,
435
  Returns:
436
  predictions (list): a list of dictionary containing the prediction details
437
  '''
438
- print("retrieve library")
439
  library_ids, library_names = retrieve_libraries(model_retrieval, input_query, db_metadata)
440
-
441
- print("generate hw patterns")
442
  predictions = generate_api_usage_patterns_batch(
443
  model_generative,
444
  tokenizer_generative,
445
  library_ids,
446
  db_constructor,
 
447
  config.get('num_beams'),
448
- config.get('num_return_sequences')
 
449
  )
450
 
451
- print("generate hw config")
452
  hw_configs = predict_hw_config(
453
  model_classifier,
454
  tokenizer_classifier,
 
24
  })
25
  return index_list
26
 
27
+ def load_db(db_metadata_path, db_constructor_path, db_params_path):
28
  '''
29
  Function to load dataframe
30
 
 
40
  db_metadata.dropna(inplace=True)
41
  db_constructor = pd.read_csv(db_constructor_path)
42
  db_constructor.dropna(inplace=True)
43
+ db_params = pd.read_csv(db_params_path)
44
+ db_params.dropna(inplace=True)
45
+ return db_metadata, db_constructor, db_params
46
 
47
 
48
 
 
144
 
145
  else:
146
  prediction_dict['Description'] = "Description not found"
 
 
147
  return predictions_cp
148
 
149
  def id_to_libname(id_, db_metadata):
 
201
  )
202
  return output_dict
203
 
204
+ def generate_api_usage_patterns(generative_model, tokenizer, model_input, num_beams, num_return_sequences, max_length):
205
  '''
206
  Function to generate API usage patterns
207
 
 
221
  num_beams=num_beams,
222
  num_return_sequences=num_return_sequences,
223
  early_stopping=True,
224
+ max_length=max_length
225
  )
226
  api_usage_patterns = tokenizer.batch_decode(
227
  model_output,
 
229
  )
230
  return api_usage_patterns
231
 
232
def add_params(api_usage_patterns, db_params, library_id):
    '''
    Function to append the parameter signature to every API call of the generated usage patterns

    Params:
    api_usage_patterns (iterable of str): generated sequences, each one a whitespace-separated list of API calls
    db_params (pandas.DataFrame): parameter table; the columns `id`, `methods`, and `params` are read
    library_id: value compared against the `id` column to select the rows of the current library

    Returns:
    processed_sequences (list): one string per retained sequence, with the API calls
        (each suffixed by its `params` value) joined by "[API-SEP]".
        A sequence is dropped when it contains fewer than two API calls or when
        any of its method names has no row for the given library.
    '''
    # The library filter is the same for every API call, so apply it once
    # instead of rebuilding a full-table boolean mask per token.
    library_rows = db_params[db_params["id"] == library_id]

    # Map method name -> params; setdefault keeps the first matching row,
    # mirroring the "take the first match" semantics of `.iloc[0]`.
    params_by_method = {}
    for method, params in zip(library_rows["methods"], library_rows["params"]):
        params_by_method.setdefault(method, params)

    processed_sequences = []
    for sequence in api_usage_patterns:
        apis = sequence.split()

        # A usage pattern needs at least two API calls to be kept.
        if len(apis) < 2:
            continue

        # Only the part after the last dot (the method name) is looked up.
        method_names = [api.rsplit(".", 1)[-1] for api in apis]

        # One unknown method invalidates the whole sequence.
        if any(name not in params_by_method for name in method_names):
            continue

        processed_sequences.append(
            "[API-SEP]".join(
                api + params_by_method[name]
                for api, name in zip(apis, method_names)
            )
        )
    return processed_sequences
259
+
260
+
261
+ def generate_api_usage_patterns_batch(generative_model, tokenizer, library_ids, db_constructor, db_params, num_beams, num_return_sequences, max_length):
262
  '''
263
  Function to generate API usage patterns in batch
264
 
 
289
  tokenizer,
290
  input_generative_model,
291
  num_beams,
292
+ num_return_sequences,
293
+ max_length
294
  )
295
 
296
  temp = input_generative_model.split("[SEP]")
 
298
  constructor = temp[1].strip()
299
 
300
  assert(constructor not in temp_dict.get('usage_patterns'))
301
+ api_usage_patterns = add_params(api_usage_patterns, db_params, id_)
302
  temp_dict['usage_patterns'][constructor] = api_usage_patterns
303
 
304
  assert(temp_dict.get('library_name')==None)
 
423
  classifier_head: a random forest model
424
  '''
425
  # load db
426
+ db_metadata, db_constructor, db_params = load_db(
427
  config.get('db_metadata_path'),
428
+ config.get('db_constructor_path'),
429
+ config.get('db_params_path')
430
  )
431
 
432
  # load model
 
443
  config.get('classifier_head_path')
444
  )
445
 
446
+ return db_metadata, db_constructor, db_params, model_retrieval, model_generative, tokenizer_generative, model_classifier, classifier_head, tokenizer_classifier
447
 
448
  def make_predictions(input_query,
449
  model_retrieval,
450
  model_generative,
451
  model_classifier, classifier_head,
452
  tokenizer_generative, tokenizer_classifier,
453
+ db_metadata, db_constructor, db_params,
454
  config):
455
  '''
456
  Function to retrieve relevant libraries, generate API usage patterns, and predict the hw configs
 
467
  Returns:
468
  predictions (list): a list of dictionary containing the prediction details
469
  '''
 
470
  library_ids, library_names = retrieve_libraries(model_retrieval, input_query, db_metadata)
471
+
 
472
  predictions = generate_api_usage_patterns_batch(
473
  model_generative,
474
  tokenizer_generative,
475
  library_ids,
476
  db_constructor,
477
+ db_params,
478
  config.get('num_beams'),
479
+ config.get('num_return_sequences'),
480
+ config.get('max_length_generate')
481
  )
482
 
 
483
  hw_configs = predict_hw_config(
484
  model_classifier,
485
  tokenizer_classifier,