# CELL-E_2-Image_Prediction / celle_taming_main.py
# EmaadKhwaja
# file upload
# 5d2263b
import argparse, os, sys, datetime, glob, importlib
from omegaconf import OmegaConf
import numpy as np
from PIL import Image
import torch
import torchvision
from torch.utils.data import DataLoader, Dataset
from dataloader import CellLoader
from pytorch_lightning.callbacks import ModelCheckpoint, Callback, LearningRateMonitor
import pytorch_lightning as pl
from pytorch_lightning import seed_everything
from pytorch_lightning.trainer import Trainer
from pytorch_lightning.callbacks import Callback
from pytorch_lightning.utilities import rank_zero_only
def get_obj_from_str(string, reload=False):
    """Resolve a dotted ``"package.module.Attr"`` path to the object it names.

    The text before the last dot is imported as a module and the text after
    it is looked up as an attribute of that module. With ``reload=True`` the
    module is reloaded first, so edits made since its first import are seen.
    Raises ValueError if ``string`` contains no dot.
    """
    module_name, attr_name = string.rsplit(".", 1)
    if reload:
        importlib.reload(importlib.import_module(module_name))
    module = importlib.import_module(module_name, package=None)
    return getattr(module, attr_name)
def get_parser(**parser_kwargs):
    """Return the parser for this script's own (non-Trainer) options.

    ``parser_kwargs`` are forwarded unchanged to ``argparse.ArgumentParser``.
    ``-n/--name`` and ``-r/--resume`` take an optional string; given bare,
    they are set to ``True``. ``-t/--train``, ``--no-test`` and ``-d/--debug``
    take an optional yes/no-style value and are ``True`` when given bare.
    """

    def str2bool(v):
        # Real booleans pass through; otherwise accept common spellings.
        if isinstance(v, bool):
            return v
        truthy = ("yes", "true", "t", "y", "1")
        falsy = ("no", "false", "f", "n", "0")
        if v.lower() in truthy:
            return True
        if v.lower() in falsy:
            return False
        raise argparse.ArgumentTypeError("Boolean value expected.")

    parser = argparse.ArgumentParser(**parser_kwargs)
    parser.add_argument(
        "-n", "--name",
        nargs="?", type=str, const=True, default="",
        help="postfix for logdir",
    )
    parser.add_argument(
        "-r", "--resume",
        nargs="?", type=str, const=True, default="",
        help="resume from logdir or checkpoint in logdir",
    )
    parser.add_argument(
        "-b", "--base",
        nargs="*", metavar="base_config.yaml", default=list(),
        help="paths to base configs. Loaded from left-to-right. "
        "Parameters can be overwritten or added with command-line options of the form `--key value`.",
    )
    parser.add_argument(
        "-t", "--train",
        nargs="?", type=str2bool, const=True, default=False,
        help="train",
    )
    parser.add_argument(
        "--no-test",
        nargs="?", type=str2bool, const=True, default=False,
        help="disable test",
    )
    parser.add_argument("-p", "--project", help="name of new or path to existing project")
    parser.add_argument(
        "-d", "--debug",
        nargs="?", type=str2bool, const=True, default=False,
        help="enable post-mortem debugging",
    )
    parser.add_argument("-s", "--seed", type=int, default=42, help="seed for seed_everything")
    parser.add_argument("-f", "--postfix", type=str, default="", help="post-postfix for default name")
    return parser
def nondefault_trainer_args(opt):
    """Return the sorted names of Trainer arguments whose value in ``opt`` is not the default.

    The defaults are read by parsing an empty command line with a fresh
    parser that only knows the arguments added by
    ``Trainer.add_argparse_args``. Every such name must exist on ``opt``.
    """
    defaults = vars(Trainer.add_argparse_args(argparse.ArgumentParser()).parse_args([]))
    changed = [name for name, default in defaults.items() if getattr(opt, name) != default]
    return sorted(changed)
