ctheodoris committed on
Commit
25dd1da
1 Parent(s): eb038a6

Update perturber stats to reflect cosine similarity, and update emb_extractor to suppress warnings for non-cls modes

Browse files
geneformer/emb_extractor.py CHANGED
@@ -78,7 +78,7 @@ def get_embs(
78
  gene_token_dict = {v:k for k,v in token_gene_dict.items()}
79
  cls_token_id = gene_token_dict["<cls>"]
80
  assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
81
- else:
82
  if cls_present:
83
  logger.warning("CLS token present in token dictionary, excluding from average.")
84
  if eos_present:
 
78
  gene_token_dict = {v:k for k,v in token_gene_dict.items()}
79
  cls_token_id = gene_token_dict["<cls>"]
80
  assert filtered_input_data["input_ids"][0][0] == cls_token_id, "First token is not <cls> token value"
81
+ elif emb_mode == "cell":
82
  if cls_present:
83
  logger.warning("CLS token present in token dictionary, excluding from average.")
84
  if eos_present:
geneformer/in_silico_perturber_stats.py CHANGED
@@ -193,9 +193,8 @@ def get_impact_component(test_value, gaussian_mixture_model):
193
 
194
  # aggregate data for single perturbation in multiple cells
195
  def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
196
- names = ["Cosine_shift", "Gene"]
197
  cos_sims_full_dfs = []
198
-
199
  if isinstance(genes_perturbed,list):
200
  if len(genes_perturbed)>1:
201
  gene_ids_df = cos_sims_df.loc[np.isin([set(idx) for idx in cos_sims_df["Ensembl_ID"]], set(genes_perturbed)), :]
@@ -222,7 +221,7 @@ def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
222
  cos_shift_data += dict_i.get((token, "cell_emb"), [])
223
 
224
  df = pd.DataFrame(columns=names)
225
- df["Cosine_shift"] = cos_shift_data
226
  df["Gene"] = symbol
227
  cos_sims_full_dfs.append(df)
228
 
@@ -233,6 +232,8 @@ def find(variable, x):
233
  try:
234
  if x in variable: # Test if variable is iterable and contains x
235
  return True
 
 
236
  except (ValueError, TypeError):
237
  return x == variable # Test if variable is x if non-iterable
238
 
@@ -273,15 +274,15 @@ def isp_aggregate_gene_shifts(
273
  cos_sims_full_df["Affected_Ensembl_ID"] = [
274
  gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
275
  ]
276
- cos_sims_full_df["Cosine_shift_mean"] = [v[0] for k, v in cos_data_mean.items()]
277
- cos_sims_full_df["Cosine_shift_stdev"] = [v[1] for k, v in cos_data_mean.items()]
278
  cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
279
 
280
  specific_val = "cell_emb"
281
  cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
282
- # reorder so cell embs are at the top and all are subordered by magnitude of cosine shift
283
  cos_sims_full_df = cos_sims_full_df.sort_values(
284
- by=(["temp", "Cosine_shift_mean"]), ascending=[False, False]
285
  ).drop("temp", axis=1)
286
 
287
  return cos_sims_full_df
@@ -939,11 +940,11 @@ class InSilicoPerturberStats:
939
  | 1: within impact component; 0: not within impact component
940
  | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
941
 
942
- | In case of aggregating gene shifts:
943
  | "Perturbed": ID(s) of gene(s) being perturbed
944
  | "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
945
- | "Cosine_shift_mean": mean of cosine shift of modeled perturbation on affected gene or cell
946
- | "Cosine_shift_stdev": standard deviation of cosine shift of modeled perturbation on affected gene or cell
947
  """
948
 
949
  if self.mode not in [
 
193
 
194
  # aggregate data for single perturbation in multiple cells
195
  def isp_aggregate_grouped_perturb(cos_sims_df, dict_list, genes_perturbed):
196
+ names = ["Cosine_sim", "Gene"]
197
  cos_sims_full_dfs = []
 
198
  if isinstance(genes_perturbed,list):
199
  if len(genes_perturbed)>1:
200
  gene_ids_df = cos_sims_df.loc[np.isin([set(idx) for idx in cos_sims_df["Ensembl_ID"]], set(genes_perturbed)), :]
 
221
  cos_shift_data += dict_i.get((token, "cell_emb"), [])
222
 
223
  df = pd.DataFrame(columns=names)
224
+ df["Cosine_sim"] = cos_shift_data
225
  df["Gene"] = symbol
226
  cos_sims_full_dfs.append(df)
227
 
 
232
  try:
233
  if x in variable: # Test if variable is iterable and contains x
234
  return True
235
+ elif x == variable:
236
+ return True
237
  except (ValueError, TypeError):
238
  return x == variable # Test if variable is x if non-iterable
239
 
 
274
  cos_sims_full_df["Affected_Ensembl_ID"] = [
275
  gene_token_id_dict.get(token, np.nan) for token in cos_sims_full_df["Affected"]
276
  ]
277
+ cos_sims_full_df["Cosine_sim_mean"] = [v[0] for k, v in cos_data_mean.items()]
278
+ cos_sims_full_df["Cosine_sim_stdev"] = [v[1] for k, v in cos_data_mean.items()]
279
  cos_sims_full_df["N_Detections"] = [v[2] for k, v in cos_data_mean.items()]
280
 
281
  specific_val = "cell_emb"
282
  cos_sims_full_df["temp"] = list(cos_sims_full_df["Affected"] == specific_val)
283
+ # reorder so cell embs are at the top and all are subordered by magnitude of cosine sim
284
  cos_sims_full_df = cos_sims_full_df.sort_values(
285
+ by=(["temp", "Cosine_sim_mean"]), ascending=[False, True]
286
  ).drop("temp", axis=1)
287
 
288
  return cos_sims_full_df
 
940
  | 1: within impact component; 0: not within impact component
941
  | "Impact_component_percent": percent of cells in which given perturbation was modeled to be within impact component
942
 
943
+ | In case of aggregating data / gene shifts:
944
  | "Perturbed": ID(s) of gene(s) being perturbed
945
  | "Affected": ID of affected gene or "cell_emb" indicating the impact on the cell embedding as a whole
946
+ | "Cosine_sim_mean": mean of cosine similarity of cell or affected gene in original vs. perturbed
947
+ | "Cosine_sim_stdev": standard deviation of cosine similarity of cell or affected gene in original vs. perturbed
948
  """
949
 
950
  if self.mode not in [