"""Main training script.""" import os from pathlib import Path import torch from cliport import agents from cliport.dataset import RavensDataset, RavensMultiTaskDataset, RavenMultiTaskDatasetBalance import hydra from pytorch_lightning import Trainer from pytorch_lightning.callbacks import ModelCheckpoint from pytorch_lightning.loggers import WandbLogger import numpy as np from torch.utils.data import DataLoader from torch.utils.data.dataloader import default_collate import IPython import pytorch_lightning as pl from pytorch_lightning.utilities import rank_zero_only import datetime import time @hydra.main(config_path="./cfg", config_name='train', version_base="1.2") def main(cfg): # Logger wandb_logger = None if cfg['train']['log']: try: wandb_logger = WandbLogger(name=cfg['tag']) except: pass # Checkpoint saver hydra_dir = Path(os.getcwd()) checkpoint_path = os.path.join(cfg['train']['train_dir'], 'checkpoints') last_checkpoint_path = os.path.join(checkpoint_path, 'last.ckpt') last_checkpoint = last_checkpoint_path if os.path.exists(last_checkpoint_path) and cfg['train']['load_from_last_ckpt'] else None checkpoint_callback = [ModelCheckpoint( # monitor=cfg['wandb']['saver']['monitor'], dirpath=os.path.join(checkpoint_path, 'best'), save_top_k=1, every_n_epochs=3, save_last=True, # every_n_train_steps=100 )] # Trainer max_epochs = cfg['train']['n_steps'] * cfg['train']['batch_size'] // cfg['train']['n_demos'] if cfg['train']['training_step_scale'] > 0: # scale training time depending on the tasks to ensure coverage. max_epochs = cfg['train']['training_step_scale'] # // cfg['train']['batch_size'] trainer = Trainer( accelerator='gpu', devices=cfg['train']['gpu'], fast_dev_run=cfg['debug'], logger=wandb_logger, callbacks=checkpoint_callback, max_epochs=max_epochs, # check_val_every_n_epoch=max_epochs // 50, # resume_from_checkpoint=last_checkpoint, sync_batchnorm=True, log_every_n_steps=30, ) print(f"max epochs: {max_epochs}!") # Resume epoch and global_steps if last_checkpoint: print(f"Resuming: {last_checkpoint}") # Config data_dir = cfg['train']['data_dir'] task = cfg['train']['task'] agent_type = cfg['train']['agent'] n_demos = cfg['train']['n_demos'] # n_demos = cfg['train']['n_demos'] # n_demos = cfg['train']['n_demos'] n_val = cfg['train']['n_val'] name = '{}-{}-{}'.format(task, agent_type, n_demos) # Datasets dataset_type = cfg['dataset']['type'] if 'multi' in dataset_type: train_ds = RavensMultiTaskDataset(data_dir, cfg, group=task, mode='train', n_demos=n_demos, augment=True) val_ds = RavensMultiTaskDataset(data_dir, cfg, group=task, mode='val', n_demos=n_val, augment=False) elif 'weighted' in dataset_type: train_ds = RavenMultiTaskDatasetBalance(data_dir, cfg, group=task, mode='train', n_demos=n_demos, augment=True) val_ds = RavenMultiTaskDatasetBalance(data_dir, cfg, group=task, mode='val', n_demos=n_val, augment=False) else: train_ds = RavensDataset(os.path.join(data_dir, '{}-train'.format(task)), cfg, n_demos=n_demos, augment=True) val_ds = RavensDataset(os.path.join(data_dir, '{}-val'.format(task)), cfg, n_demos=n_val, augment=False) # Initialize agent train_loader = DataLoader(train_ds, shuffle=True, pin_memory=True, batch_size=cfg['train']['batch_size'], num_workers=1 ) test_loader = DataLoader(val_ds, shuffle=False, num_workers=1, batch_size=cfg['train']['batch_size'], pin_memory=True) agent = agents.names[agent_type](name, cfg, train_loader, test_loader) dt_string = datetime.datetime.now().strftime("%d_%m_%Y_%H:%M:%S") print("current time:", dt_string) start_time = time.time() # Main training loop trainer.fit(agent, ckpt_path=last_checkpoint) print("current time:", time.time() - start_time) if __name__ == '__main__': main()