File size: 2,787 Bytes
2cd560a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 |
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn
from ..builder import LOSSES
@LOSSES.register_module()
class AdaptiveWingLoss(nn.Module):
    """Adaptive wing loss for heatmap regression.

    Paper ref: 'Adaptive Wing Loss for Robust Face Alignment via Heatmap
    Regression', Wang et al., ICCV'2019.

    Args:
        alpha (float): Exponent offset; the per-element exponent used by the
            loss is ``alpha - target``. Default: 2.1.
        omega (float): Scale of the non-linear (log) part. Default: 14.
        epsilon (float): Divisor applied to the error (and to ``theta``)
            inside the power term. Default: 1.
        theta (float): Threshold on ``|target - pred|`` below which the log
            part is used and above which the linear part is used.
            Default: 0.5.
        use_target_weight (bool): If True, predictions and targets are both
            multiplied by the per-joint ``target_weight`` before the loss is
            computed. Default: False.
        loss_weight (float): Weight of the loss. Default: 1.0.
    """

    def __init__(self,
                 alpha=2.1,
                 omega=14,
                 epsilon=1,
                 theta=0.5,
                 use_target_weight=False,
                 loss_weight=1.):
        super().__init__()
        self.alpha = float(alpha)
        self.omega = float(omega)
        self.epsilon = float(epsilon)
        self.theta = float(theta)
        self.use_target_weight = use_target_weight
        self.loss_weight = loss_weight

    def criterion(self, pred, target):
        """Compute the mean element-wise adaptive wing loss.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            pred (torch.Tensor[NxKxHxW]): Predicted heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.

        Returns:
            torch.Tensor: Scalar mean of the per-element losses.
        """
        delta = (target - pred).abs()
        # The exponent depends on the target value of each element, which is
        # what makes the loss "adaptive".
        exponent = self.alpha - target
        ratio = self.theta / self.epsilon
        ratio_pow = torch.pow(ratio, exponent)

        # A and C are chosen so that the linear branch (A * delta - C) has
        # the same slope and the same value as the log branch at
        # delta == theta, keeping the loss and its gradient continuous.
        A = self.omega * (1 / (1 + ratio_pow)) * exponent * torch.pow(
            ratio, exponent - 1) * (1 / self.epsilon)
        C = self.theta * A - self.omega * torch.log(1 + ratio_pow)

        log_part = self.omega * torch.log(
            1 + torch.pow(delta / self.epsilon, exponent))
        linear_part = A * delta - C
        losses = torch.where(delta < self.theta, log_part, linear_part)
        return torch.mean(losses)

    def forward(self, output, target, target_weight):
        """Forward function.

        Note:
            batch_size: N
            num_keypoints: K

        Args:
            output (torch.Tensor[NxKxHxW]): Output heatmaps.
            target (torch.Tensor[NxKxHxW]): Target heatmaps.
            target_weight (torch.Tensor[NxKx1]):
                Weights across different joint types. Only used when
                ``use_target_weight`` is True.

        Returns:
            torch.Tensor: Scalar loss scaled by ``loss_weight``.
        """
        if self.use_target_weight:
            # Add a trailing axis so the (N, K, 1) weights broadcast over the
            # spatial dims of the heatmaps (assumes the documented shape).
            weight = target_weight.unsqueeze(-1)
            loss = self.criterion(output * weight, target * weight)
        else:
            loss = self.criterion(output, target)
        return loss * self.loss_weight
|