import os
import pytorch_lightning as pl
from pytorch_lightning.callbacks import Callback


class IntervalModelCheckpoint(Callback):
    """
    Save a checkpoint every N steps, instead of Lightning's default that checkpoints
    based on validation loss.
    """

    def __init__(
            self,
            dirpath,
            save_intervals,
    ):
        """
        Args:
            save_step_frequency: how often to save in steps
            prefix: add a prefix to the name, only used if
                use_modelcheckpoint_filename=False
            use_modelcheckpoint_filename: just use the ModelCheckpoint callback's
                default filename, don't use ours.
        """
        self.dirpath = dirpath
        self.save_intervals = save_intervals
        self.best_val_loss = 1e10

    def on_batch_end(self, trainer: pl.Trainer, _):
        """Check whether we should save a checkpoint after each training batch."""
        global_step = trainer.global_step

        if (global_step + 1) in self.save_intervals:
            # Run the validation loop so callback_metrics holds a fresh
            # val_loss (run_evaluation is internal Trainer API in the older
            # Lightning versions this callback targets).
            trainer.run_evaluation()
            val_loss = trainer.callback_metrics['val_loss']
            filename = f"steps={global_step + 1:05d}-val_loss={val_loss:0.8f}.ckpt"
            ckpt_path = os.path.join(self.dirpath, filename)
            trainer.save_checkpoint(ckpt_path)

            # Also keep a copy of the best checkpoint seen so far.
            if val_loss < self.best_val_loss:
                best_ckpt_path = os.path.join(self.dirpath, 'best.ckpt')
                trainer.save_checkpoint(best_ckpt_path)
                self.best_val_loss = val_loss
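

# Example usage (a minimal sketch for the older PyTorch Lightning versions
# this callback targets). `LitModel` and the step counts are hypothetical
# placeholders; `checkpoint_callback=False` is the old-style Trainer flag
# that disables the built-in ModelCheckpoint so only this callback saves.
#
#     model = LitModel()
#     trainer = pl.Trainer(
#         max_steps=5000,
#         checkpoint_callback=False,
#         callbacks=[
#             IntervalModelCheckpoint(
#                 dirpath="checkpoints/",
#                 save_intervals={100, 500, 1000, 5000},
#             )
#         ],
#     )
#     trainer.fit(model)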