import math
import mmcv
import torch
import torch.nn as nn
import torch.nn.functional as F
from ..builder import LOSSES
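
# All losses below implement residual log-likelihood estimation (RLE): the
# per-joint loss is the sum of a flow term (``output.nf_loss``, assumed here to
# be produced by the model head's normalizing flow) and ``logQ``, the negative
# log-density of a tractable prior Q evaluated at the regressed coordinates,
# both masked by the per-joint target weights.
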
@LOSSES.register_module()
class RLELoss_poseur_old(nn.Module):
    """RLE regression loss (residual log-likelihood estimation), 2D keypoint version.

    Expects ``output`` to carry ``pred_jts``, ``sigma`` and ``nf_loss``; the
    flow term is masked with the first weight channel only.
    """

    def __init__(self, OUTPUT_3D=False, use_target_weight=True, size_average=True):
        super(RLELoss_poseur_old, self).__init__()
        self.size_average = size_average
        self.amp = 1 / math.sqrt(2 * math.pi)

    def logQ(self, gt_uv, pred_jts, sigma):
        # Negative log-density of the prior Q, evaluated per coordinate.
        return torch.log(sigma / self.amp) + torch.abs(gt_uv - pred_jts) / (math.sqrt(2) * sigma + 1e-9)

    def forward(self, output, target_uv, target_uv_weight):
        pred_jts = output.pred_jts
        sigma = output.sigma
        gt_uv = target_uv.reshape(pred_jts.shape)
        gt_uv_weight = target_uv_weight.reshape(pred_jts.shape)

        # The flow term is weighted by the first channel of the joint weights only.
        nf_loss = output.nf_loss * gt_uv_weight[:, :, :1]

        Q_logprob = self.logQ(gt_uv, pred_jts, sigma) * gt_uv_weight
        loss = nf_loss + Q_logprob

        if self.size_average and gt_uv_weight.sum() > 0:
            # Average over the batch dimension.
            return loss.sum() / len(loss)
        else:
            return loss.sum()


@LOSSES.register_module()
class RLELoss_poseur(nn.Module):
    """RLE regression loss (residual log-likelihood estimation).

    Same as ``RLELoss_poseur_old`` except that the flow term ``nf_loss`` is
    masked with the full per-coordinate target weights.
    """

    def __init__(self, OUTPUT_3D=False, use_target_weight=True, size_average=True):
        super(RLELoss_poseur, self).__init__()
        self.size_average = size_average
        self.amp = 1 / math.sqrt(2 * math.pi)

    def logQ(self, gt_uv, pred_jts, sigma):
        # Negative log-density of the prior Q, evaluated per coordinate.
        return torch.log(sigma / self.amp) + torch.abs(gt_uv - pred_jts) / (math.sqrt(2) * sigma + 1e-9)

    def forward(self, output, target_uvd, target_uvd_weight):
        pred_jts = output.pred_jts
        sigma = output.sigma
        gt_uv = target_uvd.reshape(pred_jts.shape)
        gt_uv_weight = target_uvd_weight.reshape(pred_jts.shape)

        nf_loss = output.nf_loss * gt_uv_weight

        Q_logprob = self.logQ(gt_uv, pred_jts, sigma) * gt_uv_weight
        loss = nf_loss + Q_logprob

        if self.size_average and gt_uv_weight.sum() > 0:
            return loss.sum() / len(loss)
        else:
            return loss.sum()


@LOSSES.register_module()
class RLEOHKMLoss(nn.Module):
    """RLE regression loss with online hard keypoint mining (OHKM).

    The final loss is a weighted sum of the plain RLE loss over all joints
    (``ori_weight``) and the RLE loss restricted to the ``topk`` hardest
    joints per sample (``ohkm_weight``).
    """

    def __init__(self, OUTPUT_3D=False, use_target_weight=True, size_average=True,
                 topk=8, ori_weight=1.0, ohkm_weight=0.0):
        super(RLEOHKMLoss, self).__init__()
        self.size_average = size_average
        self.amp = 1 / math.sqrt(2 * math.pi)
        self.topk = topk
        self.ori_weight = ori_weight
        self.ohkm_weight = ohkm_weight
        self.neg_inf = -float("Inf")

    def logQ(self, gt_uv, pred_jts, sigma):
        # Negative log-density of the prior Q, evaluated per coordinate.
        return torch.log(sigma / self.amp) + torch.abs(gt_uv - pred_jts) / (math.sqrt(2) * sigma + 1e-9)

    def ohkm(self, loss, weight):
        # Keep only the topk hardest joints per sample. Joints with zero weight
        # are set to -inf in a detached copy so they can never be selected,
        # e.g. with per-joint losses [0.9, 0.1, 0.5, 0.7], weights [1, 1, 0, 1]
        # and topk=2, joints 0 and 3 are kept.
        loss_value = loss.clone().detach()
        loss_value[weight == 0] = self.neg_inf
        _, topk_idx = torch.topk(loss_value, k=self.topk, dim=1, sorted=False)
        tmp_loss = torch.gather(loss, 1, topk_idx)
        tmp_weight = torch.gather(weight, 1, topk_idx)
        tmp_loss = tmp_loss * tmp_weight
        tmp_loss = tmp_loss.flatten(start_dim=1).sum(dim=1)
        return tmp_loss.mean()

    def ori(self, loss, weight):
        # Plain weighted loss over all joints, summed per sample then averaged.
        loss = loss * weight
        loss = loss.flatten(start_dim=1).sum(dim=1)
        return loss.mean()

    def forward(self, output, target_uv, target_uv_weight):
        pred_jts = output.pred_jts
        sigma = output.sigma
        gt_uv = target_uv.reshape(pred_jts.shape)
        gt_uv_weight = target_uv_weight.reshape(pred_jts.shape)

        nf_loss = output.nf_loss
        q_loss = self.logQ(gt_uv, pred_jts, sigma)

        ori_loss = nf_loss + q_loss
        ohkm_loss = self.ohkm(ori_loss, gt_uv_weight)
        ori_loss = self.ori(ori_loss, gt_uv_weight)

        return self.ori_weight * ori_loss + self.ohkm_weight * ohkm_loss  # TODO: mean?


@LOSSES.register_module()
class RLELoss3D(nn.Module):
    """RLE regression loss for 3D (u, v, d) keypoints.

    Reads the targets from a ``labels`` dict with the keys ``target_uvd`` and
    ``target_uvd_weight``.
    """

    def __init__(self, OUTPUT_3D=False, size_average=True):
        super(RLELoss3D, self).__init__()
        self.size_average = size_average
        self.amp = 1 / math.sqrt(2 * math.pi)

    def logQ(self, gt_uv, pred_jts, sigma):
        # Negative log-density of the prior Q, evaluated per coordinate.
        return torch.log(sigma / self.amp) + torch.abs(gt_uv - pred_jts) / (math.sqrt(2) * sigma + 1e-9)

    def forward(self, output, labels):
        nf_loss = output.nf_loss
        pred_jts = output.pred_jts
        sigma = output.sigma
        gt_uv = labels['target_uvd'].reshape(pred_jts.shape)
        gt_uv_weight = labels['target_uvd_weight'].reshape(pred_jts.shape)

        nf_loss = nf_loss * gt_uv_weight

        Q_logprob = self.logQ(gt_uv, pred_jts, sigma) * gt_uv_weight
        loss = nf_loss + Q_logprob

        if self.size_average and gt_uv_weight.sum() > 0:
            return loss.sum() / len(loss)
        else:
            return loss.sum()
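

if __name__ == '__main__':
    # Minimal smoke test, not part of the original module: it sketches how
    # these losses are called, assuming the model head returns an object
    # exposing ``pred_jts``, ``sigma`` and ``nf_loss``. The SimpleNamespace
    # stand-in and the shapes below (batch 2, 17 joints, 2D coordinates) are
    # illustrative assumptions. Because of the relative import above, run this
    # as a module inside the package (python -m <package>.<module>).
    from types import SimpleNamespace

    B, K, D = 2, 17, 2
    fake_output = SimpleNamespace(
        pred_jts=torch.rand(B, K, D),
        sigma=torch.rand(B, K, D) + 1e-3,   # keep sigma strictly positive
        nf_loss=torch.rand(B, K, 1),        # flow term, one value per joint
    )
    target = torch.rand(B, K, D)
    target_weight = torch.ones(B, K, D)

    print('RLELoss_poseur:', RLELoss_poseur()(fake_output, target, target_weight))
    print('RLEOHKMLoss:   ', RLEOHKMLoss(topk=8)(fake_output, target, target_weight))
    print('RLELoss3D:     ', RLELoss3D()(fake_output, {'target_uvd': target,
                                                       'target_uvd_weight': target_weight}))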