def instantiate_from_config(config):
    """Build the object described by a ``{"target": ..., "params": {...}}`` mapping.

    ``target`` is a dotted import path resolved with ``get_obj_from_str``.
    The resulting callable is invoked with ``params`` as keyword arguments;
    if ``params`` is absent, it is called with none.
    Raises KeyError when ``target`` is missing.
    """
    if "target" not in config:
        raise KeyError("Expected key `target` to instantiate.")
    factory = get_obj_from_str(config["target"])
    kwargs = config.get("params", dict())
    return factory(**kwargs)
class WrappedDataset(Dataset):
    """Expose any object with ``__len__`` and ``__getitem__`` as a torch ``Dataset``."""

    def __init__(self, dataset):
        # Length and item access are delegated to this object.
        self.data = dataset

    def __getitem__(self, idx):
        return self.data[idx]

    def __len__(self):
        return len(self.data)
class DataModuleFromConfig(pl.LightningDataModule):
    """Lightning data module serving ``CellLoader`` train/validation splits.

    ``setup`` builds a training ``CellLoader`` (``split_key="train"``, random
    crops) and a validation one (``split_key="val"``, center crops) from the
    same ``data_csv`` / ``dataset``. The ``train`` / ``validation`` / ``test``
    arguments act as switches. When one is not ``None``, it is recorded in
    ``dataset_configs`` and the matching ``*_dataloader`` hook is bound on the
    instance.

    Args:
        data_csv: forwarded to ``CellLoader``.
        dataset: forwarded to ``CellLoader``.
        crop_size: crop size forwarded to ``CellLoader``.
        resize: stored as ``self.resize``; not read by this class.
        batch_size: batch size of every DataLoader.
        sequence_mode: stored as ``self.sequence_mode``. ``setup`` passes
            ``sequence_mode=None`` to ``CellLoader`` regardless.
        vocab: forwarded to ``CellLoader``.
        text_seq_len: cast to ``int`` and forwarded to ``CellLoader``.
        num_workers: DataLoader worker count; ``None`` means ``2 * batch_size``.
        threshold: forwarded to ``CellLoader``.
        train, validation, test: enable the corresponding dataloader when
            not ``None``.
        wrap: stored as ``self.wrap``; not read by this class.
        **kwargs: accepted and ignored.
    """

    def __init__(
        self,
        data_csv,
        dataset,
        crop_size=256,
        resize=600,
        batch_size=1,
        sequence_mode="latent",
        vocab="bert",
        text_seq_len=0,
        num_workers=1,
        threshold=False,
        train=True,
        validation=True,
        test=None,
        wrap=False,
        **kwargs,
    ):
        super().__init__()
        self.data_csv = data_csv
        self.dataset = dataset
        self.image_folders = []
        self.crop_size = crop_size
        self.resize = resize
        self.batch_size = batch_size
        self.sequence_mode = sequence_mode
        self.threshold = threshold
        self.text_seq_len = int(text_seq_len)
        self.vocab = vocab
        self.dataset_configs = dict()
        self.num_workers = num_workers if num_workers is not None else batch_size * 2
        if train is not None:
            self.dataset_configs["train"] = train
            self.train_dataloader = self._train_dataloader
        if validation is not None:
            self.dataset_configs["validation"] = validation
            self.val_dataloader = self._val_dataloader
        if test is not None:
            self.dataset_configs["test"] = test
            # Bound to a method that exists, so constructing the module with
            # a `test` entry works; it raises only if test batches are requested.
            self.test_dataloader = self._test_dataloader
        self.wrap = wrap

    def prepare_data(self):
        pass

    def setup(self, stage=None):
        # called on every GPU; __main__ also calls it explicitly, and it
        # rebuilds both datasets on every call.
        self.cell_dataset_train = CellLoader(
            data_csv=self.data_csv,
            dataset=self.dataset,
            crop_size=self.crop_size,
            split_key="train",
            crop_method="random",
            sequence_mode=None,
            vocab=self.vocab,
            text_seq_len=self.text_seq_len,
            threshold=self.threshold,
        )
        self.cell_dataset_val = CellLoader(
            data_csv=self.data_csv,
            dataset=self.dataset,
            crop_size=self.crop_size,
            split_key="val",
            crop_method="center",
            sequence_mode=None,
            vocab=self.vocab,
            text_seq_len=self.text_seq_len,
            threshold=self.threshold,
        )

    def _train_dataloader(self):
        """Shuffled DataLoader over the training ``CellLoader``."""
        return DataLoader(
            self.cell_dataset_train,
            num_workers=self.num_workers,
            pin_memory=True,
            shuffle=True,
            batch_size=self.batch_size,
        )

    def _val_dataloader(self):
        """Unshuffled DataLoader over the validation ``CellLoader``."""
        return DataLoader(
            self.cell_dataset_val,
            num_workers=self.num_workers,
            pin_memory=True,
            batch_size=self.batch_size,
        )

    def _test_dataloader(self):
        """Test hook bound when ``test`` is configured.

        ``setup`` builds no test dataset, so this raises a descriptive error
        instead of silently reusing another split.
        """
        raise NotImplementedError(
            "DataModuleFromConfig does not build a test dataset; "
            "remove `test` from the data config or run with --no-test."
        )
