Christina Theodoris commited on
Commit
4bddd45
1 Parent(s): 5a43832

add option for hyperparameter tuning to cc.validate

Browse files
examples/cell_classification.ipynb CHANGED
@@ -13,7 +13,7 @@
13
  "id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
14
  "metadata": {},
15
  "source": [
16
- "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but please see the \"hyperparam_optimiz_for_disease_classifier\" script for an example of how to tune hyperparameters for downstream applications."
17
  ]
18
  },
19
  {
@@ -266,7 +266,8 @@
266
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
267
  " output_directory=output_dir,\n",
268
  " output_prefix=output_prefix,\n",
269
- " split_id_dict=train_valid_id_split_dict)"
 
270
  ]
271
  },
272
  {
 
13
  "id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
14
  "metadata": {},
15
  "source": [
16
+ "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization."
17
  ]
18
  },
19
  {
 
266
  " id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
267
  " output_directory=output_dir,\n",
268
  " output_prefix=output_prefix,\n",
269
+ " split_id_dict=train_valid_id_split_dict)\n",
270
+ " # to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)"
271
  ]
272
  },
273
  {
examples/hyperparam_optimiz_for_disease_classifier.py DELETED
@@ -1,226 +0,0 @@
1
- #!/usr/bin/env python
2
- # coding: utf-8
3
-
4
- # hyperparameter optimization with raytune for disease classification
5
-
6
- # imports
7
- import os
8
- import subprocess
9
- GPU_NUMBER = [0,1,2,3]
10
- os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
11
- os.environ["NCCL_DEBUG"] = "INFO"
12
- os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
13
- os.environ["LD_LIBRARY_PATH"] = "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"
14
-
15
- # initiate runtime environment for raytune
16
- import pyarrow # must occur prior to ray import
17
- import ray
18
- from ray import tune
19
- from ray.tune import ExperimentAnalysis
20
- from ray.tune.suggest.hyperopt import HyperOptSearch
21
- ray.shutdown() #engage new ray session
22
- runtime_env = {"conda": "base",
23
- "env_vars": {"LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"}}
24
- ray.init(runtime_env=runtime_env)
25
-
26
- def initialize_ray_with_check(ip_address):
27
- """
28
- Initialize Ray with a specified IP address and check its status and accessibility.
29
-
30
- Args:
31
- - ip_address (str): The IP address (with port) to initialize Ray.
32
-
33
- Returns:
34
- - bool: True if initialization was successful and dashboard is accessible, False otherwise.
35
- """
36
- try:
37
- ray.init(address=ip_address)
38
- print(ray.nodes())
39
-
40
- services = ray.get_webui_url()
41
- if not services:
42
- raise RuntimeError("Ray dashboard is not accessible.")
43
- else:
44
- print(f"Ray dashboard is accessible at: {services}")
45
- return True
46
- except Exception as e:
47
- print(f"Error initializing Ray: {e}")
48
- return False
49
-
50
- # Usage:
51
- ip = 'your_ip:xxxx' # Replace with your actual IP address and port
52
- if initialize_ray_with_check(ip):
53
- print("Ray initialized successfully.")
54
- else:
55
- print("Error during Ray initialization.")
56
-
57
- import datetime
58
- import numpy as np
59
- import pandas as pd
60
- import random
61
- import seaborn as sns; sns.set()
62
- from collections import Counter
63
- from datasets import load_from_disk
64
- from scipy.stats import ranksums
65
- from sklearn.metrics import accuracy_score
66
- from transformers import BertForSequenceClassification
67
- from transformers import Trainer
68
- from transformers.training_args import TrainingArguments
69
-
70
- from geneformer import DataCollatorForCellClassification
71
-
72
- # number of CPU cores
73
- num_proc=30
74
-
75
- # load train dataset with columns:
76
- # cell_type (annotation of each cell's type)
77
- # disease (healthy or disease state)
78
- # individual (unique ID for each patient)
79
- # length (length of that cell's rank value encoding)
80
- train_dataset=load_from_disk("/path/to/disease_train_data.dataset")
81
-
82
- # filter dataset for given cell_type
83
- def if_cell_type(example):
84
- return example["cell_type"].startswith("Cardiomyocyte")
85
-
86
- trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)
87
-
88
- # create dictionary of disease states : label ids
89
- target_names = ["healthy", "disease1", "disease2"]
90
- target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))
91
-
92
- trainset_v3 = trainset_v2.rename_column("disease","label")
93
-
94
- # change labels to numerical ids
95
- def classes_to_ids(example):
96
- example["label"] = target_name_id_dict[example["label"]]
97
- return example
98
-
99
- trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)
100
-
101
- # separate into train, validation, test sets
102
- indiv_set = set(trainset_v4["individual"])
103
- random.seed(42)
104
- train_indiv = random.sample(indiv_set,round(0.7*len(indiv_set)))
105
- eval_indiv = [indiv for indiv in indiv_set if indiv not in train_indiv]
106
- valid_indiv = random.sample(eval_indiv,round(0.5*len(eval_indiv)))
107
- test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv]
108
-
109
- def if_train(example):
110
- return example["individual"] in train_indiv
111
-
112
- classifier_trainset = trainset_v4.filter(if_train,num_proc=num_proc).shuffle(seed=42)
113
-
114
- def if_valid(example):
115
- return example["individual"] in valid_indiv
116
-
117
- classifier_validset = trainset_v4.filter(if_valid,num_proc=num_proc).shuffle(seed=42)
118
-
119
- # define output directory path
120
- current_date = datetime.datetime.now()
121
- datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
122
- output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"
123
-
124
- # ensure not overwriting previously saved model
125
- saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
126
- if os.path.isfile(saved_model_test) == True:
127
- raise Exception("Model already saved to this directory.")
128
-
129
- # make output directory
130
- subprocess.call(f'mkdir {output_dir}', shell=True)
131
-
132
- # set training parameters
133
- # how many pretrained layers to freeze
134
- freeze_layers = 2
135
- # batch size for training and eval
136
- geneformer_batch_size = 12
137
- # number of epochs
138
- epochs = 1
139
- # logging steps
140
- logging_steps = round(len(classifier_trainset)/geneformer_batch_size/10)
141
-
142
- # define function to initiate model
143
- def model_init():
144
- model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/",
145
- num_labels=len(target_names),
146
- output_attentions = False,
147
- output_hidden_states = False)
148
- if freeze_layers is not None:
149
- modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
150
- for module in modules_to_freeze:
151
- for param in module.parameters():
152
- param.requires_grad = False
153
-
154
- model = model.to("cuda:0")
155
- return model
156
-
157
- # define metrics
158
- # note: macro f1 score recommended for imbalanced multiclass classifiers
159
- def compute_metrics(pred):
160
- labels = pred.label_ids
161
- preds = pred.predictions.argmax(-1)
162
- # calculate accuracy using sklearn's function
163
- acc = accuracy_score(labels, preds)
164
- return {
165
- 'accuracy': acc,
166
- }
167
-
168
- # set training arguments
169
- training_args = {
170
- "do_train": True,
171
- "do_eval": True,
172
- "evaluation_strategy": "steps",
173
- "eval_steps": logging_steps,
174
- "logging_steps": logging_steps,
175
- "group_by_length": True,
176
- "length_column_name": "length",
177
- "disable_tqdm": True,
178
- "skip_memory_metrics": True, # memory tracker causes errors in raytune
179
- "per_device_train_batch_size": geneformer_batch_size,
180
- "per_device_eval_batch_size": geneformer_batch_size,
181
- "num_train_epochs": epochs,
182
- "load_best_model_at_end": True,
183
- "output_dir": output_dir,
184
- }
185
-
186
- training_args_init = TrainingArguments(**training_args)
187
-
188
- # create the trainer
189
- trainer = Trainer(
190
- model_init=model_init,
191
- args=training_args_init,
192
- data_collator=DataCollatorForCellClassification(),
193
- train_dataset=classifier_trainset,
194
- eval_dataset=classifier_validset,
195
- compute_metrics=compute_metrics,
196
- )
197
-
198
- # specify raytune hyperparameter search space
199
- ray_config = {
200
- "num_train_epochs": tune.choice([epochs]),
201
- "learning_rate": tune.loguniform(1e-6, 1e-3),
202
- "weight_decay": tune.uniform(0.0, 0.3),
203
- "lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
204
- "warmup_steps": tune.uniform(100, 2000),
205
- "seed": tune.uniform(0,100),
206
- "per_device_train_batch_size": tune.choice([geneformer_batch_size])
207
- }
208
-
209
- hyperopt_search = HyperOptSearch(
210
- metric="eval_accuracy", mode="max")
211
-
212
- # optimize hyperparameters
213
- trainer.hyperparameter_search(
214
- direction="maximize",
215
- backend="ray",
216
- resources_per_trial={"cpu":8,"gpu":1},
217
- hp_space=lambda _: ray_config,
218
- search_alg=hyperopt_search,
219
- n_trials=100, # number of trials
220
- progress_reporter=tune.CLIReporter(max_report_frequency=600,
221
- sort_by_metric=True,
222
- max_progress_rows=100,
223
- mode="max",
224
- metric="eval_accuracy",
225
- metric_columns=["loss", "eval_loss", "eval_accuracy"])
226
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
geneformer/classifier.py CHANGED
@@ -82,11 +82,12 @@ class Classifier:
82
  "training_args": {None, dict},
83
  "freeze_layers": {int},
84
  "num_crossval_splits": {0, 1, 5},
85
- "eval_size": {int, float},
86
  "no_eval": {bool},
87
  "stratify_splits_col": {None, str},
88
  "forward_batch_size": {int},
89
  "nproc": {int},
 
90
  }
