ctheodoris
commited on
Commit
•
0680d55
1
Parent(s):
3e11b4f
update to account for eos token with overexpression
Browse files- geneformer/in_silico_perturber.py +125 -43
geneformer/in_silico_perturber.py
CHANGED
@@ -50,6 +50,7 @@ from . import perturber_utils as pu
|
|
50 |
from .emb_extractor import get_embs
|
51 |
from .perturber_utils import TOKEN_DICTIONARY_FILE
|
52 |
|
|
|
53 |
sns.set()
|
54 |
|
55 |
|
@@ -65,7 +66,7 @@ class InSilicoPerturber:
|
|
65 |
"anchor_gene": {None, str},
|
66 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
67 |
"num_classes": {int},
|
68 |
-
"emb_mode": {"cell", "cell_and_gene"},
|
69 |
"cell_emb_style": {"mean_pool"},
|
70 |
"filter_data": {None, dict},
|
71 |
"cell_states_to_model": {None, dict},
|
@@ -95,10 +96,9 @@ class InSilicoPerturber:
|
|
95 |
max_ncells=None,
|
96 |
cell_inds_to_perturb="all",
|
97 |
emb_layer=-1,
|
98 |
-
token_dictionary_file=None,
|
99 |
forward_batch_size=100,
|
100 |
nproc=4,
|
101 |
-
|
102 |
):
|
103 |
"""
|
104 |
Initialize in silico perturber.
|
@@ -138,11 +138,11 @@ class InSilicoPerturber:
|
|
138 |
num_classes : int
|
139 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
140 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
141 |
-
emb_mode : {"cell", "cell_and_gene"}
|
142 |
-
| Whether to output impact of perturbation on cell and/or gene embeddings.
|
143 |
| Gene embedding shifts only available as compared to original cell, not comparing to goal state.
|
144 |
cell_emb_style : "mean_pool"
|
145 |
-
| Method for summarizing cell embeddings.
|
146 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
147 |
filter_data : None, dict
|
148 |
| Default is to use all input data for in silico perturbation study.
|
@@ -188,6 +188,10 @@ class InSilicoPerturber:
|
|
188 |
token_dictionary_file : Path
|
189 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
190 |
"""
|
|
|
|
|
|
|
|
|
191 |
|
192 |
self.perturb_type = perturb_type
|
193 |
self.perturb_rank_shift = perturb_rank_shift
|
@@ -217,9 +221,9 @@ class InSilicoPerturber:
|
|
217 |
self.max_ncells = max_ncells
|
218 |
self.cell_inds_to_perturb = cell_inds_to_perturb
|
219 |
self.emb_layer = emb_layer
|
220 |
-
self.token_dictionary_file = token_dictionary_file
|
221 |
self.forward_batch_size = forward_batch_size
|
222 |
self.nproc = nproc
|
|
|
223 |
|
224 |
self.validate_options()
|
225 |
|
@@ -236,9 +240,14 @@ class InSilicoPerturber:
|
|
236 |
# Identify if special token is present in the token dictionary
|
237 |
lowercase_token_gene_dict = {k: v.lower() for k, v in self.token_gene_dict.items()}
|
238 |
cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
|
239 |
-
|
240 |
-
if cls_present or
|
241 |
self.special_token = True
|
|
|
|
|
|
|
|
|
|
|
242 |
|
243 |
if self.anchor_gene is None:
|
244 |
self.anchor_token = None
|
@@ -442,6 +451,16 @@ class InSilicoPerturber:
|
|
442 |
filtered_input_data = pu.load_and_filter(
|
443 |
self.filter_data, self.nproc, input_data_file
|
444 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
445 |
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
446 |
|
447 |
if self.perturb_group is True:
|
@@ -544,9 +563,14 @@ class InSilicoPerturber:
|
|
544 |
# then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
|
545 |
# (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
|
546 |
# rather than only adding 2048)
|
547 |
-
|
548 |
-
|
549 |
-
|
|
|
|
|
|
|
|
|
|
|
550 |
|
551 |
if self.emb_mode == "cell_and_gene":
|
552 |
stored_gene_embs_dict = defaultdict(list)
|
@@ -592,20 +616,27 @@ class InSilicoPerturber:
|
|
592 |
silent=True,
|
593 |
)
|
594 |
|
595 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
596 |
if self.perturb_type == "overexpress":
|
597 |
perturbation_emb = full_perturbation_emb[
|
598 |
-
:, len(self.tokens_to_perturb) :, :
|
599 |
]
|
600 |
-
|
601 |
elif self.perturb_type == "delete":
|
602 |
perturbation_emb = full_perturbation_emb[
|
603 |
-
:, : max(perturbation_batch["length"]), :
|
604 |
]
|
605 |
|
606 |
n_perturbation_genes = perturbation_emb.size()[1]
|
607 |
|
608 |
-
# if no goal states, the cosine
|
609 |
if (
|
610 |
self.cell_states_to_model is None
|
611 |
or self.emb_mode == "cell_and_gene"
|
@@ -620,16 +651,22 @@ class InSilicoPerturber:
|
|
620 |
|
621 |
# if there are goal states, the cosine similarities are the cell cosine similarities
|
622 |
if self.cell_states_to_model is not None:
|
623 |
-
|
624 |
-
|
625 |
-
|
626 |
-
|
627 |
-
|
628 |
-
|
629 |
-
|
630 |
-
|
631 |
-
|
632 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
633 |
cell_cos_sims = pu.quant_cos_sims(
|
634 |
perturbation_cell_emb,
|
635 |
original_cell_emb,
|
@@ -649,6 +686,9 @@ class InSilicoPerturber:
|
|
649 |
]
|
650 |
for genes in gene_list
|
651 |
]
|
|
|
|
|
|
|
652 |
|
653 |
for cell_i, genes in enumerate(gene_list):
|
654 |
for gene_j, affected_gene in enumerate(genes):
|
@@ -681,9 +721,21 @@ class InSilicoPerturber:
|
|
681 |
]
|
682 |
else:
|
683 |
nonpadding_lens = perturbation_batch["length"]
|
684 |
-
|
685 |
-
|
686 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
687 |
cos_sims_dict = self.update_perturbation_dictionary(
|
688 |
cos_sims_dict,
|
689 |
cos_sims_data,
|
@@ -703,9 +755,15 @@ class InSilicoPerturber:
|
|
703 |
)
|
704 |
del minibatch
|
705 |
del perturbation_batch
|
|
|
706 |
del original_emb
|
|
|
707 |
del perturbation_emb
|
708 |
del cos_sims_data
|
|
|
|
|
|
|
|
|
709 |
|
710 |
torch.cuda.empty_cache()
|
711 |
|
@@ -766,6 +824,7 @@ class InSilicoPerturber:
|
|
766 |
self.anchor_token,
|
767 |
self.combos,
|
768 |
self.nproc,
|
|
|
769 |
)
|
770 |
|
771 |
full_perturbation_emb = get_embs(
|
@@ -781,16 +840,22 @@ class InSilicoPerturber:
|
|
781 |
)
|
782 |
|
783 |
num_inds_perturbed = 1 + self.combos
|
784 |
-
# need to remove overexpressed gene to quantify cosine shifts
|
|
|
|
|
|
|
|
|
785 |
if self.perturb_type == "overexpress":
|
786 |
-
perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
|
787 |
gene_list = gene_list[
|
788 |
-
num_inds_perturbed:
|
789 |
-
] # index 0 is not overexpressed
|
790 |
|
791 |
elif self.perturb_type == "delete":
|
792 |
-
perturbation_emb = full_perturbation_emb
|
|
|
793 |
|
|
|
794 |
original_batch = pu.make_comparison_batch(
|
795 |
full_original_emb, indices_to_perturb, perturb_group=False
|
796 |
)
|
@@ -803,13 +868,19 @@ class InSilicoPerturber:
|
|
803 |
self.state_embs_dict,
|
804 |
emb_mode="gene",
|
805 |
)
|
|
|
806 |
if self.cell_states_to_model is not None:
|
807 |
-
|
808 |
-
|
809 |
-
|
810 |
-
|
811 |
-
|
812 |
-
|
|
|
|
|
|
|
|
|
|
|
813 |
|
814 |
cell_cos_sims = pu.quant_cos_sims(
|
815 |
perturbation_cell_emb,
|
@@ -840,7 +911,18 @@ class InSilicoPerturber:
|
|
840 |
] = gene_cos_sims[perturbation_i, gene_j].item()
|
841 |
|
842 |
if self.cell_states_to_model is None:
|
843 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
844 |
cos_sims_dict = self.update_perturbation_dictionary(
|
845 |
cos_sims_dict,
|
846 |
cos_sims_data,
|
@@ -933,4 +1015,4 @@ class InSilicoPerturber:
|
|
933 |
for i, cos in enumerate(cos_sims_data.tolist()):
|
934 |
cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
|
935 |
|
936 |
-
return cos_sims_dict
|
|
|
50 |
from .emb_extractor import get_embs
|
51 |
from .perturber_utils import TOKEN_DICTIONARY_FILE
|
52 |
|
53 |
+
|
54 |
sns.set()
|
55 |
|
56 |
|
|
|
66 |
"anchor_gene": {None, str},
|
67 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
68 |
"num_classes": {int},
|
69 |
+
"emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
|
70 |
"cell_emb_style": {"mean_pool"},
|
71 |
"filter_data": {None, dict},
|
72 |
"cell_states_to_model": {None, dict},
|
|
|
96 |
max_ncells=None,
|
97 |
cell_inds_to_perturb="all",
|
98 |
emb_layer=-1,
|
|
|
99 |
forward_batch_size=100,
|
100 |
nproc=4,
|
101 |
+
token_dictionary_file=None,
|
102 |
):
|
103 |
"""
|
104 |
Initialize in silico perturber.
|
|
|
138 |
num_classes : int
|
139 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
140 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
141 |
+
emb_mode : {"cls", "cell", "cls_and_gene","cell_and_gene"}
|
142 |
+
| Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings.
|
143 |
| Gene embedding shifts only available as compared to original cell, not comparing to goal state.
|
144 |
cell_emb_style : "mean_pool"
|
145 |
+
| Method for summarizing cell embeddings if not using CLS token.
|
146 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
147 |
filter_data : None, dict
|
148 |
| Default is to use all input data for in silico perturbation study.
|
|
|
188 |
token_dictionary_file : Path
|
189 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
190 |
"""
|
191 |
+
try:
|
192 |
+
set_start_method("spawn")
|
193 |
+
except RuntimeError:
|
194 |
+
pass
|
195 |
|
196 |
self.perturb_type = perturb_type
|
197 |
self.perturb_rank_shift = perturb_rank_shift
|
|
|
221 |
self.max_ncells = max_ncells
|
222 |
self.cell_inds_to_perturb = cell_inds_to_perturb
|
223 |
self.emb_layer = emb_layer
|
|
|
224 |
self.forward_batch_size = forward_batch_size
|
225 |
self.nproc = nproc
|
226 |
+
self.token_dictionary_file = token_dictionary_file
|
227 |
|
228 |
self.validate_options()
|
229 |
|
|
|
240 |
# Identify if special token is present in the token dictionary
|
241 |
lowercase_token_gene_dict = {k: v.lower() for k, v in self.token_gene_dict.items()}
|
242 |
cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
|
243 |
+
eos_present = any("eos" in value for value in lowercase_token_gene_dict.values())
|
244 |
+
if cls_present or eos_present:
|
245 |
self.special_token = True
|
246 |
+
else:
|
247 |
+
if "cls" in self.emb_mode:
|
248 |
+
logger.error(f"emb_mode set to {self.emb_mode} but <cls> token not in token dictionary.")
|
249 |
+
raise
|
250 |
+
self.special_token = False
|
251 |
|
252 |
if self.anchor_gene is None:
|
253 |
self.anchor_token = None
|
|
|
451 |
filtered_input_data = pu.load_and_filter(
|
452 |
self.filter_data, self.nproc, input_data_file
|
453 |
)
|
454 |
+
|
455 |
+
# Ensure emb_mode is cls if first token of the filtered input data is cls token
|
456 |
+
if self.special_token:
|
457 |
+
cls_token_id = self.gene_token_dict["<cls>"]
|
458 |
+
if (filtered_input_data["input_ids"][0][0] == cls_token_id) and ("cls" not in self.emb_mode):
|
459 |
+
logger.error(
|
460 |
+
"Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
|
461 |
+
)
|
462 |
+
raise
|
463 |
+
|
464 |
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
465 |
|
466 |
if self.perturb_group is True:
|
|
|
563 |
# then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
|
564 |
# (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
|
565 |
# rather than only adding 2048)
|
566 |
+
if self.special_token:
|
567 |
+
filtered_input_data = filtered_input_data.map(
|
568 |
+
pu.truncate_by_n_overflow_special, num_proc=self.nproc
|
569 |
+
)
|
570 |
+
else:
|
571 |
+
filtered_input_data = filtered_input_data.map(
|
572 |
+
pu.truncate_by_n_overflow, num_proc=self.nproc
|
573 |
+
)
|
574 |
|
575 |
if self.emb_mode == "cell_and_gene":
|
576 |
stored_gene_embs_dict = defaultdict(list)
|
|
|
616 |
silent=True,
|
617 |
)
|
618 |
|
619 |
+
if "cls" not in self.emb_mode:
|
620 |
+
start = 0
|
621 |
+
else:
|
622 |
+
start = 1
|
623 |
+
|
624 |
+
# remove overexpressed genes and cls
|
625 |
+
original_emb = original_emb[
|
626 |
+
:, start :, :
|
627 |
+
]
|
628 |
if self.perturb_type == "overexpress":
|
629 |
perturbation_emb = full_perturbation_emb[
|
630 |
+
:, start+len(self.tokens_to_perturb) :, :
|
631 |
]
|
|
|
632 |
elif self.perturb_type == "delete":
|
633 |
perturbation_emb = full_perturbation_emb[
|
634 |
+
:, start : max(perturbation_batch["length"]), :
|
635 |
]
|
636 |
|
637 |
n_perturbation_genes = perturbation_emb.size()[1]
|
638 |
|
639 |
+
# if no goal states, the cosine similarities are the mean of gene cosine similarities
|
640 |
if (
|
641 |
self.cell_states_to_model is None
|
642 |
or self.emb_mode == "cell_and_gene"
|
|
|
651 |
|
652 |
# if there are goal states, the cosine similarities are the cell cosine similarities
|
653 |
if self.cell_states_to_model is not None:
|
654 |
+
if "cls" not in self.emb_mode:
|
655 |
+
original_cell_emb = pu.mean_nonpadding_embs(
|
656 |
+
full_original_emb,
|
657 |
+
torch.tensor(minibatch["length"], device="cuda"),
|
658 |
+
dim=1,
|
659 |
+
)
|
660 |
+
perturbation_cell_emb = pu.mean_nonpadding_embs(
|
661 |
+
full_perturbation_emb,
|
662 |
+
torch.tensor(perturbation_batch["length"], device="cuda"),
|
663 |
+
dim=1,
|
664 |
+
)
|
665 |
+
else:
|
666 |
+
# get cls emb
|
667 |
+
original_cell_emb = full_original_emb[:,0,:]
|
668 |
+
perturbation_cell_emb = full_perturbation_emb[:,0,:]
|
669 |
+
|
670 |
cell_cos_sims = pu.quant_cos_sims(
|
671 |
perturbation_cell_emb,
|
672 |
original_cell_emb,
|
|
|
686 |
]
|
687 |
for genes in gene_list
|
688 |
]
|
689 |
+
# remove CLS if present
|
690 |
+
if "cls" in self.emb_mode:
|
691 |
+
gene_list = gene_list[1:]
|
692 |
|
693 |
for cell_i, genes in enumerate(gene_list):
|
694 |
for gene_j, affected_gene in enumerate(genes):
|
|
|
721 |
]
|
722 |
else:
|
723 |
nonpadding_lens = perturbation_batch["length"]
|
724 |
+
if "cls" not in self.emb_mode:
|
725 |
+
cos_sims_data = pu.mean_nonpadding_embs(
|
726 |
+
gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
|
727 |
+
)
|
728 |
+
else:
|
729 |
+
original_cls_emb = full_original_emb[:,0,:]
|
730 |
+
perturbation_cls_emb = full_perturbation_emb[:,0,:]
|
731 |
+
cos_sims_data = pu.quant_cos_sims(
|
732 |
+
perturbation_cls_emb,
|
733 |
+
original_cls_emb,
|
734 |
+
self.cell_states_to_model,
|
735 |
+
self.state_embs_dict,
|
736 |
+
emb_mode="cell",
|
737 |
+
)
|
738 |
+
|
739 |
cos_sims_dict = self.update_perturbation_dictionary(
|
740 |
cos_sims_dict,
|
741 |
cos_sims_data,
|
|
|
755 |
)
|
756 |
del minibatch
|
757 |
del perturbation_batch
|
758 |
+
del full_original_emb
|
759 |
del original_emb
|
760 |
+
del full_perturbation_emb
|
761 |
del perturbation_emb
|
762 |
del cos_sims_data
|
763 |
+
if "cls" in self.emb_mode:
|
764 |
+
del original_cls_emb
|
765 |
+
del perturbation_cls_emb
|
766 |
+
del cls_cos_sims
|
767 |
|
768 |
torch.cuda.empty_cache()
|
769 |
|
|
|
824 |
self.anchor_token,
|
825 |
self.combos,
|
826 |
self.nproc,
|
827 |
+
self.special_token,
|
828 |
)
|
829 |
|
830 |
full_perturbation_emb = get_embs(
|
|
|
840 |
)
|
841 |
|
842 |
num_inds_perturbed = 1 + self.combos
|
843 |
+
# need to remove overexpressed gene and cls to quantify cosine shifts
|
844 |
+
if "cls" not in self.emb_mode:
|
845 |
+
start = 0
|
846 |
+
else:
|
847 |
+
start = 1
|
848 |
if self.perturb_type == "overexpress":
|
849 |
+
perturbation_emb = full_perturbation_emb[:, start+num_inds_perturbed:, :]
|
850 |
gene_list = gene_list[
|
851 |
+
start+num_inds_perturbed:
|
852 |
+
] # cls and index 0 is not overexpressed
|
853 |
|
854 |
elif self.perturb_type == "delete":
|
855 |
+
perturbation_emb = full_perturbation_emb[:, start:, :]
|
856 |
+
gene_list = gene_list[start:]
|
857 |
|
858 |
+
full_original_emb = full_original_emb[:, start:, :]
|
859 |
original_batch = pu.make_comparison_batch(
|
860 |
full_original_emb, indices_to_perturb, perturb_group=False
|
861 |
)
|
|
|
868 |
self.state_embs_dict,
|
869 |
emb_mode="gene",
|
870 |
)
|
871 |
+
|
872 |
if self.cell_states_to_model is not None:
|
873 |
+
if "cls" not in self.emb_mode:
|
874 |
+
original_cell_emb = pu.compute_nonpadded_cell_embedding(
|
875 |
+
full_original_emb, "mean_pool"
|
876 |
+
)
|
877 |
+
perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
|
878 |
+
full_perturbation_emb, "mean_pool"
|
879 |
+
)
|
880 |
+
else:
|
881 |
+
# get cls emb
|
882 |
+
original_cell_emb = full_original_emb[:,0,:]
|
883 |
+
perturbation_cell_emb = full_perturbation_emb[:,0,:]
|
884 |
|
885 |
cell_cos_sims = pu.quant_cos_sims(
|
886 |
perturbation_cell_emb,
|
|
|
911 |
] = gene_cos_sims[perturbation_i, gene_j].item()
|
912 |
|
913 |
if self.cell_states_to_model is None:
|
914 |
+
if "cls" not in self.emb_mode:
|
915 |
+
cos_sims_data = torch.mean(gene_cos_sims, dim=1)
|
916 |
+
else:
|
917 |
+
original_cls_emb = full_original_emb[:,0,:]
|
918 |
+
perturbation_cls_emb = full_perturbation_emb[:,0,:]
|
919 |
+
cos_sims_data = pu.quant_cos_sims(
|
920 |
+
perturbation_cls_emb,
|
921 |
+
original_cls_emb,
|
922 |
+
self.cell_states_to_model,
|
923 |
+
self.state_embs_dict,
|
924 |
+
emb_mode="cell",
|
925 |
+
)
|
926 |
cos_sims_dict = self.update_perturbation_dictionary(
|
927 |
cos_sims_dict,
|
928 |
cos_sims_data,
|
|
|
1015 |
for i, cos in enumerate(cos_sims_data.tolist()):
|
1016 |
cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
|
1017 |
|
1018 |
+
return cos_sims_dict
|