|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from ..builder import LOSSES |
|
|
|
|
|
@LOSSES.register_module() |
|
class JointsMSELoss(nn.Module):
    """Mean-squared-error loss between predicted and target heatmaps.

    The MSE is computed independently for every joint and the per-joint
    values are averaged.

    Args:
        use_target_weight (bool): If True, each joint's predicted and target
            heatmaps are multiplied by that joint's entry of ``target_weight``
            before the MSE is taken. Default: False.
        loss_weight (float): Multiplier applied to the final loss.
            Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss()
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight):
        """Forward function.

        Args:
            output (torch.Tensor): Predicted heatmaps. Dim 0 is read as the
                batch and dim 1 as the joint axis; all remaining dims are
                flattened.
            target (torch.Tensor): Ground-truth heatmaps, reshapeable to the
                same ``(batch, joints, -1)`` layout as ``output``.
            target_weight (torch.Tensor): Per-joint weights indexed as
                ``target_weight[:, joint]`` and broadcast against a
                ``(batch, H*W)`` map (presumably shape ``(N, K, 1)`` —
                confirm with callers). Only read when ``use_target_weight``.

        Returns:
            torch.Tensor: Mean over joints of the per-joint MSE, scaled by
            ``loss_weight``.
        """
        n_samples, n_joints = output.size(0), output.size(1)
        flat_shape = (n_samples, n_joints, -1)
        pred_chunks = output.reshape(flat_shape).split(1, 1)
        gt_chunks = target.reshape(flat_shape).split(1, 1)

        total = 0.
        for joint in range(n_joints):
            # Drop only the singleton joint axis: (N, 1, HW) -> (N, HW).
            pred = pred_chunks[joint].squeeze(1)
            gt = gt_chunks[joint].squeeze(1)
            if self.use_target_weight:
                weight = target_weight[:, joint]
                pred = pred * weight
                gt = gt * weight
            total += self.criterion(pred, gt)

        return total / n_joints * self.loss_weight
|
|
|
|
|
@LOSSES.register_module() |
|
class CombinedTargetMSELoss(nn.Module):
    """MSE loss for combined target.

    CombinedTarget: The combination of classification target
    (response map) and regression target (offset map).
    Paper ref: Huang et al. The Devil is in the Details: Delving into
    Unbiased Data Processing for Human Pose Estimation (CVPR 2020).

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Default: False.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self, use_target_weight=False, loss_weight=1.):
        super().__init__()
        self.criterion = nn.MSELoss(reduction='mean')
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def forward(self, output, target, target_weight):
        """Forward function.

        Channels are consumed in consecutive triples per joint:
        ``[heatmap, offset_x, offset_y]``. Any trailing channels that do not
        form a full triple are ignored.

        Args:
            output (torch.Tensor): Predicted maps. Dim 0 is the batch, dim 1
                holds the channels; all remaining dims are flattened.
            target (torch.Tensor): Ground-truth maps with the same layout as
                ``output``.
            target_weight (torch.Tensor): Per-joint weights indexed as
                ``target_weight[:, joint]`` and broadcast against a
                ``(batch, H*W)`` map (presumably shape ``(N, K, 1)`` —
                confirm with callers). Only read when ``use_target_weight``.

        Returns:
            torch.Tensor: Scalar loss averaged over joints and scaled by
            ``loss_weight``.

        Raises:
            ValueError: If ``output`` has fewer than 3 channels, i.e. not
                even one joint triple.
        """
        batch_size = output.size(0)
        num_channels = output.size(1)
        num_joints = num_channels // 3
        if num_joints == 0:
            # Without this check the final division is 0. / 0, which raises an
            # uninformative ZeroDivisionError.
            raise ValueError(
                'CombinedTargetMSELoss expects at least 3 channels '
                f'(heatmap, offset_x, offset_y per joint), got {num_channels}.')

        heatmaps_pred = output.reshape(
            (batch_size, num_channels, -1)).split(1, 1)
        heatmaps_gt = target.reshape(
            (batch_size, num_channels, -1)).split(1, 1)

        loss = 0.
        for idx in range(num_joints):
            # squeeze(1), not squeeze(): only the singleton channel axis must
            # go. A bare squeeze() also drops the batch axis when the map is
            # 1x1 (or batch is 1). The (N, 1) weight below then broadcasts
            # into an (N, N) cross-product and mixes weights across samples.
            heatmap_pred = heatmaps_pred[idx * 3].squeeze(1)
            heatmap_gt = heatmaps_gt[idx * 3].squeeze(1)
            offset_x_pred = heatmaps_pred[idx * 3 + 1].squeeze(1)
            offset_x_gt = heatmaps_gt[idx * 3 + 1].squeeze(1)
            offset_y_pred = heatmaps_pred[idx * 3 + 2].squeeze(1)
            offset_y_gt = heatmaps_gt[idx * 3 + 2].squeeze(1)

            if self.use_target_weight:
                heatmap_pred = heatmap_pred * target_weight[:, idx]
                heatmap_gt = heatmap_gt * target_weight[:, idx]

            # Classification (response map) term.
            loss += 0.5 * self.criterion(heatmap_pred, heatmap_gt)
            # Regression (offset) terms, masked by the (possibly weighted)
            # ground-truth heatmap.
            loss += 0.5 * self.criterion(heatmap_gt * offset_x_pred,
                                         heatmap_gt * offset_x_gt)
            loss += 0.5 * self.criterion(heatmap_gt * offset_y_pred,
                                         heatmap_gt * offset_y_gt)

        return loss / num_joints * self.loss_weight
