Christina Theodoris committed on
Commit
dc1481d
1 Parent(s): 67f674c

Add mixture model option for gene-gene interaction stats

Browse files
geneformer/in_silico_perturber_stats.py CHANGED
@@ -24,7 +24,7 @@ import statsmodels.stats.multitest as smt
24
  from pathlib import Path
25
  from scipy.stats import ranksums
26
  from sklearn.mixture import GaussianMixture
27
- from tqdm.notebook import trange
28
 
29
  from .tokenizer import TOKEN_DICTIONARY_FILE
30
 
@@ -37,19 +37,26 @@ def invert_dict(dictionary):
37
  return {v: k for k, v in dictionary.items()}
38
 
39
  # read raw dictionary files
40
- def read_dictionaries(dir, cell_or_gene_emb):
41
  file_found = 0
 
42
  dict_list = []
43
  for file in os.listdir(dir):
44
  # process only _raw.pickle files
45
  if file.endswith("_raw.pickle"):
46
  file_found = 1
47
- with open(f"{dir}/{file}", "rb") as fp:
48
- cos_sims_dict = pickle.load(fp)
49
- if cell_or_gene_emb == "cell":
50
- cell_emb_dict = {k: v for k,
51
- v in cos_sims_dict.items() if v and "cell_emb" in k}
 
 
52
  dict_list += [cell_emb_dict]
 
 
 
 
53
  if file_found == 0:
54
  logger.error(
55
  "No raw data for processing found within provided directory. " \
@@ -58,18 +65,27 @@ def read_dictionaries(dir, cell_or_gene_emb):
58
  return dict_list
59
 
60
  # get complete gene list
61
- def get_gene_list(dict_list):
 
 
 
 
62
  gene_set = set()
63
  for dict_i in dict_list:
64
- gene_set.update([k[0] for k, v in dict_i.items() if v])
65
  gene_list = list(gene_set)
 
 
66
  gene_list.sort()
67
  return gene_list
68
 
69
- def n_detections(token, dict_list):
70
  cos_sim_megalist = []
71
  for dict_i in dict_list:
72
- cos_sim_megalist += dict_i.get((token, "cell_emb"),[])
 
 
 
73
  return len(cos_sim_megalist)
74
 
75
  def get_fdr(pvalues):
@@ -154,7 +170,7 @@ def isp_stats_to_goal_state(cos_sims_df, dict_list):
154
  cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
155
 
156
  # quantify number of detections of each gene
157
- cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list) for i in cos_sims_full_df["Gene"]]
158
 
159
  # sort by shift to desired state
160
  cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_from_goal_end",
@@ -205,7 +221,7 @@ def isp_stats_vs_null(cos_sims_df, dict_list, null_dict_list):
205
  # reports the most likely component for each test perturbation
206
  # Note: because this approach assumes a given perturbation has a consistent effect in the cells tested,
207
  # we recommend only using the mixture model strategy with uniform cell populations
208
- def isp_stats_mixture_model(cos_sims_df, dict_list, combos):
209
 
210
  names=["Gene",
211
  "Gene_name",
@@ -232,9 +248,12 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos):
232
  name = cos_sims_df["Gene_name"][i]
233
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
234
  cos_shift_data = []
235
-
236
  for dict_i in dict_list:
237
- cos_shift_data += dict_i.get((token, "cell_emb"),[])
 
 
 
238
 
239
  # Extract values for current gene
240
  if combos == 0:
@@ -248,7 +267,7 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos):
248
  avg_value = np.mean(test_values)
249
  avg_values.append(avg_value)
250
  gene_names.append(name)
251
-
252
  # fit Gaussian mixture model to dataset of mean for each gene
253
  avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
254
  gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
@@ -260,7 +279,10 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos):
260
  cos_shift_data = []
261
 
262
  for dict_i in dict_list:
263
- cos_shift_data += dict_i.get((token, "cell_emb"),[])
 
 
 
264
 
265
  if combos == 0:
266
  mean_test = np.mean(cos_shift_data)
@@ -301,7 +323,10 @@ def isp_stats_mixture_model(cos_sims_df, dict_list, combos):
301
  cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
302
 
303
  # quantify number of detections of each gene
304
- cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list) for i in cos_sims_full_df["Gene"]]
 
 
 
305
 
306
  if combos == 0:
307
  cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
@@ -342,9 +367,11 @@ class InSilicoPerturberStats:
342
  combos : {0,1,2}
343
  Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
344
  anchor_gene : None, str
345
- ENSEMBL ID of gene to use as anchor in combination perturbations.
346
- For example, if combos=1 and anchor_gene="ENSG00000148400":
347
- anchor gene will be perturbed in combination with each other gene.
 
 
348
  cell_states_to_model: None, dict
349
  Cell states to model if testing perturbations that achieve goal state change.
350
  Single-item dictionary with key being cell attribute (e.g. "disease").
@@ -459,8 +486,14 @@ class InSilicoPerturberStats:
459
  self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)
460
 
461
  # obtain total gene list
462
- dict_list = read_dictionaries(input_data_directory, "cell")
463
- gene_list = get_gene_list(dict_list)
 
 
 
 
 
 
464
 
465
  # initiate results dataframe
466
  cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
@@ -476,11 +509,11 @@ class InSilicoPerturberStats:
476
  cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list)
477
 
478
  elif self.mode == "vs_null":
479
- null_dict_list = read_dictionaries(null_dist_data_directory, "cell")
480
  cos_sims_df = isp_stats_vs_null(cos_sims_df_initial, dict_list, null_dict_list)
481
 
482
  elif self.mode == "mixture_model":
483
- cos_sims_df = isp_stats_mixture_model(cos_sims_df_initial, dict_list, self.combos)
484
 
485
  # save perturbation stats to output_path
486
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")
 
24
  from pathlib import Path
25
  from scipy.stats import ranksums
26
  from sklearn.mixture import GaussianMixture
27
+ from tqdm.notebook import trange, tqdm
28
 
29
  from .tokenizer import TOKEN_DICTIONARY_FILE
30
 
 
37
  return {v: k for k, v in dictionary.items()}
38
 
39
  # read raw dictionary files
40
+ def read_dictionaries(dir, cell_or_gene_emb, anchor_token):
41
  file_found = 0
42
+ file_path_list = []
43
  dict_list = []
44
  for file in os.listdir(dir):
45
  # process only _raw.pickle files
46
  if file.endswith("_raw.pickle"):
47
  file_found = 1
48
+ file_path_list += [f"{dir}/{file}"]
49
+ for file_path in tqdm(file_path_list):
50
+ with open(file_path, "rb") as fp:
51
+ cos_sims_dict = pickle.load(fp)
52
+ if cell_or_gene_emb == "cell":
53
+ cell_emb_dict = {k: v for k,
54
+ v in cos_sims_dict.items() if v and "cell_emb" in k}
55
  dict_list += [cell_emb_dict]
56
+ elif cell_or_gene_emb == "gene":
57
+ gene_emb_dict = {k: v for k,
58
+ v in cos_sims_dict.items() if v and anchor_token == k[0]}
59
+ dict_list += [gene_emb_dict]
60
  if file_found == 0:
