Fixed a bug where indices were removed twice when cell_states_to_model is false

#188
Files changed (1) hide show
  1. geneformer/in_silico_perturber.py +13 -15
geneformer/in_silico_perturber.py CHANGED
@@ -123,17 +123,17 @@ def forward_pass_single_cell(model, example_cell, layer_to_quant):
123
  del outputs
124
  return emb
125
 
126
- def perturb_emb_by_index(emb, indices):
127
- mask = torch.ones(emb.numel(), dtype=torch.bool)
128
- mask[indices] = False
129
  return emb[mask]
130
 
131
- def delete_indices(example):
132
  indices = example["perturb_index"]
133
  if any(isinstance(el, list) for el in indices):
134
- indices = flatten_list(indices)
135
- for index in sorted(indices, reverse=True):
136
- del example["input_ids"][index]
137
  return example
138
 
139
  # for genes_to_perturb = "all" where only genes within cell are overexpressed
@@ -180,10 +180,10 @@ def make_perturbation_batch(example_cell,
180
  elif perturb_type in ["delete","inhibit"]:
181
  range_start = 0
182
  indices_to_perturb = [[i] for i in range(range_start,example_cell["length"][0])]
183
- elif combo_lvl>0 and (anchor_token is not None):
184
- example_input_ids = example_cell["input_ids "][0]
185
- anchor_index = example_input_ids.index(anchor_token[0])
186
- indices_to_perturb = [sorted([anchor_index,i]) if i!=anchor_index else None for i in range(example_cell["length"][0])]
187
  indices_to_perturb = [item for item in indices_to_perturb if item is not None]
188
  else:
189
  example_input_ids = example_cell["input_ids"][0]
@@ -398,7 +398,7 @@ def quant_cos_sims(model,
398
  original_minibatch_length_set = set(original_minibatch["length"])
399
 
400
  indices_to_perturb_minibatch = indices_to_perturb[i:i+forward_batch_size]
401
-
402
  if perturb_type == "overexpress":
403
  new_max_len = model_input_size - len(tokens_to_perturb)
404
  else:
@@ -440,9 +440,7 @@ def quant_cos_sims(model,
440
  if perturb_group == False:
441
  minibatch_comparison = comparison_batch[i:max_range]
442
  elif perturb_group == True:
443
- minibatch_comparison = make_comparison_batch(original_minibatch_emb,
444
- indices_to_perturb_minibatch,
445
- perturb_group)
446
 
447
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
448
  elif cell_states_to_model is not None:
 
123
  del outputs
124
  return emb
125
 
126
+ def perturb_emb_by_index(emb, indices):
127
+ mask = torch.ones(emb.numel(), dtype=torch.bool)
128
+ mask[indices] = False
129
  return emb[mask]
130
 
131
+ def delete_indices(example):
132
  indices = example["perturb_index"]
133
  if any(isinstance(el, list) for el in indices):
134
+ indices = flatten_list(indices)
135
+ for index in sorted(indices, reverse=True):
136
+ del example["input_ids"][index]
137
  return example
138
 
139
  # for genes_to_perturb = "all" where only genes within cell are overexpressed
 
180
  elif perturb_type in ["delete","inhibit"]:
181
  range_start = 0
182
  indices_to_perturb = [[i] for i in range(range_start,example_cell["length"][0])]
183
+ elif combo_lvl>0 and (anchor_token is not None):
184
+ example_input_ids = example_cell["input_ids "][0]
185
+ anchor_index = example_input_ids.index(anchor_token[0])
186
+ indices_to_perturb = [sorted([anchor_index,i]) if i!=anchor_index else None for i in range(example_cell["length"][0])]
187
  indices_to_perturb = [item for item in indices_to_perturb if item is not None]
188
  else:
189
  example_input_ids = example_cell["input_ids"][0]
 
398
  original_minibatch_length_set = set(original_minibatch["length"])
399
 
400
  indices_to_perturb_minibatch = indices_to_perturb[i:i+forward_batch_size]
401
+
402
  if perturb_type == "overexpress":
403
  new_max_len = model_input_size - len(tokens_to_perturb)
404
  else:
 
440
  if perturb_group == False:
441
  minibatch_comparison = comparison_batch[i:max_range]
442
  elif perturb_group == True:
443
+ minibatch_comparison = original_minibatch_emb
 
 
444
 
445
  cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
446
  elif cell_states_to_model is not None: