# mart9992's picture
# m
# 2cd560a
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from ..builder import LOSSES
@LOSSES.register_module()
class JointsMSELoss(nn.Module):
    """MSE loss for heatmaps.

    The loss is the mean over joints of the per-joint mean squared error.
    Every joint contributes the same number of elements, so this equals a
    single mean over all elements and is computed in one vectorised call.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Args:
            output (torch.Tensor): Predicted heatmaps of shape (N, K, ...);
                all dimensions after the second one are flattened.
            target (torch.Tensor): Ground-truth heatmaps, reshaped the same
                way as ``output``.
            target_weight (torch.Tensor, optional): Per-joint weights. Any
                shape that reshapes to (N, K, -1) and broadcasts against the
                flattened heatmaps is accepted, e.g. (N, K, 1) or (N, K).
                Only read when ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss multiplied by ``loss_weight``.

        Raises:
            TypeError: If ``use_target_weight`` is True but no
                ``target_weight`` is given.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)

        heatmaps_pred = output.reshape(batch_size, num_joints, -1)
        heatmaps_gt = target.reshape(batch_size, num_joints, -1)

        if self.use_target_weight:
            if target_weight is None:
                raise TypeError(
                    'target_weight must be a tensor when '
                    'use_target_weight is True, but got None.')
            # Normalise to (N, K, -1) so that both (N, K, 1) and (N, K)
            # weights broadcast over the flattened spatial dimension.
            weight = target_weight.reshape(batch_size, num_joints, -1)
            heatmaps_pred = heatmaps_pred * weight
            heatmaps_gt = heatmaps_gt * weight

        return self.criterion(heatmaps_pred, heatmaps_gt) * self.loss_weight
@LOSSES.register_module()
class CombinedTargetMSELoss(nn.Module):
    """MSE loss for combined target.

    CombinedTarget: The combination of classification target
    (response map) and regression target (offset map).
    Paper ref: Huang et al. The Devil is in the Details: Delving into
    Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Channels are laid out per joint as three consecutive maps:
    ``[heatmap, offset_x, offset_y]``.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Default: False.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Args:
            output (torch.Tensor): Predictions of shape (N, 3 * K, ...);
                all dimensions after the second one are flattened.
            target (torch.Tensor): Ground truth, reshaped the same way as
                ``output``.
            target_weight (torch.Tensor, optional): Per-joint weights. Any
                shape that reshapes to (N, K, -1) and broadcasts against the
                flattened maps is accepted, e.g. (N, K, 1) or (N, K).
                Only read when ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss multiplied by ``loss_weight``.

        Raises:
            ValueError: If the channel count is not a positive multiple
                of 3.
            TypeError: If ``use_target_weight`` is True but no
                ``target_weight`` is given.
        """
        batch_size = output.size(0)
        num_channels = output.size(1)
        if num_channels < 3 or num_channels % 3 != 0:
            # Trailing channels would otherwise be dropped silently.
            raise ValueError(
                'output should have 3 channels per joint, but got '
                f'{num_channels} channels.')
        num_joints = num_channels // 3

        # (N, K, 3, HW): index 0 is the response map, 1/2 are x/y offsets.
        # Explicit indexing keeps the batch dimension even when N == 1.
        pred = output.reshape(batch_size, num_joints, 3, -1)
        gt = target.reshape(batch_size, num_joints, 3, -1)
        heatmap_pred, heatmap_gt = pred[:, :, 0], gt[:, :, 0]
        offset_x_pred, offset_x_gt = pred[:, :, 1], gt[:, :, 1]
        offset_y_pred, offset_y_gt = pred[:, :, 2], gt[:, :, 2]

        if self.use_target_weight:
            if target_weight is None:
                raise TypeError(
                    'target_weight must be a tensor when '
                    'use_target_weight is True, but got None.')
            weight = target_weight.reshape(batch_size, num_joints, -1)
            heatmap_pred = heatmap_pred * weight
            heatmap_gt = heatmap_gt * weight

        # classification loss
        loss = 0.5 * self.criterion(heatmap_pred, heatmap_gt)
        # regression loss; offsets are masked by the (weighted) gt heatmap
        loss = loss + 0.5 * self.criterion(heatmap_gt * offset_x_pred,
                                           heatmap_gt * offset_x_gt)
        loss = loss + 0.5 * self.criterion(heatmap_gt * offset_y_pred,
                                           heatmap_gt * offset_y_gt)
        # Each joint contributes equally many elements, so the mean over
        # all elements already equals the average over joints.
        return loss * self.loss_weight
@LOSSES.register_module()
class JointsOHKMMSELoss(nn.Module):
    """MSE loss with online hard keypoint mining.

    For every sample only the ``topk`` joints with the largest per-joint
    MSE contribute to the loss; the result is averaged over the batch.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
        topk (int): Only top k joint losses are kept.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, topk=8, loss_weight=1.):
        super().__init__()
        assert topk > 0
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk
        self.loss_weight = loss_weight

    def _ohkm(self, loss):
        """Online hard keypoint mining.

        Args:
            loss (torch.Tensor): Per-sample, per-joint losses of shape
                (N, K).

        Returns:
            torch.Tensor: Scalar; the batch mean of each sample's average
            over its ``topk`` largest joint losses.
        """
        # topk already returns the selected values, so no gather is needed.
        topk_loss, _ = torch.topk(loss, k=self.topk, dim=1, sorted=False)
        return (topk_loss.sum(dim=1) / self.topk).mean()

    def forward(self, output, target, target_weight=None):
        """Forward function.

        Args:
            output (torch.Tensor): Predicted heatmaps of shape (N, K, ...);
                all dimensions after the second one are flattened.
            target (torch.Tensor): Ground-truth heatmaps, reshaped the same
                way as ``output``.
            target_weight (torch.Tensor, optional): Per-joint weights. Any
                shape that reshapes to (N, K, -1) and broadcasts against the
                flattened heatmaps is accepted, e.g. (N, K, 1) or (N, K).
                Only read when ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss multiplied by ``loss_weight``.

        Raises:
            ValueError: If ``topk`` exceeds the number of joints.
            TypeError: If ``use_target_weight`` is True but no
                ``target_weight`` is given.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)
        if num_joints < self.topk:
            raise ValueError(f'topk ({self.topk}) should not '
                             f'larger than num_joints ({num_joints}).')

        heatmaps_pred = output.reshape(batch_size, num_joints, -1)
        heatmaps_gt = target.reshape(batch_size, num_joints, -1)

        if self.use_target_weight:
            if target_weight is None:
                raise TypeError(
                    'target_weight must be a tensor when '
                    'use_target_weight is True, but got None.')
            # Normalise to (N, K, -1) so that both (N, K, 1) and (N, K)
            # weights broadcast over the flattened spatial dimension.
            weight = target_weight.reshape(batch_size, num_joints, -1)
            heatmaps_pred = heatmaps_pred * weight
            heatmaps_gt = heatmaps_gt * weight

        # Element-wise squared error -> per-sample, per-joint mean: (N, K).
        losses = self.criterion(heatmaps_pred, heatmaps_gt).mean(dim=2)
        return self._ohkm(losses) * self.loss_weight