# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from ..builder import LOSSES


@LOSSES.register_module()
class JointsMSELoss(nn.Module):
    """MSE loss for heatmaps.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight):
        """Forward function.

        Every joint channel is flattened to one row per sample, the
        mean-reduced MSE is computed per joint, and the per-joint values
        are averaged and scaled by ``loss_weight``.

        Args:
            output (torch.Tensor): Predicted heatmaps. Dim 0 is the batch
                and dim 1 the joints; trailing dims are flattened.
            target (torch.Tensor): Target heatmaps, reshaped the same way
                as ``output``.
            target_weight (torch.Tensor): Per-joint weights, only read when
                ``use_target_weight`` is True. ``target_weight[:, idx]``
                must broadcast against a ``(batch, -1)`` tensor
                (presumably shape ``(batch, num_joints, 1)`` - confirm
                against the caller).

        Returns:
            torch.Tensor: The weighted loss value.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)

        # One (batch, 1, H*W) slice per joint.
        heatmaps_pred = output.reshape(
            (batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        loss = 0.

        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze(1)
            heatmap_gt = heatmaps_gt[idx].squeeze(1)
            if self.use_target_weight:
                loss += self.criterion(heatmap_pred * target_weight[:, idx],
                                       heatmap_gt * target_weight[:, idx])
            else:
                loss += self.criterion(heatmap_pred, heatmap_gt)

        return loss / num_joints * self.loss_weight


@LOSSES.register_module()
class CombinedTargetMSELoss(nn.Module):
    """MSE loss for combined target.

    CombinedTarget: The combination of classification target
    (response map) and regression target (offset map).
    Paper ref: Huang et al. The Devil is in the Details: Delving into
    Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight):
        """Forward function.

        Channels are consumed in groups of three per joint: response map,
        x-offset map and y-offset map (channel ``3 * idx``, ``3 * idx + 1``
        and ``3 * idx + 2``).

        Args:
            output (torch.Tensor): Predicted maps. Dim 0 is the batch and
                dim 1 holds ``3 * num_joints`` channels; trailing dims are
                flattened.
            target (torch.Tensor): Target maps, reshaped the same way as
                ``output``.
            target_weight (torch.Tensor): Per-joint weights, only read when
                ``use_target_weight`` is True. ``target_weight[:, idx]``
                must broadcast against a ``(batch, -1)`` tensor
                (presumably shape ``(batch, num_joints, 1)`` - confirm
                against the caller).

        Returns:
            torch.Tensor: The weighted loss value.
        """
        batch_size = output.size(0)
        num_channels = output.size(1)

        heatmaps_pred = output.reshape(
            (batch_size, num_channels, -1)).split(1, 1)
        heatmaps_gt = target.reshape(
            (batch_size, num_channels, -1)).split(1, 1)

        loss = 0.
        num_joints = num_channels // 3

        for idx in range(num_joints):
            # squeeze(1) drops only the split channel axis. A bare
            # squeeze() would also drop a size-1 batch or spatial axis,
            # and the weight multiplication below could then broadcast
            # (N,) * (N, 1) into an (N, N) matrix.
            heatmap_pred = heatmaps_pred[idx * 3].squeeze(1)
            heatmap_gt = heatmaps_gt[idx * 3].squeeze(1)
            offset_x_pred = heatmaps_pred[idx * 3 + 1].squeeze(1)
            offset_x_gt = heatmaps_gt[idx * 3 + 1].squeeze(1)
            offset_y_pred = heatmaps_pred[idx * 3 + 2].squeeze(1)
            offset_y_gt = heatmaps_gt[idx * 3 + 2].squeeze(1)

            if self.use_target_weight:
                joint_weight = target_weight[:, idx]
                heatmap_pred = heatmap_pred * joint_weight
                # The weighted response map also gates the offset terms.
                heatmap_gt = heatmap_gt * joint_weight

            # classification loss
            loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
            # regression loss
            loss += 0.5 * self.criterion(heatmap_gt * offset_x_pred,
                                         heatmap_gt * offset_x_gt)
            loss += 0.5 * self.criterion(heatmap_gt * offset_y_pred,
                                         heatmap_gt * offset_y_gt)

        return loss / num_joints * self.loss_weight


@LOSSES.register_module()
class JointsOHKMMSELoss(nn.Module):
    """MSE loss with online hard keypoint mining.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        topk (int): Only top k joint losses are kept.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, topk=8, loss_weight=1.):
        super().__init__()
        assert topk > 0
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk
        self.loss_weight = loss_weight

    def _ohkm(self, loss):
        """Online hard keypoint mining.

        Args:
            loss (torch.Tensor): Per-sample, per-joint losses with shape
                ``(batch, num_joints)``.

        Returns:
            torch.Tensor: Mean over the batch of each sample's average
            ``topk`` largest joint losses.
        """
        # topk already returns the selected values, so one batched call
        # along the joint axis replaces a per-sample loop plus gather.
        hardest, _ = torch.topk(loss, k=self.topk, dim=1, sorted=False)
        return (hardest.sum(dim=1) / self.topk).mean()

    def forward(self, output, target, target_weight):
        """Forward function.

        Args:
            output (torch.Tensor): Predicted heatmaps. Dim 0 is the batch
                and dim 1 the joints; trailing dims are flattened.
            target (torch.Tensor): Target heatmaps, reshaped the same way
                as ``output``.
            target_weight (torch.Tensor): Per-joint weights, only read when
                ``use_target_weight`` is True. ``target_weight[:, idx]``
                must broadcast against a ``(batch, -1)`` tensor
                (presumably shape ``(batch, num_joints, 1)`` - confirm
                against the caller).

        Returns:
            torch.Tensor: The weighted loss value.

        Raises:
            ValueError: If the number of joints is smaller than ``topk``.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)
        if num_joints < self.topk:
            raise ValueError(f'topk ({self.topk}) should not '
                             f'larger than num_joints ({num_joints}).')

        heatmaps_pred = output.reshape(
            (batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        losses = []
        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze(1)
            heatmap_gt = heatmaps_gt[idx].squeeze(1)
            if self.use_target_weight:
                joint_loss = self.criterion(
                    heatmap_pred * target_weight[:, idx],
                    heatmap_gt * target_weight[:, idx])
            else:
                joint_loss = self.criterion(heatmap_pred, heatmap_gt)
            # Element-wise loss -> one value per sample for this joint.
            losses.append(joint_loss.mean(dim=1))

        # (batch, num_joints) matrix of per-joint losses.
        losses = torch.stack(losses, dim=1)

        return self._ohkm(losses) * self.loss_weight