ctheodoris hchen725 commited on
Commit
94095d1
1 Parent(s): 0568479

Update for gene classification (#330)

Browse files

- Update for gene classification (f49922d12ee1fe946a511e6d96b9cda14ce7c22b)


Co-authored-by: Han Chen <[email protected]>

Files changed (1) hide show
  1. geneformer/classifier_utils.py +72 -33
geneformer/classifier_utils.py CHANGED
@@ -1,4 +1,6 @@
 
1
  import logging
 
2
  import random
3
  from collections import Counter, defaultdict
4
 
@@ -6,6 +8,7 @@ import numpy as np
6
  import pandas as pd
7
  from scipy.stats import chisquare, ranksums
8
  from sklearn.metrics import accuracy_score, f1_score
 
9
 
10
  from . import perturber_utils as pu
11
 
@@ -133,61 +136,55 @@ def label_gene_classes(example, class_id_dict, gene_class_dict):
133
  ]
134
 
135
 
136
- def prep_gene_classifier_split(
137
  data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
 
 
 
 
 
 
 
 
 
 
 
 
 
138
  ):
139
  # generate cross-validation splits
140
  targets = np.array(targets)
141
  labels = np.array(labels)
142
- targets_train, targets_eval = targets[train_index], targets[eval_index]
143
- labels_train, labels_eval = labels[train_index], labels[eval_index]
144
- label_dict_train = dict(zip(targets_train, labels_train))
145
- label_dict_eval = dict(zip(targets_eval, labels_eval))
146
 
147
  # function to filter by whether contains train or eval labels
148
- def if_contains_train_label(example):
149
- a = targets_train
150
- b = example["input_ids"]
151
- return not set(a).isdisjoint(b)
152
-
153
- def if_contains_eval_label(example):
154
- a = targets_eval
155
  b = example["input_ids"]
156
  return not set(a).isdisjoint(b)
157
 
158
  # filter dataset for examples containing classes for this split
159
- logger.info(f"Filtering training data for genes in split {iteration_num}")
160
- train_data = data.filter(if_contains_train_label, num_proc=num_proc)
161
  logger.info(
162
- f"Filtered {round((1-len(train_data)/len(data))*100)}%; {len(train_data)} remain\n"
163
- )
164
- logger.info(f"Filtering evalation data for genes in split {iteration_num}")
165
- eval_data = data.filter(if_contains_eval_label, num_proc=num_proc)
166
- logger.info(
167
- f"Filtered {round((1-len(eval_data)/len(data))*100)}%; {len(eval_data)} remain\n"
168
  )
169
 
170
  # subsample to max_ncells
171
- train_data = downsample_and_shuffle(train_data, max_ncells, None, None)
172
- eval_data = downsample_and_shuffle(eval_data, max_ncells, None, None)
173
 
174
  # relabel genes for this split
175
- def train_classes_to_ids(example):
176
  example["labels"] = [
177
- label_dict_train.get(token_id, -100) for token_id in example["input_ids"]
178
  ]
179
  return example
180
 
181
- def eval_classes_to_ids(example):
182
- example["labels"] = [
183
- label_dict_eval.get(token_id, -100) for token_id in example["input_ids"]
184
- ]
185
- return example
186
 
187
- train_data = train_data.map(train_classes_to_ids, num_proc=num_proc)
188
- eval_data = eval_data.map(eval_classes_to_ids, num_proc=num_proc)
189
-
190
- return train_data, eval_data
191
 
192
 
193
  def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
@@ -423,3 +420,45 @@ def get_default_train_args(model, classifier, data, output_dir):
423
  training_args.update(default_training_args)
424
 
425
  return training_args, freeze_layers
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
  import logging
3
+ import os
4
  import random
5
  from collections import Counter, defaultdict
6
 
 
8
  import pandas as pd
9
  from scipy.stats import chisquare, ranksums
10
  from sklearn.metrics import accuracy_score, f1_score
11
+ from sklearn.model_selection import StratifiedKFold, train_test_split
12
 
13
  from . import perturber_utils as pu
14
 
 
136
  ]
137
 
138
 
139
+ def prep_gene_classifier_train_eval_split(
140
  data, targets, labels, train_index, eval_index, max_ncells, iteration_num, num_proc
141
+ ):
142
+ # generate cross-validation splits
143
+ train_data = prep_gene_classifier_split(
144
+ data, targets, labels, train_index, "train", max_ncells, iteration_num, num_proc
145
+ )
146
+ eval_data = prep_gene_classifier_split(
147
+ data, targets, labels, eval_index, "eval", max_ncells, iteration_num, num_proc
148
+ )
149
+ return train_data, eval_data
150
+
151
+
152
+ def prep_gene_classifier_split(
153
+ data, targets, labels, index, subset_name, max_ncells, iteration_num, num_proc
154
  ):
155
  # generate cross-validation splits
156
  targets = np.array(targets)
157
  labels = np.array(labels)
158
+ targets_subset = targets[index]
159
+ labels_subset = labels[index]
160
+ label_dict_subset = dict(zip(targets_subset, labels_subset))
 
161
 
162
  # function to filter by whether contains train or eval labels
163
+ def if_contains_subset_label(example):
164
+ a = targets_subset
 
 
 
 
 
165
  b = example["input_ids"]
166
  return not set(a).isdisjoint(b)
167
 
168
  # filter dataset for examples containing classes for this split
169
+ logger.info(f"Filtering data for {subset_name} genes in split {iteration_num}")
170
+ subset_data = data.filter(if_contains_subset_label, num_proc=num_proc)
171
  logger.info(
172
+ f"Filtered {round((1-len(subset_data)/len(data))*100)}%; {len(subset_data)} remain\n"
 
 
 
 
 
173
  )
174
 
175
  # subsample to max_ncells
176
+ subset_data = downsample_and_shuffle(subset_data, max_ncells, None, None)
 
177
 
178
  # relabel genes for this split
179
+ def subset_classes_to_ids(example):
180
  example["labels"] = [
181
+ label_dict_subset.get(token_id, -100) for token_id in example["input_ids"]
182
  ]
183
  return example
184
 
185
+ subset_data = subset_data.map(subset_classes_to_ids, num_proc=num_proc)
 
 
 
 
186
 
187
+ return subset_data
 
 
 
188
 
189
 
190
  def prep_gene_classifier_all_data(data, targets, labels, max_ncells, num_proc):
 
420
  training_args.update(default_training_args)
421
 
422
  return training_args, freeze_layers
423
+
424
+
425
+ def load_best_model(directory, model_type, num_classes, mode="eval"):
426
+ file_dict = dict()
427
+ for subdir, dirs, files in os.walk(directory):
428
+ for file in files:
429
+ if file.endswith("result.json"):
430
+ with open(f"{subdir}/{file}", "rb") as fp:
431
+ result_json = json.load(fp)
432
+ file_dict[f"{subdir}"] = result_json["eval_macro_f1"]
433
+ file_df = pd.DataFrame(
434
+ {"dir": file_dict.keys(), "eval_macro_f1": file_dict.values()}
435
+ )
436
+ model_superdir = (
437
+ "run-"
438
+ + file_df.iloc[file_df["eval_macro_f1"].idxmax()]["dir"]
439
+ .split("_objective_")[2]
440
+ .split("_")[0]
441
+ )
442
+
443
+ for subdir, dirs, files in os.walk(f"{directory}/{model_superdir}"):
444
+ for file in files:
445
+ if file.endswith("model.safetensors"):
446
+ model = pu.load_model(model_type, num_classes, f"{subdir}", mode)
447
+ return model
448
+
449
+
450
+ class StratifiedKFold3(StratifiedKFold):
451
+ def split(self, targets, labels, test_ratio=0.5, groups=None):
452
+ s = super().split(targets, labels, groups)
453
+ for train_indxs, test_indxs in s:
454
+ if test_ratio == 0:
455
+ yield train_indxs, test_indxs, None
456
+ else:
457
+ labels_test = np.array(labels)[test_indxs]
458
+ valid_indxs, test_indxs = train_test_split(
459
+ test_indxs,
460
+ stratify=labels_test,
461
+ test_size=test_ratio,
462
+ random_state=0,
463
+ )
464
+ yield train_indxs, valid_indxs, test_indxs