class SetupCallback(Callback):
    """Prepare the run directory when fitting starts.

    On global rank 0, this creates ``logdir``, ``ckptdir`` and ``cfgdir``. It
    then prints the project and Lightning configs and saves them as
    ``<now>-project.yaml`` and ``<now>-lightning.yaml`` in ``cfgdir``.

    On any other rank, for a run that is not being resumed, an existing
    ``logdir`` is moved into a sibling ``child_runs/`` directory.
    """

    def __init__(self, resume, now, logdir, ckptdir, cfgdir, config, lightning_config):
        super().__init__()
        self.resume = resume
        self.now = now
        self.logdir = logdir
        self.ckptdir = ckptdir
        self.cfgdir = cfgdir
        self.config = config
        self.lightning_config = lightning_config

    def on_fit_start(self, trainer, pl_module):
        if trainer.global_rank != 0:
            self._move_to_child_runs()
            return
        # Create logdirs and save configs
        for directory in (self.logdir, self.ckptdir, self.cfgdir):
            os.makedirs(directory, exist_ok=True)

        print("Project config")
        print(OmegaConf.to_yaml(self.config))
        project_path = os.path.join(self.cfgdir, "{}-project.yaml".format(self.now))
        OmegaConf.save(self.config, project_path)

        print("Lightning config")
        print(OmegaConf.to_yaml(self.lightning_config))
        lightning_path = os.path.join(self.cfgdir, "{}-lightning.yaml".format(self.now))
        OmegaConf.save(OmegaConf.create({"lightning": self.lightning_config}), lightning_path)

    def _move_to_child_runs(self):
        """Move this (non-rank-0) process's fresh logdir under ``child_runs/``."""
        # Per the original note, the ModelCheckpoint callback may have created
        # this directory. NOTE(review): if ranks ever share the same logdir,
        # this would move rank 0's directory -- confirm per-rank logdirs differ.
        if self.resume or not os.path.exists(self.logdir):
            return
        parent, run_name = os.path.split(self.logdir)
        target = os.path.join(parent, "child_runs", run_name)
        os.makedirs(os.path.dirname(target), exist_ok=True)
        try:
            os.rename(self.logdir, target)
        except FileNotFoundError:
            # Another process may have moved or removed it meanwhile.
            pass
