ctheodoris committed on
Commit
0680d55
1 Parent(s): 3e11b4f

update to account for eos token with overexpression

Browse files
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +125 -43
geneformer/in_silico_perturber.py CHANGED
@@ -50,6 +50,7 @@ from . import perturber_utils as pu
50
  from .emb_extractor import get_embs
51
  from .perturber_utils import TOKEN_DICTIONARY_FILE
52
 
 
53
  sns.set()
54
 
55
 
@@ -65,7 +66,7 @@ class InSilicoPerturber:
65
  "anchor_gene": {None, str},
66
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
67
  "num_classes": {int},
68
- "emb_mode": {"cell", "cell_and_gene"},
69
  "cell_emb_style": {"mean_pool"},
70
  "filter_data": {None, dict},
71
  "cell_states_to_model": {None, dict},
@@ -95,10 +96,9 @@ class InSilicoPerturber:
95
  max_ncells=None,
96
  cell_inds_to_perturb="all",
97
  emb_layer=-1,
98
- token_dictionary_file=None,
99
  forward_batch_size=100,
100
  nproc=4,
101
-
102
  ):
103
  """
104
  Initialize in silico perturber.
@@ -138,11 +138,11 @@ class InSilicoPerturber:
138
  num_classes : int
139
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
140
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
141
- emb_mode : {"cell", "cell_and_gene"}
142
- | Whether to output impact of perturbation on cell and/or gene embeddings.
143
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
144
  cell_emb_style : "mean_pool"
145
- | Method for summarizing cell embeddings.
146
  | Currently only option is mean pooling of gene embeddings for given cell.
147
  filter_data : None, dict
148
  | Default is to use all input data for in silico perturbation study.
@@ -188,6 +188,10 @@ class InSilicoPerturber:
188
  token_dictionary_file : Path
189
  | Path to pickle file containing token dictionary (Ensembl ID:token).
190
  """
 
 
 
 
191
 
192
  self.perturb_type = perturb_type
193
  self.perturb_rank_shift = perturb_rank_shift
@@ -217,9 +221,9 @@ class InSilicoPerturber:
217
  self.max_ncells = max_ncells
218
  self.cell_inds_to_perturb = cell_inds_to_perturb
219
  self.emb_layer = emb_layer
220
- self.token_dictionary_file = token_dictionary_file
221
  self.forward_batch_size = forward_batch_size
222
  self.nproc = nproc
 
223
 
224
  self.validate_options()
225
 
@@ -236,9 +240,14 @@ class InSilicoPerturber:
236
  # Identify if special token is present in the token dictionary
237
  lowercase_token_gene_dict = {k: v.lower() for k, v in self.token_gene_dict.items()}
238
  cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
239
- sep_present = any("sep" in value for value in lowercase_token_gene_dict.values())
240
- if cls_present or sep_present:
241
  self.special_token = True
 
 
 
 
 
242
 
243
  if self.anchor_gene is None:
244
  self.anchor_token = None
@@ -442,6 +451,16 @@ class InSilicoPerturber:
442
  filtered_input_data = pu.load_and_filter(
443
  self.filter_data, self.nproc, input_data_file
444
  )
 
 
 
 
 
 
 
 
 
 
445
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
446
 
447
  if self.perturb_group is True:
@@ -544,9 +563,14 @@ class InSilicoPerturber:
544
  # then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
545
  # (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
546
  # rather than only adding 2048)
547
- filtered_input_data = filtered_input_data.map(
548
- pu.truncate_by_n_overflow, num_proc=self.nproc
549
- )
 
 
 
 
 
550
 
551
  if self.emb_mode == "cell_and_gene":
552
  stored_gene_embs_dict = defaultdict(list)
@@ -592,20 +616,27 @@ class InSilicoPerturber:
592
  silent=True,
593
  )
594
 
595
- # remove overexpressed genes
 
 
 
 
 
 
 
 
596
  if self.perturb_type == "overexpress":
597
  perturbation_emb = full_perturbation_emb[
598
- :, len(self.tokens_to_perturb) :, :
599
  ]
600
-
601
  elif self.perturb_type == "delete":
602
  perturbation_emb = full_perturbation_emb[
603
- :, : max(perturbation_batch["length"]), :
604
  ]
605
 
606
  n_perturbation_genes = perturbation_emb.size()[1]
607
 
