import torch
import torch.nn as nn
import torch.nn.functional as F

from ..builder import LOSSES


def _expand_onehot_labels(labels, label_weights, label_channels):
    """Turn class-index labels into one-hot targets with matching weights.

    Args:
        labels (Tensor): Class indices, one per sample. Entries outside
            ``[0, label_channels)`` yield an all-zero row.
        label_weights (Tensor): One weight per sample.
        label_channels (int): Number of columns of the one-hot output.

    Returns:
        tuple[Tensor, Tensor]: The one-hot labels (same dtype as ``labels``)
        and the per-sample weights expanded to
        ``(label_weights.size(0), label_channels)``.
    """
    onehot = labels.new_zeros((labels.size(0), label_channels))
    in_range = (labels >= 0) & (labels < label_channels)
    pos_inds = torch.nonzero(in_range, as_tuple=False).squeeze()
    if pos_inds.numel() > 0:
        onehot[pos_inds, labels[pos_inds]] = 1
    # Repeat each sample weight across every channel.
    channel_weights = label_weights.view(-1, 1).expand(
        label_weights.size(0), label_channels)
    return onehot, channel_weights


# TODO: code refactoring to make it consistent with other losses
@LOSSES.register_module()
class GHMC(nn.Module):
    """GHM Classification Loss.

    Details of the theorem can be viewed in the paper
    `Gradient Harmonized Single-stage Detector
    <https://arxiv.org/abs/1811.05181>`_.

    Args:
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average. A value greater
            than 0 enables a running per-bin count kept in ``acc_sum``.
        use_sigmoid (bool): Can only be true for BCE based loss now.
        loss_weight (float): The weight of the total GHM-C loss.

    Raises:
        NotImplementedError: If ``use_sigmoid`` is false.
    """

    def __init__(self, bins=10, momentum=0, use_sigmoid=True, loss_weight=1.0):
        super().__init__()
        self.bins = bins
        self.momentum = momentum

        # Uniform bin edges over [0, 1]; the last edge is nudged upwards so
        # that a gradient norm of exactly 1 falls inside the final bin.
        bin_edges = torch.arange(bins + 1).float() / bins
        bin_edges[-1] += 1e-6
        self.register_buffer('edges', bin_edges)

        if momentum > 0:
            # Running (moving-average) number of samples per bin.
            self.register_buffer('acc_sum', torch.zeros(bins))

        self.use_sigmoid = use_sigmoid
        if not self.use_sigmoid:
            raise NotImplementedError
        self.loss_weight = loss_weight

    def forward(self, pred, target, label_weight, *args, **kwargs):
        """Calculate the GHM-C loss.

        When ``momentum > 0`` this call also updates the ``acc_sum`` buffer
        for every non-empty bin.

        Args:
            pred (float tensor of size [batch_num, class_num]):
                The direct prediction of classification fc layer.
            target (float tensor of size [batch_num, class_num]):
                Binary class target for each sample. If its number of
                dimensions differs from ``pred``, it is treated as class
                indices and expanded to one-hot form first.
            label_weight (float tensor of size [batch_num, class_num]):
                the value is 1 if the sample is valid and 0 if ignored.
            *args: Ignored.
            **kwargs: Ignored.

        Returns:
            The gradient harmonized loss.
        """
        # the target should be binary class label
        if pred.dim() != target.dim():
            target, label_weight = _expand_onehot_labels(
                target, label_weight, pred.size(-1))
        target = target.float()
        label_weight = label_weight.float()

        momentum = self.momentum
        bin_edges = self.edges
        sample_weights = torch.zeros_like(pred)

        # gradient length
        grad_norm = (pred.sigmoid().detach() - target).abs()

        is_valid = label_weight > 0
        num_valid = max(is_valid.float().sum().item(), 1.0)

        nonempty_bins = 0
        for idx in range(self.bins):
            lower, upper = bin_edges[idx], bin_edges[idx + 1]
            in_bin = (grad_norm >= lower) & (grad_norm < upper) & is_valid
            count = in_bin.sum().item()
            if count <= 0:
                continue
            if momentum > 0:
                self.acc_sum[idx] = (
                    momentum * self.acc_sum[idx] + (1 - momentum) * count)
                sample_weights[in_bin] = num_valid / self.acc_sum[idx]
            else:
                sample_weights[in_bin] = num_valid / count
            nonempty_bins += 1

        if nonempty_bins > 0:
            sample_weights = sample_weights / nonempty_bins

        summed = F.binary_cross_entropy_with_logits(
            pred, target, sample_weights, reduction='sum')
        return summed / num_valid * self.loss_weight