class ImageLogger(Callback):
    """Periodically log the images returned by ``pl_module.log_images``.

    When :meth:`check_frequency` fires for a train or validation batch:

    - the module is put in eval mode;
    - ``log_images(batch, split=...)`` is called under ``torch.no_grad()``;
    - at most ``max_images`` samples per key are kept;
    - PNG grids are written under ``<logger.save_dir>/images/<split>/``;
    - for a ``TensorBoardLogger``, the grids are also added to the experiment.

    Train mode is restored afterwards if it was active.

    Args:
        batch_frequency: log whenever ``batch_idx`` is a multiple of this.
        max_images: maximum samples kept per image key; ``<= 0`` disables logging.
        clamp: clamp tensors to ``[-1, 1]`` before they are rescaled for display.
        increase_log_steps: also log at batch indices found in the schedule
            ``1, 2, 4, ...`` (powers of two up to ``batch_frequency``). Every
            logging event pops the first remaining schedule entry.
    """

    def __init__(
        self, batch_frequency, max_images, clamp=True, increase_log_steps=True
    ):
        super().__init__()
        self.batch_freq = batch_frequency
        self.max_images = max_images
        # Logger-specific image hook, keyed by the exact logger class; any
        # other logger only gets the PNG files written by `log_local`.
        self.logger_log_images = {
            pl.loggers.WandbLogger: self._wandb,
            pl.loggers.TensorBoardLogger: self._testtube,
        }
        self.log_steps = [2**n for n in range(int(np.log2(self.batch_freq)) + 1)]
        if not increase_log_steps:
            self.log_steps = [self.batch_freq]
        self.clamp = clamp

    @staticmethod
    def _rescale_to_unit_(images):
        """Min-max rescale every tensor in ``images`` to ``[0, 1]`` in place.

        The min and max are taken over the whole tensor stored under each key,
        not per sample. A constant tensor becomes all zeros rather than the
        NaNs a 0/0 division would produce.
        """
        for k in images:
            images[k] -= torch.min(images[k])
            peak = torch.max(images[k])
            if peak > 0:
                images[k] /= peak

    @rank_zero_only
    def _wandb(self, pl_module, images, batch_idx, split):
        # wandb image logging is deliberately disabled (this file does not
        # import `wandb`), so selecting a WandbLogger fails loudly here.
        raise ValueError("No way wandb")

    @rank_zero_only
    def _testtube(self, pl_module, images, batch_idx, split):
        """Add one image grid per key to the logger's experiment (TensorBoard).

        ``batch_idx`` is unused; ``log_img`` passes the global step here.
        """
        self._rescale_to_unit_(images)
        for k in images:
            grid = torchvision.utils.make_grid(images[k])
            tag = f"{split}/{k}"
            pl_module.logger.experiment.add_image(
                tag, grid, global_step=pl_module.global_step
            )

    @rank_zero_only
    def log_local(self, save_dir, split, images, global_step, current_epoch, batch_idx):
        """Write one PNG grid per key under ``<save_dir>/images/<split>/``.

        Tensors in ``images`` are rescaled to ``[0, 1]`` in place first.
        They are presumably ``(N, C, H, W)`` batches -- confirm against
        ``log_images``.
        """
        root = os.path.join(save_dir, "images", split)
        self._rescale_to_unit_(images)
        for k in images:
            grid = torchvision.utils.make_grid(images[k], nrow=4)
            # (C, H, W) -> (H, W, C), dropping a trailing singleton channel.
            grid = grid.transpose(0, 1).transpose(1, 2).squeeze(-1)
            grid = grid.numpy()
            grid = (grid * 255).astype(np.uint8)
            filename = "{}_gs-{:06}_e-{:06}_b-{:06}.png".format(
                k, global_step, current_epoch, batch_idx
            )
            path = os.path.join(root, filename)
            os.makedirs(os.path.split(path)[0], exist_ok=True)
            Image.fromarray(grid).save(path)

    def log_img(self, pl_module, batch, batch_idx, split="train"):
        """Render and log images for ``batch`` if the schedule says so."""
        # check_frequency must run first: it mutates self.log_steps.
        if not (
            self.check_frequency(batch_idx)
            and hasattr(pl_module, "log_images")
            and callable(pl_module.log_images)
            and self.max_images > 0
        ):
            return

        logger = type(pl_module.logger)
        is_train = pl_module.training
        if is_train:
            pl_module.eval()
        try:
            with torch.no_grad():
                images = pl_module.log_images(batch, split=split)

            for k in images:
                N = min(images[k].shape[0], self.max_images)
                images[k] = images[k][:N]
                if isinstance(images[k], torch.Tensor):
                    images[k] = images[k].detach().cpu()
                    if self.clamp:
                        images[k] = torch.clamp(images[k], -1.0, 1.0)

            self.log_local(
                pl_module.logger.save_dir,
                split,
                images,
                pl_module.global_step,
                pl_module.current_epoch,
                batch_idx,
            )

            logger_log_images = self.logger_log_images.get(
                logger, lambda *args, **kwargs: None
            )
            logger_log_images(pl_module, images, pl_module.global_step, split)
        finally:
            # Restore train mode even if image generation/logging failed.
            if is_train:
                pl_module.train()

    def check_frequency(self, batch_idx):
        """Return True when ``batch_idx`` should be logged.

        That is the case when ``batch_idx`` is a multiple of ``batch_freq`` or
        is still in ``log_steps``. On True, the first remaining ``log_steps``
        entry is dropped.
        """
        if (batch_idx % self.batch_freq) == 0 or (batch_idx in self.log_steps):
            try:
                self.log_steps.pop(0)
            except IndexError:
                pass
            return True
        return False

    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        # NOTE(review): signature matches the Lightning version this was
        # written for; older versions also pass `dataloader_idx`.
        self.log_img(pl_module, batch, batch_idx, split="train")

    def on_validation_batch_end(
        self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx=0
    ):
        # Default lets this work whether or not Lightning passes dataloader_idx.
        self.log_img(pl_module, batch, batch_idx, split="val")
