File size: 7,842 Bytes
79a0c41 748f48a 79a0c41 748f48a c48e37c 748f48a 79a0c41 c48e37c 79a0c41 45b9d69 79a0c41 45b9d69 79a0c41 c48e37c 79a0c41 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 |
#!/usr/bin/env python
# coding: utf-8
# hyperparameter optimization with raytune for disease classification
# imports
import os
import subprocess
# GPU ids made visible to CUDA for this process (and anything it spawns).
GPU_NUMBER = [0, 1, 2, 3]
os.environ.update({
    "CUDA_VISIBLE_DEVICES": ",".join(map(str, GPU_NUMBER)),
    # verbose NCCL logging for multi-GPU communication
    "NCCL_DEBUG": "INFO",
    # NOTE(review): placeholder value presumably tuned to the target system's glibc — confirm
    "CONDA_OVERRIDE_GLIBC": "2.56",
    # NOTE(review): setting LD_LIBRARY_PATH here presumably only affects processes
    # launched afterwards, not libraries already loaded; /path/to/sw/lib is listed twice.
    "LD_LIBRARY_PATH": "/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib",
})
# initiate runtime environment for raytune
import pyarrow # must occur prior to ray import
import ray
from ray import tune
from ray.tune import ExperimentAnalysis
from ray.tune.suggest.hyperopt import HyperOptSearch
# Build the runtime environment for Ray workers: run them in the conda
# environment named "base" and hand them the library search path.
runtime_env = dict(
    conda="base",
    env_vars=dict(
        LD_LIBRARY_PATH="/path/to/miniconda3/lib:/path/to/sw/lib:/path/to/sw/lib",
    ),
)
# Drop any Ray session already active in this process, then start a new one.
ray.shutdown()
ray.init(runtime_env=runtime_env)
def initialize_ray_with_check(ip_address, runtime_env=None):
    """
    Connect to a Ray cluster at a given address and check that its dashboard is reported.
    Ray refuses a second ``ray.init`` in the same process, so any Ray session that is
    already active (e.g. one started earlier in this script) is shut down first.
    If connecting then fails, no Ray session remains active in this process.
    Args:
    - ip_address (str): The address (with port) of the Ray cluster to connect to.
    - runtime_env (dict, optional): Runtime environment forwarded to ``ray.init``;
      None (the default) keeps Ray's default behaviour.
    Returns:
    - bool: True if initialization succeeded and a dashboard URL was reported, False otherwise.
    """
    if ray.is_initialized():
        # Without this, ray.init raises "Maybe you called ray.init twice by accident?".
        ray.shutdown()
    try:
        ray.init(address=ip_address, runtime_env=runtime_env)
        print(ray.nodes())
        # NOTE(review): get_webui_url is an older Ray API; newer releases expose the
        # dashboard URL differently — confirm against the installed Ray version.
        dashboard_url = ray.get_webui_url()
    except Exception as e:  # best-effort check: report and let the caller decide
        print(f"Error initializing Ray: {e}")
        return False
    if not dashboard_url:
        print("Error initializing Ray: Ray dashboard is not accessible.")
        return False
    print(f"Ray dashboard is accessible at: {dashboard_url}")
    return True
# Usage: connect to an existing Ray cluster (replace the placeholder address first).
ip = 'your_ip:xxxx'  # Replace with your actual IP address and port
ray_ready = initialize_ray_with_check(ip)
print("Ray initialized successfully." if ray_ready else "Error during Ray initialization.")
import datetime
import numpy as np
import pandas as pd
import random
import seaborn as sns; sns.set()
from collections import Counter
from datasets import load_from_disk
from scipy.stats import ranksums
from sklearn.metrics import accuracy_score
from transformers import BertForSequenceClassification
from transformers import Trainer
from transformers.training_args import TrainingArguments
from geneformer import DataCollatorForCellClassification
# number of CPU cores (worker processes for the datasets filter/map calls)
num_proc = 30

# Load the training dataset from disk. Expected columns:
#   cell_type  - annotation of each cell's type
#   disease    - healthy or disease state
#   individual - unique ID for each patient
#   length     - length of that cell's rank value encoding
train_dataset = load_from_disk("/path/to/disease_train_data.dataset")
# predicate used to keep only the cell type of interest
def if_cell_type(example):
    """Return True when the example's ``cell_type`` starts with "Cardiomyocyte" (case-sensitive)."""
    cell_type = example["cell_type"]
    return cell_type.startswith("Cardiomyocyte")
