File size: 2,058 Bytes
fdb2891
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
"""Monitor rate of change of loss."""
from __future__ import annotations
import torch

class FDiffMetrics(Callback):
    """Rate of change of metrics.

    Logs the finite difference (current value minus previous value) of the
    training loss and training metrics after every batch, and of each
    evaluator's metrics after every evaluation run. This is effectively a
    numerical derivative of the metric curves. A metric produces no log entry
    until a previous value for it has been recorded.

    Args:
        diff_train_metrics (bool, optional): Whether to log finite differences
            of the training loss and training metrics at the end of every
            batch. Default: ``False``.
        diff_eval_metrics (bool, optional): Whether to log finite differences
            of evaluation metrics at the end of every evaluation.
            Default: ``True``.
    """

    def __init__(self, diff_train_metrics: bool = False, diff_eval_metrics: bool = True):
        super().__init__()
        self.diff_train_metrics = diff_train_metrics
        self.diff_eval_metrics = diff_eval_metrics
        # Previous training loss (Python float from ``.item()``); None until
        # the first batch has been seen.
        self.train_prev_loss: float | None = None
        # Previous value of each training metric, keyed by metric name.
        self.train_prev_metric = {}
        # Previous value of each eval metric, keyed by 'metrics/<evaluator>/<name>'
        # so that different evaluators are tracked independently.
        self.eval_prev_metric = {}

    def batch_end(self, state: State, logger: Logger) -> None:
        """Log the change in training loss and metrics since the previous batch.

        Args:
            state: Training state; ``state.loss`` and
                ``state.train_metric_values`` are read.
            logger: Receives the ``*_fdiff`` values via ``log_metrics``.

        Raises:
            NotImplementedError: If ``diff_train_metrics`` is enabled and
                ``state.loss`` is not a single ``torch.Tensor``.
        """
        if not self.diff_train_metrics:
            return
        if not isinstance(state.loss, torch.Tensor):
            raise NotImplementedError('Multiple losses not supported yet')
        loss = state.loss.item()
        # Compare against None rather than relying on truthiness: a previous
        # loss of exactly 0.0 is a real value and must still yield a difference.
        if self.train_prev_loss is not None:
            logger.log_metrics({'loss/train/total_fdiff': loss - self.train_prev_loss})
        self.train_prev_loss = loss

        current = state.train_metric_values
        for k, value in current.items():
            # Only metrics recorded on an earlier batch have something to
            # difference against; a previously seen metric that is missing
            # this batch is simply skipped instead of raising KeyError.
            if k in self.train_prev_metric:
                logger.log_metrics({f'metrics/train/{k}_fdiff': value - self.train_prev_metric[k]})
        # Remember the current values; keys not reported this batch keep
        # their last known value.
        self.train_prev_metric.update(current)

    def eval_end(self, state: State, logger: Logger) -> None:
        """Log the change in each eval metric since this evaluator's previous run.

        Args:
            state: Training state; ``state.dataloader_label``,
                ``state.eval_metrics`` and ``state.eval_metric_values`` are read.
            logger: Receives the ``metrics/<evaluator>/<name>_fdiff`` values.
        """
        if not self.diff_eval_metrics:
            return
        evaluator = state.dataloader_label
        assert evaluator is not None, 'dataloader should have been set'
        values = state.eval_metric_values
        for k in state.eval_metrics[evaluator]:
            # The full path includes the evaluator label, so each evaluator's
            # history is kept separately.
            mkey = '/'.join(['metrics', evaluator, k])
            value = values[k]
            if mkey in self.eval_prev_metric:
                logger.log_metrics({f'{mkey}_fdiff': value - self.eval_prev_metric[mkey]})
            self.eval_prev_metric[mkey] = value