ctheodoris commited on
Commit
65d6e69
1 Parent(s): 0680d55

update perturber_utils to account for cls

Browse files
Files changed (1) hide show
  1. geneformer/perturber_utils.py +37 -10
geneformer/perturber_utils.py CHANGED
@@ -228,16 +228,32 @@ def overexpress_indices(example):
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
 
 
 
 
 
 
 
 
 
 
231
 
232
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
233
- def overexpress_tokens(example, max_len):
234
  # -100 indicates tokens to overexpress are not present in rank value encoding
235
  if example["perturb_index"] != [-100]:
236
  example = delete_indices(example)
237
- [
238
- example["input_ids"].insert(0, token)
239
- for token in example["tokens_to_perturb"][::-1]
240
- ]
 
 
 
 
 
 
241
 
242
  # truncate to max input size, must also truncate original emb to be comparable
243
  if len(example["input_ids"]) > max_len:
@@ -259,6 +275,12 @@ def truncate_by_n_overflow(example):
259
  example["length"] = len(example["input_ids"])
260
  return example
261
 
 
 
 
 
 
 
262
 
263
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
264
  # indices_to_remove is list of indices to remove
@@ -321,7 +343,7 @@ def remove_perturbed_indices_set(
321
 
322
 
323
  def make_perturbation_batch(
324
- example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc
325
  ) -> tuple[Dataset, List[int]]:
326
  if combo_lvl == 0 and tokens_to_perturb == "all":
327
  if perturb_type in ["overexpress", "activate"]:
@@ -383,9 +405,14 @@ def make_perturbation_batch(
383
  delete_indices, num_proc=num_proc_i
384
  )
385
  elif perturb_type == "overexpress":
386
- perturbation_dataset = perturbation_dataset.map(
387
- overexpress_indices, num_proc=num_proc_i
388
- )
 
 
 
 
 
389
 
390
  perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
391
 
@@ -758,4 +785,4 @@ class GeneIdHandler:
758
  return self.ens_to_symbol(self.token_to_ens(token))
759
 
760
  def symbol_to_token(self, symbol):
761
- return self.ens_to_token(self.symbol_to_ens(symbol))
 
228
  example["length"] = len(example["input_ids"])
229
  return example
230
 
231
+ # if CLS token present, move to 1st rather than 0th position
232
+ def overexpress_indices_special(example):
233
+ indices = example["perturb_index"]
234
+ if any(isinstance(el, list) for el in indices):
235
+ indices = flatten_list(indices)
236
+ for index in sorted(indices, reverse=True):
237
+ example["input_ids"].insert(1, example["input_ids"].pop(index))
238
+
239
+ example["length"] = len(example["input_ids"])
240
+ return example
241
 
242
  # for genes_to_perturb = list of genes to overexpress that are not necessarily expressed in cell
243
+ def overexpress_tokens(example, max_len, special_token):
244
  # -100 indicates tokens to overexpress are not present in rank value encoding
245
  if example["perturb_index"] != [-100]:
246
  example = delete_indices(example)
247
+ if special_token:
248
+ [
249
+ example["input_ids"].insert(1, token)
250
+ for token in example["tokens_to_perturb"][::-1]
251
+ ]
252
+ else:
253
+ [
254
+ example["input_ids"].insert(0, token)
255
+ for token in example["tokens_to_perturb"][::-1]
256
+ ]
257
 
258
  # truncate to max input size, must also truncate original emb to be comparable
259
  if len(example["input_ids"]) > max_len:
 
275
  example["length"] = len(example["input_ids"])
276
  return example
277
 
278
+ def truncate_by_n_overflow_special(example):
279
+ new_max_len = example["length"] - example["n_overflow"]
280
+ example["input_ids"] = example["input_ids"][0:new_max_len-1]+[example["input_ids"][-1]]
281
+ example["length"] = len(example["input_ids"])
282
+ return example
283
+
284
 
285
  def remove_indices_from_emb(emb, indices_to_remove, gene_dim):
286
  # indices_to_remove is list of indices to remove
 
343
 
344
 
345
  def make_perturbation_batch(
346
+ example_cell, perturb_type, tokens_to_perturb, anchor_token, combo_lvl, num_proc, special_token
347
  ) -> tuple[Dataset, List[int]]:
348
  if combo_lvl == 0 and tokens_to_perturb == "all":
349
  if perturb_type in ["overexpress", "activate"]:
 
405
  delete_indices, num_proc=num_proc_i
406
  )
407
  elif perturb_type == "overexpress":
408
+ if special_token:
409
+ perturbation_dataset = perturbation_dataset.map(
410
+ overexpress_indices_special, num_proc=num_proc_i
411
+ )
412
+ else:
413
+ perturbation_dataset = perturbation_dataset.map(
414
+ overexpress_indices, num_proc=num_proc_i
415
+ )
416
 
417
  perturbation_dataset = perturbation_dataset.map(measure_length, num_proc=num_proc_i)
418
 
 
785
  return self.ens_to_symbol(self.token_to_ens(token))
786
 
787
  def symbol_to_token(self, symbol):
788
+ return self.ens_to_token(self.symbol_to_ens(symbol))