File size: 12,646 Bytes
2cd560a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
331
332
333
334
335
336
337
338
339
340
341
# Copyright (c) OpenMMLab. All rights reserved.
import torch
import torch.nn as nn

from ..builder import LOSSES
from ..utils.geometry import batch_rodrigues


def perspective_projection(points, rotation, translation, focal_length,
                           camera_center):
    """This function computes the perspective projection of a set of 3D points.

    Note:
        - batch size: B
        - point number: N

    Args:
        points (Tensor([B, N, 3])): A set of 3D points
        rotation (Tensor([B, 3, 3])): Camera rotation matrix
        translation (Tensor([B, 3])): Camera translation
        focal_length (Tensor([B,])): Focal length
        camera_center (Tensor([B, 2])): Camera center

    Returns:
        projected_points (Tensor([B, N, 2])): Projected 2D
            points in image space.
    """

    batch_size = points.shape[0]
    # Build the pinhole intrinsic matrix K = [[f, 0, cx],
    #                                         [0, f, cy],
    #                                         [0, 0,  1]].
    # Match the dtype of the input points so that float64/float16 inputs
    # do not trigger a dtype-mismatch error in the einsum below.
    K = torch.zeros([batch_size, 3, 3],
                    dtype=points.dtype,
                    device=points.device)
    K[:, 0, 0] = focal_length
    K[:, 1, 1] = focal_length
    K[:, 2, 2] = 1.
    K[:, :-1, -1] = camera_center

    # Transform points into the camera frame: p' = R @ p + t
    points = torch.einsum('bij,bkj->bki', rotation, points)
    points = points + translation.unsqueeze(1)

    # Apply perspective distortion (divide by the depth z)
    projected_points = points / points[:, :, -1].unsqueeze(-1)

    # Apply camera intrinsics, then drop the homogeneous coordinate
    projected_points = torch.einsum('bij,bkj->bki', K, projected_points)
    projected_points = projected_points[:, :, :-1]
    return projected_points


