Update geneformer/in_silico_perturber.py
Browse files
Add custom token gene dictionary to in silico perturber
geneformer/in_silico_perturber.py
CHANGED
@@ -50,7 +50,6 @@ from . import perturber_utils as pu
|
|
50 |
from .emb_extractor import get_embs
|
51 |
from .perturber_utils import TOKEN_DICTIONARY_FILE
|
52 |
|
53 |
-
|
54 |
sns.set()
|
55 |
|
56 |
|
@@ -74,6 +73,7 @@ class InSilicoPerturber:
|
|
74 |
"max_ncells": {None, int},
|
75 |
"cell_inds_to_perturb": {"all", dict},
|
76 |
"emb_layer": {-1, 0},
|
|
|
77 |
"forward_batch_size": {int},
|
78 |
"nproc": {int},
|
79 |
}
|
@@ -95,9 +95,10 @@ class InSilicoPerturber:
|
|
95 |
max_ncells=None,
|
96 |
cell_inds_to_perturb="all",
|
97 |
emb_layer=-1,
|
|
|
98 |
forward_batch_size=100,
|
99 |
nproc=4,
|
100 |
-
|
101 |
):
|
102 |
"""
|
103 |
Initialize in silico perturber.
|
@@ -187,10 +188,6 @@ class InSilicoPerturber:
|
|
187 |
token_dictionary_file : Path
|
188 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
189 |
"""
|
190 |
-
try:
|
191 |
-
set_start_method("spawn")
|
192 |
-
except RuntimeError:
|
193 |
-
pass
|
194 |
|
195 |
self.perturb_type = perturb_type
|
196 |
self.perturb_rank_shift = perturb_rank_shift
|
@@ -220,17 +217,29 @@ class InSilicoPerturber:
|
|
220 |
self.max_ncells = max_ncells
|
221 |
self.cell_inds_to_perturb = cell_inds_to_perturb
|
222 |
self.emb_layer = emb_layer
|
|
|
223 |
self.forward_batch_size = forward_batch_size
|
224 |
self.nproc = nproc
|
225 |
|
226 |
self.validate_options()
|
227 |
|
228 |
# load token dictionary (Ensembl IDs:token)
|
|
|
|
|
229 |
with open(token_dictionary_file, "rb") as f:
|
230 |
self.gene_token_dict = pickle.load(f)
|
|
|
231 |
|
232 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
233 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
234 |
if self.anchor_gene is None:
|
235 |
self.anchor_token = None
|
236 |
else:
|
@@ -287,7 +296,7 @@ class InSilicoPerturber:
|
|
287 |
continue
|
288 |
valid_type = False
|
289 |
for option in valid_options:
|
290 |
-
if (option in [bool, int, list, dict]) and isinstance(
|
291 |
attr_value, option
|
292 |
):
|
293 |
valid_type = True
|
@@ -428,7 +437,6 @@ class InSilicoPerturber:
|
|
428 |
self.max_len = pu.get_model_input_size(model)
|
429 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
430 |
|
431 |
-
|
432 |
### filter input data ###
|
433 |
# general filtering of input data based on filter_data argument
|
434 |
filtered_input_data = pu.load_and_filter(
|
@@ -506,7 +514,7 @@ class InSilicoPerturber:
|
|
506 |
if self.perturb_type == "delete":
|
507 |
example = pu.delete_indices(example)
|
508 |
elif self.perturb_type == "overexpress":
|
509 |
-
example = pu.overexpress_tokens(example, self.max_len)
|
510 |
example["n_overflow"] = pu.calc_n_overflow(
|
511 |
self.max_len,
|
512 |
example["length"],
|
@@ -527,7 +535,6 @@ class InSilicoPerturber:
|
|
527 |
perturbed_data = filtered_input_data.map(
|
528 |
make_group_perturbation_batch, num_proc=self.nproc
|
529 |
)
|
530 |
-
|
531 |
if self.perturb_type == "overexpress":
|
532 |
filtered_input_data = filtered_input_data.add_column(
|
533 |
"n_overflow", perturbed_data["n_overflow"]
|
@@ -560,6 +567,7 @@ class InSilicoPerturber:
|
|
560 |
layer_to_quant,
|
561 |
self.pad_token_id,
|
562 |
self.forward_batch_size,
|
|
|
563 |
summary_stat=None,
|
564 |
silent=True,
|
565 |
)
|
@@ -579,6 +587,7 @@ class InSilicoPerturber:
|
|
579 |
layer_to_quant,
|
580 |
self.pad_token_id,
|
581 |
self.forward_batch_size,
|
|
|
582 |
summary_stat=None,
|
583 |
silent=True,
|
584 |
)
|
@@ -738,6 +747,7 @@ class InSilicoPerturber:
|
|
738 |
layer_to_quant,
|
739 |
self.pad_token_id,
|
740 |
self.forward_batch_size,
|
|
|
741 |
summary_stat=None,
|
742 |
silent=True,
|
743 |
)
|
@@ -765,6 +775,7 @@ class InSilicoPerturber:
|
|
765 |
layer_to_quant,
|
766 |
self.pad_token_id,
|
767 |
self.forward_batch_size,
|
|
|
768 |
summary_stat=None,
|
769 |
silent=True,
|
770 |
)
|
|
|
50 |
from .emb_extractor import get_embs
|
51 |
from .perturber_utils import TOKEN_DICTIONARY_FILE
|
52 |
|
|
|
53 |
sns.set()
|
54 |
|
55 |
|
|
|
73 |
"max_ncells": {None, int},
|
74 |
"cell_inds_to_perturb": {"all", dict},
|
75 |
"emb_layer": {-1, 0},
|
76 |
+
"token_dictionary_file" : {None, str},
|
77 |
"forward_batch_size": {int},
|
78 |
"nproc": {int},
|
79 |
}
|
|
|
95 |
max_ncells=None,
|
96 |
cell_inds_to_perturb="all",
|
97 |
emb_layer=-1,
|
98 |
+
token_dictionary_file=None,
|
99 |
forward_batch_size=100,
|
100 |
nproc=4,
|
101 |
+
|
102 |
):
|
103 |
"""
|
104 |
Initialize in silico perturber.
|
|
|
188 |
token_dictionary_file : Path
|
189 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
190 |
"""
|
|
|
|
|
|
|
|
|
191 |
|
192 |
self.perturb_type = perturb_type
|
193 |
self.perturb_rank_shift = perturb_rank_shift
|
|
|
217 |
self.max_ncells = max_ncells
|
218 |
self.cell_inds_to_perturb = cell_inds_to_perturb
|
219 |
self.emb_layer = emb_layer
|
220 |
+
self.token_dictionary_file = token_dictionary_file
|
221 |
self.forward_batch_size = forward_batch_size
|
222 |
self.nproc = nproc
|
223 |
|
224 |
self.validate_options()
|
225 |
|
226 |
# load token dictionary (Ensembl IDs:token)
|
227 |
+
if self.token_dictionary_file is None:
|
228 |
+
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
229 |
with open(token_dictionary_file, "rb") as f:
|
230 |
self.gene_token_dict = pickle.load(f)
|
231 |
+
self.token_gene_dict = {v: k for k, v in self.gene_token_dict.items()}
|
232 |
|
233 |
self.pad_token_id = self.gene_token_dict.get("<pad>")
|
234 |
|
235 |
+
|
236 |
+
# Identify if special token is present in the token dictionary
|
237 |
+
lowercase_token_gene_dict = {k: v.lower() for k, v in self.token_gene_dict.items()}
|
238 |
+
cls_present = any("cls" in value for value in lowercase_token_gene_dict.values())
|
239 |
+
sep_present = any("sep" in value for value in lowercase_token_gene_dict.values())
|
240 |
+
if cls_present or sep_present:
|
241 |
+
self.special_token = True
|
242 |
+
|
243 |
if self.anchor_gene is None:
|
244 |
self.anchor_token = None
|
245 |
else:
|
|
|
296 |
continue
|
297 |
valid_type = False
|
298 |
for option in valid_options:
|
299 |
+
if (option in [bool, int, list, dict, str]) and isinstance(
|
300 |
attr_value, option
|
301 |
):
|
302 |
valid_type = True
|
|
|
437 |
self.max_len = pu.get_model_input_size(model)
|
438 |
layer_to_quant = pu.quant_layers(model) + self.emb_layer
|
439 |
|
|
|
440 |
### filter input data ###
|
441 |
# general filtering of input data based on filter_data argument
|
442 |
filtered_input_data = pu.load_and_filter(
|
|
|
514 |
if self.perturb_type == "delete":
|
515 |
example = pu.delete_indices(example)
|
516 |
elif self.perturb_type == "overexpress":
|
517 |
+
example = pu.overexpress_tokens(example, self.max_len, self.special_token)
|
518 |
example["n_overflow"] = pu.calc_n_overflow(
|
519 |
self.max_len,
|
520 |
example["length"],
|
|
|
535 |
perturbed_data = filtered_input_data.map(
|
536 |
make_group_perturbation_batch, num_proc=self.nproc
|
537 |
)
|
|
|
538 |
if self.perturb_type == "overexpress":
|
539 |
filtered_input_data = filtered_input_data.add_column(
|
540 |
"n_overflow", perturbed_data["n_overflow"]
|
|
|
567 |
layer_to_quant,
|
568 |
self.pad_token_id,
|
569 |
self.forward_batch_size,
|
570 |
+
self.token_gene_dict,
|
571 |
summary_stat=None,
|
572 |
silent=True,
|
573 |
)
|
|
|
587 |
layer_to_quant,
|
588 |
self.pad_token_id,
|
589 |
self.forward_batch_size,
|
590 |
+
self.token_gene_dict,
|
591 |
summary_stat=None,
|
592 |
silent=True,
|
593 |
)
|
|
|
747 |
layer_to_quant,
|
748 |
self.pad_token_id,
|
749 |
self.forward_batch_size,
|
750 |
+
self.token_gene_dict,
|
751 |
summary_stat=None,
|
752 |
silent=True,
|
753 |
)
|
|
|
775 |
layer_to_quant,
|
776 |
self.pad_token_id,
|
777 |
self.forward_batch_size,
|
778 |
+
self.token_gene_dict,
|
779 |
summary_stat=None,
|
780 |
silent=True,
|
781 |
)
|