91
 
92
  def __init__(
@@ -99,13 +100,15 @@ class Classifier:
99
  max_ncells=None,
100
  max_ncells_per_class=None,
101
  training_args=None,
 
102
  freeze_layers=0,
103
  num_crossval_splits=1,
104
- eval_size=0.2,
105
  stratify_splits_col=None,
106
  no_eval=False,
107
  forward_batch_size=100,
108
  nproc=4,
 
109
  ):
110
  """
111
  Initialize Geneformer classifier.
@@ -152,15 +155,18 @@ class Classifier:
152
  | Otherwise, will use the Hugging Face defaults:
153
  | https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
154
  | Note: Hyperparameter tuning is highly recommended, rather than using defaults.
 
 
155
  freeze_layers : int
156
  | Number of layers to freeze from fine-tuning.
157
  | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
158
  num_crossval_splits : {0, 1, 5}
159
  | 0: train on all data without splitting
160
- | 1: split data into train and eval sets by designated eval_size
161
- | 5: split data into 5 folds of train and eval sets by designated eval_size
162
- eval_size : None, float
163
- | Proportion of data to hold out for evaluation (e.g. 0.2 if intending 80:20 train/eval split)
 
164
  stratify_splits_col : None, str
165
  | Name of column in .dataset to be used for stratified splitting.
166
  | Proportion of each class in this column will be the same in the splits as in the original dataset.
@@ -171,6 +177,8 @@ class Classifier:
171
  | Batch size for forward pass (for evaluation, not training).
172
  nproc : int
173
  | Number of CPU processes to use.
 
 
174
 
175
  """
