|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from ..builder import LOSSES |
|
from ..utils.geometry import batch_rodrigues |
|
|
|
|
|
def perspective_projection(points, rotation, translation, focal_length,
                           camera_center):
    """Project a batch of 3D points to 2D with a pinhole camera model.

    The points are moved into the camera frame
    (``rotation @ p + translation``), divided by their depth (last
    coordinate) and finally mapped through the intrinsic matrix assembled
    from ``focal_length`` and ``camera_center``.

    Note:
        - batch size: B
        - point number: N

    Args:
        points (Tensor([B, N, 3])): A set of 3D points.
        rotation (Tensor([B, 3, 3])): Camera rotation matrix.
        translation (Tensor([B, 3])): Camera translation.
        focal_length (float | Tensor([B,])): Focal length; the same value
            is used for both image axes.
        camera_center (Tensor([B, 2])): Camera center.

    Returns:
        Tensor([B, N, 2]): Projected 2D points in image space. The result
        has the dtype and device of ``points``.
    """
    batch_size = points.shape[0]

    # Allocate the intrinsics from ``points`` so they inherit its dtype as
    # well as its device; a default-dtype (float32) matrix would make the
    # final einsum fail for float64 or half-precision inputs.
    K = points.new_zeros([batch_size, 3, 3])
    K[:, 0, 0] = focal_length
    K[:, 1, 1] = focal_length
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center

    # Rigid transform into the camera frame.
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)

    # Perspective divide. Depth is not clamped here, so a point with zero
    # depth yields inf/nan in the output.
    projected_points = points / points[:, :, -1].unsqueeze(-1)

    # Apply the camera intrinsics and drop the homogeneous coordinate.
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
    projected_points = projected_points[:, :, :-1]
    return projected_points