608
- # if no goal states, the cosine similarties are the mean of gene cosine similarities
609
  if (
610
  self.cell_states_to_model is None
611
  or self.emb_mode == "cell_and_gene"
@@ -620,16 +651,22 @@ class InSilicoPerturber:
620
 
621
  # if there are goal states, the cosine similarities are the cell cosine similarities
622
  if self.cell_states_to_model is not None:
623
- original_cell_emb = pu.mean_nonpadding_embs(
624
- full_original_emb,
625
- torch.tensor(minibatch["length"], device="cuda"),
626
- dim=1,
627
- )
628
- perturbation_cell_emb = pu.mean_nonpadding_embs(
629
- full_perturbation_emb,
630
- torch.tensor(perturbation_batch["length"], device="cuda"),
631
- dim=1,
632
- )
 
 
 
 
 
 
633
  cell_cos_sims = pu.quant_cos_sims(
634
  perturbation_cell_emb,
635
  original_cell_emb,
@@ -649,6 +686,9 @@ class InSilicoPerturber:
649
  ]
650
  for genes in gene_list
651
  ]
 
 
 
652
 
653
  for cell_i, genes in enumerate(gene_list):
654
  for gene_j, affected_gene in enumerate(genes):
@@ -681,9 +721,21 @@ class InSilicoPerturber:
681
  ]
682
  else:
683
  nonpadding_lens = perturbation_batch["length"]
684
- cos_sims_data = pu.mean_nonpadding_embs(
685
- gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
686
- )
 
 
 
 
 
 
 
 
 
 
 
 
687
  cos_sims_dict = self.update_perturbation_dictionary(
688
  cos_sims_dict,
689
  cos_sims_data,
@@ -703,9 +755,15 @@ class InSilicoPerturber:
703
  )
704
  del minibatch
705
  del perturbation_batch
 
706
  del original_emb
 
707
  del perturbation_emb
708
  del cos_sims_data
 
 
 
 
709
 
710
  torch.cuda.empty_cache()
711
 
@@ -766,6 +824,7 @@ class InSilicoPerturber:
766
  self.anchor_token,
767
  self.combos,
768
  self.nproc,
 
769
  )
770
 
771
  full_perturbation_emb = get_embs(
@@ -781,16 +840,22 @@ class InSilicoPerturber:
781
  )
782
 
783
  num_inds_perturbed = 1 + self.combos
784
- # need to remove overexpressed gene to quantify cosine shifts
 
 
 
 
785
  if self.perturb_type == "overexpress":
786
- perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
787
  gene_list = gene_list[
788
- num_inds_perturbed:
789
- ] # index 0 is not overexpressed
790
 
791
  elif self.perturb_type == "delete":
792
- perturbation_emb = full_perturbation_emb
 
793
 
 
794
  original_batch = pu.make_comparison_batch(
795
  full_original_emb, indices_to_perturb, perturb_group=False
796
  )
@@ -803,13 +868,19 @@ class InSilicoPerturber:
803
  self.state_embs_dict,
804
  emb_mode="gene",
805
  )
 
806
  if self.cell_states_to_model is not None:
807
- original_cell_emb = pu.compute_nonpadded_cell_embedding(
808
- full_original_emb, "mean_pool"
809
- )
810
- perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
811
- full_perturbation_emb, "mean_pool"
812
- )
 
 
 
 
 
813
 
814
  cell_cos_sims = pu.quant_cos_sims(
815
  perturbation_cell_emb,
@@ -840,7 +911,18 @@ class InSilicoPerturber:
840
  ] = gene_cos_sims[perturbation_i, gene_j].item()
841
 
842
  if self.cell_states_to_model is None:
843
- cos_sims_data = torch.mean(gene_cos_sims, dim=1)
 
 
 
 
 
 
 
 
 
 
 
844
  cos_sims_dict = self.update_perturbation_dictionary(
845
  cos_sims_dict,
846
  cos_sims_data,
@@ -933,4 +1015,4 @@ class InSilicoPerturber:
933
  for i, cos in enumerate(cos_sims_data.tolist()):
934
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
935
 
936
- return cos_sims_dict
 
50
  from .emb_extractor import get_embs
51
  from .perturber_utils import TOKEN_DICTIONARY_FILE
52
 
53
+
54
  sns.set()
55
 
56
 
 
66
  "anchor_gene": {None, str},
67
  "model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
68
  "num_classes": {int},
69
+ "emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
70
  "cell_emb_style": {"mean_pool"},
71
  "filter_data": {None, dict},
72
  "cell_states_to_model": {None, dict},
 
96
  max_ncells=None,
97
  cell_inds_to_perturb="all",
98
  emb_layer=-1,
 
99
  forward_batch_size=100,
100
  nproc=4,
101
+ token_dictionary_file=None,
102
  ):