176
 
@@ -182,13 +190,19 @@ class Classifier:
182
  self.max_ncells = max_ncells
183
  self.max_ncells_per_class = max_ncells_per_class
184
  self.training_args = training_args
 
185
  self.freeze_layers = freeze_layers
186
  self.num_crossval_splits = num_crossval_splits
187
- self.eval_size = eval_size
 
 
 
 
188
  self.stratify_splits_col = stratify_splits_col
189
  self.no_eval = no_eval
190
  self.forward_batch_size = forward_batch_size
191
  self.nproc = nproc
 
192
 
193
  if self.training_args is None:
194
  logger.warning(
@@ -301,6 +315,9 @@ class Classifier:
301
  "Gene_class_dict should contain at least 2 gene classes to classify."
302
  )
303
  raise
 
 
 
304
 
305
  def prepare_data(
306
  self,
@@ -337,6 +354,7 @@ class Classifier:
337
  test_size : None, float
338
  | Proportion of data to be saved separately and held out for test set
339
  | (e.g. 0.2 if intending hold out 20%)
 
340
  | The training set will be further split to train / validation in self.validate
341
  | Note: only available for CellClassifiers
342
  attr_to_split : None, str
@@ -356,6 +374,9 @@ class Classifier:
356
  | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
357
  """
358
 
 
 
 
359
  # prepare data and labels for classification
360
  data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)
361
 
@@ -555,6 +576,7 @@ class Classifier:
555
  save_eval_output=True,
556
  predict_eval=True,
557
  predict_trainer=False,
 
558
  ):
559
  """
560
  (Cross-)validate cell state or gene classifier.
@@ -604,6 +626,9 @@ class Classifier:
604
  predict_trainer : bool
605
  | Whether or not to save eval predictions from trainer
606
  | Saves as a pickle file of trainer predictions
 
 
 
607
  """
608
 
609
  if self.num_crossval_splits == 0:
@@ -700,14 +725,30 @@ class Classifier:
700
  ]
701
  eval_data = data.select(eval_indices)
702
  train_data = data.select(train_indices)
703
- trainer = self.train_classifier(
704
- model_directory,
705
- num_classes,
706
- train_data,
707
- eval_data,
708
- ksplit_output_dir,
709
- predict_trainer,
710
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
711
  result = self.evaluate_model(
712
  trainer.model,
713
  num_classes,
@@ -752,14 +793,29 @@ class Classifier:
752
  self.nproc,
753
  )
754
 
755
- trainer = self.train_classifier(
756
- model_directory,
757
- num_classes,
758
- train_data,
759
- eval_data,
760
- ksplit_output_dir,
761
- predict_trainer,
762
- )
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
763
  result = self.evaluate_model(
764
  trainer.model,
765
  num_classes,
@@ -810,6 +866,162 @@ class Classifier:
810
 
811
  return all_metrics
812
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
813
  def train_classifier(
814
  self,
815
  model_directory,
 
82
  "training_args": {None, dict},
83
  "freeze_layers": {int},
84
  "num_crossval_splits": {0, 1, 5},
85
+ "split_sizes": {None, dict},
86
  "no_eval": {bool},
87
  "stratify_splits_col": {None, str},
88
  "forward_batch_size": {int},
89
  "nproc": {int},
90
+ "ngpu": {int},
91
  }
92
 
93
  def __init__(
 
100
  max_ncells=None,
101
  max_ncells_per_class=None,
102
  training_args=None,
103
+ ray_config=None,
104
  freeze_layers=0,
105
  num_crossval_splits=1,
106
+ split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1},
107
  stratify_splits_col=None,
108
  no_eval=False,
109
  forward_batch_size=100,
110
  nproc=4,
111
+ ngpu=1,
112
  ):
113
  """
114
  Initialize Geneformer classifier.
 
155
  | Otherwise, will use the Hugging Face defaults:
156
  | https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
157
  | Note: Hyperparameter tuning is highly recommended, rather than using defaults.
158
+ ray_config : None, dict
159
+ | Training argument ranges for tuning hyperparameters with Ray.
160
  freeze_layers : int
161
  | Number of layers to freeze from fine-tuning.
162
  | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
163
  num_crossval_splits : {0, 1, 5}
164
  | 0: train on all data without splitting
165
+ | 1: split data into train and eval sets by designated split_sizes["valid"]
166
+ | 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"]
167
+ split_sizes : None, dict
168
+ | Dictionary of proportion of data to hold out for train, validation, and test sets
169
+ | {"train": 0.8, "valid": 0.1, "test": 0.1} if intending 80/10/10 train/valid/test split
170
  stratify_splits_col : None, str
171
  | Name of column in .dataset to be used for stratified splitting.
172
  | Proportion of each class in this column will be the same in the splits as in the original dataset.
 
177
  | Batch size for forward pass (for evaluation, not training).
178
  nproc : int
179
  | Number of CPU processes to use.
180
+ ngpu : int
181
+ | Number of GPUs available.
182
 
183
  """
184
 
 
190
  self.max_ncells = max_ncells
191
  self.max_ncells_per_class = max_ncells_per_class
192
  self.training_args = training_args
193
+ self.ray_config = ray_config
194
  self.freeze_layers = freeze_layers
195
  self.num_crossval_splits = num_crossval_splits
196
+ self.split_sizes = split_sizes
197
+ self.train_size = self.split_sizes["train"]
198
+ self.valid_size = self.split_sizes["valid"]
199
+ self.oos_test_size = self.split_sizes["test"]
200
+ self.eval_size = self.valid_size / (self.train_size + self.valid_size)
201
  self.stratify_splits_col = stratify_splits_col
202
  self.no_eval = no_eval
203
  self.forward_batch_size = forward_batch_size
204
  self.nproc = nproc
205
+ self.ngpu = ngpu
206
 
207
  if self.training_args is None:
208
  logger.warning(
 
315
  "Gene_class_dict should contain at least 2 gene classes to classify."
316
  )
317
  raise
318
+ if sum(self.split_sizes.values()) != 1:
319
+ logger.error("Train, validation, and test proportions should sum to 1.")
320
+ raise
321
 
322
  def prepare_data(
323
  self,
 
354
  test_size : None, float
355
  | Proportion of data to be saved separately and held out for test set
356
  | (e.g. 0.2 if intending hold out 20%)
357
+ | If None, will inherit from split_sizes["test"] from Classifier
358
  | The training set will be further split to train / validation in self.validate
359
  | Note: only available for CellClassifiers
360
  attr_to_split : None, str
 
374
  | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
375
  """
376
 
377
+ if test_size is None:
378
+ test_size = self.oos_test_size
379
+
380
  # prepare data and labels for classification
381
  data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)
