from typing import Dict, Callable import torch from torchmetrics.aggregation import MeanMetric from torchmetrics.classification.accuracy import MulticlassAccuracy from torchmetrics.classification import MulticlassCohenKappa class Metrics: def __init__(self, num_classes: int, labelmap: Dict[int, str], split: str, log_fn: Callable[..., None]) -> None: self.labelmap = labelmap self.loss = MeanMetric(nan_strategy='ignore') self.accuracy = MulticlassAccuracy(num_classes=num_classes) self.per_class_accuracies = MulticlassAccuracy( num_classes=num_classes, average=None) self.kappa = MulticlassCohenKappa(num_classes) self.split = split self.log_fn = log_fn def update(self, loss: torch.Tensor, preds: torch.Tensor, labels: torch.Tensor) -> None: self.loss.update(loss) self.accuracy.update(preds, labels) self.per_class_accuracies.update(preds, labels) self.kappa.update(preds, labels) def log(self) -> None: loss = self.loss.compute() accuracy = self.accuracy.compute() accuracies = self.per_class_accuracies.compute() kappa = self.kappa.compute() mean_accuracy = torch.nanmean(accuracies) self.log_fn(f"{self.split}/loss", loss, sync_dist=True) self.log_fn(f"{self.split}/accuracy", accuracy, sync_dist=True) self.log_fn(f"{self.split}/mean_accuracy", mean_accuracy, sync_dist=True) for i_class, acc in enumerate(accuracies): name = self.labelmap[i_class] self.log_fn(f"{self.split}/acc/{i_class} {name}", acc, sync_dist=True) self.log_fn(f"{self.split}/kappa", kappa, sync_dist=True) def to(self, device) -> 'Metrics': self.loss.to(device) # BUG HERE? should I assign it back? self.accuracy.to(device) self.per_class_accuracies.to(device) self.kappa.to(device) return self