Update geneformer/in_silico_perturber.py
#355
by
hchen725
- opened
- geneformer/in_silico_perturber.py +136 -43
- geneformer/perturber_utils.py +37 -10
geneformer/in_silico_perturber.py
CHANGED
@@ -66,7 +66,7 @@ class InSilicoPerturber:
|
|
66 |
"anchor_gene": {None, str},
|
67 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
68 |
"num_classes": {int},
|
69 |
-
"emb_mode": {"cell", "cell_and_gene"},
|
70 |
"cell_emb_style": {"mean_pool"},
|
71 |
"filter_data": {None, dict},
|
72 |
"cell_states_to_model": {None, dict},
|
@@ -74,6 +74,7 @@ class InSilicoPerturber:
|
|
74 |
"max_ncells": {None, int},
|
75 |
"cell_inds_to_perturb": {"all", dict},
|
76 |
"emb_layer": {-1, 0},
|
|
|
77 |
"forward_batch_size": {int},
|
78 |
"nproc": {int},
|
79 |
}
|
@@ -97,7 +98,7 @@ class InSilicoPerturber:
|
|
97 |
emb_layer=-1,
|
98 |
forward_batch_size=100,
|
99 |
nproc=4,
|
100 |
-
token_dictionary_file=
|
101 |
):
|
102 |
"""
|
103 |
Initialize in silico perturber.
|
@@ -137,11 +138,11 @@ class InSilicoPerturber:
|
|
137 |
num_classes : int
|
138 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
139 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
140 |
-
emb_mode : {"cell", "cell_and_gene"}
|
141 |
-
| Whether to output impact of perturbation on cell and/or gene embeddings.
|
142 |
| Gene embedding shifts only available as compared to original cell, not comparing to goal state.
|
143 |
cell_emb_style : "mean_pool"
|
144 |
-
| Method for summarizing cell embeddings.
|
145 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
146 |
filter_data : None, dict
|
147 |
| Default is to use all input data for in silico perturbation study.
|
@@ -222,15 +223,32 @@ class InSilicoPerturber:
|
|
222 |
self.emb_layer = emb_layer
|
223 |
self.forward_batch_size = forward_batch_size
|
224 |
self.nproc = nproc
|
|
|
225 |
|
226 |
self.validate_options()
|
227 |
|
228 |
# load token dictionary (Ensembl IDs:token)
|
|
|
|
|
229 |
with open(token_dictionary_file, "rb") as f:
|
230 |
self.gene_token_dict = pickle.load(f)
|
|
|
231 |
|
232 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
if self.anchor_gene is None:
|
235 |
self.anchor_token = None
|
236 |
else:
|
@@ -287,7 +305,7 @@ class InSilicoPerturber:
|
|
287 |
continue
|
288 |
valid_type = False
|
289 |
for option in valid_options:
|
290 |
-
if (option in [bool, int, list, dict]) and isinstance(
|
291 |
attr_value, option
|
292 |
):
|
293 |
valid_type = True
|
@@ -428,12 +446,21 @@ class InSilicoPerturber:
|
|
428 |
self.max_len = pu.get_model_input_size(model)
|
429 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
430 |
|
431 |
-
|
432 |
### filter input data ###
|
433 |
# general filtering of input data based on filter_data argument
|
434 |
filtered_input_data = pu.load_and_filter(
|
435 |
self.filter_data, self.nproc, input_data_file
|
436 |
)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
437 |
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
438 |
|
439 |
if self.perturb_group is True:
|
@@ -506,7 +533,7 @@ class InSilicoPerturber:
|
|
506 |
if self.perturb_type == "delete":
|
507 |
example = pu.delete_indices(example)
|
508 |
elif self.perturb_type == "overexpress":
|
509 |
-
example = pu.overexpress_tokens(example, self.max_len)
|
510 |
example["n_overflow"] = pu.calc_n_overflow(
|
511 |
self.max_len,
|
512 |
example["length"],
|
@@ -527,7 +554,6 @@ class InSilicoPerturber:
|
|
527 |
perturbed_data = filtered_input_data.map(
|
528 |
make_group_perturbation_batch, num_proc=self.nproc
|
529 |
)
|
530 |
-
|
531 |
if self.perturb_type == "overexpress":
|
532 |
filtered_input_data = filtered_input_data.add_column(
|
533 |
"n_overflow", perturbed_data["n_overflow"]
|
@@ -537,9 +563,14 @@ class InSilicoPerturber:
|
|
537 |
# then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
|
538 |
# (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
|
539 |
# rather than only adding 2048)
|
540 |
-
|
541 |
-
|
542 |
-
|
|
|
|
|
|
|
|
|
|
|
543 |
|
544 |
if self.emb_mode == "cell_and_gene":
|
545 |
stored_gene_embs_dict = defaultdict(list)
|
@@ -560,6 +591,7 @@ class InSilicoPerturber:
|
|
560 |
layer_to_quant,
|
561 |
self.pad_token_id,
|
562 |
self.forward_batch_size,
|
|
|
563 |
summary_stat=None,
|
564 |
silent=True,
|
565 |
)
|
@@ -579,24 +611,32 @@ class InSilicoPerturber:
|
|
579 |
layer_to_quant,
|
580 |
self.pad_token_id,
|
581 |
self.forward_batch_size,
|
|
|
582 |
summary_stat=None,
|
583 |
silent=True,
|
584 |
)
|
585 |
|
586 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
587 |
if self.perturb_type == "overexpress":
|
588 |
perturbation_emb = full_perturbation_emb[
|
589 |
-
:, len(self.tokens_to_perturb) :, :
|
590 |
]
|
591 |
-
|
592 |
elif self.perturb_type == "delete":
|
593 |
perturbation_emb = full_perturbation_emb[
|
594 |
-
:, : max(perturbation_batch["length"]), :
|
595 |
]
|
596 |
|
597 |
n_perturbation_genes = perturbation_emb.size()[1]
|
598 |
|
599 |
-
# if no goal states, the cosine
|
600 |
if (
|
601 |
self.cell_states_to_model is None
|
602 |
or self.emb_mode == "cell_and_gene"
|
@@ -611,16 +651,22 @@ class InSilicoPerturber:
|
|
611 |
|
612 |
# if there are goal states, the cosine similarities are the cell cosine similarities
|
613 |
if self.cell_states_to_model is not None:
|
614 |
-
|
615 |
-
|
616 |
-
|
617 |
-
|
618 |
-
|
619 |
-
|
620 |
-
|
621 |
-
|
622 |
-
|
623 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
624 |
cell_cos_sims = pu.quant_cos_sims(
|
625 |
perturbation_cell_emb,
|
626 |
original_cell_emb,
|
@@ -640,6 +686,9 @@ class InSilicoPerturber:
|
|
640 |
]
|
641 |
for genes in gene_list
|
642 |
]
|
|
|
|
|
|
|
643 |
|
644 |
for cell_i, genes in enumerate(gene_list):
|
645 |
for gene_j, affected_gene in enumerate(genes):
|
@@ -672,9 +721,21 @@ class InSilicoPerturber:
|
|
672 |
]
|
673 |
else:
|
674 |
nonpadding_lens = perturbation_batch["length"]
|
675 |
-
|
676 |
-
|
677 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
678 |
cos_sims_dict = self.update_perturbation_dictionary(
|
679 |
cos_sims_dict,
|
680 |
cos_sims_data,
|
@@ -694,9 +755,15 @@ class InSilicoPerturber:
|
|
694 |
)
|
695 |
del minibatch
|
696 |
del perturbation_batch
|
|
|
697 |
del original_emb
|
|
|
698 |
del perturbation_emb
|
699 |
del cos_sims_data
|
|
|
|
|
|
|
|
|
700 |
|
701 |
torch.cuda.empty_cache()
|
702 |
|
@@ -738,6 +805,7 @@ class InSilicoPerturber:
|
|
738 |
layer_to_quant,
|
739 |
self.pad_token_id,
|
740 |
self.forward_batch_size,
|
|
|
741 |
summary_stat=None,
|
742 |
silent=True,
|
743 |
)
|
@@ -756,6 +824,7 @@ class InSilicoPerturber:
|
|
756 |
self.anchor_token,
|
757 |
self.combos,
|
758 |
self.nproc,
|
|
|
759 |
)
|
760 |
|
761 |
full_perturbation_emb = get_embs(
|
@@ -765,21 +834,28 @@ class InSilicoPerturber:
|
|
765 |
layer_to_quant,
|
766 |
self.pad_token_id,
|
767 |
self.forward_batch_size,
|
|
|
768 |
summary_stat=None,
|
769 |
silent=True,
|
770 |
)
|
771 |
|
772 |
num_inds_perturbed = 1 + self.combos
|
773 |
-
# need to remove overexpressed gene to quantify cosine shifts
|
|
|
|
|
|
|
|
|
774 |
if self.perturb_type == "overexpress":
|
775 |
-
perturbation_emb = full_perturbation_emb[:, num_inds_perturbed:, :]
|
776 |
gene_list = gene_list[
|
777 |
-
num_inds_perturbed:
|
778 |
-
] # index 0 is not overexpressed
|
779 |
|
780 |
elif self.perturb_type == "delete":
|
781 |
-
perturbation_emb = full_perturbation_emb
|
|
|
782 |
|
|
|
783 |
original_batch = pu.make_comparison_batch(
|
784 |
full_original_emb, indices_to_perturb, perturb_group=False
|
785 |
)
|
@@ -792,13 +868,19 @@ class InSilicoPerturber:
|
|
792 |
self.state_embs_dict,
|
793 |
emb_mode="gene",
|
794 |
)
|
|
|
795 |
if self.cell_states_to_model is not None:
|
796 |
-
|
797 |
-
|
798 |
-
|
799 |
-
|
800 |
-
|
801 |
-
|
|
|
|
|
|
|
|
|
|
|
802 |
|
803 |
cell_cos_sims = pu.quant_cos_sims(
|
804 |
perturbation_cell_emb,
|
@@ -829,7 +911,18 @@ class InSilicoPerturber:
|
|
829 |
] = gene_cos_sims[perturbation_i, gene_j].item()
|
830 |
|
831 |
if self.cell_states_to_model is None:
|
832 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
833 |
cos_sims_dict = self.update_perturbation_dictionary(
|
834 |
cos_sims_dict,
|
835 |
cos_sims_data,
|
@@ -922,4 +1015,4 @@ class InSilicoPerturber:
|
|
922 |
for i, cos in enumerate(cos_sims_data.tolist()):
|
923 |
cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
|
924 |
|
925 |
-
return cos_sims_dict
|
|
|
66 |
"anchor_gene": {None, str},
|
67 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
68 |
"num_classes": {int},
|
69 |
+
"emb_mode": {"cls", "cell", "cls_and_gene", "cell_and_gene"},
|
70 |
"cell_emb_style": {"mean_pool"},
|
71 |
"filter_data": {None, dict},
|
72 |
"cell_states_to_model": {None, dict},
|
|
|
74 |
"max_ncells": {None, int},
|
75 |
"cell_inds_to_perturb": {"all", dict},
|
76 |
"emb_layer": {-1, 0},
|
77 |
+
"token_dictionary_file" : {None, str},
|
78 |
"forward_batch_size": {int},
|
79 |
"nproc": {int},
|
80 |
}
|
|
|
98 |
emb_layer=-1,
|
99 |
forward_batch_size=100,
|
100 |
nproc=4,
|
101 |
+
token_dictionary_file=None,
|
102 |
):
|
103 |
"""
|
104 |
Initialize in silico perturber.
|
|
|
138 |
num_classes : int
|
139 |
| If model is a gene or cell classifier, specify number of classes it was trained to classify.
|
140 |
| For the pretrained Geneformer model, number of classes is 0 as it is not a classifier.
|
141 |
+
emb_mode : {"cls", "cell", "cls_and_gene", "cell_and_gene"}
|
142 |
+
| Whether to output impact of perturbation on CLS token, cell, and/or gene embeddings.
|
143 |
| Gene embedding shifts only available as compared to original cell, not comparing to goal state.
|
144 |
cell_emb_style : "mean_pool"
|
145 |
+
| Method for summarizing cell embeddings if not using CLS token.
|
146 |
| Currently only option is mean pooling of gene embeddings for given cell.
|
147 |
filter_data : None, dict
|
148 |
| Default is to use all input data for in silico perturbation study.
|
|
|
223 |
self.emb_layer = emb_layer
|
224 |
self.forward_batch_size = forward_batch_size
|
225 |
self.nproc = nproc
|
226 |
+
self.token_dictionary_file = token_dictionary_file
|
227 |
|
228 |
self.validate_options()
|
229 |
|
230 |
# load token dictionary (Ensembl IDs:token)
|
231 |
+
if self.token_dictionary_file is None:
|
232 |
+
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
233 |
with open(token_dictionary_file, "rb") as f:
|
234 |
self.gene_token_dict = pickle.load(f)
|
235 |
+
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
236 |
|
237 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
238 |
|
239 |
+
|
240 |
+
# Identify if special token is present in the token dictionary
|
241 |
+
lowercase_token_gene_dict = {k: v.lower() for k, v in self.token_gene_dict.items()}
|
242 |
+
cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
|
243 |
+
eos_present = any("eos" in value for value in lowercase_token_gene_dict.values())
|
244 |
+
if cls_present or eos_present:
|
245 |
+
self.special_token = True
|
246 |
+
else:
|
247 |
+
if "cls" in self.emb_mode:
|
248 |
+
logger.error(f"emb_mode set to {self.emb_mode} but <cls> token not in token dictionary.")
|
249 |
+
raise
|
250 |
+
self.special_token = False
|
251 |
+
|
252 |
if self.anchor_gene is None:
|
253 |
self.anchor_token = None
|
254 |
else:
|
|
|
305 |
continue
|
306 |
valid_type = False
|
307 |
for option in valid_options:
|
308 |
+
if (option in [bool, int, list, dict, str]) and isinstance(
|
309 |
attr_value, option
|
310 |
):
|
311 |
valid_type = True
|
|
|
446 |
self.max_len = pu.get_model_input_size(model)
|
447 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
448 |
|
|
|
449 |
### filter input data ###
|
450 |
# general filtering of input data based on filter_data argument
|
451 |
filtered_input_data = pu.load_and_filter(
|
452 |
self.filter_data, self.nproc, input_data_file
|
453 |
)
|
454 |
+
|
455 |
+
# Ensure emb_mode is cls if first token of the filtered input data is cls token
|
456 |
+
if self.special_token:
|
457 |
+
cls_token_id = self.gene_token_dict["<cls>"]
|
458 |
+
if (filtered_input_data["input_ids"][0][0] == cls_token_id) and ("cls" not in self.emb_mode):
|
459 |
+
logger.error(
|
460 |
+
"Emb mode 'cls' or 'cls_and_gene' required when first token is <cls>."
|
461 |
+
)
|
462 |
+
raise
|
463 |
+
|
464 |
filtered_input_data = self.apply_additional_filters(filtered_input_data)
|
465 |
|
466 |
if self.perturb_group is True:
|
|
|
533 |
if self.perturb_type == "delete":
|
534 |
example = pu.delete_indices(example)
|
535 |
elif self.perturb_type == "overexpress":
|
536 |
+
example = pu.overexpress_tokens(example, self.max_len, self.special_token)
|
537 |
example["n_overflow"] = pu.calc_n_overflow(
|
538 |
self.max_len,
|
539 |
example["length"],
|
|
|
554 |
perturbed_data = filtered_input_data.map(
|
555 |
make_group_perturbation_batch, num_proc=self.nproc
|
556 |
)
|
|
|
557 |
if self.perturb_type == "overexpress":
|
558 |
filtered_input_data = filtered_input_data.add_column(
|
559 |
"n_overflow", perturbed_data["n_overflow"]
|
|
|
563 |
# then the perturbed cell will be 2048+0:2046 so we compare it to an original cell 0:2046.
|
564 |
# (otherwise we will be modeling the effect of both deleting 2047 and adding 2048,
|
565 |
# rather than only adding 2048)
|
566 |
+
if self.special_token:
|
567 |
+
filtered_input_data = filtered_input_data.map(
|
568 |
+
pu.truncate_by_n_overflow_special, num_proc=self.nproc
|
569 |
+
)
|
570 |
+
else:
|
571 |
+
filtered_input_data = filtered_input_data.map(
|
572 |
+
pu.truncate_by_n_overflow, num_proc=self.nproc
|
573 |
+
)
|
574 |
|
575 |
if self.emb_mode == "cell_and_gene":
|
576 |
stored_gene_embs_dict = defaultdict(list)
|
|
|
591 |
layer_to_quant,
|
592 |
self.pad_token_id,
|
593 |
self.forward_batch_size,
|
594 |
+
self.token_gene_dict,
|
595 |
summary_stat=None,
|
596 |
silent=True,
|
597 |
)
|
|
|
611 |
layer_to_quant,
|
612 |
self.pad_token_id,
|
613 |
self.forward_batch_size,
|
614 |
+
self.token_gene_dict,
|
615 |
summary_stat=None,
|
616 |
silent=True,
|
617 |
)
|
618 |
|
619 |
+
if "cls" not in self.emb_mode:
|
620 |
+
start = 0
|
621 |
+
else:
|
622 |
+
start = 1
|
623 |
+
|
624 |
+
# remove overexpressed genes and cls
|
625 |
+
original_emb = original_emb[
|
626 |
+
:, start :, :
|
627 |
+
]
|
628 |
if self.perturb_type == "overexpress":
|
629 |
perturbation_emb = full_perturbation_emb[
|
630 |
+
:, start+len(self.tokens_to_perturb) :, :
|
631 |
]
|
|
|
632 |
elif self.perturb_type == "delete":
|
633 |
perturbation_emb = full_perturbation_emb[
|
634 |
+
:, start : max(perturbation_batch["length"]), :
|
635 |
]
|
636 |
|
637 |
n_perturbation_genes = perturbation_emb.size()[1]
|
638 |
|
639 |
+
# if no goal states, the cosine similarities are the mean of gene cosine similarities
|
640 |
if (
|
641 |
self.cell_states_to_model is None
|
642 |
or self.emb_mode == "cell_and_gene"
|
|
|
651 |
|
652 |
# if there are goal states, the cosine similarities are the cell cosine similarities
|
653 |
if self.cell_states_to_model is not None:
|
654 |
+
if "cls" not in self.emb_mode:
|
655 |
+
original_cell_emb = pu.mean_nonpadding_embs(
|
656 |
+
full_original_emb,
|
657 |
+
torch.tensor(minibatch["length"], device="cuda"),
|
658 |
+
dim=1,
|
659 |
+
)
|
660 |
+
perturbation_cell_emb = pu.mean_nonpadding_embs(
|
661 |
+
full_perturbation_emb,
|
662 |
+
torch.tensor(perturbation_batch["length"], device="cuda"),
|
663 |
+
dim=1,
|
664 |
+
)
|
665 |
+
else:
|
666 |
+
# get cls emb
|
667 |
+
original_cell_emb = full_original_emb[:,0,:]
|
668 |
+
perturbation_cell_emb = full_perturbation_emb[:,0,:]
|
669 |
+
|
670 |
cell_cos_sims = pu.quant_cos_sims(
|
671 |
perturbation_cell_emb,
|
672 |
original_cell_emb,
|
|
|
686 |
]
|
687 |
for genes in gene_list
|
688 |
]
|
689 |
+
# remove CLS if present
|
690 |
+
if "cls" in self.emb_mode:
|
691 |
+
gene_list = gene_list[1:]
|
692 |
|
693 |
for cell_i, genes in enumerate(gene_list):
|
694 |
for gene_j, affected_gene in enumerate(genes):
|
|
|
721 |
]
|
722 |
else:
|
723 |
nonpadding_lens = perturbation_batch["length"]
|
724 |
+
if "cls" not in self.emb_mode:
|
725 |
+
cos_sims_data = pu.mean_nonpadding_embs(
|
726 |
+
gene_cos_sims, torch.tensor(nonpadding_lens, device="cuda")
|
727 |
+
)
|
728 |
+
else:
|
729 |
+
original_cls_emb = full_original_emb[:,0,:]
|
730 |
+
perturbation_cls_emb = full_perturbation_emb[:,0,:]
|
731 |
+
cos_sims_data = pu.quant_cos_sims(
|
732 |
+
perturbation_cls_emb,
|
733 |
+
original_cls_emb,
|
734 |
+
self.cell_states_to_model,
|
735 |
+
self.state_embs_dict,
|
736 |
+
emb_mode="cell",
|
737 |
+
)
|
738 |
+
|
739 |
cos_sims_dict = self.update_perturbation_dictionary(
|
740 |
cos_sims_dict,
|
741 |
cos_sims_data,
|
|
|
755 |
)
|
756 |
del minibatch
|
757 |
del perturbation_batch
|
758 |
+
del full_original_emb
|
759 |
del original_emb
|
760 |
+
del full_perturbation_emb
|
761 |
del perturbation_emb
|
762 |
del cos_sims_data
|
763 |
+
if "cls" in self.emb_mode:
|
764 |
+
del original_cls_emb
|
765 |
+
del perturbation_cls_emb
|
766 |
+
del cls_cos_sims
|
767 |
|
768 |
torch.cuda.empty_cache()
|
769 |
|
|
|
805 |
layer_to_quant,
|
806 |
self.pad_token_id,
|
807 |
self.forward_batch_size,
|
808 |
+
self.token_gene_dict,
|
809 |
summary_stat=None,
|
810 |
silent=True,
|
811 |
)
|
|
|
824 |
self.anchor_token,
|
825 |
self.combos,
|
826 |
self.nproc,
|
827 |
+
self.special_token,
|
828 |
)
|
829 |
|
830 |
full_perturbation_emb = get_embs(
|
|
|
834 |
layer_to_quant,
|
835 |
self.pad_token_id,
|
836 |
self.forward_batch_size,
|
837 |
+
self.token_gene_dict,
|
838 |
summary_stat=None,
|
839 |
silent=True,
|
840 |
)
|
841 |
|
842 |
num_inds_perturbed = 1 + self.combos
|
843 |
+
# need to remove overexpressed gene and cls to quantify cosine shifts
|
844 |
+
if "cls" not in self.emb_mode:
|
845 |
+
start = 0
|
846 |
+
else:
|
847 |
+
start = 1
|
848 |
if self.perturb_type == "overexpress":
|
849 |
+
perturbation_emb = full_perturbation_emb[:, start+num_inds_perturbed:, :]
|
850 |
gene_list = gene_list[
|
851 |
+
start+num_inds_perturbed:
|
852 |
+
] # cls and index 0 are not overexpressed
|
853 |
|
854 |
elif self.perturb_type == "delete":
|
855 |
+
perturbation_emb = full_perturbation_emb[:, start:, :]
|
856 |
+
gene_list = gene_list[start:]
|
857 |
|
858 |
+
full_original_emb = full_original_emb[:, start:, :]
|
859 |
original_batch = pu.make_comparison_batch(
|
860 |
full_original_emb, indices_to_perturb, perturb_group=False
|
861 |
)
|
|
|
868 |
self.state_embs_dict,
|
869 |
emb_mode="gene",
|
870 |
)
|
871 |
+
|
872 |
if self.cell_states_to_model is not None:
|
873 |
+
if "cls" not in self.emb_mode:
|
874 |
+
original_cell_emb = pu.compute_nonpadded_cell_embedding(
|
875 |
+
full_original_emb, "mean_pool"
|
876 |
+
)
|
877 |
+
perturbation_cell_emb = pu.compute_nonpadded_cell_embedding(
|
878 |
+
full_perturbation_emb, "mean_pool"
|
879 |
+
)
|
880 |
+
else:
|
881 |
+
# get cls emb
|
882 |
+
original_cell_emb = full_original_emb[:,0,:]
|
883 |
+
perturbation_cell_emb = full_perturbation_emb[:,0,:]
|
884 |
|
885 |
cell_cos_sims = pu.quant_cos_sims(
|
886 |
perturbation_cell_emb,
|
|
|
911 |
] = gene_cos_sims[perturbation_i, gene_j].item()
|
912 |
|
913 |
if self.cell_states_to_model is None:
|
914 |
+
if "cls" not in self.emb_mode:
|
915 |
+
cos_sims_data = torch.mean(gene_cos_sims, dim=1)
|
916 |
+
else:
|
917 |
+
original_cls_emb = full_original_emb[:,0,:]
|
918 |
+
perturbation_cls_emb = full_perturbation_emb[:,0,:]
|
919 |
+
cos_sims_data = pu.quant_cos_sims(
|
920 |
+
perturbation_cls_emb,
|
921 |
+
original_cls_emb,
|
922 |
+
self.cell_states_to_model,
|
923 |
+
self.state_embs_dict,
|
924 |
+
emb_mode="cell",
|
925 |
+
)
|
926 |
cos_sims_dict = self.update_perturbation_dictionary(
|
927 |
cos_sims_dict,
|
928 |
cos_sims_data,
|
|
|
1015 |
for i, cos in enumerate(cos_sims_data.tolist()):
|
1016 |
cos_sims_dict[(gene_list[i], "cell_emb")].append(cos)
|
1017 |
|
1018 |
+
return cos_sims_dict
|
geneformer/perturber_utils.py
CHANGED
@@ -228,16 +228,32 @@ def overexpress_indices(example):
|
|
228 |
example["length"] = len(example["input_ids"])
|
229 |
return example
|
230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
233 |
-
def overexpress_tokens(example, max_len):
|
234 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
235 |
if example["perturb_index"] != [-100]:
|
236 |
example = delete_indices(example)
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
# truncate to max input size, must also truncate original emb to be comparable
|
243 |
if len(example["input_ids"]) > max_len:
|
@@ -259,6 +275,12 @@ def truncate_by_n_overflow(example):
|
|
259 |
example["length"] = len(example["input_ids"])
|
260 |
return example
|
261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
|
264 |
# indices_to_remove is list of indices to remove
|
@@ -321,7 +343,7 @@ def remove_perturbed_indices_set(
|
|
321 |
|
322 |
|
323 |
def make_perturbation_batch(
|
324 |
-
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
325 |
) -> tuple[Dataset, List[int]]:
|
326 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
327 |
if perturb_type in ["overexpress", "activate"]:
|
@@ -383,9 +405,14 @@ def make_perturbation_batch(
|
|
383 |
delete_indices, num_proc=num_proc_i
|
384 |
)
|
385 |
elif perturb_type == "overexpress":
|
386 |
-
|
387 |
-
|
388 |
-
|
|
|
|
|
|
|
|
|
|
|
389 |
|
390 |
perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
|
391 |
|
@@ -758,4 +785,4 @@ class GeneIdHandler:
|
|
758 |
return self.ens_to_symbol(self.token_to_ens(token))
|
759 |
|
760 |
def symbol_to_token(self, symbol):
|
761 |
-
return self.ens_to_token(self.symbol_to_ens(symbol))
|
|
|
228 |
example["length"] = len(example["input_ids"])
|
229 |
return example
|
230 |
|
231 |
+
# if CLS token present, move to 1st rather than 0th position
|
232 |
+
def overexpress_indices_special(example):
    """Move perturbed tokens to the front of the rank value encoding,
    directly after the <cls> token (position 0 is left untouched).

    Bug fix: the previous implementation looped over the indices in
    descending order and inserted each popped token at position 1; every
    insert shifted all remaining (smaller) target indices by +1, so when
    several indices were perturbed at once (e.g. combo perturbations) the
    wrong tokens were moved. Iterating in ascending order with an advancing
    insertion point keeps each remaining target index valid and preserves
    the moved tokens' relative rank order. Behavior for a single perturbed
    index is unchanged.

    Parameters
    ----------
    example : dict
        Must contain "perturb_index" (list of int, possibly nested for
        combos) and "input_ids" (list of int with <cls> at position 0).

    Returns
    -------
    dict
        The same example with "input_ids" reordered and "length" refreshed.
    """
    indices = example["perturb_index"]
    # combo perturbations arrive as nested lists of indices
    if any(isinstance(el, list) for el in indices):
        indices = flatten_list(indices)
    insert_pos = 1  # position 0 is reserved for the <cls> token
    for index in sorted(indices):
        example["input_ids"].insert(insert_pos, example["input_ids"].pop(index))
        insert_pos += 1

    example["length"] = len(example["input_ids"])
    return example
|
241 |
|
242 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
243 |
+
def overexpress_tokens(example, max_len, special_token):
|
244 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
245 |
if example["perturb_index"] != [-100]:
|
246 |
example = delete_indices(example)
|
247 |
+
if special_token:
|
248 |
+
[
|
249 |
+
example["input_ids"].insert(1, token)
|
250 |
+
for token in example["tokens_to_perturb"][::-1]
|
251 |
+
]
|
252 |
+
else:
|
253 |
+
[
|
254 |
+
example["input_ids"].insert(0, token)
|
255 |
+
for token in example["tokens_to_perturb"][::-1]
|
256 |
+
]
|
257 |
|
258 |
# truncate to max input size, must also truncate original emb to be comparable
|
259 |
if len(example["input_ids"]) > max_len:
|
|
|
275 |
example["length"] = len(example["input_ids"])
|
276 |
return example
|
277 |
|
278 |
+
def truncate_by_n_overflow_special(example):
    """Truncate an overflowing rank value encoding while keeping its final token.

    Drops ``n_overflow`` tokens from the end of "input_ids" but re-appends the
    last element, so the sequence-terminating token survives truncation.
    NOTE(review): this assumes the encoding ends in a special token such as
    <eos>; if only <cls> is present (no <eos>), the preserved last token would
    be an ordinary gene — confirm against the caller.

    Returns the example with "input_ids" truncated and "length" refreshed.
    """
    target_len = example["length"] - example["n_overflow"]
    ids = example["input_ids"]
    truncated = ids[: target_len - 1] + [ids[-1]]
    example["input_ids"] = truncated
    example["length"] = len(truncated)
    return example
|
283 |
+
|
284 |
|
285 |
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
|
286 |
# indices_to_remove is list of indices to remove
|
|
|
343 |
|
344 |
|
345 |
def make_perturbation_batch(
|
346 |
+
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc, special_token
|
347 |
) -> tuple[Dataset, List[int]]:
|
348 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
349 |
if perturb_type in ["overexpress", "activate"]:
|
|
|
405 |
delete_indices, num_proc=num_proc_i
|
406 |
)
|
407 |
elif perturb_type == "overexpress":
|
408 |
+
if special_token:
|
409 |
+
perturbation_dataset = perturbation_dataset.map(
|
410 |
+
overexpress_indices_special, num_proc=num_proc_i
|
411 |
+
)
|
412 |
+
else:
|
413 |
+
perturbation_dataset = perturbation_dataset.map(
|
414 |
+
overexpress_indices, num_proc=num_proc_i
|
415 |
+
)
|
416 |
|
417 |
perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
|
418 |
|
|
|
785 |
return self.ens_to_symbol(self.token_to_ens(token))
|
786 |
|
787 |
def symbol_to_token(self, symbol):
|
788 |
+
return self.ens_to_token(self.symbol_to_ens(symbol))
|