"""Element-wise (unreduced) loss modules built on PyTorch."""

import torch
import torch.nn as nn


class CoordLoss(nn.Module):
    """Masked L1 loss between predicted and ground-truth coordinates."""

    def __init__(self):
        super().__init__()

    def forward(self, coord_out, coord_gt, valid, is_3D=None):
        """Return the element-wise absolute error, masked by ``valid``.

        Args:
            coord_out: Predicted coordinates.
            coord_gt: Ground-truth coordinates; must broadcast against
                ``coord_out``.
            valid: Multiplicative mask/weight applied to the absolute error;
                must broadcast against it.
            is_3D: Optional per-sample flag tensor. When given, the error in
                every channel from index 2 onward of the last axis is
                multiplied by it, while channels 0 and 1 are left untouched.
                NOTE(review): the indexing below assumes a 3-D loss tensor
                and a 1-D ``is_3D`` over the first axis -- confirm with
                callers.

        Returns:
            The unreduced loss tensor (no sum or mean is applied).
        """
        loss = torch.abs(coord_out - coord_gt) * valid
        if is_3D is not None:
            # Reshape the flag to (N, 1, 1) so it broadcasts over the two
            # trailing axes, then scale only the channels from index 2 on.
            depth_weight = is_3D[:, None, None].float()
            loss_z = loss[:, :, 2:] * depth_weight
            loss = torch.cat((loss[:, :, :2], loss_z), 2)
        return loss


class ParamLoss(nn.Module):
    """Masked L1 loss between predicted and ground-truth parameters."""

    def __init__(self):
        super().__init__()

    def forward(self, param_out, param_gt, valid):
        """Return the element-wise absolute error multiplied by ``valid``.

        Args:
            param_out: Predicted parameters.
            param_gt: Ground-truth parameters; must broadcast against
                ``param_out``.
            valid: Multiplicative mask/weight applied to the absolute error.

        Returns:
            The unreduced loss tensor (no sum or mean is applied).
        """
        loss = torch.abs(param_out - param_gt) * valid
        return loss


class CELoss(nn.Module):
    """Cross-entropy loss that keeps one loss value per target element."""

    def __init__(self):
        super().__init__()
        # reduction='none' leaves any averaging/summing to the caller.
        self.ce_loss = nn.CrossEntropyLoss(reduction='none')

    def forward(self, out, gt_index):
        """Return the unreduced cross-entropy of ``out`` against ``gt_index``.

        Args:
            out: Prediction tensor passed as the input of
                ``nn.CrossEntropyLoss``.
            gt_index: Target tensor passed as the target of
                ``nn.CrossEntropyLoss``.

        Returns:
            The loss tensor produced with ``reduction='none'``.
        """
        loss = self.ce_loss(out, gt_index)
        return loss