382
 
 
576
  save_eval_output=True,
577
  predict_eval=True,
578
  predict_trainer=False,
579
+ n_hyperopt_trials=0,
580
  ):
581
  """
582
  (Cross-)validate cell state or gene classifier.
 
626
  predict_trainer : bool
627
  | Whether or not to save eval predictions from trainer
628
  | Saves as a pickle file of trainer predictions
629
+ n_hyperopt_trials : int
630
+ | Number of trials to run for hyperparameter optimization
631
+ | If 0, will not optimize hyperparameters
632
  """
633
 
634
  if self.num_crossval_splits == 0:
 
725
  ]
726
  eval_data = data.select(eval_indices)
727
  train_data = data.select(train_indices)
728
+ if n_hyperopt_trials == 0:
729
+ trainer = self.train_classifier(
730
+ model_directory,
731
+ num_classes,
732
+ train_data,
733
+ eval_data,
734
+ ksplit_output_dir,
735
+ predict_trainer,
736
+ )
737
+ else:
738
+ trainer = self.hyperopt_classifier(
739
+ model_directory,
740
+ num_classes,
741
+ train_data,
742
+ eval_data,
743
+ ksplit_output_dir,
744
+ n_trials=n_hyperopt_trials,
745
+ )
746
+ if iteration_num == self.num_crossval_splits:
747
+ return
748
+ else:
749
+ iteration_num = iteration_num + 1
750
+ continue
751
+
752
  result = self.evaluate_model(
753
  trainer.model,
754
  num_classes,
 
793
  self.nproc,
794
  )
