hchen725 commited on
Commit
3e11b4f
1 Parent(s): 471eefc

Update geneformer/in_silico_perturber.py

Browse files

Add custom token gene dictionary to in silico perturber

Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +21 -10
geneformer/in_silico_perturber.py CHANGED
@@ -50,7 +50,6 @@ from . import perturber_utils as pu
50
  from .emb_extractor import get_embs
51
  from .perturber_utils import TOKEN_DICTIONARY_FILE
52
 
53
-
54
  sns.set()
55
 
56
 
@@ -74,6 +73,7 @@ class InSilicoPerturber:
74
  "max_ncells": {None, int},
75
  "cell_inds_to_perturb": {"all", dict},
76
  "emb_layer": {-1, 0},
 
77
  "forward_batch_size": {int},
78
  "nproc": {int},
79
  }
@@ -95,9 +95,10 @@ class InSilicoPerturber:
95
  max_ncells=None,
96
  cell_inds_to_perturb="all",
97
  emb_layer=-1,
 
98
  forward_batch_size=100,
99
  nproc=4,
100
- token_dictionary_file=TOKEN_DICTIONARY_FILE,
101
  ):
102
  """
103
  Initialize in silico perturber.
@@ -187,10 +188,6 @@ class InSilicoPerturber:
187
  token_dictionary_file : Path
188
  | Path to pickle file containing token dictionary (Ensembl ID:token).
189
  """
190
- try:
191
- set_start_method("spawn")
192
- except RuntimeError:
193
- pass
194
 
195
  self.perturb_type = perturb_type
196
  self.perturb_rank_shift = perturb_rank_shift
@@ -220,17 +217,29 @@ class InSilicoPerturber:
220
  self.max_ncells = max_ncells
221
  self.cell_inds_to_perturb = cell_inds_to_perturb
222
  self.emb_layer = emb_layer
 
223
  self.forward_batch_size = forward_batch_size
224
  self.nproc = nproc
225
 
226
  self.validate_options()
227
 
228
  # load token dictionary (Ensembl IDs:token)
 
 
229
  with open(token_dictionary_file, "rb") as f:
230
  self.gene_token_dict = pickle.load(f)
 
231
 
232
  self.pad_token_id = self.gene_token_dict.get("<pad>")
233
 
 
 
 
 
 
 
 
 
234
  if self.anchor_gene is None:
235
  self.anchor_token = None
236
  else:
@@ -287,7 +296,7 @@ class InSilicoPerturber:
287
  continue
288
  valid_type = False
289
  for option in valid_options:
290
- if (option in [bool, int, list, dict]) and isinstance(
291
  attr_value, option
292
  ):
293
  valid_type = True
@@ -428,7 +437,6 @@ class InSilicoPerturber:
428
  self.max_len = pu.get_model_input_size(model)
429
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
430
 
431
-
432
  ### filter input data ###
433
  # general filtering of input data based on filter_data argument
434
  filtered_input_data = pu.load_and_filter(
@@ -506,7 +514,7 @@ class InSilicoPerturber:
506
  if self.perturb_type == "delete":
507
  example = pu.delete_indices(example)
508
  elif self.perturb_type == "overexpress":
509
- example = pu.overexpress_tokens(example, self.max_len)
510
  example["n_overflow"] = pu.calc_n_overflow(
511
  self.max_len,
512
  example["length"],
@@ -527,7 +535,6 @@ class InSilicoPerturber:
527
  perturbed_data = filtered_input_data.map(
528
  make_group_perturbation_batch, num_proc=self.nproc
529
  )
530
-
531
  if self.perturb_type == "overexpress":
532
  filtered_input_data = filtered_input_data.add_column(
533
  "n_overflow", perturbed_data["n_overflow"]
@@ -560,6 +567,7 @@ class InSilicoPerturber:
560
  layer_to_quant,
561
  self.pad_token_id,
562
  self.forward_batch_size,
 
563
  summary_stat=None,
564
  silent=True,
565
  )
@@ -579,6 +587,7 @@ class InSilicoPerturber:
579
  layer_to_quant,
580
  self.pad_token_id,
581
  self.forward_batch_size,
 
582
  summary_stat=None,
583
  silent=True,
584
  )
@@ -738,6 +747,7 @@ class InSilicoPerturber:
738
  layer_to_quant,
739
  self.pad_token_id,
740
  self.forward_batch_size,
 
741
  summary_stat=None,
742
  silent=True,
743
  )
