DmitriiKhizbullin's picture
Reorganized files
8afb176
raw
history blame
No virus
2.07 kB
from typing import Dict, Callable
import torch
from torchmetrics.aggregation import MeanMetric
from torchmetrics.classification.accuracy import MulticlassAccuracy
from torchmetrics.classification import MulticlassCohenKappa
class Metrics:
    """Bundle of classification metrics for one data split.

    Accumulates the mean loss, overall accuracy, per-class accuracy and
    Cohen's kappa across batches, and reports them through a
    user-supplied logging callback under the ``"<split>/..."`` prefix.
    """

    def __init__(self,
                 num_classes: int,
                 labelmap: Dict[int, str],
                 split: str,
                 log_fn: Callable[..., None]) -> None:
        """Create the underlying metric objects.

        Args:
            num_classes: Number of target classes.
            labelmap: Mapping from class index to a human-readable name;
                used to build the per-class log keys.
            split: Prefix for every logged key (e.g. a dataset split name).
            log_fn: Callable invoked as
                ``log_fn(name, value, sync_dist=True)`` for each metric.
        """
        self.labelmap = labelmap
        # NaN loss values are skipped rather than poisoning the running mean.
        self.loss = MeanMetric(nan_strategy='ignore')
        self.accuracy = MulticlassAccuracy(num_classes=num_classes)
        # average=None keeps one accuracy value per class.
        self.per_class_accuracies = MulticlassAccuracy(
            num_classes=num_classes, average=None)
        self.kappa = MulticlassCohenKappa(num_classes)
        self.split = split
        self.log_fn = log_fn

    def update(self,
               loss: torch.Tensor,
               preds: torch.Tensor,
               labels: torch.Tensor) -> None:
        """Fold one batch into every metric.

        Args:
            loss: Loss value(s) for the batch.
            preds: Model predictions for the batch.
            labels: Ground-truth labels for the batch.
        """
        self.loss.update(loss)
        self.accuracy.update(preds, labels)
        self.per_class_accuracies.update(preds, labels)
        self.kappa.update(preds, labels)

    def log(self) -> None:
        """Compute every metric and send it to ``log_fn``.

        Keys are emitted in this order: ``loss``, ``accuracy``,
        ``mean_accuracy``, one ``acc/<index> <name>`` entry per class,
        then ``kappa``.

        This method does not clear the accumulated state; call
        :meth:`reset` afterwards if the same instance is reused for
        another epoch.

        Raises:
            KeyError: If ``labelmap`` has no entry for a class index.
        """
        loss = self.loss.compute()
        accuracy = self.accuracy.compute()
        accuracies = self.per_class_accuracies.compute()
        kappa = self.kappa.compute()
        # nanmean skips any NaN entries among the per-class accuracies.
        mean_accuracy = torch.nanmean(accuracies)

        prefix = self.split
        self.log_fn(f"{prefix}/loss", loss, sync_dist=True)
        self.log_fn(f"{prefix}/accuracy", accuracy, sync_dist=True)
        self.log_fn(f"{prefix}/mean_accuracy", mean_accuracy, sync_dist=True)
        for i_class, acc in enumerate(accuracies):
            name = self.labelmap[i_class]
            self.log_fn(f"{prefix}/acc/{i_class} {name}", acc, sync_dist=True)
        self.log_fn(f"{prefix}/kappa", kappa, sync_dist=True)

    def reset(self) -> None:
        """Clear the accumulated state of every metric."""
        for metric in (self.loss, self.accuracy,
                       self.per_class_accuracies, self.kappa):
            metric.reset()

    def to(self, device) -> 'Metrics':
        """Move every metric to ``device`` and return ``self``.

        The result of each ``.to()`` call is bound back to its attribute.
        ``torch.nn.Module.to`` moves the module in place and returns the
        same object, so for module-based metrics the rebinding changes
        nothing; it also keeps this method correct for any object whose
        ``.to()`` returns a new instance (as tensors do).

        Args:
            device: Target passed through unchanged to each ``.to()`` call.

        Returns:
            This ``Metrics`` instance, to allow chaining.
        """
        self.loss = self.loss.to(device)
        self.accuracy = self.accuracy.to(device)
        self.per_class_accuracies = self.per_class_accuracies.to(device)
        self.kappa = self.kappa.to(device)
        return self