|
from abc import ABC, abstractmethod |
|
import torch |
|
import pytorch_lightning as pl |
|
|
|
class ValidationLoopHook(ABC):
    """
    Interface for hooks that observe a validation loop.

    Implementations are driven in three phases:

    * :meth:`process` is called for every validation batch,
    * :meth:`trigger` is called after the validation epoch has concluded,
    * :meth:`reset` is called to clean up before the next validation epoch.

    All three methods are abstract; a concrete hook must implement each one.
    """

    @abstractmethod
    def process(
        self,
        batch: torch.Tensor,
        target_batch: torch.Tensor,
        logits_batch: torch.Tensor,
        prediction_batch: torch.Tensor,
    ) -> None:
        """
        Called for every validation batch to process results.

        Args:
            batch: The batch that was fed to the model.
            target_batch: The targets belonging to ``batch``.
            logits_batch: The model outputs (logits) for ``batch``.
            prediction_batch: The predictions for ``batch``.

        NOTE(review): shapes and dtypes of these tensors are not fixed by
        this interface; they presumably depend on the model/task of the
        caller. Confirm against the validation step that invokes hooks.
        """

    @abstractmethod
    def trigger(self, module: "pl.LightningModule") -> None:
        """
        Called after the validation epoch has concluded to further interact
        with the module and/or log data.

        Args:
            module: The Lightning module to interact with and/or log through.
        """

    @abstractmethod
    def reset(self) -> None:
        """
        Called to clean up before the next validation epoch starts.

        NOTE(review): an earlier version of this docstring said this runs
        "right after build()", but no build() is defined on this interface;
        presumably it runs after trigger(). Confirm against the caller.
        """