# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from ..builder import LOSSES
from ..utils.geometry import batch_rodrigues


def perspective_projection(points, rotation, translation, focal_length,
                           camera_center):
    """Compute the perspective projection of a set of 3D points.

    Note:
        - batch size: B
        - point number: N

    Args:
        points (Tensor([B, N, 3])): A set of 3D points
        rotation (Tensor([B, 3, 3])): Camera rotation matrix
        translation (Tensor([B, 3])): Camera translation
        focal_length (float | Tensor([B,])): Focal length
        camera_center (Tensor([B, 2])): Camera center

    Returns:
        projected_points (Tensor([B, N, 2])): Projected 2D points
            in image space.
    """
    batch_size = points.shape[0]
    # Build the intrinsics with the same dtype as the points so that the
    # einsum below does not fail on non-float32 inputs.
    K = torch.zeros([batch_size, 3, 3],
                    device=points.device,
                    dtype=points.dtype)
    K[:, 0, 0] = focal_length
    K[:, 1, 1] = focal_length
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center

    # Transform points
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)

    # Apply perspective distortion (divide by the depth coordinate)
    projected_points = points / points[:, :, -1].unsqueeze(-1)

    # Apply camera intrinsics
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
    # Drop the homogeneous coordinate
    projected_points = projected_points[:, :, :-1]
    return projected_points