# TODO: code refactoring to make it consistent with other losses
@LOSSES.register_module()
class GHMR(nn.Module):
    """GHM Regression Loss.

    Details of the theorem can be viewed in the paper
    `Gradient Harmonized Single-stage Detector
    <https://arxiv.org/abs/1811.05181>`_.

    Args:
        mu (float): The parameter for the Authentic Smooth L1 loss.
        bins (int): Number of the unit regions for distribution calculation.
        momentum (float): The parameter for moving average. A value greater
            than 0 enables a running per-bin count kept in ``acc_sum``.
        loss_weight (float): The weight of the total GHM-R loss.
    """

    def __init__(self, mu=0.02, bins=10, momentum=0, loss_weight=1.0):
        super().__init__()
        self.mu = mu
        self.bins = bins

        # Uniform bin edges; the last edge is replaced by a large value so
        # the final bin has no effective upper limit.
        bin_edges = torch.arange(bins + 1).float() / bins
        bin_edges[-1] = 1e3
        self.register_buffer('edges', bin_edges)

        self.momentum = momentum
        if momentum > 0:
            # Running (moving-average) number of samples per bin.
            self.register_buffer('acc_sum', torch.zeros(bins))
        self.loss_weight = loss_weight

    # TODO: support reduction parameter
    def forward(self, pred, target, label_weight, avg_factor=None):
        """Calculate the GHM-R loss.

        When ``momentum > 0`` this call also updates the ``acc_sum`` buffer
        for every non-empty bin.

        Args:
            pred (float tensor of size [batch_num, 4 (* class_num)]):
                The prediction of box regression layer. Channel number can be
                4 or 4 * class_num depending on whether it is class-agnostic.
            target (float tensor of size [batch_num, 4 (* class_num)]):
                The target regression values with the same size of pred.
            label_weight (float tensor of size [batch_num, 4 (* class_num)]):
                The weight of each sample, 0 if ignored.
            avg_factor: Accepted but not used by this implementation.

        Returns:
            The gradient harmonized loss.
        """
        mu = self.mu
        momentum = self.momentum
        bin_edges = self.edges

        # ASL1 loss
        diff = pred - target
        root = torch.sqrt(diff * diff + mu * mu)
        loss = root - mu

        # gradient length
        grad_norm = torch.abs(diff / root).detach()
        sample_weights = torch.zeros_like(grad_norm)

        is_valid = label_weight > 0
        weight_total = max(label_weight.float().sum().item(), 1.0)

        nonempty_bins = 0
        for idx in range(self.bins):
            lower, upper = bin_edges[idx], bin_edges[idx + 1]
            in_bin = (grad_norm >= lower) & (grad_norm < upper) & is_valid
            count = in_bin.sum().item()
            if count <= 0:
                continue
            nonempty_bins += 1
            if momentum > 0:
                self.acc_sum[idx] = (
                    momentum * self.acc_sum[idx] + (1 - momentum) * count)
                sample_weights[in_bin] = weight_total / self.acc_sum[idx]
            else:
                sample_weights[in_bin] = weight_total / count

        if nonempty_bins > 0:
            sample_weights /= nonempty_bins

        harmonized = (loss * sample_weights).sum() / weight_total
        return harmonized * self.loss_weight