File size: 6,348 Bytes
2cd560a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
# Copyright (c) OpenMMLab. All rights reserved.
import numpy as np
import torch
import torch.nn as nn

from ..builder import MESH_MODELS

try:
    from smplx import SMPL as SMPL_
    has_smpl = True
except (ImportError, ModuleNotFoundError):
    has_smpl = False


@MESH_MODELS.register_module()
class SMPL(nn.Module):
    """SMPL 3d human mesh model of paper ref: Matthew Loper. ``SMPL: A skinned
    multi-person linear model''. This module is based on the smplx project
    (https://github.com/vchoutas/smplx).

    Three smplx models (neutral, male, female) are loaded from the same
    folder; ``forward`` dispatches every sample of a batch to the model
    matching its gender label and regresses 3d joints from the resulting
    mesh vertices with an external joints regressor.

    Args:
        smpl_path (str): The path to the folder where the model weights are
            stored.
        joints_regressor (str): The path to the file where the joints
            regressor weight are stored.
    """

    def __init__(self, smpl_path, joints_regressor):
        super().__init__()

        assert has_smpl, 'Please install smplx to use SMPL.'

        # Construction options shared by all three gendered models. Pose,
        # orientation and translation are always supplied by the caller at
        # forward time, so no default parameters are created for them.
        shared_cfg = dict(
            model_path=smpl_path,
            create_global_orient=False,
            create_body_pose=False,
            create_transl=False)

        # Keep the registration order (neutral, male, female) stable so the
        # submodule / state_dict layout is unchanged.
        self.smpl_neutral = SMPL_(gender='neutral', **shared_cfg)
        self.smpl_male = SMPL_(
            create_betas=False, gender='male', **shared_cfg)
        self.smpl_female = SMPL_(
            create_betas=False, gender='female', **shared_cfg)

        # A leading singleton dimension is added so that the regressor
        # broadcasts over the batch dimension in ``torch.matmul``.
        joints_regressor = torch.tensor(
            np.load(joints_regressor), dtype=torch.float)[None, ...]
        self.register_buffer('joints_regressor', joints_regressor)

        self.num_verts = self.smpl_neutral.get_num_verts()
        self.num_joints = self.joints_regressor.shape[1]

    def smpl_forward(self, model, **kwargs):
        """Apply a specific SMPL model with given model parameters.

        Note:
            B: batch size
            V: number of vertices
            K: number of joints

        Args:
            model (nn.Module): The smplx model to apply (one of the
                neutral / male / female models held by this module).
            **kwargs: Keyword arguments forwarded unchanged to ``model``.
                Must contain ``betas``, which is used to determine the
                batch size and device.

        Returns:
            outputs (dict): Dict with mesh vertices and joints.
                - vertices: Tensor([B, V, 3]), mesh vertices
                - joints: Tensor([B, K, 3]), 3d joints regressed
                    from mesh vertices.
        """

        betas = kwargs['betas']
        batch_size = betas.shape[0]
        device = betas.device
        output = {}
        if batch_size == 0:
            # The underlying model is not called on an empty batch;
            # correctly shaped empty tensors are returned instead.
            output['vertices'] = betas.new_zeros([0, self.num_verts, 3])
            output['joints'] = betas.new_zeros([0, self.num_joints, 3])
        else:
            smpl_out = model(**kwargs)
            output['vertices'] = smpl_out.vertices
            output['joints'] = torch.matmul(
                self.joints_regressor.to(device), output['vertices'])
        return output

    def get_faces(self):
        """Return mesh faces.

        Note:
            F: number of faces

        Returns:
            faces: np.ndarray([F, 3]), mesh faces
        """
        return self.smpl_neutral.faces

    def forward(self,
                betas,
                body_pose,
                global_orient,
                transl=None,
                gender=None):
        """Forward function.

        Note:
            B: batch size
            J: number of controllable joints of model, for smpl model J=23
            K: number of joints

        Args:
            betas: Tensor([B, 10]), human body shape parameters of SMPL model.
            body_pose: Tensor([B, J*3] or [B, J, 3, 3]), human body pose
                parameters of SMPL model. It should be axis-angle vector
                ([B, J*3]) or rotation matrix ([B, J, 3, 3]).
            global_orient: Tensor([B, 3] or [B, 1, 3, 3]), global orientation
                of human body. It should be axis-angle vector ([B, 3]) or
                rotation matrix ([B, 1, 3, 3]).
            transl: Tensor([B, 3]), global translation of human body.
            gender: Tensor([B]), gender parameters of human body. -1 for
                neutral, 0 for male , 1 for female. If None, the neutral
                model is used for the whole batch. Samples whose value is
                none of ``< 0``, ``0`` or ``1`` are left as zeros in the
                output.

        Returns:
            outputs (dict): Dict with mesh vertices and joints.
                - vertices: Tensor([B, V, 3]), mesh vertices
                - joints: Tensor([B, K, 3]), 3d joints regressed from
                    mesh vertices.
        """

        batch_size = betas.shape[0]
        # A 2-dim body_pose is treated as axis-angle input, so the model is
        # asked to convert it to rotation matrices.
        pose2rot = body_pose.dim() == 2

        if batch_size == 0 or gender is None:
            # No per-sample gender information (or nothing to dispatch):
            # run the neutral model on the whole batch.
            return self.smpl_forward(
                self.smpl_neutral,
                betas=betas,
                body_pose=body_pose,
                global_orient=global_orient,
                transl=transl,
                pose2rot=pose2rot)

        output = {
            'vertices': betas.new_zeros([batch_size, self.num_verts, 3]),
            'joints': betas.new_zeros([batch_size, self.num_joints, 3])
        }

        # Each sample is routed to the model matching its gender label and
        # the results are scattered back into the batch-ordered output.
        gender_models = (
            (gender < 0, self.smpl_neutral),
            (gender == 0, self.smpl_male),
            (gender == 1, self.smpl_female),
        )
        for mask, model in gender_models:
            _out = self.smpl_forward(
                model,
                betas=betas[mask],
                body_pose=body_pose[mask],
                global_orient=global_orient[mask],
                transl=transl[mask] if transl is not None else None,
                pose2rot=pose2rot)
            output['vertices'][mask] = _out['vertices']
            output['joints'][mask] = _out['joints']

        return output