H-Liu1997's picture
init
31f2f28
raw
history blame
891 Bytes
import torch
import torch.nn as nn
class CoordLoss(nn.Module):
    """Element-wise L1 loss between coordinates, weighted by a validity mask.

    No reduction is performed: the per-element loss tensor is returned so
    the caller decides how to aggregate it.
    """

    def __init__(self):
        super().__init__()

    def forward(self, coord_out, coord_gt, valid, is_3D=None):
        """Compute ``|coord_out - coord_gt| * valid`` without reduction.

        Args:
            coord_out: Predicted coordinates.
            coord_gt: Target coordinates, broadcastable against ``coord_out``.
            valid: Multiplicative mask/weight applied to the absolute error.
            is_3D: Optional per-sample flag indexed along dim 0. When given,
                the loss entries from index 2 onward along dim 2 are
                multiplied by this flag (cast to float); entries 0 and 1
                along dim 2 are left as they are.
                NOTE(review): presumably dim 2 holds (x, y, z) and the flag
                disables the depth term for 2D-only samples — confirm
                against the caller.

        Returns:
            The unreduced loss tensor.
        """
        masked_l1 = (coord_out - coord_gt).abs() * valid
        if is_3D is None:
            return masked_l1

        # Reshape the per-sample flag so it broadcasts over dims 1 and 2.
        depth_flag = is_3D.float()[:, None, None]
        planar_part = masked_l1[:, :, :2]
        depth_part = masked_l1[:, :, 2:] * depth_flag
        return torch.cat([planar_part, depth_part], dim=2)
class ParamLoss(nn.Module):
    """Element-wise L1 loss between parameters, weighted by a validity mask.

    The result is not reduced; aggregation is left to the caller.
    """

    def __init__(self):
        super().__init__()

    def forward(self, param_out, param_gt, valid):
        """Return ``|param_out - param_gt| * valid`` as an unreduced tensor.

        Args:
            param_out: Predicted parameters.
            param_gt: Target parameters, broadcastable against ``param_out``.
            valid: Multiplicative mask/weight applied to the absolute error.
        """
        abs_error = (param_out - param_gt).abs()
        return abs_error * valid
class CELoss(nn.Module):
    """Cross-entropy loss that keeps one loss value per target element.

    Wraps ``nn.CrossEntropyLoss`` configured with ``reduction='none'`` so
    the caller receives unreduced losses.
    """

    def __init__(self):
        super().__init__()
        # reduction='none' keeps the individual losses instead of averaging.
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, out, gt_index):
        """Return the unreduced cross-entropy of ``out`` against ``gt_index``.

        Args:
            out: Predictions passed as the input of ``nn.CrossEntropyLoss``.
            gt_index: Targets passed as the target of ``nn.CrossEntropyLoss``.
        """
        return self.ce_loss(out, gt_index)