|
from typing import Any |
|
|
|
from pytorch_toolbelt.losses import BinaryFocalLoss |
|
from torch import nn |
|
from torch.nn.modules.loss import BCEWithLogitsLoss |
|
|
|
|
|
class WeightedLosses(nn.Module):
    """Weighted sum of several loss criteria evaluated on the same inputs.

    ``forward`` calls every criterion in ``losses`` with the same positional
    and keyword arguments and returns ``sum(w_i * loss_i(...))``.

    Args:
        losses: Sized sequence of callables, typically ``nn.Module`` losses.
            NOTE(review): a plain ``list`` is stored as-is and is therefore not
            registered as submodules, so ``.to(device)`` / ``state_dict()`` do
            not see buffers of the contained losses; pass an ``nn.ModuleList``
            if that matters.
        weights: Sized sequence of multipliers, one per entry in ``losses``.

    Raises:
        ValueError: If ``losses`` and ``weights`` have different lengths.
    """

    def __init__(self, losses, weights):
        super().__init__()
        # zip() in forward would silently drop the unmatched tail, so a length
        # mismatch is a configuration error; fail at construction time.
        if len(losses) != len(weights):
            raise ValueError(
                "losses and weights must have the same length, "
                f"got {len(losses)} losses and {len(weights)} weights"
            )
        self.losses = losses
        self.weights = weights

    def forward(self, *inputs: Any, **kwargs: Any):
        """Return the weighted sum of all losses evaluated on ``inputs``.

        Every loss receives the same ``*inputs`` and ``**kwargs``. Returns the
        int ``0`` when ``losses`` is empty.
        """
        total = 0
        for loss, weight in zip(self.losses, self.weights):
            # Call the criterion itself (not ``loss.forward``) so nn.Module
            # forward hooks run and plain callables are accepted as well.
            total = total + weight * loss(*inputs, **kwargs)
        return total
|
|
|
|
|
class BinaryCrossentropy(BCEWithLogitsLoss):
    """Binary cross-entropy computed on raw logits.

    Pure alias of ``torch.nn.BCEWithLogitsLoss``: constructor arguments and
    ``forward`` behavior are inherited unchanged.
    """
|
|
|
|
|
class FocalLoss(BinaryFocalLoss):
    """Binary focal loss whose ``gamma`` defaults to 3.

    Every constructor argument is passed straight through to
    ``pytorch_toolbelt.losses.BinaryFocalLoss``; this subclass only changes
    the default ``gamma``.
    NOTE(review): confirm how gamma=3 compares to the base-class default.
    """

    def __init__(
        self,
        alpha=None,
        gamma=3,
        ignore_index=None,
        reduction="mean",
        normalized=False,
        reduced_threshold=None,
    ):
        # Forward by keyword so the mapping does not hinge on the base class
        # keeping this exact positional parameter order.
        super().__init__(
            alpha=alpha,
            gamma=gamma,
            ignore_index=ignore_index,
            reduction=reduction,
            normalized=normalized,
            reduced_threshold=reduced_threshold,
        )