|
|
|
import mmcv |
|
import torch |
|
import torch.nn as nn |
|
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init, |
|
normal_init) |
|
from mmcv.utils import digit_version |
|
from torch.nn.modules.batchnorm import _BatchNorm |
|
|
|
from mmpose.models.utils.ops import resize |
|
from ..backbones.resnet import BasicBlock, Bottleneck |
|
from ..builder import NECKS |
|
|
|
try: |
|
from mmcv.ops import DeformConv2d |
|
has_mmcv_full = True |
|
except (ImportError, ModuleNotFoundError): |
|
has_mmcv_full = False |
|
|
|
|
|
@NECKS.register_module()
class PoseWarperNeck(nn.Module):
    """PoseWarper neck.

    `"Learning temporal pose estimation from sparsely-labeled videos"
    <https://arxiv.org/abs/1906.04016>`_.

    Args:
        in_channels (int): Number of input channels from backbone
        out_channels (int): Number of output channels
        inner_channels (int): Number of intermediate channels of the res block
        deform_groups (int): Number of groups in the deformable conv
        dilations (list|tuple): different dilations of the offset conv layers
        trans_conv_kernel (int): the kernel of the trans conv layer, which is
            used to get heatmap from the output of backbone. Default: 1
        res_blocks_cfg (dict|None): config of residual blocks. If None,
            use the default values. If not None, it should contain the
            following keys:

            - block (str): the type of residual block, Default: 'BASIC'.
            - num_blocks (int): the number of blocks, Default: 20.

        offsets_kernel (int): the kernel of offset conv layer.
        deform_conv_kernel (int): the kernel of deformable conv layer.
        in_index (int|Sequence[int]): Input feature index. Default: 0
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            Default: None.

            - 'resize_concat': Multiple feature maps will be resize to \
                the same size as first one and than concat together. \
                Usually used in FCN head of HRNet.
            - 'multiple_select': Multiple feature maps will be bundle into \
                a list and passed into decode head.
            - None: Only one select feature map is allowed.

        freeze_trans_layer (bool): Whether to freeze the transition layer
            (stop grad and set eval mode). Default: True.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        im2col_step (int): the argument `im2col_step` in deformable conv,
            Default: 80.
        align_corners (bool): the `align_corners` argument of :func:`resize`
            used when ``input_transform='resize_concat'``. Default: False.
    """
    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
    # `im2col_step` is only supported by mmcv's DeformConv2d from this
    # version on, see https://github.com/open-mmlab/mmcv/issues/1440
    minimum_mmcv_version = '1.3.17'

    def __init__(self,
                 in_channels,
                 out_channels,
                 inner_channels,
                 deform_groups=17,
                 dilations=(3, 6, 12, 18, 24),
                 trans_conv_kernel=1,
                 res_blocks_cfg=None,
                 offsets_kernel=3,
                 deform_conv_kernel=3,
                 in_index=0,
                 input_transform=None,
                 freeze_trans_layer=True,
                 norm_eval=False,
                 im2col_step=80,
                 align_corners=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inner_channels = inner_channels
        self.deform_groups = deform_groups
        self.dilations = dilations
        self.trans_conv_kernel = trans_conv_kernel
        self.res_blocks_cfg = res_blocks_cfg
        self.offsets_kernel = offsets_kernel
        self.deform_conv_kernel = deform_conv_kernel
        self.in_index = in_index
        self.input_transform = input_transform
        self.freeze_trans_layer = freeze_trans_layer
        self.norm_eval = norm_eval
        self.im2col_step = im2col_step
        # BUGFIX: `_transform_inputs` reads `self.align_corners` in
        # 'resize_concat' mode, but this attribute was never set before,
        # raising AttributeError. Expose it as a parameter (default False).
        self.align_corners = align_corners

        identity_trans_layer = False

        assert trans_conv_kernel in [0, 1, 3]
        kernel_size = trans_conv_kernel
        if kernel_size == 3:
            padding = 1
        elif kernel_size == 1:
            padding = 0
        else:
            # 0 means no trans layer: pass backbone features through as-is
            identity_trans_layer = True

        if identity_trans_layer:
            self.trans_layer = nn.Identity()
        else:
            self.trans_layer = build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=kernel_size,
                stride=1,
                padding=padding)

        # build chain of residual blocks
        if res_blocks_cfg is not None and not isinstance(res_blocks_cfg, dict):
            raise TypeError('res_blocks_cfg should be dict or None.')

        if res_blocks_cfg is None:
            block_type = 'BASIC'
            num_blocks = 20
        else:
            block_type = res_blocks_cfg.get('block', 'BASIC')
            num_blocks = res_blocks_cfg.get('num_blocks', 20)

        block = self.blocks_dict[block_type]

        res_layers = []
        # first block changes channel count, so it needs a 1x1 downsample
        # branch to match the residual's shape
        downsample = nn.Sequential(
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=out_channels,
                out_channels=inner_channels,
                kernel_size=1,
                stride=1,
                bias=False),
            build_norm_layer(dict(type='BN'), inner_channels)[1])
        res_layers.append(
            block(
                in_channels=out_channels,
                out_channels=inner_channels,
                downsample=downsample))

        for _ in range(1, num_blocks):
            res_layers.append(block(inner_channels, inner_channels))
        self.offset_feats = nn.Sequential(*res_layers)

        # build offset layers, one per dilation rate
        self.num_offset_layers = len(dilations)
        assert self.num_offset_layers > 0, 'Number of offset layers ' \
            'should be larger than 0.'

        # 2 (x, y) * kernel area * deform groups offset channels, as
        # required by DeformConv2d
        target_offset_channels = 2 * offsets_kernel**2 * deform_groups

        offset_layers = [
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=inner_channels,
                out_channels=target_offset_channels,
                kernel_size=offsets_kernel,
                stride=1,
                dilation=dilations[i],
                padding=dilations[i],
                bias=False,
            ) for i in range(self.num_offset_layers)
        ]
        self.offset_layers = nn.ModuleList(offset_layers)

        # build deformable conv layers
        assert digit_version(mmcv.__version__) >= \
            digit_version(self.minimum_mmcv_version), \
            f'Current MMCV version: {mmcv.__version__}, ' \
            f'but MMCV >= {self.minimum_mmcv_version} is required, see ' \
            f'https://github.com/open-mmlab/mmcv/issues/1440, ' \
            f'Please install the latest MMCV.'

        if has_mmcv_full:
            deform_conv_layers = [
                DeformConv2d(
                    in_channels=out_channels,
                    out_channels=out_channels,
                    kernel_size=deform_conv_kernel,
                    stride=1,
                    # 'same' padding scaled by the dilation rate
                    padding=int(deform_conv_kernel / 2) * dilations[i],
                    dilation=dilations[i],
                    deform_groups=deform_groups,
                    im2col_step=self.im2col_step,
                ) for i in range(self.num_offset_layers)
            ]
        else:
            raise ImportError('Please install the full version of mmcv '
                              'to use `DeformConv2d`.')

        self.deform_conv_layers = nn.ModuleList(deform_conv_layers)

        self.freeze_layers()

    def freeze_layers(self):
        """Freeze the transition layer (eval mode + no grad) if requested."""
        if self.freeze_trans_layer:
            self.trans_layer.eval()

            for param in self.trans_layer.parameters():
                param.requires_grad = False

    def init_weights(self):
        """Initialize model weights.

        Conv layers get a small normal init, norm layers a constant init.
        Deformable conv weights are set to an identity mapping (each output
        channel copies its matching input channel at the kernel center), and
        offset layers are zero-initialized, so the warping starts as a no-op.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)
            elif isinstance(m, DeformConv2d):
                filler = torch.zeros([
                    m.weight.size(0),
                    m.weight.size(1),
                    m.weight.size(2),
                    m.weight.size(3)
                ],
                                     dtype=torch.float32,
                                     device=m.weight.device)
                for k in range(m.weight.size(0)):
                    filler[k, k,
                           int(m.weight.size(2) / 2),
                           int(m.weight.size(3) / 2)] = 1.0
                m.weight = torch.nn.Parameter(filler)
                m.weight.requires_grad = True

        # posewarper offset layer weight initialization; runs after the
        # generic Conv2d init above so the zeros take precedence
        for m in self.offset_layers.modules():
            constant_init(m, 0)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor] | Tensor): multi-level img features.

        Returns:
            Tensor: The transformed inputs
        """
        if not isinstance(inputs, list):
            return inputs

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, inputs, frame_weight):
        """Warp each frame's heatmap towards the reference frame (the first
        one) and accumulate them with per-frame weights.

        Args:
            inputs (list | tuple): backbone features, either one entry per
                frame (length > 1), or a single entry whose batch dimension
                stacks all frames of the sequence.
            frame_weight (Sequence[float]): weight of each frame; frames
                with weight 0 are skipped.

        Returns:
            Tensor: weighted sum of warped heatmaps.
        """
        assert isinstance(inputs, (list, tuple)), 'PoseWarperNeck inputs ' \
            'should be list or tuple, even though the length is 1, ' \
            'for unified processing.'

        output_heatmap = 0
        if len(inputs) > 1:
            inputs = [self._transform_inputs(input) for input in inputs]
            inputs = [self.trans_layer(input) for input in inputs]

            # inputs[0] is the reference frame; offsets are predicted from
            # the feature difference w.r.t. it
            diff_features = [
                self.offset_feats(inputs[0] - input) for input in inputs
            ]

            for i in range(len(inputs)):
                if frame_weight[i] == 0:
                    continue
                warped_heatmap = 0
                # average the warped results over all dilation branches
                for j in range(self.num_offset_layers):
                    offset = (self.offset_layers[j](diff_features[i]))
                    warped_heatmap_tmp = self.deform_conv_layers[j](inputs[i],
                                                                    offset)
                    warped_heatmap += warped_heatmap_tmp / \
                        self.num_offset_layers

                output_heatmap += warped_heatmap * frame_weight[i]

        else:
            inputs = inputs[0]
            inputs = self._transform_inputs(inputs)
            inputs = self.trans_layer(inputs)

            # the first `batch_size` samples are the reference frames; tile
            # them so every frame is compared against its reference
            num_frames = len(frame_weight)
            batch_size = inputs.size(0) // num_frames
            ref_x = inputs[:batch_size]
            ref_x_tiled = ref_x.repeat(num_frames, 1, 1, 1)

            offset_features = self.offset_feats(ref_x_tiled - inputs)

            warped_heatmap = 0
            for j in range(self.num_offset_layers):
                offset = self.offset_layers[j](offset_features)

                warped_heatmap_tmp = self.deform_conv_layers[j](inputs, offset)
                warped_heatmap += warped_heatmap_tmp / self.num_offset_layers

            for i in range(num_frames):
                if frame_weight[i] == 0:
                    continue
                output_heatmap += warped_heatmap[i * batch_size:(i + 1) *
                                                 batch_size] * frame_weight[i]

        return output_heatmap

    def train(self, mode=True):
        """Convert the model into training mode."""
        super().train(mode)
        self.freeze_layers()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
|
|