Fixed a bug where indices were removed twice when cell_states_to_model is false
#188
by
davidjwen
- opened
geneformer/in_silico_perturber.py
CHANGED
@@ -123,17 +123,17 @@ def forward_pass_single_cell(model, example_cell, layer_to_quant):
|
|
123 |
del outputs
|
124 |
return emb
|
125 |
|
126 |
-
def perturb_emb_by_index(emb, indices):
|
127 |
-
mask = torch.ones(emb.numel(), dtype=torch.bool)
|
128 |
-
mask[indices] = False
|
129 |
return emb[mask]
|
130 |
|
131 |
-
def delete_indices(example):
|
132 |
indices = example["perturb_index"]
|
133 |
if any(isinstance(el, list) for el in indices):
|
134 |
-
indices = flatten_list(indices)
|
135 |
-
for index in sorted(indices, reverse=True):
|
136 |
-
del example["input_ids"][index]
|
137 |
return example
|
138 |
|
139 |
# for genes_to_perturb = "all" where only genes within cell are overexpressed
|
@@ -180,10 +180,10 @@ def make_perturbation_batch(example_cell,
|
|
180 |
elif perturb_type in ["delete","inhibit"]:
|
181 |
range_start = 0
|
182 |
indices_to_perturb = [[i] for i in range(range_start,example_cell["length"][0])]
|
183 |
-
elif combo_lvl>0 and (anchor_token is not None):
|
184 |
-
example_input_ids = example_cell["input_ids "][0]
|
185 |
-
anchor_index = example_input_ids.index(anchor_token[0])
|
186 |
-
indices_to_perturb = [sorted([anchor_index,i]) if i!=anchor_index else None for i in range(example_cell["length"][0])]
|
187 |
indices_to_perturb = [item for item in indices_to_perturb if item is not None]
|
188 |
else:
|
189 |
example_input_ids = example_cell["input_ids"][0]
|
@@ -398,7 +398,7 @@ def quant_cos_sims(model,
|
|
398 |
original_minibatch_length_set = set(original_minibatch["length"])
|
399 |
|
400 |
indices_to_perturb_minibatch = indices_to_perturb[i:i+forward_batch_size]
|
401 |
-
|
402 |
if perturb_type == "overexpress":
|
403 |
new_max_len = model_input_size - len(tokens_to_perturb)
|
404 |
else:
|
@@ -440,9 +440,7 @@ def quant_cos_sims(model,
|
|
440 |
if perturb_group == False:
|
441 |
minibatch_comparison = comparison_batch[i:max_range]
|
442 |
elif perturb_group == True:
|
443 |
-
minibatch_comparison =
|
444 |
-
indices_to_perturb_minibatch,
|
445 |
-
perturb_group)
|
446 |
|
447 |
cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
|
448 |
elif cell_states_to_model is not None:
|
|
|
123 |
del outputs
|
124 |
return emb
|
125 |
|
126 |
+
def perturb_emb_by_index(emb, indices):
|
127 |
+
mask = torch.ones(emb.numel(), dtype=torch.bool)
|
128 |
+
mask[indices] = False
|
129 |
return emb[mask]
|
130 |
|
131 |
+
def delete_indices(example):
|
132 |
indices = example["perturb_index"]
|
133 |
if any(isinstance(el, list) for el in indices):
|
134 |
+
indices = flatten_list(indices)
|
135 |
+
for index in sorted(indices, reverse=True):
|
136 |
+
del example["input_ids"][index]
|
137 |
return example
|
138 |
|
139 |
# for genes_to_perturb = "all" where only genes within cell are overexpressed
|
|
|
180 |
elif perturb_type in ["delete","inhibit"]:
|
181 |
range_start = 0
|
182 |
indices_to_perturb = [[i] for i in range(range_start,example_cell["length"][0])]
|
183 |
+
elif combo_lvl>0 and (anchor_token is not None):
|
184 |
+
example_input_ids = example_cell["input_ids "][0]
|
185 |
+
anchor_index = example_input_ids.index(anchor_token[0])
|
186 |
+
indices_to_perturb = [sorted([anchor_index,i]) if i!=anchor_index else None for i in range(example_cell["length"][0])]
|
187 |
indices_to_perturb = [item for item in indices_to_perturb if item is not None]
|
188 |
else:
|
189 |
example_input_ids = example_cell["input_ids"][0]
|
|
|
398 |
original_minibatch_length_set = set(original_minibatch["length"])
|
399 |
|
400 |
indices_to_perturb_minibatch = indices_to_perturb[i:i+forward_batch_size]
|
401 |
+
|
402 |
if perturb_type == "overexpress":
|
403 |
new_max_len = model_input_size - len(tokens_to_perturb)
|
404 |
else:
|
|
|
440 |
if perturb_group == False:
|
441 |
minibatch_comparison = comparison_batch[i:max_range]
|
442 |
elif perturb_group == True:
|
443 |
+
minibatch_comparison = original_minibatch_emb
|
|
|
|
|
444 |
|
445 |
cos_sims += [cos(minibatch_emb, minibatch_comparison).to("cpu")]
|
446 |
elif cell_states_to_model is not None:
|