@LOSSES.register_module()
class MeshLoss(nn.Module):
    """Mix loss for 3D human mesh.

    It is composed of loss on 2D joints, 3D joints, mesh vertices and
    smpl parameters (if any).

    Args:
        joints_2d_loss_weight (float): Weight for loss on 2D joints.
        joints_3d_loss_weight (float): Weight for loss on 3D joints.
        vertex_loss_weight (float): Weight for loss on 3D vertices.
        smpl_pose_loss_weight (float): Weight for loss on SMPL
            pose parameters.
        smpl_beta_loss_weight (float): Weight for loss on SMPL
            shape parameters.
        img_res (int): Input image resolution.
        focal_length (float): Focal length of camera model. Default=5000.
    """

    def __init__(self,
                 joints_2d_loss_weight,
                 joints_3d_loss_weight,
                 vertex_loss_weight,
                 smpl_pose_loss_weight,
                 smpl_beta_loss_weight,
                 img_res,
                 focal_length=5000):

        super().__init__()
        # All criteria use reduction='none' so that the per-element losses
        # can be weighted by visibility / availability before averaging.

        # Per-vertex loss on the mesh
        self.criterion_vertex = nn.L1Loss(reduction='none')

        # Joints (2D and 3D) loss
        self.criterion_joints_2d = nn.SmoothL1Loss(reduction='none')
        self.criterion_joints_3d = nn.SmoothL1Loss(reduction='none')

        # Loss for SMPL parameter regression
        self.criterion_regr = nn.MSELoss(reduction='none')

        self.joints_2d_loss_weight = joints_2d_loss_weight
        self.joints_3d_loss_weight = joints_3d_loss_weight
        self.vertex_loss_weight = vertex_loss_weight
        self.smpl_pose_loss_weight = smpl_pose_loss_weight
        self.smpl_beta_loss_weight = smpl_beta_loss_weight
        self.focal_length = focal_length
        self.img_res = img_res

    def joints_2d_loss(self, pred_joints_2d, gt_joints_2d, joints_2d_visible):
        """Compute 2D reprojection loss on the joints.

        The loss is weighted by joints_2d_visible.
        """
        conf = joints_2d_visible.float()
        loss = (conf *
                self.criterion_joints_2d(pred_joints_2d, gt_joints_2d)).mean()
        return loss

    def joints_3d_loss(self, pred_joints_3d, gt_joints_3d, joints_3d_visible):
        """Compute 3D joints loss for the examples that 3D joint annotations
        are available.

        Both the prediction and the ground truth are re-centered on the
        midpoint of joints 2 and 3 before comparison. The loss is weighted
        by joints_3d_visible. For an empty batch a zero that is still
        connected to ``pred_joints_3d`` in the graph is returned.
        """
        conf = joints_3d_visible.float()
        if len(gt_joints_3d) > 0:
            # The midpoint of joints 2 and 3 is used as the pelvis.
            gt_pelvis = (gt_joints_3d[:, 2, :] + gt_joints_3d[:, 3, :]) / 2
            gt_joints_3d = gt_joints_3d - gt_pelvis[:, None, :]
            pred_pelvis = (pred_joints_3d[:, 2, :] +
                           pred_joints_3d[:, 3, :]) / 2
            pred_joints_3d = pred_joints_3d - pred_pelvis[:, None, :]
            return (
                conf *
                self.criterion_joints_3d(pred_joints_3d, gt_joints_3d)).mean()
        return pred_joints_3d.sum() * 0

    def vertex_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute 3D vertex loss for the examples that 3D human mesh
        annotations are available.

        The loss is weighted by the has_smpl.
        """
        conf = has_smpl.float()
        loss_vertex = self.criterion_vertex(pred_vertices, gt_vertices)
        # Broadcast the per-sample flag over the two trailing dimensions.
        loss_vertex = (conf[:, None, None] * loss_vertex).mean()
        return loss_vertex

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
                    has_smpl):
        """Compute SMPL parameters loss for the examples that SMPL parameter
        annotations are available.

        The ground-truth pose is converted to rotation matrices with
        ``batch_rodrigues`` and reshaped to (-1, 24, 3, 3) before being
        compared with ``pred_rotmat``. The loss is weighted by has_smpl.

        Returns:
            tuple: (pose loss, beta loss), both scalar tensors.
        """
        conf = has_smpl.float()
        gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3)
        loss_regr_pose = self.criterion_regr(pred_rotmat, gt_rotmat)
        loss_regr_betas = self.criterion_regr(pred_betas, gt_betas)
        loss_regr_pose = (conf[:, None, None, None] * loss_regr_pose).mean()
        loss_regr_betas = (conf[:, None] * loss_regr_betas).mean()
        return loss_regr_pose, loss_regr_betas

    def project_points(self, points_3d, camera):
        """Project 3D points to the image plane using the camera parameters.

        The (scale, translation_x, translation_y) camera is converted to a
        3D translation whose depth is
        ``2 * focal_length / (img_res * scale + 1e-9)``; the points are then
        projected by ``perspective_projection`` with an identity rotation
        and a zero camera center.

        Note:
            - batch size: B
            - point number: N

        Args:
            points_3d (Tensor([B, N, 3])): 3D points.
            camera (Tensor([B, 3])): camera parameters with the
                3 channel as (scale, translation_x, translation_y)

        Returns:
            Tensor([B, N, 2]): projected 2D points in image space.
        """
        batch_size = points_3d.shape[0]
        device = points_3d.device
        cam_t = torch.stack([
            camera[:, 1], camera[:, 2],
            # The small epsilon guards against division by zero.
            2 * self.focal_length / (self.img_res * camera[:, 0] + 1e-9)
        ],
                            dim=-1)
        camera_center = camera.new_zeros([batch_size, 2])
        rot_t = torch.eye(
            3, device=device,
            dtype=points_3d.dtype).unsqueeze(0).expand(batch_size, -1, -1)
        joints_2d = perspective_projection(
            points_3d,
            rotation=rot_t,
            translation=cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)
        return joints_2d

    def forward(self, output, target):
        """Forward function.

        Args:
            output (dict): dict of network predicted results.
                Keys: 'vertices', 'joints_3d', 'camera',
                'pose'(optional), 'beta'(optional)
            target (dict): dict of ground-truth labels.
                Keys: 'vertices', 'joints_3d', 'joints_3d_visible',
                'joints_2d', 'joints_2d_visible', 'pose', 'beta',
                'has_smpl'

        Returns:
            dict: dict of losses. Keys: 'vertex_loss', 'joints_3d_loss',
                'joints_2d_loss', plus 'smpl_pose_loss' and
                'smpl_beta_loss' when ``output`` has 'pose' and 'beta'.
        """
        losses = {}

        # Per-vertex loss for the shape
        pred_vertices = output['vertices']

        gt_vertices = target['vertices']
        has_smpl = target['has_smpl']
        loss_vertex = self.vertex_loss(pred_vertices, gt_vertices, has_smpl)
        losses['vertex_loss'] = loss_vertex * self.vertex_loss_weight

        # Compute loss on SMPL parameters, if available
        if 'pose' in output and 'beta' in output:
            pred_rotmat = output['pose']
            pred_betas = output['beta']
            gt_pose = target['pose']
            gt_betas = target['beta']
            loss_regr_pose, loss_regr_betas = self.smpl_losses(
                pred_rotmat, pred_betas, gt_pose, gt_betas, has_smpl)
            losses['smpl_pose_loss'] = \
                loss_regr_pose * self.smpl_pose_loss_weight
            losses['smpl_beta_loss'] = \
                loss_regr_betas * self.smpl_beta_loss_weight

        # Compute 3D joints loss
        pred_joints_3d = output['joints_3d']
        gt_joints_3d = target['joints_3d']
        joints_3d_visible = target['joints_3d_visible']
        loss_joints_3d = self.joints_3d_loss(pred_joints_3d, gt_joints_3d,
                                             joints_3d_visible)
        losses['joints_3d_loss'] = loss_joints_3d * self.joints_3d_loss_weight

        # Compute 2D reprojection loss for the 2D joints
        pred_camera = output['camera']
        gt_joints_2d = target['joints_2d']
        joints_2d_visible = target['joints_2d_visible']
        pred_joints_2d = self.project_points(pred_joints_3d, pred_camera)

        # Normalize keypoints to [-1,1]
        # The coordinate origin of pred_joints_2d is
        # the center of the input image.
        pred_joints_2d = 2 * pred_joints_2d / (self.img_res - 1)
        # The coordinate origin of gt_joints_2d is
        # the top left corner of the input image.
        gt_joints_2d = 2 * gt_joints_2d / (self.img_res - 1) - 1
        loss_joints_2d = self.joints_2d_loss(pred_joints_2d, gt_joints_2d,
                                             joints_2d_visible)
        losses['joints_2d_loss'] = loss_joints_2d * self.joints_2d_loss_weight

        return losses


@LOSSES.register_module()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always
            1.0 for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not one of the supported
            types.
    """

    def __init__(self,
                 gan_type,
                 real_label_val=1.0,
                 fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'hinge':
            # Only the ReLU part is stored here; the hinge margin and the
            # sign handling are applied in forward().
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    @staticmethod
    def _wgan_loss(input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        return -input.mean() if target else input.mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan,
                otherwise, return Tensor.
        """

        if self.gan_type == 'wgan':
            return target_is_real
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        if self.gan_type == 'hinge':
            if is_disc:  # for discriminators in hinge-gan
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:  # for generators in hinge-gan
                loss = -input.mean()
        else:  # other gan types
            # The target label is only needed by the non-hinge losses, so it
            # is built here rather than unconditionally.
            target_label = self.get_target_label(input, target_is_real)
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        return loss if is_disc else loss * self.loss_weight