103
  """
104
  Initialize in silico perturber.
 
138
  num_classes : int
139
  | If model is a gene or cell classifier, specify number of classes it was trained to classify.
140
  | For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
141
+ emb_mode : {"cls", "cell", "cls_and_gene","cell_and_gene"}
142
+ | Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings.
143
  | Gene embedding shifts only available as compared to original cell, not comparing to goal state.
144
  cell_emb_style : "mean_pool"
145
+ | Method for summarizing cell embeddings if not using CLS token.
146
  | Currently only option is mean pooling of gene embeddings for given cell.
147
  filter_data : None, dict
148
  | Default is to use all input data for in silico perturbation study.
 
188
  token_dictionary_file : Path
189
  | Path to pickle file containing token dictionary (Ensembl ID:token).
190
  """
191
+ try:
192
+ set_start_method("spawn")
193
+ except RuntimeError:
194
+ pass
195
 
196
  self.perturb_type = perturb_type
197
  self.perturb_rank_shift = perturb_rank_shift
 
221
  self.max_ncells = max_ncells
222
  self.cell_inds_to_perturb = cell_inds_to_perturb
223
  self.emb_layer = emb_layer
 
224
  self.forward_batch_size = forward_batch_size
225
  self.nproc = nproc
226
+ self.token_dictionary_file = token_dictionary_file
227
 
228
  self.validate_options()
229
 
 
240
  # Identify if special token is present in the token dictionary
241
  lowercase_token_gene_dict = {k: v.lower() for k, v in self.token_gene_dict.items()}
242
  cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
243
+ eos_present = any("eos" in value for value in lowercase_token_gene_dict.values())
244
+ if cls_present or eos_present:
245
  self.special_token = True
246
+ else:
247
+ if "cls" in self.emb_mode:
248
+ logger.error(f"emb_mode set to {self.emb_mode} but <cls> token not in token dictionary.")
249
+ raise
250
+ self.special_token = False
251
 
252
  if self.anchor_gene is None:
253
  self.anchor_token = None
 
451
  filtered_input_data = pu.load_and_filter(
452
  self.filter_data, self.nproc, input_data_file
453
  )
454
+
455
+ # Ensure emb_mode is cls if first token of the filtered input data is cls token
456
+ if self.special_token:
457
+ cls_token_id = self.gene_token_dict["<cls>"]
458
+ if (filtered_input_data["input_ids"][0][0] == cls_token_id) and ("cls" not in self.emb_mode):
459
+ logger.error(
460
+ "Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
461
+ )
462
+ raise
463
+
464
  filtered_input_data = self.apply_additional_filters(filtered_input_data)
465
 
466
  if self.perturb_group is True:
 
563
  # then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
564
  # (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
565
  # rather than only adding 2048)
566
+ if self.special_token:
567
+ filtered_input_data = filtered_input_data.map(
568
+ pu.truncate_by_n_overflow_special, num_proc=self.nproc
569
+ )
570
+ else:
571
+ filtered_input_data = filtered_input_data.map(
572
+ pu.truncate_by_n_overflow, num_proc=self.nproc
573
+ )
574
 
575
  if self.emb_mode == "cell_and_gene":
576
  stored_gene_embs_dict = defaultdict(list)
 
616
  silent=True,
617
  )
618
 
619
+ if "cls" not in self.emb_mode:
620
+ start = 0
621
+ else:
622
+ start = 1
623
+
624
+ # remove overexpressed genes and cls
625
+ original_emb = original_emb[
626
+ :, start :, :
627
+ ]
628
  if self.perturb_type == "overexpress":
629
  perturbation_emb = full_perturbation_emb[
630
+ :, start+len(self.tokens_to_perturb) :, :
631
  ]
 
632
  elif self.perturb_type == "delete":
633
  perturbation_emb = full_perturbation_emb[
634
+ :, start : max(perturbation_batch["length"]), :
635
  ]
636
 
637
  n_perturbation_genes = perturbation_emb.size()[1]
638
 
639
+ # if no goal states, the cosine similarities are the mean of gene cosine similarities
640
  if (
641
  self.cell_states_to_model is None
642
  or self.emb_mode == "cell_and_gene"
 
651
 
652
  # if there are goal states, the cosine similarities are the cell cosine similarities
653
  if self.cell_states_to_model is not None:
654
+ if "cls" not in self.emb_mode:
655
+ original_cell_emb = pu.mean_nonpadding_embs(
656
+ full_original_emb,
657
+ torch.tensor(minibatch["length"], device="cuda"),
658
+ dim=1,
659
+ )
660
+ perturbation_cell_emb = pu.mean_nonpadding_embs(
661
+ full_perturbation_emb,
662
+ torch.tensor(perturbation_batch["length"], device="cuda"),
663
+ dim=1,
664
+ )
665
+ else:
666
+ # get cls emb
667
+ original_cell_emb = full_original_emb[:,0,:]
668
+ perturbation_cell_emb = full_perturbation_emb[:,0,:]
669
+
670
  cell_cos_sims = pu.quant_cos_sims(
671
  perturbation_cell_emb,
672
  original_cell_emb,
 
686
  ]
687
  for genes in gene_list
688
  ]