@LOSSES.register_module()
class MeshLoss(nn.Module):
    """Mix loss for 3D human mesh. It is composed of loss on 2D joints, 3D
    joints, mesh vertices and smpl parameters (if any).

    Args:
        joints_2d_loss_weight (float): Weight for loss on 2D joints.
        joints_3d_loss_weight (float): Weight for loss on 3D joints.
        vertex_loss_weight (float): Weight for loss on 3D verteices.
        smpl_pose_loss_weight (float): Weight for loss on SMPL
            pose parameters.
        smpl_beta_loss_weight (float): Weight for loss on SMPL
            shape parameters.
        img_res (int): Input image resolution.
        focal_length (float): Focal length of camera model. Default=5000.
    """

    def __init__(self,
                 joints_2d_loss_weight,
                 joints_3d_loss_weight,
                 vertex_loss_weight,
                 smpl_pose_loss_weight,
                 smpl_beta_loss_weight,
                 img_res,
                 focal_length=5000):

        super().__init__()

        # All criteria use reduction='none' so that per-element losses can
        # be masked by visibility/availability flags before averaging.
        self.criterion_vertex = nn.L1Loss(reduction='none')
        self.criterion_joints_2d = nn.SmoothL1Loss(reduction='none')
        self.criterion_joints_3d = nn.SmoothL1Loss(reduction='none')
        self.criterion_regr = nn.MSELoss(reduction='none')

        self.joints_2d_loss_weight = joints_2d_loss_weight
        self.joints_3d_loss_weight = joints_3d_loss_weight
        self.vertex_loss_weight = vertex_loss_weight
        self.smpl_pose_loss_weight = smpl_pose_loss_weight
        self.smpl_beta_loss_weight = smpl_beta_loss_weight
        self.focal_length = focal_length
        self.img_res = img_res

    def joints_2d_loss(self, pred_joints_2d, gt_joints_2d, joints_2d_visible):
        """Compute 2D reprojection loss on the joints.

        The loss is weighted by joints_2d_visible.
        """
        weight = joints_2d_visible.float()
        per_joint = self.criterion_joints_2d(pred_joints_2d, gt_joints_2d)
        return (weight * per_joint).mean()

    def joints_3d_loss(self, pred_joints_3d, gt_joints_3d, joints_3d_visible):
        """Compute 3D joints loss for the examples that 3D joint annotations
        are available.

        The loss is weighted by joints_3d_visible.
        """
        if len(gt_joints_3d) == 0:
            # No 3D annotation in this batch: return a zero that still
            # carries a gradient path through the predictions.
            return pred_joints_3d.sum() * 0

        weight = joints_3d_visible.float()
        # Align both sets of joints by centering on the pelvis (the
        # midpoint of joints 2 and 3) before comparing.
        gt_pelvis = (gt_joints_3d[:, 2, :] + gt_joints_3d[:, 3, :]) / 2
        pred_pelvis = (pred_joints_3d[:, 2, :] + pred_joints_3d[:, 3, :]) / 2
        gt_centered = gt_joints_3d - gt_pelvis[:, None, :]
        pred_centered = pred_joints_3d - pred_pelvis[:, None, :]
        per_joint = self.criterion_joints_3d(pred_centered, gt_centered)
        return (weight * per_joint).mean()

    def vertex_loss(self, pred_vertices, gt_vertices, has_smpl):
        """Compute 3D vertex loss for the examples that 3D human mesh
        annotations are available.

        The loss is weighted by the has_smpl.
        """
        weight = has_smpl.float()[:, None, None]
        per_vertex = self.criterion_vertex(pred_vertices, gt_vertices)
        return (weight * per_vertex).mean()

    def smpl_losses(self, pred_rotmat, pred_betas, gt_pose, gt_betas,
                    has_smpl):
        """Compute SMPL parameters loss for the examples that SMPL parameter
        annotations are available.

        The loss is weighted by has_smpl.
        """
        weight = has_smpl.float()
        # Convert ground-truth axis-angle pose (24 joints x 3) into
        # rotation matrices for comparison with the predicted ones.
        gt_rotmat = batch_rodrigues(gt_pose.view(-1, 3)).view(-1, 24, 3, 3)
        pose_err = self.criterion_regr(pred_rotmat, gt_rotmat)
        betas_err = self.criterion_regr(pred_betas, gt_betas)
        loss_regr_pose = (weight[:, None, None, None] * pose_err).mean()
        loss_regr_betas = (weight[:, None] * betas_err).mean()
        return loss_regr_pose, loss_regr_betas

    def project_points(self, points_3d, camera):
        """Perform orthographic projection of 3D points using the camera
        parameters, return projected 2D points in image plane.

        Note:
            - batch size: B
            - point number: N

        Args:
            points_3d (Tensor([B, N, 3])): 3D points.
            camera (Tensor([B, 3])): camera parameters with the
                3 channel as (scale, translation_x, translation_y)

        Returns:
            Tensor([B, N, 2]): projected 2D points \
                in image space.
        """
        batch_size = points_3d.shape[0]
        device = points_3d.device

        # Convert the weak-perspective camera (s, tx, ty) into a full
        # translation; the depth is chosen so that a perspective camera
        # with self.focal_length reproduces the scale s.
        depth = 2 * self.focal_length / (self.img_res * camera[:, 0] + 1e-9)
        cam_t = torch.stack([camera[:, 1], camera[:, 2], depth], dim=-1)

        camera_center = camera.new_zeros([batch_size, 2])
        rot_t = torch.eye(
            3, device=device,
            dtype=points_3d.dtype).unsqueeze(0).expand(batch_size, -1, -1)
        return perspective_projection(
            points_3d,
            rotation=rot_t,
            translation=cam_t,
            focal_length=self.focal_length,
            camera_center=camera_center)

    def forward(self, output, target):
        """Forward function.

        Args:
            output (dict): dict of network predicted results.
                Keys: 'vertices', 'joints_3d', 'camera',
                'pose'(optional), 'beta'(optional)
            target (dict): dict of ground-truth labels.
                Keys: 'vertices', 'joints_3d', 'joints_3d_visible',
                'joints_2d', 'joints_2d_visible', 'pose', 'beta',
                'has_smpl'

        Returns:
            dict: dict of losses.
        """
        losses = {}
        has_smpl = target['has_smpl']

        # Per-vertex loss for the shape
        losses['vertex_loss'] = self.vertex_loss(
            output['vertices'], target['vertices'],
            has_smpl) * self.vertex_loss_weight

        # Loss on SMPL parameters, only when the network predicts them
        if 'pose' in output and 'beta' in output:
            loss_regr_pose, loss_regr_betas = self.smpl_losses(
                output['pose'], output['beta'], target['pose'],
                target['beta'], has_smpl)
            losses['smpl_pose_loss'] = \
                loss_regr_pose * self.smpl_pose_loss_weight
            losses['smpl_beta_loss'] = \
                loss_regr_betas * self.smpl_beta_loss_weight

        # 3D joints loss
        pred_joints_3d = output['joints_3d']
        losses['joints_3d_loss'] = self.joints_3d_loss(
            pred_joints_3d, target['joints_3d'],
            target['joints_3d_visible']) * self.joints_3d_loss_weight

        # 2D reprojection loss: project predicted 3D joints with the
        # predicted camera, then compare in normalized image coordinates.
        pred_joints_2d = self.project_points(pred_joints_3d, output['camera'])

        # Normalize keypoints to [-1, 1]. pred_joints_2d has its origin at
        # the image center, gt_joints_2d at the top-left corner, hence the
        # extra -1 shift for the ground truth.
        pred_joints_2d = 2 * pred_joints_2d / (self.img_res - 1)
        gt_joints_2d = 2 * target['joints_2d'] / (self.img_res - 1) - 1
        losses['joints_2d_loss'] = self.joints_2d_loss(
            pred_joints_2d, gt_joints_2d,
            target['joints_2d_visible']) * self.joints_2d_loss_weight

        return losses


