Update geneformer/emb_extractor.py
Browse files- geneformer/emb_extractor.py +54 -12
geneformer/emb_extractor.py
CHANGED
@@ -38,12 +38,14 @@ def get_embs(
|
|
38 |
layer_to_quant,
|
39 |
pad_token_id,
|
40 |
forward_batch_size,
|
|
|
|
|
41 |
summary_stat=None,
|
42 |
silent=False,
|
43 |
):
|
44 |
model_input_size = pu.get_model_input_size(model)
|
45 |
total_batch_length = len(filtered_input_data)
|
46 |
-
|
47 |
if summary_stat is None:
|
48 |
embs_list = []
|
49 |
elif summary_stat is not None:
|
@@ -67,9 +69,25 @@ def get_embs(
|
|
67 |
k: [TDigest() for _ in range(emb_dims)] for k in gene_set
|
68 |
}
|
69 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
70 |
overall_max_len = 0
|
71 |
-
|
72 |
-
for i in trange(0, total_batch_length, forward_batch_size, leave=(not silent)):
|
73 |
max_range = min(i + forward_batch_size, total_batch_length)
|
74 |
|
75 |
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
@@ -90,9 +108,16 @@ def get_embs(
|
|
90 |
)
|
91 |
|
92 |
embs_i = outputs.hidden_states[layer_to_quant]
|
93 |
-
|
94 |
if emb_mode == "cell":
|
95 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
96 |
if summary_stat is None:
|
97 |
embs_list.append(mean_embs)
|
98 |
elif summary_stat is not None:
|
@@ -121,7 +146,13 @@ def get_embs(
|
|
121 |
accumulate_tdigests(
|
122 |
embs_tdigests_dict[int(k)], dict_h[k], emb_dims
|
123 |
)
|
124 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
125 |
overall_max_len = max(overall_max_len, max_len)
|
126 |
del outputs
|
127 |
del minibatch
|
@@ -129,7 +160,8 @@ def get_embs(
|
|
129 |
del embs_i
|
130 |
|
131 |
torch.cuda.empty_cache()
|
132 |
-
|
|
|
133 |
if summary_stat is None:
|
134 |
if emb_mode == "cell":
|
135 |
embs_stack = torch.cat(embs_list, dim=0)
|
@@ -142,6 +174,8 @@ def get_embs(
|
|
142 |
1,
|
143 |
pu.pad_3d_tensor,
|
144 |
)
|
|
|
|
|
145 |
|
146 |
# calculate summary stat embs from approximated tdigests
|
147 |
elif summary_stat is not None:
|
@@ -348,7 +382,7 @@ def plot_heatmap(embs_df, emb_dims, label, output_file, kwargs_dict):
|
|
348 |
bbox_to_anchor=(0.5, 1),
|
349 |
facecolor="white",
|
350 |
)
|
351 |
-
|
352 |
plt.savefig(output_file, bbox_inches="tight")
|
353 |
|
354 |
|
@@ -356,7 +390,7 @@ class EmbExtractor:
|
|
356 |
valid_option_dict = {
|
357 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
358 |
"num_classes": {int},
|
359 |
-
"emb_mode": {"cell", "gene"},
|
360 |
"cell_emb_style": {"mean_pool"},
|
361 |
"gene_emb_style": {"mean_pool"},
|
362 |
"filter_data": {None, dict},
|
@@ -365,6 +399,7 @@ class EmbExtractor:
|
|
365 |
"emb_label": {None, list},
|
366 |
"labels_to_plot": {None, list},
|
367 |
"forward_batch_size": {int},
|
|
|
368 |
"nproc": {int},
|
369 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
370 |
}
|
@@ -384,7 +419,7 @@ class EmbExtractor:
|
|
384 |
forward_batch_size=100,
|
385 |
nproc=4,
|
386 |
summary_stat=None,
|
387 |
-
token_dictionary_file=
|
388 |
):
|
389 |
"""
|
390 |
Initialize embedding extractor.
|
@@ -434,6 +469,7 @@ class EmbExtractor:
|
|
434 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
435 |
| Non-exact is slower but more memory-efficient.
|
436 |
token_dictionary_file : Path
|
|
|
437 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
438 |
|
439 |
**Examples:**
|
@@ -463,6 +499,7 @@ class EmbExtractor:
|
|
463 |
self.emb_layer = emb_layer
|
464 |
self.emb_label = emb_label
|
465 |
self.labels_to_plot = labels_to_plot
|
|
|
466 |
self.forward_batch_size = forward_batch_size
|
467 |
self.nproc = nproc
|
468 |
if (summary_stat is not None) and ("exact" in summary_stat):
|
@@ -475,6 +512,8 @@ class EmbExtractor:
|
|
475 |
self.validate_options()
|
476 |
|
477 |
# load token dictionary (Ensembl IDs:token)
|
|
|
|
|
478 |
with open(token_dictionary_file, "rb") as f:
|
479 |
self.gene_token_dict = pickle.load(f)
|
480 |
|
@@ -490,7 +529,7 @@ class EmbExtractor:
|
|
490 |
continue
|
491 |
valid_type = False
|
492 |
for option in valid_options:
|
493 |
-
if (option in [int, list, dict, bool]) and isinstance(
|
494 |
attr_value, option
|
495 |
):
|
496 |
valid_type = True
|
@@ -570,6 +609,7 @@ class EmbExtractor:
|
|
570 |
layer_to_quant,
|
571 |
self.pad_token_id,
|
572 |
self.forward_batch_size,
|
|
|
573 |
self.summary_stat,
|
574 |
)
|
575 |
|
@@ -584,6 +624,8 @@ class EmbExtractor:
|
|
584 |
elif self.summary_stat is not None:
|
585 |
embs_df = pd.DataFrame(embs).T
|
586 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
|
|
|
|
587 |
|
588 |
# save embeddings to output_path
|
589 |
if cell_state is None:
|
@@ -781,7 +823,7 @@ class EmbExtractor:
|
|
781 |
f"not present in provided embeddings dataframe."
|
782 |
)
|
783 |
continue
|
784 |
-
output_prefix_label =
|
785 |
output_file = (
|
786 |
Path(output_directory) / output_prefix_label
|
787 |
).with_suffix(".pdf")
|
|
|
38 |
layer_to_quant,
|
39 |
pad_token_id,
|
40 |
forward_batch_size,
|
41 |
+
token_gene_dict,
|
42 |
+
special_token=False,
|
43 |
summary_stat=None,
|
44 |
silent=False,
|
45 |
):
|
46 |
model_input_size = pu.get_model_input_size(model)
|
47 |
total_batch_length = len(filtered_input_data)
|
48 |
+
|
49 |
if summary_stat is None:
|
50 |
embs_list = []
|
51 |
elif summary_stat is not None:
|
|
|
69 |
k: [TDigest() for _ in range(emb_dims)] for k in gene_set
|
70 |
}
|
71 |
|
72 |
+
# Check whether the <cls> and <eos> tokens are present in the token dictionary
|
73 |
+
cls_present = any("<cls>" in value for value in token_gene_dict.values())
|
74 |
+
eos_present = any("<eos>" in value for value in token_gene_dict.values())
|
75 |
+
if emb_mode == "cls":
|
76 |
+
assert cls_present, "<cls> token missing in token dictionary"
|
77 |
+
# Verify that the first token of the first entry in the filtered input data is the <cls> token
|
78 |
+
for key, value in token_gene_dict.items():
|
79 |
+
if value == "<cls>":
|
80 |
+
cls_token_id = key
|
81 |
+
assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
|
82 |
+
else:
|
83 |
+
if cls_present:
|
84 |
+
logger.warning("CLS token present in token dictionary, excluding from average")
|
85 |
+
if eos_present:
|
86 |
+
logger.warning("EOS token present in token dictionary, excluding from average")
|
87 |
+
|
88 |
overall_max_len = 0
|
89 |
+
|
90 |
+
for i in trange(0, total_batch_length, forward_batch_size, leave = (not silent)):
|
91 |
max_range = min(i + forward_batch_size, total_batch_length)
|
92 |
|
93 |
minibatch = filtered_input_data.select([i for i in range(i, max_range)])
|
|
|
108 |
)
|
109 |
|
110 |
embs_i = outputs.hidden_states[layer_to_quant]
|
111 |
+
|
112 |
if emb_mode == "cell":
|
113 |
+
if cls_present:
|
114 |
+
non_cls_embs = embs_i[:, 1:, :] # Drop position 0 (the <cls> token) along dim 1
|
115 |
+
if eos_present:
|
116 |
+
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 2)
|
117 |
+
else:
|
118 |
+
mean_embs = pu.mean_nonpadding_embs(non_cls_embs, original_lens - 1)
|
119 |
+
else:
|
120 |
+
mean_embs = pu.mean_nonpadding_embs(embs_i, original_lens)
|
121 |
if summary_stat is None:
|
122 |
embs_list.append(mean_embs)
|
123 |
elif summary_stat is not None:
|
|
|
146 |
accumulate_tdigests(
|
147 |
embs_tdigests_dict[int(k)], dict_h[k], emb_dims
|
148 |
)
|
149 |
+
del embs_h
|
150 |
+
del dict_h
|
151 |
+
elif emb_mode == "cls":
|
152 |
+
cls_embs = embs_i[:,0,:] # Embedding at position 0 (the <cls> token)
|
153 |
+
embs_list.append(cls_embs)
|
154 |
+
del cls_embs
|
155 |
+
|
156 |
overall_max_len = max(overall_max_len, max_len)
|
157 |
del outputs
|
158 |
del minibatch
|
|
|
160 |
del embs_i
|
161 |
|
162 |
torch.cuda.empty_cache()
|
163 |
+
|
164 |
+
|
165 |
if summary_stat is None:
|
166 |
if emb_mode == "cell":
|
167 |
embs_stack = torch.cat(embs_list, dim=0)
|
|
|
174 |
1,
|
175 |
pu.pad_3d_tensor,
|
176 |
)
|
177 |
+
elif emb_mode == "cls":
|
178 |
+
embs_stack = torch.cat(embs_list, dim=0)
|
179 |
|
180 |
# calculate summary stat embs from approximated tdigests
|
181 |
elif summary_stat is not None:
|
|
|
382 |
bbox_to_anchor=(0.5, 1),
|
383 |
facecolor="white",
|
384 |
)
|
385 |
+
print(f"Output file: {output_file}")
|
386 |
plt.savefig(output_file, bbox_inches="tight")
|
387 |
|
388 |
|
|
|
390 |
valid_option_dict = {
|
391 |
"model_type": {"Pretrained", "GeneClassifier", "CellClassifier"},
|
392 |
"num_classes": {int},
|
393 |
+
"emb_mode": {"cell", "gene", "cls"},
|
394 |
"cell_emb_style": {"mean_pool"},
|
395 |
"gene_emb_style": {"mean_pool"},
|
396 |
"filter_data": {None, dict},
|
|
|
399 |
"emb_label": {None, list},
|
400 |
"labels_to_plot": {None, list},
|
401 |
"forward_batch_size": {int},
|
402 |
+
"token_dictionary_file" : {None, str},
|
403 |
"nproc": {int},
|
404 |
"summary_stat": {None, "mean", "median", "exact_mean", "exact_median"},
|
405 |
}
|
|
|
419 |
forward_batch_size=100,
|
420 |
nproc=4,
|
421 |
summary_stat=None,
|
422 |
+
token_dictionary_file=None,
|
423 |
):
|
424 |
"""
|
425 |
Initialize embedding extractor.
|
|
|
469 |
| Non-exact recommended if encountering memory constraints while generating goal embedding positions.
|
470 |
| Non-exact is slower but more memory-efficient.
|
471 |
token_dictionary_file : Path
|
472 |
+
| If None (default), the default Geneformer token dictionary is used.
|
473 |
| Path to pickle file containing token dictionary (Ensembl ID:token).
|
474 |
|
475 |
**Examples:**
|
|
|
499 |
self.emb_layer = emb_layer
|
500 |
self.emb_label = emb_label
|
501 |
self.labels_to_plot = labels_to_plot
|
502 |
+
self.token_dictionary_file = token_dictionary_file
|
503 |
self.forward_batch_size = forward_batch_size
|
504 |
self.nproc = nproc
|
505 |
if (summary_stat is not None) and ("exact" in summary_stat):
|
|
|
512 |
self.validate_options()
|
513 |
|
514 |
# load token dictionary (Ensembl IDs:token)
|
515 |
+
if self.token_dictionary_file is None:
|
516 |
+
token_dictionary_file = TOKEN_DICTIONARY_FILE
|
517 |
with open(token_dictionary_file, "rb") as f:
|
518 |
self.gene_token_dict = pickle.load(f)
|
519 |
|
|
|
529 |
continue
|
530 |
valid_type = False
|
531 |
for option in valid_options:
|
532 |
+
if (option in [int, list, dict, bool, str]) and isinstance(
|
533 |
attr_value, option
|
534 |
):
|
535 |
valid_type = True
|
|
|
609 |
layer_to_quant,
|
610 |
self.pad_token_id,
|
611 |
self.forward_batch_size,
|
612 |
+
self.token_gene_dict,
|
613 |
self.summary_stat,
|
614 |
)
|
615 |
|
|
|
624 |
elif self.summary_stat is not None:
|
625 |
embs_df = pd.DataFrame(embs).T
|
626 |
embs_df.index = [self.token_gene_dict[token] for token in embs_df.index]
|
627 |
+
elif self.emb_mode == "cls":
|
628 |
+
embs_df = label_cell_embs(embs, downsampled_data, self.emb_label)
|
629 |
|
630 |
# save embeddings to output_path
|
631 |
if cell_state is None:
|
|
|
823 |
f"not present in provided embeddings dataframe."
|
824 |
)
|
825 |
continue
|
826 |
+
output_prefix_label = output_prefix + f"_umap_{label}"
|
827 |
output_file = (
|
828 |
Path(output_directory) / output_prefix_label
|
829 |
).with_suffix(".pdf")
|