|
|
|
|
|
@LOSSES.register_module()
class MeshLoss(nn.Module):
    """Combined loss for 3D human mesh estimation.

    The returned dict holds separately weighted terms for mesh vertices,
    3D joints, re-projected 2D joints and, when the network predicts them,
    SMPL pose and shape parameters.

    Args:
        joints_2d_loss_weight (float): Weight for loss on 2D joints.
        joints_3d_loss_weight (float): Weight for loss on 3D joints.
        vertex_loss_weight (float): Weight for loss on 3D vertices.
        smpl_pose_loss_weight (float): Weight for loss on SMPL
            pose parameters.
        smpl_beta_loss_weight (float): Weight for loss on SMPL
            shape parameters.
        img_res (int): Input image resolution.
        focal_length (float): Focal length of camera model. Default=5000.
    """

    def __init__(self,
                 joints_2d_loss_weight,
                 joints_3d_loss_weight,
                 vertex_loss_weight,
                 smpl_pose_loss_weight,
                 smpl_beta_loss_weight,
                 img_res,
                 focal_length=5000):
        super().__init__()

        # Camera settings used when re-projecting the 3D joints.
        self.img_res = img_res
        self.focal_length = focal_length

        # Scalar weights applied to each term in ``forward``.
        self.joints_2d_loss_weight = joints_2d_loss_weight
        self.joints_3d_loss_weight = joints_3d_loss_weight
        self.vertex_loss_weight = vertex_loss_weight
        self.smpl_pose_loss_weight = smpl_pose_loss_weight
        self.smpl_beta_loss_weight = smpl_beta_loss_weight

        # All criteria are element-wise (reduction='none') so that each
        # term can be masked before it is averaged.
        # L1 on mesh vertices.
        self.criterion_vertex = nn.L1Loss(reduction='none')
        # Smooth L1 on 2D / 3D joints.
        self.criterion_joints_2d = nn.SmoothL1Loss(reduction='none')
        self.criterion_joints_3d = nn.SmoothL1Loss(reduction='none')
        # MSE on SMPL pose and shape parameters.
        self.criterion_regr = nn.MSELoss(reduction='none')

    def joints_2d_loss(self, pred_joints_2d, gt_joints_2d, joints_2d_visible):
        """Return the visibility-weighted 2D joint loss.

        The element-wise smooth L1 error is multiplied by
        ``joints_2d_visible`` (cast to float) and averaged over all
        elements.
        """
        visibility = joints_2d_visible.float()
        per_element = self.criterion_joints_2d(pred_joints_2d, gt_joints_2d)
        return (visibility * per_element).mean()

    def joints_3d_loss(self, pred_joints_3d, gt_joints_3d, joints_3d_visible):
        """Return the visibility-weighted, root-aligned 3D joint loss.

        Both joint sets are centred on the midpoint of joints 2 and 3
        (treated as the pelvis) before the element-wise smooth L1 error is
        multiplied by ``joints_3d_visible`` and averaged.
        """
        visibility = joints_3d_visible.float()

        if len(gt_joints_3d) == 0:
            # Empty batch: return a zero derived from the prediction
            # (presumably to keep the result attached to the graph).
            return pred_joints_3d.sum() * 0

        gt_root = (gt_joints_3d[:, 2, :] + gt_joints_3d[:, 3, :]) / 2
        gt_centered = gt_joints_3d - gt_root[:, None, :]

        pred_root = (pred_joints_3d[:, 2, :] + pred_joints_3d[:, 3, :]) / 2
        pred_centered = pred_joints_3d - pred_root[:, None, :]

        per_element = self.criterion_joints_3d(pred_centered, gt_centered)
        return (visibility * per_element).mean()

    def vertex_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Return the mesh vertex loss, masked per sample by ``has_smpl``.

        The element-wise L1 error is multiplied by ``has_smpl`` (cast to
        float and broadcast over the two trailing dimensions) and averaged
        over all elements.
        """
        sample_mask = has_smpl.float()[:, None, None]
        per_element = self.criterion_vertex(pred_vertices, gt_vertices)
        return (sample_mask * per_element).mean()

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
                    has_smpl):
        """Return the SMPL pose and shape losses, masked by ``has_smpl``.

        The ground-truth pose is converted with ``batch_rodrigues`` and
        reshaped to ``(-1, 24, 3, 3)`` before being compared with
        ``pred_rotmat``. Both terms use element-wise MSE, multiplied by the
        per-sample mask and averaged over all elements.

        Returns:
            tuple: ``(pose_loss, betas_loss)``.
        """
        sample_mask = has_smpl.float()

        # One 3-vector per joint -> one 3x3 matrix per joint.
        gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3)

        pose_error = self.criterion_regr(pred_rotmat, gt_rotmat)
        betas_error = self.criterion_regr(pred_betas, gt_betas)

        pose_loss = (sample_mask[:, None, None, None] * pose_error).mean()
        betas_loss = (sample_mask[:, None] * betas_error).mean()
        return pose_loss, betas_loss

    def project_points(self, points_3d, camera):
        """Project 3D points to the image plane using the predicted camera.

        The ``(scale, translation_x, translation_y)`` camera is turned into
        a 3D translation whose depth is
        ``2 * focal_length / (img_res * scale + 1e-9)``; the points are then
        passed to ``perspective_projection`` with an identity rotation and
        a zero camera center.

        Note:
            - batch size: B
            - point number: N

        Args:
            points_3d (Tensor([B, N, 3])): 3D points.
            camera (Tensor([B, 3])): camera parameters with the
                3 channel as (scale, translation_x, translation_y)

        Returns:
            Tensor([B, N, 2]): projected 2D points in image space.
        """
        num_samples = points_3d.shape[0]

        scale = camera[:, 0]
        shift_x = camera[:, 1]
        shift_y = camera[:, 2]
        # The small epsilon guards against division by a zero scale.
        depth = 2 * self.focal_length / (self.img_res * scale + 1e-9)
        cam_t = torch.stack([shift_x, shift_y, depth], dim=-1)

        principal_point = camera.new_zeros([num_samples, 2])

        identity = torch.eye(
            3, device=points_3d.device, dtype=points_3d.dtype)
        rot_t = identity.unsqueeze(0).expand(num_samples, -1, -1)

        return perspective_projection(
            points_3d,
            rotation=rot_t,
            translation=cam_t,
            focal_length=self.focal_length,
            camera_center=principal_point)

    def forward(self, output, target):
        """Forward function.

        Args:
            output (dict): dict of network predicted results.
                Keys: 'vertices', 'joints_3d', 'camera',
                'pose'(optional), 'beta'(optional)
            target (dict): dict of ground-truth labels.
                Keys: 'vertices', 'joints_3d', 'joints_3d_visible',
                'joints_2d', 'joints_2d_visible', 'pose', 'beta',
                'has_smpl'

        Returns:
            dict: dict of losses. Always contains 'vertex_loss',
            'joints_3d_loss' and 'joints_2d_loss'; 'smpl_pose_loss' and
            'smpl_beta_loss' are added only when ``output`` has both
            'pose' and 'beta'.
        """
        losses = {}

        # Mesh vertices.
        pred_vertices = output['vertices']
        gt_vertices, has_smpl = target['vertices'], target['has_smpl']
        vertex_term = self.vertex_loss(pred_vertices, gt_vertices, has_smpl)
        losses['vertex_loss'] = vertex_term * self.vertex_loss_weight

        # SMPL parameters, only when the network regresses both of them.
        if 'pose' in output and 'beta' in output:
            pose_term, betas_term = self.smpl_losses(
                output['pose'], output['beta'], target['pose'],
                target['beta'], has_smpl)
            losses['smpl_pose_loss'] = pose_term * self.smpl_pose_loss_weight
            losses['smpl_beta_loss'] = betas_term * self.smpl_beta_loss_weight

        # 3D joints.
        pred_joints_3d = output['joints_3d']
        joints_3d_term = self.joints_3d_loss(
            pred_joints_3d, target['joints_3d'], target['joints_3d_visible'])
        losses['joints_3d_loss'] = joints_3d_term * self.joints_3d_loss_weight

        # 2D joints obtained by re-projecting the predicted 3D joints.
        pred_camera = output['camera']
        gt_joints_2d = target['joints_2d']
        joints_2d_visible = target['joints_2d_visible']
        pred_joints_2d = self.project_points(pred_joints_3d, pred_camera)

        # Bring both joint sets to a common normalised range. The projection
        # uses a zero camera center, so predictions only need scaling; the
        # ground truth is additionally shifted by -1 (presumably because it
        # is measured from the top-left image corner).
        pixel_span = self.img_res - 1
        pred_joints_2d = 2 * pred_joints_2d / pixel_span
        gt_joints_2d = 2 * gt_joints_2d / pixel_span - 1

        joints_2d_term = self.joints_2d_loss(pred_joints_2d, gt_joints_2d,
                                             joints_2d_visible)
        losses['joints_2d_loss'] = joints_2d_term * self.joints_2d_loss_weight

        return losses
|
|
|
|
|
@LOSSES.register_module()
class GANLoss(nn.Module):
    """Define GAN loss.

    The underlying criterion is chosen from ``gan_type``:
    'vanilla' uses ``nn.BCEWithLogitsLoss``, 'lsgan' uses ``nn.MSELoss``,
    'wgan' uses the signed mean of the input and 'hinge' uses ``nn.ReLU``
    as the building block of the hinge loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always
            1.0 for discriminators.

    Raises:
        NotImplementedError: If ``gan_type`` is not one of the supported
            values.
    """

    def __init__(self,
                 gan_type,
                 real_label_val=1.0,
                 fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        if self.gan_type == 'vanilla':
            self.loss = nn.BCEWithLogitsLoss()
        elif self.gan_type == 'lsgan':
            self.loss = nn.MSELoss()
        elif self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type == 'hinge':
            # Only the ReLU part; the sign and margin are handled in
            # ``forward``.
            self.loss = nn.ReLU()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    @staticmethod
    def _wgan_loss(input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss, i.e. ``-input.mean()`` for a real target and
            ``input.mean()`` for a fake one.
        """
        return -input.mean() if target else input.mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan,
            otherwise, return a Tensor with the size of ``input`` filled
            with the real or fake label value.
        """
        if self.gan_type == 'wgan':
            return target_is_real
        target_val = (
            self.real_label_val if target_is_real else self.fake_label_val)
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """Compute the GAN loss for a batch of predictions.

        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the target is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value. It is multiplied by ``loss_weight`` only
            when ``is_disc`` is False.
        """
        if self.gan_type == 'hinge':
            # The hinge loss never reads a target tensor, so none is built.
            if is_disc:
                # Discriminator: relu(1 - x) for real, relu(1 + x) for fake.
                input = -input if target_is_real else input
                loss = self.loss(1 + input).mean()
            else:
                # Generator: raise the discriminator score.
                loss = -input.mean()
        else:
            # Build the label only for the criteria that consume it.
            target_label = self.get_target_label(input, target_is_real)
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators.
        return loss if is_disc else loss * self.loss_weight
|
|