61
  logger.error(
62
  "No raw data for processing found within provided directory. " \
 
65
  return dict_list
66
 
67
  # get complete gene list
68
+ def get_gene_list(dict_list,mode):
69
+ if mode == "cell":
70
+ position = 0
71
+ elif mode == "gene":
72
+ position = 1
73
  gene_set = set()
74
  for dict_i in dict_list:
75
+ gene_set.update([k[position] for k, v in dict_i.items() if v])
76
  gene_list = list(gene_set)
77
+ if mode == "gene":
78
+ gene_list.remove("cell_emb")
79
  gene_list.sort()
80
  return gene_list
81
 
82
+ def n_detections(token, dict_list, mode, anchor_token):
83
  cos_sim_megalist = []
84
  for dict_i in dict_list:
85
+ if mode == "cell":
86
+ cos_sim_megalist += dict_i.get((token, "cell_emb"),[])
87
+ elif mode == "gene":
88
+ cos_sim_megalist += dict_i.get((anchor_token, token),[])
89
  return len(cos_sim_megalist)
90
 
91
  def get_fdr(pvalues):
 
170
  cos_sims_full_df["Alt_end_FDR"] = get_fdr(list(cos_sims_full_df["Alt_end_vs_random_pval"]))
171
 
172
  # quantify number of detections of each gene
173
+ cos_sims_full_df["N_Detections"] = [n_detections(i, dict_list, "cell", None) for i in cos_sims_full_df["Gene"]]
174
 
175
  # sort by shift to desired state
176
  cos_sims_full_df = cos_sims_full_df.sort_values(by=["Shift_from_goal_end",
 
221
  # reports the most likely component for each test perturbation
222
  # Note: because this approach assumes a given perturbation has a consistent effect in the cells tested,
223
  # we recommend only using the mixture model strategy with uniform cell populations
224
+ def isp_stats_mixture_model(cos_sims_df, dict_list, combos, anchor_token):
225
 
226
  names=["Gene",
227
  "Gene_name",
 
248
  name = cos_sims_df["Gene_name"][i]
249
  ensembl_id = cos_sims_df["Ensembl_ID"][i]
250
  cos_shift_data = []
251
+
252
  for dict_i in dict_list:
253
+ if (combos == 0) and (anchor_token is not None):
254
+ cos_shift_data += dict_i.get((anchor_token, token),[])
255
+ else:
256
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
257
 
258
  # Extract values for current gene
259
  if combos == 0:
 
267
  avg_value = np.mean(test_values)
268
  avg_values.append(avg_value)
269
  gene_names.append(name)
270
+
271
  # fit Gaussian mixture model to dataset of mean for each gene
272
  avg_values_to_fit = np.array(avg_values).reshape(-1, 1)
273
  gm = GaussianMixture(n_components=2, random_state=0).fit(avg_values_to_fit)
 
279
  cos_shift_data = []
280
 
281
  for dict_i in dict_list:
282
+ if (combos == 0) and (anchor_token is not None):
283
+ cos_shift_data += dict_i.get((anchor_token, token),[])
284
+ else:
285
+ cos_shift_data += dict_i.get((token, "cell_emb"),[])
286
 
287
  if combos == 0:
288
  mean_test = np.mean(cos_shift_data)
 
323
  cos_sims_full_df = pd.concat([cos_sims_full_df,cos_sims_df_i])
324
 
325
  # quantify number of detections of each gene
326
+ cos_sims_full_df["N_Detections"] = [n_detections(i,
327
+ dict_list,
328
+ "gene",
329
+ anchor_token) for i in cos_sims_full_df["Gene"]]
330
 
331
  if combos == 0:
332
  cos_sims_full_df = cos_sims_full_df.sort_values(by=["Impact_component",
 
367
  combos : {0,1,2}
368
  Whether to perturb genes individually (0), in pairs (1), or in triplets (2).
369
  anchor_gene : None, str
370
+ ENSEMBL ID of gene to use as anchor in combination perturbations or in testing effect on downstream genes.
371
+ For example, if combos=1 and anchor_gene="ENSG00000136574":
372
+ analyzes data for anchor gene perturbed in combination with each other gene.
373
+ However, if combos=0 and anchor_gene="ENSG00000136574":
374
+ analyzes data for the effect of anchor gene's perturbation on the embedding of each other gene.
375
  cell_states_to_model: None, dict
376
  Cell states to model if testing perturbations that achieve goal state change.
377
  Single-item dictionary with key being cell attribute (e.g. "disease").
 
486
  self.gene_id_name_dict = invert_dict(self.gene_name_id_dict)
487
 
488
  # obtain total gene list
489
+ if (self.combos == 0) and (self.anchor_token is not None):
490
+ # cos sim data for effect of gene perturbation on the embedding of each other gene
491
+ dict_list = read_dictionaries(input_data_directory, "gene", self.anchor_token)
492
+ gene_list = get_gene_list(dict_list, "gene")
493
+ else:
494
+ # cos sim data for effect of gene perturbation on the embedding of each cell
495
+ dict_list = read_dictionaries(input_data_directory, "cell", self.anchor_token)
496
+ gene_list = get_gene_list(dict_list, "cell")
497
 
498
  # initiate results dataframe
499
  cos_sims_df_initial = pd.DataFrame({"Gene": gene_list,
 
509
  cos_sims_df = isp_stats_to_goal_state(cos_sims_df_initial, dict_list)
510
 
511
  elif self.mode == "vs_null":
512
+ null_dict_list = read_dictionaries(null_dist_data_directory, "cell", self.anchor_token)
513
  cos_sims_df = isp_stats_vs_null(cos_sims_df_initial, dict_list, null_dict_list)
514
 
515
  elif self.mode == "mixture_model":
516
+ cos_sims_df = isp_stats_mixture_model(cos_sims_df_initial, dict_list, self.combos, self.anchor_token)
517
 
518
  # save perturbation stats to output_path
519
  output_path = (Path(output_directory) / output_prefix).with_suffix(".csv")