|
|
|
|
|
@LOSSES.register_module() |
|
class JointsOHKMMSELoss(nn.Module):
    """MSE loss with online hard keypoint mining.

    For every sample, only the ``topk`` joints with the largest per-joint
    MSE contribute to the loss.

    Args:
        use_target_weight (bool): Option to use weighted MSE loss.
            Different joint types may have different target weights.
            Default: False.
        topk (int): Only top k joint losses are kept. Must be positive.
            Default: 8.
        loss_weight (float): Weight of the loss. Default: 1.0.

    Raises:
        ValueError: If ``topk`` is not positive.
    """

    def __init__(self, use_target_weight=False, topk=8, loss_weight=1.):
        super().__init__()
        # Explicit check instead of ``assert``, so the validation is not
        # stripped when Python runs with -O.
        if topk <= 0:
            raise ValueError(f'topk ({topk}) should be a positive integer.')
        self.criterion = nn.MSELoss(reduction='none')
        self.use_target_weight = use_target_weight
        self.topk = topk
        self.loss_weight = loss_weight

    def _ohkm(self, loss):
        """Online hard keypoint mining.

        Args:
            loss (torch.Tensor): Per-sample, per-joint losses of shape
                ``(N, K)``.

        Returns:
            torch.Tensor: For each sample, the mean of its ``topk`` largest
            joint losses, averaged over the batch.
        """
        # torch.topk already returns the selected values, so no gather is
        # needed. Running it over dim=1 handles the whole batch in one call.
        # The mean over the (N, topk) result equals the per-sample
        # sum / topk averaged over N.
        topk_loss, _ = torch.topk(loss, k=self.topk, dim=1, sorted=False)
        return topk_loss.mean()

    def forward(self, output, target, target_weight):
        """Forward function.

        Args:
            output (torch.Tensor): Predicted heatmaps. Dim 0 is the batch and
                dim 1 the joint axis; remaining dims are flattened.
            target (torch.Tensor): Ground-truth heatmaps with the same layout
                as ``output``.
            target_weight (torch.Tensor): Per-joint weights indexed as
                ``target_weight[:, joint]`` and broadcast against a
                ``(batch, H*W)`` map (presumably shape ``(N, K, 1)`` —
                confirm with callers). Only read when ``use_target_weight``.

        Returns:
            torch.Tensor: Hard-keypoint-mined loss scaled by ``loss_weight``.

        Raises:
            ValueError: If ``topk`` exceeds the number of joints.
        """
        batch_size = output.size(0)
        num_joints = output.size(1)
        if num_joints < self.topk:
            raise ValueError(f'topk ({self.topk}) should not '
                             f'larger than num_joints ({num_joints}).')
        heatmaps_pred = output.reshape(
            (batch_size, num_joints, -1)).split(1, 1)
        heatmaps_gt = target.reshape((batch_size, num_joints, -1)).split(1, 1)

        losses = []
        for idx in range(num_joints):
            heatmap_pred = heatmaps_pred[idx].squeeze(1)
            heatmap_gt = heatmaps_gt[idx].squeeze(1)
            if self.use_target_weight:
                heatmap_pred = heatmap_pred * target_weight[:, idx]
                heatmap_gt = heatmap_gt * target_weight[:, idx]
            # Element-wise squared error (N, H*W) -> per-sample MSE (N,).
            losses.append(
                self.criterion(heatmap_pred, heatmap_gt).mean(dim=1))

        # (N, K) matrix of per-sample, per-joint losses.
        return self._ohkm(torch.stack(losses, dim=1)) * self.loss_weight
|
|