@@ -765,6 +775,7 @@ class InSilicoPerturber:
765
  layer_to_quant,
766
  self.pad_token_id,
767
  self.forward_batch_size,
 
768
  summary_stat=None,
769
  silent=True,
770
  )
 
50
  from .emb_extractor import get_embs
51
  from .perturber_utils import TOKEN_DICTIONARY_FILE
52
 
 
53
  sns.set()
54
 
55
 
 
73
  "max_ncells": {None, int},
74
  "cell_inds_to_perturb": {"all", dict},
75
  "emb_layer": {-1, 0},
76
+ "token_dictionary_file" : {None, str},
77
  "forward_batch_size": {int},
78
  "nproc": {int},
79
  }
 
95
  max_ncells=None,
96
  cell_inds_to_perturb="all",
97
  emb_layer=-1,
98
+ token_dictionary_file=None,
99
  forward_batch_size=100,
100
  nproc=4,
101
+
102
  ):
103
  """
104
  Initialize in silico perturber.
 
188
  token_dictionary_file : Path
189
  | Path to pickle file containing token dictionary (Ensembl ID:token).
190
  """
 
 
 
 
191
 
192
  self.perturb_type = perturb_type
193
  self.perturb_rank_shift = perturb_rank_shift
 
217
  self.max_ncells = max_ncells
218
  self.cell_inds_to_perturb = cell_inds_to_perturb
219
  self.emb_layer = emb_layer
220
+ self.token_dictionary_file = token_dictionary_file
221
  self.forward_batch_size = forward_batch_size
222
  self.nproc = nproc
223
 
224
  self.validate_options()
225
 
226
  # load token dictionary (Ensembl IDs:token)
227
+ if self.token_dictionary_file is None:
228
+ token_dictionary_file = TOKEN_DICTIONARY_FILE
229
  with open(token_dictionary_file, "rb") as f:
230
  self.gene_token_dict = pickle.load(f)
231
+ self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
232
 
233
  self.pad_token_id = self.gene_token_dict.get("<pad>")
234
 
235
+
236
+ # Identify if special token is present in the token dictionary
237
+ lowercase_token_gene_dict = {k: v.lower() for k, v in self.token_gene_dict.items()}
238
+ cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
239
+ sep_present = any("sep" in value for value in lowercase_token_gene_dict.values())
240
+ if cls_present or sep_present:
241
+ self.special_token = True
242
+
243
  if self.anchor_gene is None:
244
  self.anchor_token = None
245
  else:
 
296
  continue
297
  valid_type = False
298
  for option in valid_options:
299
+ if (option in [bool, int, list, dict, str]) and isinstance(
300
  attr_value, option
301
  ):
302
  valid_type = True
 
437
  self.max_len = pu.get_model_input_size(model)
438
  layer_to_quant = pu.quant_layers(model) + self.emb_layer
439
 
 
440
  ### filter input data ###
441
  # general filtering of input data based on filter_data argument
442
  filtered_input_data = pu.load_and_filter(
 
514
  if self.perturb_type == "delete":
515
  example = pu.delete_indices(example)
516
  elif self.perturb_type == "overexpress":
517
+ example = pu.overexpress_tokens(example, self.max_len, self.special_token)
518
  example["n_overflow"] = pu.calc_n_overflow(
519
  self.max_len,
520
  example["length"],
 
535
  perturbed_data = filtered_input_data.map(
536
  make_group_perturbation_batch, num_proc=self.nproc
537
  )
 
538
  if self.perturb_type == "overexpress":
539
  filtered_input_data = filtered_input_data.add_column(
540
  "n_overflow", perturbed_data["n_overflow"]
 
567
  layer_to_quant,
568
  self.pad_token_id,
569
  self.forward_batch_size,
570
+ self.token_gene_dict,
571
  summary_stat=None,
572
  silent=True,
573
  )
 
587
  layer_to_quant,
588
  self.pad_token_id,
589
  self.forward_batch_size,
590
+ self.token_gene_dict,
591
  summary_stat=None,
592
  silent=True,
593
  )
 
747
  layer_to_quant,
748
  self.pad_token_id,
749
  self.forward_batch_size,
750
+ self.token_gene_dict,
751
  summary_stat=None,
752
  silent=True,
753
  )
 
775
  layer_to_quant,
776
  self.pad_token_id,
777
  self.forward_batch_size,
778
+ self.token_gene_dict,
779
  summary_stat=None,
780
  silent=True,
781
  )