import os

import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class IntervalModelCheckpoint(Callback):
    """
    Save a checkpoint at specific training steps, instead of Lightning's
    default of checkpointing based on validation loss.
    """

    def __init__(self, dirpath, save_intervals):
        """
        Args:
            dirpath: directory the checkpoint files are written to
            save_intervals: collection of global-step counts at which to
                run validation and save a checkpoint
        """
        super().__init__()
        self.dirpath = dirpath
        self.save_intervals = save_intervals
        self.best_val_loss = float("inf")

    def on_batch_end(self, trainer: pl.Trainer, pl_module):
        """Check whether we should save a checkpoint after this train batch."""
        global_step = trainer.global_step
        if (global_step + 1) in self.save_intervals:
            # Force a validation pass so callback_metrics holds a fresh val_loss.
            trainer.run_evaluation()
            val_loss = trainer.callback_metrics["val_loss"]

            filename = f"steps={global_step + 1:05d}-val_loss={val_loss:0.8f}.ckpt"
            ckpt_path = os.path.join(self.dirpath, filename)
            trainer.save_checkpoint(ckpt_path)

            # Keep an extra copy of the best checkpoint seen so far.
            if val_loss < self.best_val_loss:
                best_ckpt_path = os.path.join(self.dirpath, "best.ckpt")
                trainer.save_checkpoint(best_ckpt_path)
                self.best_val_loss = val_loss
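

# Example usage (a minimal sketch, not part of the original snippet):
# `MyLitModel`, `train_loader`, and `val_loader` are hypothetical placeholders,
# and the step counts in `save_intervals` are arbitrary.
#
#   checkpoint_cb = IntervalModelCheckpoint(
#       dirpath="checkpoints",
#       save_intervals={1000, 5000, 10000},
#   )
#   trainer = pl.Trainer(callbacks=[checkpoint_cb], max_steps=10000)
#   trainer.fit(MyLitModel(), train_loader, val_loader)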