ctheodoris
commited on
Commit
•
65d6e69
1
Parent(s):
0680d55
update perturber_utils to account for cls
Browse files- geneformer/perturber_utils.py +37 -10
geneformer/perturber_utils.py
CHANGED
@@ -228,16 +228,32 @@ def overexpress_indices(example):
|
|
228 |
example["length"] = len(example["input_ids"])
|
229 |
return example
|
230 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
231 |
|
232 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
233 |
-
def overexpress_tokens(example, max_len):
|
234 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
235 |
if example["perturb_index"] != [-100]:
|
236 |
example = delete_indices(example)
|
237 |
-
|
238 |
-
|
239 |
-
|
240 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
241 |
|
242 |
# truncate to max input size, must also truncate original emb to be comparable
|
243 |
if len(example["input_ids"]) > max_len:
|
@@ -259,6 +275,12 @@ def truncate_by_n_overflow(example):
|
|
259 |
example["length"] = len(example["input_ids"])
|
260 |
return example
|
261 |
|
|
|
|
|
|
|
|
|
|
|
|
|
262 |
|
263 |
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
|
264 |
# indices_to_remove is list of indices to remove
|
@@ -321,7 +343,7 @@ def remove_perturbed_indices_set(
|
|
321 |
|
322 |
|
323 |
def make_perturbation_batch(
|
324 |
-
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
|
325 |
) -> tuple[Dataset, List[int]]:
|
326 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
327 |
if perturb_type in ["overexpress", "activate"]:
|
@@ -383,9 +405,14 @@ def make_perturbation_batch(
|
|
383 |
delete_indices, num_proc=num_proc_i
|
384 |
)
|
385 |
elif perturb_type == "overexpress":
|
386 |
-
|
387 |
-
|
388 |
-
|
|
|
|
|
|
|
|
|
|
|
389 |
|
390 |
perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
|
391 |
|
@@ -758,4 +785,4 @@ class GeneIdHandler:
|
|
758 |
return self.ens_to_symbol(self.token_to_ens(token))
|
759 |
|
760 |
def symbol_to_token(self, symbol):
|
761 |
-
return self.ens_to_token(self.symbol_to_ens(symbol))
|
|
|
228 |
example["length"] = len(example["input_ids"])
|
229 |
return example
|
230 |
|
231 |
+
# if CLS token present, move to 1st rather than 0th position
|
232 |
+
def overexpress_indices_special(example):
|
233 |
+
indices = example["perturb_index"]
|
234 |
+
if any(isinstance(el, list) for el in indices):
|
235 |
+
indices = flatten_list(indices)
|
236 |
+
for index in sorted(indices, reverse=True):
|
237 |
+
example["input_ids"].insert(1, example["input_ids"].pop(index))
|
238 |
+
|
239 |
+
example["length"] = len(example["input_ids"])
|
240 |
+
return example
|
241 |
|
242 |
# for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
|
243 |
+
def overexpress_tokens(example, max_len, special_token):
|
244 |
# -100 indicates tokens to overexpress are not present in rank value encoding
|
245 |
if example["perturb_index"] != [-100]:
|
246 |
example = delete_indices(example)
|
247 |
+
if special_token:
|
248 |
+
[
|
249 |
+
example["input_ids"].insert(1, token)
|
250 |
+
for token in example["tokens_to_perturb"][::-1]
|
251 |
+
]
|
252 |
+
else:
|
253 |
+
[
|
254 |
+
example["input_ids"].insert(0, token)
|
255 |
+
for token in example["tokens_to_perturb"][::-1]
|
256 |
+
]
|
257 |
|
258 |
# truncate to max input size, must also truncate original emb to be comparable
|
259 |
if len(example["input_ids"]) > max_len:
|
|
|
275 |
example["length"] = len(example["input_ids"])
|
276 |
return example
|
277 |
|
278 |
+
def truncate_by_n_overflow_special(example):
|
279 |
+
new_max_len = example["length"] - example["n_overflow"]
|
280 |
+
example["input_ids"] = example["input_ids"][0:new_max_len-1]+[example["input_ids"][-1]]
|
281 |
+
example["length"] = len(example["input_ids"])
|
282 |
+
return example
|
283 |
+
|
284 |
|
285 |
def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
|
286 |
# indices_to_remove is list of indices to remove
|
|
|
343 |
|
344 |
|
345 |
def make_perturbation_batch(
|
346 |
+
example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc, special_token
|
347 |
) -> tuple[Dataset, List[int]]:
|
348 |
if combo_lvl == 0 and tokens_to_perturb == "all":
|
349 |
if perturb_type in ["overexpress", "activate"]:
|
|
|
405 |
delete_indices, num_proc=num_proc_i
|
406 |
)
|
407 |
elif perturb_type == "overexpress":
|
408 |
+
if special_token:
|
409 |
+
perturbation_dataset = perturbation_dataset.map(
|
410 |
+
overexpress_indices_special, num_proc=num_proc_i
|
411 |
+
)
|
412 |
+
else:
|
413 |
+
perturbation_dataset = perturbation_dataset.map(
|
414 |
+
overexpress_indices, num_proc=num_proc_i
|
415 |
+
)
|
416 |
|
417 |
perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
|
418 |
|
|
|
785 |
return self.ens_to_symbol(self.token_to_ens(token))
|
786 |
|
787 |
def symbol_to_token(self, symbol):
|
788 |
+
return self.ens_to_token(self.symbol_to_ens(symbol))
|