# keep only rows whose cell_type passes if_cell_type, using num_proc workers
trainset_v2 = train_dataset.filter(
    if_cell_type,
    num_proc=num_proc,
)
# create dictionary of disease states : label ids
# (each state's id is its position in target_names)
target_names = ["healthy", "disease1", "disease2"]
target_name_id_dict = {name: label_id for label_id, name in enumerate(target_names)}
# expose the disease state under the column name "label"
# (presumably the name the classification collator/Trainer expects — confirm)
trainset_v3 = trainset_v2.rename_column(
    original_column_name="disease",
    new_column_name="label",
)
# change labels to numerical ids
def classes_to_ids(example):
    """Replace the example's string ``label`` with its integer id and return the example.

    The example is modified in place. A label missing from ``target_name_id_dict``
    raises KeyError.
    """
    disease_name = example["label"]
    example["label"] = target_name_id_dict[disease_name]
    return example
# apply the label-to-id conversion to every row using num_proc workers
trainset_v4 = trainset_v3.map(
    classes_to_ids,
    num_proc=num_proc,
)
# separate into train, validation, test sets at the level of individuals, so each
# individual lands in exactly one split: ~70% train, the remaining ~30% halved
# into validation and test.
# random.sample requires a sequence (passing a set raises TypeError on Python 3.11+),
# and sorting makes the seeded split independent of set iteration order, which
# for string IDs varies between interpreter runs.
indiv_set = set(trainset_v4["individual"])
# NOTE(review): sorting assumes individual IDs are mutually comparable (e.g. all str)
indiv_list = sorted(indiv_set)
random.seed(42)
train_indiv = random.sample(indiv_list, round(0.7 * len(indiv_list)))
train_indiv_set = set(train_indiv)  # O(1) membership instead of scanning the list
eval_indiv = [indiv for indiv in indiv_list if indiv not in train_indiv_set]
valid_indiv = random.sample(eval_indiv, round(0.5 * len(eval_indiv)))
valid_indiv_set = set(valid_indiv)
test_indiv = [indiv for indiv in eval_indiv if indiv not in valid_indiv_set]
def if_train(example):
    """Return True if the example's individual was assigned to the training split."""
    # NOTE(review): train_indiv is a list, so each lookup scans it linearly.
    individual = example["individual"]
    return individual in train_indiv
# training rows: cells from training individuals, shuffled with a fixed seed
classifier_trainset = (
    trainset_v4
    .filter(if_train, num_proc=num_proc)
    .shuffle(seed=42)
)
def if_valid(example):
    """Return True if the example's individual was assigned to the validation split."""
    # NOTE(review): valid_indiv is a list, so each lookup scans it linearly.
    individual = example["individual"]
    return individual in valid_indiv
# validation rows: cells from validation individuals, shuffled with a fixed seed
classifier_validset = (
    trainset_v4
    .filter(if_valid, num_proc=num_proc)
    .shuffle(seed=42)
)
# define output directory path, stamped with today's date as YYMMDD
current_date = datetime.datetime.now()
datestamp = current_date.strftime("%y%m%d")
output_dir = f"/path/to/models/{datestamp}_geneformer_DiseaseClassifier/"
# ensure not overwriting previously saved model; check both the legacy PyTorch
# weights file and model.safetensors, which newer transformers releases write by default
saved_model_test = os.path.join(output_dir, "pytorch_model.bin")
for saved_weights in (saved_model_test, os.path.join(output_dir, "model.safetensors")):
    if os.path.isfile(saved_weights):
        raise FileExistsError("Model already saved to this directory.")
# make output directory (including missing parents) without going through a shell;
# unlike the shelled-out `mkdir`, a failure here raises instead of being ignored
os.makedirs(output_dir, exist_ok=True)
# set training parameters
# how many pretrained layers to freeze (model_init skips freezing when this is None)
freeze_layers = 2
# batch size for training and eval
geneformer_batch_size = 12
# number of epochs
epochs = 1
# log (and evaluate) roughly 10 times per epoch; clamp to at least 1 because a
# step interval of 0 is rejected by TrainingArguments with the "steps" strategy
logging_steps = max(1, round(len(classifier_trainset) / geneformer_batch_size / 10))
# define function to initiate model
def model_init():
    """Build a fresh sequence classifier from the pretrained checkpoint.

    The classification head has one output per entry in ``target_names``. When
    ``freeze_layers`` is not None, parameters of the first ``freeze_layers``
    encoder layers get ``requires_grad = False``. The model is moved to
    "cuda:0" before being returned. The Trainer receives this as ``model_init``
    so it can build a new model instance per run/trial.
    """
    classifier = BertForSequenceClassification.from_pretrained(
        "/path/to/pretrained_model/",
        num_labels=len(target_names),
        output_attentions=False,
        output_hidden_states=False,
    )
    if freeze_layers is not None:
        frozen_params = (
            param
            for layer in classifier.bert.encoder.layer[:freeze_layers]
            for param in layer.parameters()
        )
        for param in frozen_params:
            param.requires_grad = False
    # NOTE(review): presumably Ray restricts each trial's visible GPUs so that
    # "cuda:0" is the trial's assigned device — confirm.
    return classifier.to("cuda:0")