@LOSSES.register_module()
class GANLoss(nn.Module):
    """Define GAN loss.

    Args:
        gan_type (str): Support 'vanilla', 'lsgan', 'wgan', 'hinge'.
        real_label_val (float): The value for real label. Default: 1.0.
        fake_label_val (float): The value for fake label. Default: 0.0.
        loss_weight (float): Loss weight. Default: 1.0.
            Note that loss_weight is only for generators; and it is always 1.0
            for discriminators.
    """

    def __init__(self,
                 gan_type,
                 real_label_val=1.0,
                 fake_label_val=0.0,
                 loss_weight=1.0):
        super().__init__()
        self.gan_type = gan_type
        self.loss_weight = loss_weight
        self.real_label_val = real_label_val
        self.fake_label_val = fake_label_val

        # Map each supported GAN type to its criterion. 'hinge' uses ReLU
        # because the hinge terms are composed in forward(); 'wgan' uses a
        # plain mean instead of an nn criterion.
        criterion_factories = {
            'vanilla': nn.BCEWithLogitsLoss,
            'lsgan': nn.MSELoss,
            'hinge': nn.ReLU,
        }
        if self.gan_type == 'wgan':
            self.loss = self._wgan_loss
        elif self.gan_type in criterion_factories:
            self.loss = criterion_factories[self.gan_type]()
        else:
            raise NotImplementedError(
                f'GAN type {self.gan_type} is not implemented.')

    @staticmethod
    def _wgan_loss(input, target):
        """wgan loss.

        Args:
            input (Tensor): Input tensor.
            target (bool): Target label.

        Returns:
            Tensor: wgan loss.
        """
        if target:
            return -input.mean()
        return input.mean()

    def get_target_label(self, input, target_is_real):
        """Get target label.

        Args:
            input (Tensor): Input tensor.
            target_is_real (bool): Whether the target is real or fake.

        Returns:
            (bool | Tensor): Target tensor. Return bool for wgan, \
                otherwise, return Tensor.
        """
        # wgan works directly on the boolean flag; no label tensor needed.
        if self.gan_type == 'wgan':
            return target_is_real
        if target_is_real:
            target_val = self.real_label_val
        else:
            target_val = self.fake_label_val
        return input.new_ones(input.size()) * target_val

    def forward(self, input, target_is_real, is_disc=False):
        """
        Args:
            input (Tensor): The input for the loss module, i.e., the network
                prediction.
            target_is_real (bool): Whether the targe is real or fake.
            is_disc (bool): Whether the loss for discriminators or not.
                Default: False.

        Returns:
            Tensor: GAN loss value.
        """
        target_label = self.get_target_label(input, target_is_real)
        if self.gan_type == 'hinge':
            if is_disc:
                # Discriminator hinge terms:
                # relu(1 - D(real)) and relu(1 + D(fake)).
                signed = -input if target_is_real else input
                loss = self.loss(1 + signed).mean()
            else:
                # Generator hinge loss: -D(fake).
                loss = -input.mean()
        else:
            loss = self.loss(input, target_label)

        # loss_weight is always 1.0 for discriminators
        if is_disc:
            return loss
        return loss * self.loss_weight