Paul Engstler
Initial commit
92f0e98
raw
history blame contribute delete
755 Bytes
from abc import ABC, abstractmethod
import torch
import pytorch_lightning as pl
class ValidationLoopHook(ABC):
@abstractmethod
def process(self, batch: torch.Tensor, target_batch: torch.Tensor, logits_batch: torch.Tensor, prediction_batch: torch.Tensor) -> None:
"""
Called for every validation batch to process results.
"""
pass
@abstractmethod
def trigger(self, module: pl.LightningModule):
"""
Called after the validation epoch has concluced to further interact with the module and/or log data.
"""
pass
@abstractmethod
def reset(self):
"""
Called right after build() to clean up before the next validation epoch starts.
"""
pass