689
+ # remove CLS if present
690
+ if "cls" in self.emb_mode:
691
+ gene_list = gene_list[1:]
692
 
693
  for cell_i, genes in enumerate(gene_list):
694
  for gene_j, affected_gene in enumerate(genes):
 
721
  ]
722
  else:
723
  nonpadding_lens = perturbation_batch["length"]
724
+ if "cls" not in self.emb_mode:
725
+ cos_sims_data = pu.mean_nonpadding_embs(
726
+ gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
727
+ )
728
+ else:
729
+ original_cls_emb = full_original_emb[:,0,:]
730
+ perturbation_cls_emb = full_perturbation_emb[:,0,:]
731
+ cos_sims_data = pu.quant_cos_sims(
732
+ perturbation_cls_emb,
733
+ original_cls_emb,
734
+ self.cell_states_to_model,
735
+ self.state_embs_dict,
736
+ emb_mode="cell",
737
+ )
738
+
739
  cos_sims_dict = self.update_perturbation_dictionary(
740
  cos_sims_dict,
741
  cos_sims_data,
 
755
  )
756
  del minibatch
757
  del perturbation_batch
758
+ del full_original_emb
759
  del original_emb
760
+ del full_perturbation_emb
761
  del perturbation_emb
762
  del cos_sims_data
763
+ if "cls" in self.emb_mode:
764
+ del original_cls_emb
765
+ del perturbation_cls_emb
766
+ del cls_cos_sims
767
 
768
  torch.cuda.empty_cache()
769
 
 
824
  self.anchor_token,
825
  self.combos,
826
  self.nproc,
827
+ self.special_token,
828
  )
829
 
830
  full_perturbation_emb = get_embs(
 
840
  )
841
 
842
  num_inds_perturbed = 1 + self.combos
843
+ # need to remove overexpressed gene and cls to quantify cosine shifts
844
+ if "cls" not in self.emb_mode:
845
+ start = 0
846
+ else:
847
+ start = 1
848
  if self.perturb_type == "overexpress":
849
+ perturbation_emb = full_perturbation_emb[:, start+num_inds_perturbed:, :]
850
  gene_list = gene_list[
851
+ start+num_inds_perturbed:
852
+ ] # cls and index 0 is not overexpressed
853
 
854
  elif self.perturb_type == "delete":
855
+ perturbation_emb = full_perturbation_emb[:, start:, :]
856
+ gene_list = gene_list[start:]
857
 
858
+ full_original_emb = full_original_emb[:, start:, :]
859
  original_batch = pu.make_comparison_batch(
860
  full_original_emb, indices_to_perturb, perturb_group=False
861
  )
 
868
  self.state_embs_dict,
869
  emb_mode="gene",
870
  )
871
+
872
  if self.cell_states_to_model is not None:
873
+ if "cls" not in self.emb_mode:
874
+ original_cell_emb = pu.compute_nonpadded_cell_embedding(
875
+ full_original_emb, "mean_pool"
876
+ )
877
+ perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
878
+ full_perturbation_emb, "mean_pool"
879
+ )
880
+ else:
881
+ # get cls emb
882
+ original_cell_emb = full_original_emb[:,0,:]
883
+ perturbation_cell_emb = full_perturbation_emb[:,0,:]
884
 
885
  cell_cos_sims = pu.quant_cos_sims(
886
  perturbation_cell_emb,
 
911
  ] = gene_cos_sims[perturbation_i, gene_j].item()
912
 
913
  if self.cell_states_to_model is None:
914
+ if "cls" not in self.emb_mode:
915
+ cos_sims_data = torch.mean(gene_cos_sims, dim=1)
916
+ else:
917
+ original_cls_emb = full_original_emb[:,0,:]
918
+ perturbation_cls_emb = full_perturbation_emb[:,0,:]
919
+ cos_sims_data = pu.quant_cos_sims(
920
+ perturbation_cls_emb,
921
+ original_cls_emb,
922
+ self.cell_states_to_model,
923
+ self.state_embs_dict,
924
+ emb_mode="cell",
925
+ )
926
  cos_sims_dict = self.update_perturbation_dictionary(
927
  cos_sims_dict,
928
  cos_sims_data,
 
1015
  for i, cos in enumerate(cos_sims_data.tolist()):
1016
  cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
1017
 
1018
+ return cos_sims_dict