from abc import abstractmethod

import torch
import torch.nn as nn
from mmcv.cnn import normal_init, xavier_init

from mmpose.models.utils.geometry import batch_rodrigues


class BaseDiscriminator(nn.Module):
    """Base linear module for SMPL parameter discriminator.

    Args:
        fc_layers (Tuple): Tuple of neuron counts,
            such as (9, 32, 32, 1)
        use_dropout (Tuple): Tuple of bool indicating whether to use
            dropout for each layer, such as (True, True, False)
        drop_prob (Tuple): Tuple of float giving the dropout probability
            of each layer, such as (0.5, 0.5, 0)
        use_activation (Tuple): Tuple of bool indicating whether to use an
            activation function for each layer, such as (True, True, False)
    """

|
|
def __init__(self, fc_layers, use_dropout, drop_prob, use_activation): |
|
super().__init__() |
|
self.fc_layers = fc_layers |
|
self.use_dropout = use_dropout |
|
self.drop_prob = drop_prob |
|
self.use_activation = use_activation |
|
self._check() |
|
self.create_layers() |
|
|
|
    def _check(self):
        """Check input to avoid ValueError."""
        if not isinstance(self.fc_layers, tuple):
            raise TypeError(f'fc_layers requires a tuple, '
                            f'got {type(self.fc_layers)}')

        if not isinstance(self.use_dropout, tuple):
            raise TypeError(f'use_dropout requires a tuple, '
                            f'got {type(self.use_dropout)}')

        if not isinstance(self.drop_prob, tuple):
            raise TypeError(f'drop_prob requires a tuple, '
                            f'got {type(self.drop_prob)}')

        if not isinstance(self.use_activation, tuple):
            raise TypeError(f'use_activation requires a tuple, '
                            f'got {type(self.use_activation)}')

        l_fc_layer = len(self.fc_layers)
        l_use_drop = len(self.use_dropout)
        l_drop_prob = len(self.drop_prob)
        l_use_activation = len(self.use_activation)

        pass_check = (
            l_fc_layer >= 2 and l_use_drop < l_fc_layer
            and l_drop_prob < l_fc_layer and l_use_activation < l_fc_layer
            and l_drop_prob == l_use_drop)

        if not pass_check:
            msg = 'Wrong BaseDiscriminator parameters!'
            raise ValueError(msg)

    def create_layers(self):
        """Create layers."""
        l_fc_layer = len(self.fc_layers)
        l_use_drop = len(self.use_dropout)
        l_use_activation = len(self.use_activation)

        self.fc_blocks = nn.Sequential()

        for i in range(l_fc_layer - 1):
            self.fc_blocks.add_module(
                name=f'regressor_fc_{i}',
                module=nn.Linear(
                    in_features=self.fc_layers[i],
                    out_features=self.fc_layers[i + 1]))

            if i < l_use_activation and self.use_activation[i]:
                self.fc_blocks.add_module(
                    name=f'regressor_af_{i}', module=nn.ReLU())

            if i < l_use_drop and self.use_dropout[i]:
                self.fc_blocks.add_module(
                    name=f'regressor_fc_dropout_{i}',
                    module=nn.Dropout(p=self.drop_prob[i]))

    @abstractmethod
    def forward(self, inputs):
        """Forward function."""
        msg = 'the base class [BaseDiscriminator] is not callable!'
        raise NotImplementedError(msg)

    def init_weights(self):
        """Initialize model weights."""
        for m in self.fc_blocks:
            if isinstance(m, nn.Linear):
                xavier_init(m, gain=0.01)
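
# An illustrative sketch (comments only, not executed) of how the config
# tuples of BaseDiscriminator expand into ``fc_blocks``: with
#   fc_layers=(10, 5, 1), use_activation=(True, False),
#   use_dropout=(True, False), drop_prob=(0.5, 0),
# ``create_layers`` builds the nn.Sequential
#   Linear(10, 5) -> ReLU -> Dropout(p=0.5) -> Linear(5, 1),
# i.e. one Linear per consecutive pair in ``fc_layers``, with ReLU/Dropout
# appended after layer i only when the corresponding flag is True.

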
class ShapeDiscriminator(BaseDiscriminator):
    """Discriminator for SMPL shape parameters.

    The input is of shape (batch_size, 10).

    Args:
        fc_layers (Tuple): Tuple of neuron counts, such as (10, 5, 1)
        use_dropout (Tuple): Tuple of bool indicating whether to use
            dropout for each layer, such as (True, False)
        drop_prob (Tuple): Tuple of float giving the dropout probability
            of each layer, such as (0.5, 0)
        use_activation (Tuple): Tuple of bool indicating whether to use an
            activation function for each layer, such as (True, False)
    """

    def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
        if fc_layers[-1] != 1:
            msg = f'the neuron count of the last layer ' \
                  f'must be 1, but got {fc_layers[-1]}'
            raise ValueError(msg)

        super().__init__(fc_layers, use_dropout, drop_prob, use_activation)

    def forward(self, inputs):
        """Forward function."""
        return self.fc_blocks(inputs)
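
# An illustrative usage sketch (comments only, not executed); the config is an
# assumed example, not a prescribed setting:
#   shape_disc = ShapeDiscriminator(
#       fc_layers=(10, 5, 1),
#       use_dropout=(True, False),
#       drop_prob=(0.5, 0),
#       use_activation=(True, False))
#   shape_disc.init_weights()
#   betas = torch.rand(4, 10)   # (batch_size, 10) SMPL shape parameters
#   score = shape_disc(betas)   # -> shape (4, 1)

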
class PoseDiscriminator(nn.Module):
    """Discriminator for the SMPL pose parameters of each joint.

    It is composed of one discriminator per joint. The input is of shape
    (batch_size, joint_count, 9).

    Args:
        channels (Tuple): Tuple of channel numbers,
            such as (9, 32, 32, 1)
        joint_count (int): Number of joints, such as 23
    """

    def __init__(self, channels, joint_count):
        super().__init__()
        if channels[-1] != 1:
            msg = f'the neuron count of the last layer ' \
                  f'must be 1, but got {channels[-1]}'
            raise ValueError(msg)
        self.joint_count = joint_count

        self.conv_blocks = nn.Sequential()
        len_channels = len(channels)
        for idx in range(len_channels - 2):
            self.conv_blocks.add_module(
                name=f'conv_{idx}',
                module=nn.Conv2d(
                    in_channels=channels[idx],
                    out_channels=channels[idx + 1],
                    kernel_size=1,
                    stride=1))

        self.fc_layer = nn.ModuleList()
        for _ in range(joint_count):
            self.fc_layer.append(
                nn.Linear(
                    in_features=channels[len_channels - 2], out_features=1))

    def forward(self, inputs):
        """Forward function.

        The input is of shape (batch_size, joint_count, 9).
        """
        inputs = inputs.transpose(1, 2).unsqueeze(2).contiguous()

        internal_outputs = self.conv_blocks(inputs)
        outputs = []
        for idx in range(self.joint_count):
            outputs.append(self.fc_layer[idx](internal_outputs[:, :, 0, idx]))

        return torch.cat(outputs, 1), internal_outputs

    def init_weights(self):
        """Initialize model weights."""
        for m in self.conv_blocks:
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001, bias=0)
        for m in self.fc_layer:
            if isinstance(m, nn.Linear):
                xavier_init(m, gain=0.01)
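
# An illustrative usage sketch (comments only, not executed), assuming the
# default per-joint channels (9, 32, 32, 1) and 23 SMPL body joints:
#   pose_disc = PoseDiscriminator(channels=(9, 32, 32, 1), joint_count=23)
#   pose_disc.init_weights()
#   rotmats = torch.rand(4, 23, 9)      # flattened 3x3 rotation per joint
#   scores, feats = pose_disc(rotmats)
#   # scores: (4, 23), one value per joint
#   # feats:  (4, 32, 1, 23), the shared conv features that are flattened
#   #         and fed to FullPoseDiscriminator below

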
class FullPoseDiscriminator(BaseDiscriminator):
    """Discriminator for the SMPL pose parameters of all joints.

    Args:
        fc_layers (Tuple): Tuple of neuron counts,
            such as (736, 1024, 1024, 1)
        use_dropout (Tuple): Tuple of bool indicating whether to use
            dropout for each layer, such as (True, True, False)
        drop_prob (Tuple): Tuple of float giving the dropout probability
            of each layer, such as (0.5, 0.5, 0)
        use_activation (Tuple): Tuple of bool indicating whether to use an
            activation function for each layer, such as (True, True, False)
    """

    def __init__(self, fc_layers, use_dropout, drop_prob, use_activation):
        if fc_layers[-1] != 1:
            msg = f'the neuron count of the last layer must be 1,' \
                  f' but got {fc_layers[-1]}'
            raise ValueError(msg)

        super().__init__(fc_layers, use_dropout, drop_prob, use_activation)

    def forward(self, inputs):
        """Forward function."""
        return self.fc_blocks(inputs)
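
# Illustrative note: with the default full-pose config (23 * 32, 1024, 1024, 1),
# the expected input is the per-joint conv features of PoseDiscriminator
# flattened to (batch_size, 23 * 32), and the output is a single score of
# shape (batch_size, 1).

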
class SMPLDiscriminator(nn.Module):
    """Discriminator for SMPL pose and shape parameters.

    It is composed of a discriminator for SMPL shape parameters, a
    discriminator for the SMPL pose parameters of all joints and a
    discriminator for the SMPL pose parameters of each joint.

    Args:
        beta_channel (tuple of int): Tuple of neuron counts of the
            discriminator of shape parameters. Defaults to (10, 5, 1)
        per_joint_channel (tuple of int): Tuple of neuron counts of the
            discriminator of each joint. Defaults to (9, 32, 32, 1)
        full_pose_channel (tuple of int): Tuple of neuron counts of the
            discriminator of full pose. Defaults to (23*32, 1024, 1024, 1)
    """

    def __init__(self,
                 beta_channel=(10, 5, 1),
                 per_joint_channel=(9, 32, 32, 1),
                 full_pose_channel=(23 * 32, 1024, 1024, 1)):
        super().__init__()
        self.joint_count = 23

        assert beta_channel[0] == 10
        assert per_joint_channel[0] == 9
        assert self.joint_count * per_joint_channel[-2] \
            == full_pose_channel[0]

        self.beta_channel = beta_channel
        self.per_joint_channel = per_joint_channel
        self.full_pose_channel = full_pose_channel
        self._create_sub_modules()

    def _create_sub_modules(self):
        """Create sub discriminators."""
        # discriminator for the pose parameters of each joint
        self.pose_discriminator = PoseDiscriminator(self.per_joint_channel,
                                                    self.joint_count)

        # discriminator for the pose parameters of all joints
        fc_layers = self.full_pose_channel
        use_dropout = tuple([False] * (len(fc_layers) - 1))
        drop_prob = tuple([0.5] * (len(fc_layers) - 1))
        use_activation = tuple([True] * (len(fc_layers) - 2) + [False])

        self.full_pose_discriminator = FullPoseDiscriminator(
            fc_layers, use_dropout, drop_prob, use_activation)

        # discriminator for the shape parameters
        fc_layers = self.beta_channel
        use_dropout = tuple([False] * (len(fc_layers) - 1))
        drop_prob = tuple([0.5] * (len(fc_layers) - 1))
        use_activation = tuple([True] * (len(fc_layers) - 2) + [False])
        self.shape_discriminator = ShapeDiscriminator(fc_layers, use_dropout,
                                                      drop_prob,
                                                      use_activation)

    def forward(self, thetas):
        """Forward function."""
        _, poses, shapes = thetas

        batch_size = poses.shape[0]
        shape_disc_value = self.shape_discriminator(shapes)

        # If poses are axis-angle vectors, convert them to rotation matrices;
        # in both cases keep the 23 body joints and drop the global rotation.
        if poses.dim() == 2:
            rotate_matrixs = \
                batch_rodrigues(poses.contiguous().view(-1, 3)
                                ).view(batch_size, 24, 9)[:, 1:, :]
        else:
            rotate_matrixs = poses.contiguous().view(batch_size, 24,
                                                     9)[:, 1:, :].contiguous()
        pose_disc_value, pose_inter_disc_value \
            = self.pose_discriminator(rotate_matrixs)
        full_pose_disc_value = self.full_pose_discriminator(
            pose_inter_disc_value.contiguous().view(batch_size, -1))
        return torch.cat(
            (pose_disc_value, full_pose_disc_value, shape_disc_value), 1)

    def init_weights(self):
        """Initialize model weights."""
        self.full_pose_discriminator.init_weights()
        self.pose_discriminator.init_weights()
        self.shape_discriminator.init_weights()
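
# An illustrative end-to-end sketch (comments only, not executed), assuming the
# default channel configs and poses given as SMPL axis-angle vectors:
#   disc = SMPLDiscriminator()
#   disc.init_weights()
#   cams = torch.rand(4, 3)      # camera parameters (ignored by forward)
#   poses = torch.rand(4, 72)    # 24 joints x 3 axis-angle values
#   shapes = torch.rand(4, 10)   # SMPL betas
#   out = disc((cams, poses, shapes))
#   # out: (4, 25) = 23 per-joint scores + 1 full-pose score + 1 shape score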