795
 
796
+ if n_hyperopt_trials == 0:
797
+ trainer = self.train_classifier(
798
+ model_directory,
799
+ num_classes,
800
+ train_data,
801
+ eval_data,
802
+ ksplit_output_dir,
803
+ predict_trainer,
804
+ )
805
+ else:
806
+ trainer = self.hyperopt_classifier(
807
+ model_directory,
808
+ num_classes,
809
+ train_data,
810
+ eval_data,
811
+ ksplit_output_dir,
812
+ n_trials=n_hyperopt_trials,
813
+ )
814
+ if iteration_num == self.num_crossval_splits:
815
+ return
816
+ else:
817
+ iteration_num = iteration_num + 1
818
+ continue
819
  result = self.evaluate_model(
820
  trainer.model,
821
  num_classes,
 
866
 
867
  return all_metrics
868
 
869
+ def hyperopt_classifier(
870
+ self,
871
+ model_directory,
872
+ num_classes,
873
+ train_data,
874
+ eval_data,
875
+ output_directory,
876
+ n_trials=100,
877
+ ):
878
+ """
879
+ Fine-tune model for cell state or gene classification.
880
+
881
+ **Parameters**
882
+
883
+ model_directory : Path
884
+ | Path to directory containing model
885
+ num_classes : int
886
+ | Number of classes for classifier
887
+ train_data : Dataset
888
+ | Loaded training .dataset input
889
+ | For cell classifier, labels in column "label".
890
+ | For gene classifier, labels in column "labels".
891
+ eval_data : None, Dataset
892
+ | (Optional) Loaded evaluation .dataset input
893
+ | For cell classifier, labels in column "label".
894
+ | For gene classifier, labels in column "labels".
895
+ output_directory : Path
896
+ | Path to directory where fine-tuned model will be saved
897
+ n_trials : int
898
+ | Number of trials to run for hyperparameter optimization
899
+ """
900
+
901
+ # initiate runtime environment for raytune
902
+ import ray
903
+ from ray import tune
904
+ from ray.tune.search.hyperopt import HyperOptSearch
905
+
906
+ ray.shutdown() # engage new ray session
907
+ ray.init()
908
+
909
+ ##### Validate and prepare data #####
910
+ train_data, eval_data = cu.validate_and_clean_cols(
911
+ train_data, eval_data, self.classifier
912
+ )
913
+
914
+ if (self.no_eval is True) and (eval_data is not None):
915
+ logger.warning(
916
+ "no_eval set to True; hyperparameter optimization requires eval, proceeding with eval"
917
+ )
918
+
919
+ # ensure not overwriting previously saved model
920
+ saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
921
+ if os.path.isfile(saved_model_test) is True:
922
+ logger.error("Model already saved to this designated output directory.")
923
+ raise
924
+ # make output directory
925
+ subprocess.call(f"mkdir {output_directory}", shell=True)
926
+
927
+ ##### Load model and training args #####
928
+ if self.classifier == "cell":
929
+ model_type = "CellClassifier"
930
+ elif self.classifier == "gene":
931
+ model_type = "GeneClassifier"
932
+
933
+ model = pu.load_model(model_type, num_classes, model_directory, "train")
934
+ def_training_args, def_freeze_layers = cu.get_default_train_args(
935
+ model, self.classifier, train_data, output_directory
936
+ )
937
+ del model
938
+
939
+ if self.training_args is not None:
940
+ def_training_args.update(self.training_args)
941
+ logging_steps = round(
942
+ len(train_data) / def_training_args["per_device_train_batch_size"] / 10
943
+ )
944
+ def_training_args["logging_steps"] = logging_steps
945
+ def_training_args["output_dir"] = output_directory
946
+ if eval_data is None:
947
+ def_training_args["evaluation_strategy"] = "no"
948
+ def_training_args["load_best_model_at_end"] = False
949
+ training_args_init = TrainingArguments(**def_training_args)
950
+
951
+ ##### Fine-tune the model #####
952
+ # define the data collator
953
+ if self.classifier == "cell":
954
+ data_collator = DataCollatorForCellClassification()
955
+ elif self.classifier == "gene":
956
+ data_collator = DataCollatorForGeneClassification()
957
+
958
+ # define function to initiate model
959
+ def model_init():
960
+ model = pu.load_model(model_type, num_classes, model_directory, "train")
961
+
962
+ if self.freeze_layers is not None:
963
+ def_freeze_layers = self.freeze_layers
964
+
965
+ if def_freeze_layers > 0:
966
+ modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
967
+ for module in modules_to_freeze:
968
+ for param in module.parameters():
969
+ param.requires_grad = False
970
+
971
+ model = model.to("cuda:0")
972
+ return model
973
+
974
+ # create the trainer
975
+ trainer = Trainer(
976
+ model_init=model_init,
977
+ args=training_args_init,
978
+ data_collator=data_collator,
979
+ train_dataset=train_data,
980
+ eval_dataset=eval_data,
981
+ compute_metrics=cu.compute_metrics,
982
+ )
983
+
984
+ # specify raytune hyperparameter search space
985
+ if self.ray_config is None:
986
+ logger.warning(
987
+ "No ray_config provided. Proceeding with default, but ranges may need adjustment depending on model."
988
+ )
989
+ def_ray_config = {
990
+ "num_train_epochs": tune.choice([1]),
991
+ "learning_rate": tune.loguniform(1e-6, 1e-3),
992
+ "weight_decay": tune.uniform(0.0, 0.3),
993
+ "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
994
+ "warmup_steps": tune.uniform(100, 2000),
995
+ "seed": tune.uniform(0, 100),
996
+ "per_device_train_batch_size": tune.choice(
997
+ [def_training_args["per_device_train_batch_size"]]
998
+ ),
999
+ }
1000
+
1001
+ hyperopt_search = HyperOptSearch(metric="eval_macro_f1", mode="max")
1002
+
1003
+ # optimize hyperparameters
1004
+ trainer.hyperparameter_search(
1005
+ direction="maximize",
1006
+ backend="ray",
1007
+ resources_per_trial={"cpu": int(self.nproc / self.ngpu), "gpu": 1},
1008
+ hp_space=lambda _: def_ray_config
1009
+ if self.ray_config is None
1010
+ else self.ray_config,
1011
+ search_alg=hyperopt_search,
1012
+ n_trials=n_trials, # number of trials
1013
+ progress_reporter=tune.CLIReporter(
1014
+ max_report_frequency=600,
1015
+ sort_by_metric=True,
1016
+ max_progress_rows=n_trials,
1017
+ mode="max",
1018
+ metric="eval_macro_f1",
1019
+ metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
1020
+ ),
1021
+ )
1022
+
1023
+ return trainer
1024
+
1025
  def train_classifier(
1026
  self,
1027
  model_directory,
geneformer/classifier_utils.py CHANGED
@@ -360,9 +360,23 @@ def get_num_classes(id_class_dict):
360
  def compute_metrics(pred):
361
  labels = pred.label_ids
362
  preds = pred.predictions.argmax(-1)
 
363
  # calculate accuracy and macro f1 using sklearn's function
364
- acc = accuracy_score(labels, preds)
365
- macro_f1 = f1_score(labels, preds, average="macro")
 
 
 
 
 
 
 
 
 
 
 
 
 
366
  return {"accuracy": acc, "macro_f1": macro_f1}
367
 
368
 
@@ -387,6 +401,11 @@ def get_default_train_args(model, classifier, data, output_dir):
387
  "per_device_train_batch_size": batch_size,
388
  "per_device_eval_batch_size": batch_size,
389
  }
 
 
 
 
 
390
 
391
  training_args = {
392
  "num_train_epochs": epochs,
 
360
  def compute_metrics(pred):
361
  labels = pred.label_ids
362
  preds = pred.predictions.argmax(-1)
363
+
364
  # calculate accuracy and macro f1 using sklearn's function
365
+ if len(labels.shape) == 1:
366
+ acc = accuracy_score(labels, preds)
367
+ macro_f1 = f1_score(labels, preds, average="macro")
368
+ else:
369
+ flat_labels = labels.flatten().tolist()
370
+ flat_preds = preds.flatten().tolist()
371
+ logit_label_paired = [
372
+ item for item in list(zip(flat_preds, flat_labels)) if item[1] != -100
373
+ ]
374
+ y_pred = [item[0] for item in logit_label_paired]
375
+ y_true = [item[1] for item in logit_label_paired]
376
+
377
+ acc = accuracy_score(y_true, y_pred)
378
+ macro_f1 = f1_score(y_true, y_pred, average="macro")
379
+
380
  return {"accuracy": acc, "macro_f1": macro_f1}
381
 
382
 
 
401
  "per_device_train_batch_size": batch_size,
402
  "per_device_eval_batch_size": batch_size,
403
  }
404
+ else:
405
+ default_training_args = {
406
+ "per_device_train_batch_size": batch_size,
407
+ "per_device_eval_batch_size": batch_size,
408
+ }
409
 
410
  training_args = {
411
  "num_train_epochs": epochs,
requirements.txt CHANGED
@@ -1,5 +1,6 @@
1
  anndata>=0.9
2
  datasets>=2.12
 
3
  loompy>=3.0
4
  matplotlib>=3.7
5
  numpy>=1.23
 
1
  anndata>=0.9
2
  datasets>=2.12
3
+ hyperopt>=0.2
4
  loompy>=3.0
5
  matplotlib>=3.7
6
  numpy>=1.23