if __name__ == "__main__":
    # Script entry point: parse the command line, merge configs, build the
    # model, data module, logger and callbacks, then train and/or test.
    #
    # custom parser to specify config files, train, test and debug mode,
    # postfix, resume.
    # `--key value` arguments are interpreted as arguments to the trainer.
    # `nested.key=value` arguments are interpreted as config parameters.
    # configs are merged from left-to-right followed by command line parameters.

    # model:
    #   base_learning_rate: float
    #   target: path to lightning module
    #   params:
    #       key: value
    # data:
    #   target: celle_taming_main.DataModuleFromConfig
    #   params:
    #       data_csv: ...
    #       dataset: ...
    #       batch_size: int
    #       train / validation / test: enable the matching dataloader when
    #           not null (see DataModuleFromConfig for the remaining keys)
    # lightning: (optional, has sane defaults and can be specified on cmdline)
    #   trainer:
    #       additional arguments to trainer
    #   logger:
    #       logger to instantiate
    #   modelcheckpoint:
    #       checkpoint_callback:    # merged into `callbacks` below
    #           target: importpath
    #           params:
    #               key: value
    #   callbacks:
    #       callback1:
    #           target: importpath
    #           params:
    #               key: value

    now = datetime.datetime.now().strftime("%Y-%m-%dT%H-%M-%S")

    # add cwd for convenience and to make classes in this file available when
    # running as `python celle_taming_main.py`
    # (in particular `celle_taming_main.DataModuleFromConfig`)
    sys.path.append(os.getcwd())

    parser = get_parser()
    parser = Trainer.add_argparse_args(parser)

    # Unrecognised tokens are kept and parsed below as `nested.key=value`.
    opt, unknown = parser.parse_known_args()
    if opt.name and opt.resume:
        raise ValueError(
            "-n/--name and -r/--resume cannot be specified both. "
            "If you want to resume training in a new log folder, "
            "use -n/--name in combination with --resume_from_checkpoint"
        )
    if opt.resume:
        if not os.path.exists(opt.resume):
            raise ValueError("Cannot find {}".format(opt.resume))
        if os.path.isfile(opt.resume):
            # A checkpoint file: logdir is the path up to and including the
            # component right after the last "logs" directory.
            paths = opt.resume.split("/")
            idx = len(paths) - paths[::-1].index("logs") + 1
            logdir = "/".join(paths[:idx])
            ckpt = opt.resume
        else:
            assert os.path.isdir(opt.resume), opt.resume
            logdir = opt.resume.rstrip("/")
            ckpt = os.path.join(logdir, "checkpoints", "last.ckpt")

        opt.resume_from_checkpoint = ckpt
        # Configs saved by the previous run come first; --base configs given
        # on the command line are merged on top of them.
        base_configs = sorted(glob.glob(os.path.join(logdir, "configs/*.yaml")))
        opt.base = base_configs + opt.base
        _tmp = logdir.split("/")
        nowname = _tmp[_tmp.index("logs") + 1]
    else:
        if opt.name:
            name = "_" + opt.name
        elif opt.base:
            cfg_fname = os.path.split(opt.base[0])[-1]
            cfg_name = os.path.splitext(cfg_fname)[0]
            name = "_" + cfg_name
        else:
            name = ""
        nowname = now + name + opt.postfix
        logdir = os.path.join("logs", nowname)

    ckptdir = os.path.join(logdir, "checkpoints")
    cfgdir = os.path.join(logdir, "configs")
    seed_everything(opt.seed)

    # Bound before the try so the except/finally clauses can test it even if
    # setup fails before the Trainer exists.
    trainer = None
    try:
        # init and save configs
        configs = [OmegaConf.load(cfg) for cfg in opt.base]
        cli = OmegaConf.from_dotlist(unknown)
        config = OmegaConf.merge(*configs, cli)
        lightning_config = config.pop("lightning", OmegaConf.create())
        # merge trainer cli with config
        trainer_config = lightning_config.get("trainer", OmegaConf.create())
        # default to ddp
        trainer_config["distributed_backend"] = "ddp"
        # No DistributedSampler is injected, so every rank iterates the full
        # (independently shuffled) dataset.
        trainer_config["replace_sampler_ddp"] = False
        trainer_config["strategy"] = "ddp"
        # NOTE(review): `persistent_workers` is a DataLoader option, not a
        # Trainer argument; it presumably only ends up in the saved config --
        # confirm against Trainer.from_argparse_args for the pinned version.
        trainer_config["persistent_workers"] = True
        for k in nondefault_trainer_args(opt):
            trainer_config[k] = getattr(opt, k)
        if "gpus" not in trainer_config:
            del trainer_config["distributed_backend"]
            cpu = True
        else:
            gpuinfo = trainer_config["gpus"]
            print(f"Running on GPUs {gpuinfo}")
            cpu = False
        trainer_opt = argparse.Namespace(**trainer_config)
        lightning_config.trainer = trainer_config

        # model
        model = instantiate_from_config(config.model)

        # trainer and callbacks
        trainer_kwargs = dict()

        # default logger configs (TensorBoard; wandb is kept for reference)
        default_logger_cfgs = {
            "wandb": {
                "target": "pytorch_lightning.loggers.WandbLogger",
                "params": {
                    "name": nowname,
                    "save_dir": logdir,
                    "offline": opt.debug,
                    "id": nowname,
                },
            },
            "testtube": {
                "target": "pytorch_lightning.loggers.TensorBoardLogger",
                "params": {
                    "name": "testtube",
                    "save_dir": logdir,
                },
            },
        }
        default_logger_cfg = default_logger_cfgs["testtube"]
        logger_cfg = lightning_config.get("logger", OmegaConf.create())
        logger_cfg = OmegaConf.merge(default_logger_cfg, logger_cfg)
        trainer_kwargs["logger"] = instantiate_from_config(logger_cfg)

        # Model checkpointing. This mapping is merged into the callbacks
        # below, so a user `modelcheckpoint` config must be nested under the
        # `checkpoint_callback` key.
        default_modelckpt_cfg = {
            "checkpoint_callback": {
                "target": "pytorch_lightning.callbacks.ModelCheckpoint",
                "params": {
                    "dirpath": ckptdir,
                    "filename": "{epoch:06}",
                    "verbose": True,
                    "save_last": True,
                },
            }
        }
        if hasattr(model, "monitor"):
            print(f"Monitoring {model.monitor} as checkpoint metric.")
            default_modelckpt_cfg["checkpoint_callback"]["params"][
                "monitor"
            ] = model.monitor
            default_modelckpt_cfg["checkpoint_callback"]["params"]["save_top_k"] = 3

        modelckpt_cfg = lightning_config.get("modelcheckpoint", OmegaConf.create())
        modelckpt_cfg = OmegaConf.merge(default_modelckpt_cfg, modelckpt_cfg)

        # add callback which sets up log directory
        default_callbacks_cfg = {
            "setup_callback": {
                "target": "celle_taming_main.SetupCallback",
                "params": {
                    "resume": opt.resume,
                    "now": now,
                    "logdir": logdir,
                    "ckptdir": ckptdir,
                    "cfgdir": cfgdir,
                    "config": config,
                    "lightning_config": lightning_config,
                },
            },
            "image_logger": {
                "target": "celle_taming_main.ImageLogger",
                "params": {
                    "batch_frequency": 2000,
                    "max_images": 10,
                    "clamp": True,
                    "increase_log_steps": False,
                },
            },
            "learning_rate_logger": {
                "target": "celle_taming_main.LearningRateMonitor",
                "params": {
                    "logging_interval": "step",
                },
            },
        }
        callbacks_cfg = lightning_config.get("callbacks", OmegaConf.create())
        callbacks_cfg = OmegaConf.merge(default_callbacks_cfg, callbacks_cfg)
        callbacks_cfg = OmegaConf.merge(modelckpt_cfg, callbacks_cfg)
        trainer_kwargs["callbacks"] = [
            instantiate_from_config(callbacks_cfg[k]) for k in callbacks_cfg
        ]

        trainer = Trainer.from_argparse_args(trainer_opt, **trainer_kwargs)

        # data
        data = instantiate_from_config(config.data)
        # NOTE according to https://pytorch-lightning.readthedocs.io/en/latest/datamodules.html
        # calling these ourselves should not be necessary but it is.
        # lightning still takes care of proper multiprocessing though
        data.prepare_data()
        data.setup()

        # configure learning rate: scale base_lr by the effective batch size
        bs, base_lr = config.data.params.batch_size, config.model.base_learning_rate
        if not cpu:
            gpus = lightning_config.trainer.gpus
            if isinstance(gpus, str):
                # comma-separated device ids, e.g. "0,1,"
                ngpu = len(gpus.strip(",").split(","))
            elif isinstance(gpus, int):
                # A count; -1 requests all visible GPUs, 0 falls back to 1.
                ngpu = torch.cuda.device_count() if gpus == -1 else max(gpus, 1)
            else:
                # list-like of device ids
                ngpu = len(gpus)
        else:
            ngpu = 1
        accumulate_grad_batches = lightning_config.trainer.get("accumulate_grad_batches", 1)
        print(f"accumulate_grad_batches = {accumulate_grad_batches}")
        lightning_config.trainer.accumulate_grad_batches = accumulate_grad_batches
        model.learning_rate = accumulate_grad_batches * ngpu * bs * base_lr
        print(
            "Setting learning rate to {:.2e} = {} (accumulate_grad_batches) * {} (num_gpus) * {} (batchsize) * {:.2e} (base_lr)".format(
                model.learning_rate, accumulate_grad_batches, ngpu, bs, base_lr
            )
        )

        # allow checkpointing via USR1
        def melk(*args, **kwargs):
            # run all checkpoint hooks
            if trainer.global_rank == 0:
                print("Summoning checkpoint.")
                ckpt_path = os.path.join(ckptdir, "last.ckpt")
                trainer.save_checkpoint(ckpt_path)

        def divein(*args, **kwargs):
            if trainer.global_rank == 0:
                import pudb

                pudb.set_trace()

        import signal

        signal.signal(signal.SIGUSR1, melk)
        signal.signal(signal.SIGUSR2, divein)

        # run
        # Trainer.fit returns None, so wrapping it in torch.compile did
        # nothing useful; to compile, wrap the model before calling fit.
        if opt.train:
            try:
                trainer.fit(model, data)
            except Exception:
                melk()
                raise
        if not opt.no_test and not trainer.interrupted:
            trainer.test(model, data)
    except Exception:
        # No trainer yet means setup failed in this single process.
        if opt.debug and (trainer is None or trainer.global_rank == 0):
            try:
                import pudb as debugger
            except ImportError:
                import pdb as debugger
            debugger.post_mortem()
        raise
    finally:
        # move newly created debug project to debug_runs
        if (
            opt.debug
            and not opt.resume
            and (trainer is None or trainer.global_rank == 0)
            and os.path.exists(logdir)
        ):
            dst, name = os.path.split(logdir)
            dst = os.path.join(dst, "debug_runs", name)
            os.makedirs(os.path.split(dst)[0], exist_ok=True)
            os.rename(logdir, dst)