Commit 4bddd45
Christina Theodoris committed
Parent(s): 5a43832

add option for hyperparameter tuning to cc.validate

Files changed:
- examples/cell_classification.ipynb +3 -2
- examples/hyperparam_optimiz_for_disease_classifier.py +0 -226
- geneformer/classifier.py +235 -23
- geneformer/classifier_utils.py +21 -2
- requirements.txt +1 -0
examples/cell_classification.ipynb  CHANGED

@@ -13,7 +13,7 @@
     "id": "1792e51c-86c3-406f-be5a-273c4e4aec20",
     "metadata": {},
     "source": [
-     "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but
+     "### Please note that, as usual with deep learning models, we **highly** recommend tuning learning hyperparameters for all fine-tuning applications as this can significantly improve model performance. Example below uses previously optimized hyperparameters, but one can optimize hyperparameters with the argument n_hyperopt_trials=n in cc.validate() where n>0 and represents the number of trials for hyperparameter optimization."
     ]
   },
   {

@@ -266,7 +266,8 @@
     "                            id_class_dict_file=f\"{output_dir}/{output_prefix}_id_class_dict.pkl\",\n",
     "                            output_directory=output_dir,\n",
     "                            output_prefix=output_prefix,\n",
-     "                            split_id_dict=train_valid_id_split_dict)"
+     "                            split_id_dict=train_valid_id_split_dict)\n",
+     "                            # to optimize hyperparameters, set n_hyperopt_trials=100 (or alternative desired # of trials)"
     ]
   },
   {
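For reference, a minimal usage sketch of the new argument from the notebook's point of view. Only the arguments visible in the diff above (id_class_dict_file, output_directory, output_prefix, split_id_dict, n_hyperopt_trials) come from this commit; the constructor call, file paths, prefixes, and the prepared_input_data_file name are placeholder assumptions, not values from this commit.

# Hypothetical sketch, not the exact notebook cell: paths, prefixes, and the
# Classifier constructor arguments below are placeholder assumptions.
from geneformer import Classifier

output_dir = "/path/to/output"        # placeholder
output_prefix = "cm_classifier_test"  # placeholder
train_valid_id_split_dict = {
    "attr_key": "individual",          # assumed split attribute
    "train": ["patient1", "patient2"],  # placeholder IDs
    "eval": ["patient3"],
}

cc = Classifier(classifier="cell", nproc=16)  # other constructor args omitted

all_metrics = cc.validate(
    model_directory="/path/to/pretrained_model/",  # placeholder path
    prepared_input_data_file=f"{output_dir}/{output_prefix}_labeled_train.dataset",  # assumed name
    id_class_dict_file=f"{output_dir}/{output_prefix}_id_class_dict.pkl",
    output_directory=output_dir,
    output_prefix=output_prefix,
    split_id_dict=train_valid_id_split_dict,
    n_hyperopt_trials=100,  # 0 (default) skips tuning; >0 runs that many Ray Tune trials
)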
examples/hyperparam_optimiz_for_disease_classifier.py  DELETED

@@ -1,226 +0,0 @@
#!/usr/bin/env python
# coding: utf-8

# hyperparameter optimization with raytune for disease classification

# imports
import os
import subprocess
GPU_NUMBER = [0,1,2,3]
os.environ["CUDA_VISIBLE_DEVICES"] = ",".join([str(s) for s in GPU_NUMBER])
os.environ["NCCL_DEBUG"] = "INFO"
os.environ["CONDA_OVERRIDE_GLIBC"] = "2.56"
os.environ["LD_LIBRARY_PATH"] = "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"

# initiate runtime environment for raytune
import pyarrow  # must occur prior to ray import
import ray
from ray import tune
from ray.tune import ExperimentAnalysis
from ray.tune.suggest.hyperopt import HyperOptSearch
ray.shutdown()  # engage new ray session
runtime_env = {"conda": "base",
               "env_vars": {"LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib"}}
ray.init(runtime_env=runtime_env)

def initialize_ray_with_check(ip_address):
    """
    Initialize Ray with a specified IP address and check its status and accessibility.

    Args:
    - ip_address (str): The IP address (with port) to initialize Ray.

    Returns:
    - bool: True if initialization was successful and dashboard is accessible, False otherwise.
    """
    try:
        ray.init(address=ip_address)
        print(ray.nodes())

        services = ray.get_webui_url()
        if not services:
            raise RuntimeError("Ray dashboard is not accessible.")
        else:
            print(f"Ray dashboard is accessible at: {services}")
            return True
    except Exception as e:
        print(f"Error initializing Ray: {e}")
        return False

# Usage:
ip = 'your_ip:xxxx'  # Replace with your actual IP address and port
if initialize_ray_with_check(ip):
    print("Ray initialized successfully.")
else:
    print("Error during Ray initialization.")

import datetime
import numpy as np
import pandas as pd
import random
import seaborn as sns; sns.set()
from collections import Counter
from datasets import load_from_disk
from scipy.stats import ranksums
from sklearn.metrics import accuracy_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments

from geneformer import DataCollatorForCellClassification

# number of CPU cores
num_proc=30

# load train dataset with columns:
# cell_type (annotation of each cell's type)
# disease (healthy or disease state)
# individual (unique ID for each patient)
# length (length of that cell's rank value encoding)
train_dataset=load_from_disk("/path/to/disease_train_data.dataset")

# filter dataset for given cell_type
def if_cell_type(example):
    return example["cell_type"].startswith("Cardiomyocyte")

trainset_v2 = train_dataset.filter(if_cell_type, num_proc=num_proc)

# create dictionary of disease states : label ids
target_names = ["healthy", "disease1", "disease2"]
target_name_id_dict = dict(zip(target_names,[i for i in range(len(target_names))]))

trainset_v3 = trainset_v2.rename_column("disease","label")

# change labels to numerical ids
def classes_to_ids(example):
    example["label"] = target_name_id_dict[example["label"]]
    return example

trainset_v4 = trainset_v3.map(classes_to_ids, num_proc=num_proc)

# separate into train, validation, test sets
indiv_set = set(trainset_v4["individual"])
random.seed(42)
train_indiv = random.sample(indiv_set,round(0.7*len(indiv_set)))
eval_indiv = [indiv for indiv in indiv_set if indiv not in train_indiv]
valid_indiv = random.sample(eval_indiv,round(0.5*len(eval_indiv)))
test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv]

def if_train(example):
    return example["individual"] in train_indiv

classifier_trainset = trainset_v4.filter(if_train,num_proc=num_proc).shuffle(seed=42)

def if_valid(example):
    return example["individual"] in valid_indiv

classifier_validset = trainset_v4.filter(if_valid,num_proc=num_proc).shuffle(seed=42)

# define output directory path
current_date = datetime.datetime.now()
datestamp = f"{str(current_date.year)[-2:]}{current_date.month:02d}{current_date.day:02d}"
output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"

# ensure not overwriting previously saved model
saved_model_test = os.path.join(output_dir, f"pytorch_model.bin")
if os.path.isfile(saved_model_test) == True:
    raise Exception("Model already saved to this directory.")

# make output directory
subprocess.call(f'mkdir {output_dir}', shell=True)

# set training parameters
# how many pretrained layers to freeze
freeze_layers = 2
# batch size for training and eval
geneformer_batch_size = 12
# number of epochs
epochs = 1
# logging steps
logging_steps = round(len(classifier_trainset)/geneformer_batch_size/10)

# define function to initiate model
def model_init():
    model = BertForSequenceClassification.from_pretrained("/path/to/pretrained_model/",
                                                          num_labels=len(target_names),
                                                          output_attentions = False,
                                                          output_hidden_states = False)
    if freeze_layers is not None:
        modules_to_freeze = model.bert.encoder.layer[:freeze_layers]
        for module in modules_to_freeze:
            for param in module.parameters():
                param.requires_grad = False

    model = model.to("cuda:0")
    return model

# define metrics
# note: macro f1 score recommended for imbalanced multiclass classifiers
def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    # calculate accuracy using sklearn's function
    acc = accuracy_score(labels, preds)
    return {
      'accuracy': acc,
    }

# set training arguments
training_args = {
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "steps",
    "eval_steps": logging_steps,
    "logging_steps": logging_steps,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": True,
    "skip_memory_metrics": True,  # memory tracker causes errors in raytune
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}

training_args_init = TrainingArguments(**training_args)

# create the trainer
trainer = Trainer(
    model_init=model_init,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=classifier_trainset,
    eval_dataset=classifier_validset,
    compute_metrics=compute_metrics,
)

# specify raytune hyperparameter search space
ray_config = {
    "num_train_epochs": tune.choice([epochs]),
    "learning_rate": tune.loguniform(1e-6, 1e-3),
    "weight_decay": tune.uniform(0.0, 0.3),
    "lr_scheduler_type": tune.choice(["linear","cosine","polynomial"]),
    "warmup_steps": tune.uniform(100, 2000),
    "seed": tune.uniform(0,100),
    "per_device_train_batch_size": tune.choice([geneformer_batch_size])
}

hyperopt_search = HyperOptSearch(
    metric="eval_accuracy", mode="max")

# optimize hyperparameters
trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    resources_per_trial={"cpu":8,"gpu":1},
    hp_space=lambda _: ray_config,
    search_alg=hyperopt_search,
    n_trials=100,  # number of trials
    progress_reporter=tune.CLIReporter(max_report_frequency=600,
                                       sort_by_metric=True,
                                       max_progress_rows=100,
                                       mode="max",
                                       metric="eval_accuracy",
                                       metric_columns=["loss", "eval_loss", "eval_accuracy"])
)
geneformer/classifier.py  CHANGED

@@ -82,11 +82,12 @@ class Classifier:
         "training_args": {None, dict},
         "freeze_layers": {int},
         "num_crossval_splits": {0, 1, 5},
-        "
+        "split_sizes": {None, dict},
         "no_eval": {bool},
         "stratify_splits_col": {None, str},
         "forward_batch_size": {int},
         "nproc": {int},
+        "ngpu": {int},
     }

     def __init__(

@@ -99,13 +100,15 @@ class Classifier:
         max_ncells=None,
         max_ncells_per_class=None,
         training_args=None,
+        ray_config=None,
         freeze_layers=0,
         num_crossval_splits=1,
-
+        split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1},
         stratify_splits_col=None,
         no_eval=False,
         forward_batch_size=100,
         nproc=4,
+        ngpu=1,
     ):
         """
         Initialize Geneformer classifier.

@@ -152,15 +155,18 @@ class Classifier:
             | Otherwise, will use the Hugging Face defaults:
             | https://huggingface.co/docs/transformers/main_classes/trainer#transformers.TrainingArguments
             | Note: Hyperparameter tuning is highly recommended, rather than using defaults.
+        ray_config : None, dict
+            | Training argument ranges for tuning hyperparameters with Ray.
         freeze_layers : int
            | Number of layers to freeze from fine-tuning.
            | 0: no layers will be frozen; 2: first two layers will be frozen; etc.
         num_crossval_splits : {0, 1, 5}
            | 0: train on all data without splitting
-           | 1: split data into train and eval sets by designated
-           | 5: split data into 5 folds of train and eval sets by designated
-
-
+           | 1: split data into train and eval sets by designated split_sizes["valid"]
+           | 5: split data into 5 folds of train and eval sets by designated split_sizes["valid"]
+        split_sizes : None, dict
+            | Dictionary of proportion of data to hold out for train, validation, and test sets
+            | {"train": 0.8, "valid": 0.1, "test": 0.1} if intending 80/10/10 train/valid/test split
         stratify_splits_col : None, str
            | Name of column in .dataset to be used for stratified splitting.
            | Proportion of each class in this column will be the same in the splits as in the original dataset.

@@ -171,6 +177,8 @@ class Classifier:
            | Batch size for forward pass (for evaluation, not training).
         nproc : int
            | Number of CPU processes to use.
+        ngpu : int
+            | Number of GPUs available.

         """

@@ -182,13 +190,19 @@ class Classifier:
         self.max_ncells = max_ncells
         self.max_ncells_per_class = max_ncells_per_class
         self.training_args = training_args
+        self.ray_config = ray_config
         self.freeze_layers = freeze_layers
         self.num_crossval_splits = num_crossval_splits
-        self.
+        self.split_sizes = split_sizes
+        self.train_size = self.split_sizes["train"]
+        self.valid_size = self.split_sizes["valid"]
+        self.oos_test_size = self.split_sizes["test"]
+        self.eval_size = self.valid_size / (self.train_size + self.valid_size)
         self.stratify_splits_col = stratify_splits_col
         self.no_eval = no_eval
         self.forward_batch_size = forward_batch_size
         self.nproc = nproc
+        self.ngpu = ngpu

         if self.training_args is None:
             logger.warning(

@@ -301,6 +315,9 @@ class Classifier:
                 "Gene_class_dict should contain at least 2 gene classes to classify."
             )
             raise
+        if sum(self.split_sizes.values()) != 1:
+            logger.error("Train, validation, and test proportions should sum to 1.")
+            raise

     def prepare_data(
         self,

@@ -337,6 +354,7 @@ class Classifier:
         test_size : None, float
            | Proportion of data to be saved separately and held out for test set
            | (e.g. 0.2 if intending hold out 20%)
+           | If None, will inherit from split_sizes["test"] from Classifier
            | The training set will be further split to train / validation in self.validate
            | Note: only available for CellClassifiers
         attr_to_split : None, str

@@ -356,6 +374,9 @@ class Classifier:
            | E.g. if set to 0.1, will accept trial if p >= 0.1 for all attributes in attr_to_balance
         """

+        if test_size is None:
+            test_size = self.oos_test_size
+
         # prepare data and labels for classification
         data = pu.load_and_filter(self.filter_data, self.nproc, input_data_file)

@@ -555,6 +576,7 @@ class Classifier:
         save_eval_output=True,
         predict_eval=True,
         predict_trainer=False,
+        n_hyperopt_trials=0,
     ):
         """
         (Cross-)validate cell state or gene classifier.

@@ -604,6 +626,9 @@ class Classifier:
         predict_trainer : bool
            | Whether or not to save eval predictions from trainer
            | Saves as a pickle file of trainer predictions
+        n_hyperopt_trials : int
+           | Number of trials to run for hyperparameter optimization
+           | If 0, will not optimize hyperparameters
         """

         if self.num_crossval_splits == 0:

@@ -700,14 +725,30 @@ class Classifier:
                 ]
                 eval_data = data.select(eval_indices)
                 train_data = data.select(train_indices)
-                trainer = self.train_classifier(
-                    model_directory,
-                    num_classes,
-                    train_data,
-                    eval_data,
-                    ksplit_output_dir,
-                    predict_trainer,
-                )
+                if n_hyperopt_trials == 0:
+                    trainer = self.train_classifier(
+                        model_directory,
+                        num_classes,
+                        train_data,
+                        eval_data,
+                        ksplit_output_dir,
+                        predict_trainer,
+                    )
+                else:
+                    trainer = self.hyperopt_classifier(
+                        model_directory,
+                        num_classes,
+                        train_data,
+                        eval_data,
+                        ksplit_output_dir,
+                        n_trials=n_hyperopt_trials,
+                    )
+                    if iteration_num == self.num_crossval_splits:
+                        return
+                    else:
+                        iteration_num = iteration_num + 1
+                        continue
+
                 result = self.evaluate_model(
                     trainer.model,
                     num_classes,

@@ -752,14 +793,29 @@ class Classifier:
                    self.nproc,
                 )

-                trainer = self.train_classifier(
-                    model_directory,
-                    num_classes,
-                    train_data,
-                    eval_data,
-                    ksplit_output_dir,
-                    predict_trainer,
-                )
+                if n_hyperopt_trials == 0:
+                    trainer = self.train_classifier(
+                        model_directory,
+                        num_classes,
+                        train_data,
+                        eval_data,
+                        ksplit_output_dir,
+                        predict_trainer,
+                    )
+                else:
+                    trainer = self.hyperopt_classifier(
+                        model_directory,
+                        num_classes,
+                        train_data,
+                        eval_data,
+                        ksplit_output_dir,
+                        n_trials=n_hyperopt_trials,
+                    )
+                    if iteration_num == self.num_crossval_splits:
+                        return
+                    else:
+                        iteration_num = iteration_num + 1
+                        continue
                 result = self.evaluate_model(
                     trainer.model,
                     num_classes,

@@ -810,6 +866,162 @@ class Classifier:

         return all_metrics

+    def hyperopt_classifier(
+        self,
+        model_directory,
+        num_classes,
+        train_data,
+        eval_data,
+        output_directory,
+        n_trials=100,
+    ):
+        """
+        Fine-tune model for cell state or gene classification.
+
+        **Parameters**
+
+        model_directory : Path
+            | Path to directory containing model
+        num_classes : int
+            | Number of classes for classifier
+        train_data : Dataset
+            | Loaded training .dataset input
+            | For cell classifier, labels in column "label".
+            | For gene classifier, labels in column "labels".
+        eval_data : None, Dataset
+            | (Optional) Loaded evaluation .dataset input
+            | For cell classifier, labels in column "label".
+            | For gene classifier, labels in column "labels".
+        output_directory : Path
+            | Path to directory where fine-tuned model will be saved
+        n_trials : int
+            | Number of trials to run for hyperparameter optimization
+        """
+
+        # initiate runtime environment for raytune
+        import ray
+        from ray import tune
+        from ray.tune.search.hyperopt import HyperOptSearch
+
+        ray.shutdown()  # engage new ray session
+        ray.init()
+
+        ##### Validate and prepare data #####
+        train_data, eval_data = cu.validate_and_clean_cols(
+            train_data, eval_data, self.classifier
+        )
+
+        if (self.no_eval is True) and (eval_data is not None):
+            logger.warning(
+                "no_eval set to True; hyperparameter optimization requires eval, proceeding with eval"
+            )
+
+        # ensure not overwriting previously saved model
+        saved_model_test = os.path.join(output_directory, "pytorch_model.bin")
+        if os.path.isfile(saved_model_test) is True:
+            logger.error("Model already saved to this designated output directory.")
+            raise
+        # make output directory
+        subprocess.call(f"mkdir {output_directory}", shell=True)
+
+        ##### Load model and training args #####
+        if self.classifier == "cell":
+            model_type = "CellClassifier"
+        elif self.classifier == "gene":
+            model_type = "GeneClassifier"
+
+        model = pu.load_model(model_type, num_classes, model_directory, "train")
+        def_training_args, def_freeze_layers = cu.get_default_train_args(
+            model, self.classifier, train_data, output_directory
+        )
+        del model
+
+        if self.training_args is not None:
+            def_training_args.update(self.training_args)
+        logging_steps = round(
+            len(train_data) / def_training_args["per_device_train_batch_size"] / 10
+        )
+        def_training_args["logging_steps"] = logging_steps
+        def_training_args["output_dir"] = output_directory
+        if eval_data is None:
+            def_training_args["evaluation_strategy"] = "no"
+            def_training_args["load_best_model_at_end"] = False
+        training_args_init = TrainingArguments(**def_training_args)
+
+        ##### Fine-tune the model #####
+        # define the data collator
+        if self.classifier == "cell":
+            data_collator = DataCollatorForCellClassification()
+        elif self.classifier == "gene":
+            data_collator = DataCollatorForGeneClassification()
+
+        # define function to initiate model
+        def model_init():
+            model = pu.load_model(model_type, num_classes, model_directory, "train")
+
+            if self.freeze_layers is not None:
+                def_freeze_layers = self.freeze_layers
+
+            if def_freeze_layers > 0:
+                modules_to_freeze = model.bert.encoder.layer[:def_freeze_layers]
+                for module in modules_to_freeze:
+                    for param in module.parameters():
+                        param.requires_grad = False
+
+            model = model.to("cuda:0")
+            return model
+
+        # create the trainer
+        trainer = Trainer(
+            model_init=model_init,
+            args=training_args_init,
+            data_collator=data_collator,
+            train_dataset=train_data,
+            eval_dataset=eval_data,
+            compute_metrics=cu.compute_metrics,
+        )
+
+        # specify raytune hyperparameter search space
+        if self.ray_config is None:
+            logger.warning(
+                "No ray_config provided. Proceeding with default, but ranges may need adjustment depending on model."
+            )
+            def_ray_config = {
+                "num_train_epochs": tune.choice([1]),
+                "learning_rate": tune.loguniform(1e-6, 1e-3),
+                "weight_decay": tune.uniform(0.0, 0.3),
+                "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
+                "warmup_steps": tune.uniform(100, 2000),
+                "seed": tune.uniform(0, 100),
+                "per_device_train_batch_size": tune.choice(
+                    [def_training_args["per_device_train_batch_size"]]
+                ),
+            }
+
+        hyperopt_search = HyperOptSearch(metric="eval_macro_f1", mode="max")
+
+        # optimize hyperparameters
+        trainer.hyperparameter_search(
+            direction="maximize",
+            backend="ray",
+            resources_per_trial={"cpu": int(self.nproc / self.ngpu), "gpu": 1},
+            hp_space=lambda _: def_ray_config
+            if self.ray_config is None
+            else self.ray_config,
+            search_alg=hyperopt_search,
+            n_trials=n_trials,  # number of trials
+            progress_reporter=tune.CLIReporter(
+                max_report_frequency=600,
+                sort_by_metric=True,
+                max_progress_rows=n_trials,
+                mode="max",
+                metric="eval_macro_f1",
+                metric_columns=["loss", "eval_loss", "eval_accuracy", "eval_macro_f1"],
+            ),
+        )
+
+        return trainer
+
     def train_classifier(
         self,
         model_directory,
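A short sketch of how the new constructor options introduced in this file could be combined. Only ray_config, split_sizes, ngpu, nproc, freeze_layers, num_crossval_splits, and forward_batch_size appear in this diff; the remaining arguments and concrete values are assumptions for illustration.

# Hypothetical initialization sketch using the options added in this commit.
# The ray_config keys mirror the default search space defined in hyperopt_classifier above.
from ray import tune
from geneformer import Classifier

ray_config = {
    "num_train_epochs": tune.choice([1]),
    "learning_rate": tune.loguniform(1e-6, 1e-3),
    "weight_decay": tune.uniform(0.0, 0.3),
    "lr_scheduler_type": tune.choice(["linear", "cosine", "polynomial"]),
    "warmup_steps": tune.uniform(100, 2000),
    "seed": tune.uniform(0, 100),
    "per_device_train_batch_size": tune.choice([12]),  # placeholder batch size
}

cc = Classifier(
    classifier="cell",            # assumed constructor argument
    ray_config=ray_config,        # search space for Ray Tune (None falls back to the default above)
    split_sizes={"train": 0.8, "valid": 0.1, "test": 0.1},  # proportions must sum to 1
    freeze_layers=2,
    num_crossval_splits=1,
    forward_batch_size=100,
    nproc=16,
    ngpu=1,                       # used to size Ray trials: cpu per trial = nproc / ngpu
)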
geneformer/classifier_utils.py  CHANGED

@@ -360,9 +360,23 @@ def get_num_classes(id_class_dict):
 def compute_metrics(pred):
     labels = pred.label_ids
     preds = pred.predictions.argmax(-1)
+
     # calculate accuracy and macro f1 using sklearn's function
-    acc = accuracy_score(labels, preds)
-    macro_f1 = f1_score(labels, preds, average="macro")
+    if len(labels.shape) == 1:
+        acc = accuracy_score(labels, preds)
+        macro_f1 = f1_score(labels, preds, average="macro")
+    else:
+        flat_labels = labels.flatten().tolist()
+        flat_preds = preds.flatten().tolist()
+        logit_label_paired = [
+            item for item in list(zip(flat_preds, flat_labels)) if item[1] != -100
+        ]
+        y_pred = [item[0] for item in logit_label_paired]
+        y_true = [item[1] for item in logit_label_paired]
+
+        acc = accuracy_score(y_true, y_pred)
+        macro_f1 = f1_score(y_true, y_pred, average="macro")
+
     return {"accuracy": acc, "macro_f1": macro_f1}

@@ -387,6 +401,11 @@ def get_default_train_args(model, classifier, data, output_dir):
             "per_device_train_batch_size": batch_size,
             "per_device_eval_batch_size": batch_size,
         }
+    else:
+        default_training_args = {
+            "per_device_train_batch_size": batch_size,
+            "per_device_eval_batch_size": batch_size,
+        }

     training_args = {
         "num_train_epochs": epochs,
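To illustrate the two branches of the updated compute_metrics, a self-contained toy re-implementation that mirrors the diff above. The inputs are hypothetical; the real function lives in geneformer/classifier_utils.py and receives a transformers EvalPrediction.

# Toy check of the updated metric logic: 1-D labels (cell classifier) vs.
# 2-D per-token labels with -100 padding (gene classifier).
from collections import namedtuple
import numpy as np
from sklearn.metrics import accuracy_score, f1_score

Pred = namedtuple("Pred", ["label_ids", "predictions"])  # stand-in for EvalPrediction

def compute_metrics(pred):
    labels = pred.label_ids
    preds = pred.predictions.argmax(-1)
    if len(labels.shape) == 1:  # cell classifier: one label per example
        acc = accuracy_score(labels, preds)
        macro_f1 = f1_score(labels, preds, average="macro")
    else:  # gene classifier: per-token labels, -100 marks padding to ignore
        paired = [p for p in zip(preds.flatten().tolist(), labels.flatten().tolist()) if p[1] != -100]
        y_pred = [p[0] for p in paired]
        y_true = [p[1] for p in paired]
        acc = accuracy_score(y_true, y_pred)
        macro_f1 = f1_score(y_true, y_pred, average="macro")
    return {"accuracy": acc, "macro_f1": macro_f1}

# cell-style: labels shape (n_cells,)
cell_pred = Pred(np.array([0, 1, 1]), np.array([[2.0, 0.1], [0.2, 1.5], [1.0, 0.3]]))
print(compute_metrics(cell_pred))   # accuracy and macro f1 of 0.67 here

# gene-style: labels shape (n_cells, seq_len); the -100 position is excluded
gene_pred = Pred(np.array([[0, 1, -100]]), np.array([[[2.0, 0.1], [0.2, 1.5], [0.9, 0.8]]]))
print(compute_metrics(gene_pred))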
requirements.txt  CHANGED

@@ -1,5 +1,6 @@
 anndata>=0.9
 datasets>=2.12
+hyperopt>=0.2
 loompy>=3.0
 matplotlib>=3.7
 numpy>=1.23