# define metrics
# note: macro f1 score recommended for imbalanced multiclass classifiers
def compute_metrics(pred):
    """Return ``{'accuracy': ...}`` for an evaluation prediction.

    The predicted class is the argmax over the last axis of ``pred.predictions``.
    It is scored against ``pred.label_ids`` with sklearn's ``accuracy_score``.
    Only accuracy is reported, despite the macro-F1 recommendation above.
    Unless ``hyperparameter_search`` is given an explicit ``compute_objective``,
    adding a metric here changes the Trainer's default search objective.
    """
    predicted_classes = pred.predictions.argmax(-1)
    return {'accuracy': accuracy_score(pred.label_ids, predicted_classes)}
# set training arguments
training_args = {
    "do_train": True,
    "do_eval": True,
    "evaluation_strategy": "steps",
    "eval_steps": logging_steps,
    "logging_steps": logging_steps,
    # load_best_model_at_end needs checkpoints saved on the same schedule as
    # evaluation (save_steps must be a multiple of eval_steps). The default
    # save_steps=500 generally is not, so align checkpointing with evaluation.
    "save_strategy": "steps",
    "save_steps": logging_steps,
    "group_by_length": True,
    "length_column_name": "length",
    "disable_tqdm": True,
    "skip_memory_metrics": True,  # memory tracker causes errors in raytune
    "per_device_train_batch_size": geneformer_batch_size,
    "per_device_eval_batch_size": geneformer_batch_size,
    "num_train_epochs": epochs,
    "load_best_model_at_end": True,
    "output_dir": output_dir,
}
training_args_init = TrainingArguments(**training_args)
# create the trainer; model_init (rather than a model instance) lets the
# hyperparameter search build a fresh model for each trial
trainer = Trainer(
    model_init=model_init,
    args=training_args_init,
    data_collator=DataCollatorForCellClassification(),
    train_dataset=classifier_trainset,
    eval_dataset=classifier_validset,
    compute_metrics=compute_metrics,
)
# specify raytune hyperparameter search space
# (epochs and per-device batch size are pinned by offering a single choice)
ray_config = dict(
    num_train_epochs=tune.choice([epochs]),
    learning_rate=tune.loguniform(1e-6, 1e-3),
    weight_decay=tune.uniform(0.0, 0.3),
    lr_scheduler_type=tune.choice(["linear", "cosine", "polynomial"]),
    # NOTE(review): warmup_steps and seed are sampled as floats; transformers is
    # expected to cast them to the int type of the matching TrainingArguments field — confirm.
    warmup_steps=tune.uniform(100, 2000),
    seed=tune.uniform(0, 100),
    per_device_train_batch_size=tune.choice([geneformer_batch_size]),
)
# hyperopt-based search algorithm maximizing the reported eval_accuracy
hyperopt_search = HyperOptSearch(metric="eval_accuracy", mode="max")
# optimize hyperparameters; keep the returned BestRun instead of discarding it
best_run = trainer.hyperparameter_search(
    direction="maximize",
    backend="ray",
    # rank trials by eval accuracy, the same metric HyperOptSearch optimizes;
    # the default objective would instead sum every reported metric
    compute_objective=lambda metrics: metrics["eval_accuracy"],
    resources_per_trial={"cpu": 8, "gpu": 1},
    hp_space=lambda _: ray_config,
    search_alg=hyperopt_search,
    n_trials=100,  # number of trials
    progress_reporter=tune.CLIReporter(
        max_report_frequency=600,
        sort_by_metric=True,
        max_progress_rows=100,
        mode="max",
        metric="eval_accuracy",
        metric_columns=["loss", "eval_loss", "eval_accuracy"],
    ),
)
print(f"Best trial: {best_run.run_id} (eval_accuracy={best_run.objective})")
print(f"Best hyperparameters: {best_run.hyperparameters}")