File size: 1,283 Bytes
2cd560a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
@LOSSES.register_module()
class BCELoss(nn.Module):
    """Binary Cross Entropy loss.

    Thin wrapper around ``torch.nn.functional.binary_cross_entropy`` that
    optionally applies per-sample or per-label weights and scales the final
    result by a constant factor.

    Args:
        use_target_weight (bool): If True, ``forward`` requires a
            ``target_weight`` tensor and multiplies the element-wise loss
            by it before averaging. Default: False.
        loss_weight (float): Scalar multiplied onto the reduced loss.
            Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = F.binary_cross_entropy
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Note:
            - batch_size: N
            - num_labels: K

        Args:
            output (torch.Tensor[N, K]): Output classification.
            target (torch.Tensor[N, K]): Target classification.
            target_weight (torch.Tensor[N, K] or torch.Tensor[N]):
                Weights across different labels. Only used (and then
                required) when ``use_target_weight`` is True; ignored
                otherwise.

        Returns:
            torch.Tensor: Scalar loss, already multiplied by
            ``loss_weight``.

        Raises:
            AssertionError: If ``use_target_weight`` is True but
                ``target_weight`` is None.
        """
        if not self.use_target_weight:
            # Unweighted path: rely on the criterion's default reduction.
            return self.criterion(output, target) * self.loss_weight

        # Explicit check instead of a bare ``assert`` so the validation is
        # not stripped under ``python -O``. The exception type is kept as
        # AssertionError so existing callers observe the same error class.
        if target_weight is None:
            raise AssertionError(
                'target_weight must be provided when use_target_weight '
                'is True')

        # Keep the element-wise losses so the weights can be applied
        # before the final mean.
        loss = self.criterion(output, target, reduction='none')
        if target_weight.dim() == 1:
            # Per-sample weights [N] -> [N, 1] so they broadcast over K.
            target_weight = target_weight[:, None]
        return (loss * target_weight).mean() * self.loss_weight
|