File size: 12,514 Bytes
2cd560a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 211 212 213 214 215 216 217 218 219 220 221 222 223 224 225 226 227 228 229 230 231 232 233 234 235 236 237 238 239 240 241 242 243 244 245 246 247 248 249 250 251 252 253 254 255 256 257 258 259 260 261 262 263 264 265 266 267 268 269 270 271 272 273 274 275 276 277 278 279 280 281 282 283 284 285 286 287 288 289 290 291 292 293 294 295 296 297 298 299 300 301 302 303 304 305 306 307 308 309 310 311 312 313 314 315 316 317 318 319 320 321 322 323 324 325 326 327 328 329 330 |
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
normal_init)
from mmcv.utils import digit_version
from torch.nn.modules.batchnorm import _BatchNorm
from mmpose.models.utils.ops import resize
from ..backbones.resnet import BasicBlock, Bottleneck
from ..builder import NECKS
try:
from mmcv.ops import DeformConv2d
has_mmcv_full = True
except (ImportError, ModuleNotFoundError):
has_mmcv_full = False
@NECKS.register_module()
class PoseWarperNeck(nn.Module):
    """PoseWarper neck.

    `"Learning temporal pose estimation from sparsely-labeled videos"
    <https://arxiv.org/abs/1906.04016>`_.

    Args:
        in_channels (int): Number of input channels from backbone
        out_channels (int): Number of output channels
        inner_channels (int): Number of intermediate channels of the res block
        deform_groups (int): Number of groups in the deformable conv
        dilations (list|tuple): different dilations of the offset conv layers
        trans_conv_kernel (int): the kernel of the trans conv layer, which is
            used to get heatmap from the output of backbone. Must be one of
            0, 1 or 3, where 0 selects an identity mapping. Default: 1
        res_blocks_cfg (dict|None): config of residual blocks. If None,
            use the default values. If not None, it should contain the
            following keys:

            - block (str): the type of residual block, Default: 'BASIC'.
            - num_blocks (int): the number of blocks, Default: 20.

        offsets_kernel (int): the kernel of offset conv layer.
        deform_conv_kernel (int): the kernel of deformable conv layer.
        in_index (int|Sequence[int]): Input feature index. Default: 0
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            Default: None.

            - 'resize_concat': Multiple feature maps will be resized to
              the same size as the first one and then concatenated
              together. Usually used in FCN head of HRNet.
            - 'multiple_select': Multiple feature maps will be bundled into
              a list and passed into decode head.
            - None: Only one selected feature map is allowed.

        freeze_trans_layer (bool): Whether to freeze the transition layer
            (stop grad and set eval mode). Default: True.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        im2col_step (int): the argument `im2col_step` in deformable conv,
            Default: 80.
        align_corners (bool): the `align_corners` argument passed to
            `resize` when `input_transform` is 'resize_concat'.
            Default: False.
    """

    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
    minimum_mmcv_version = '1.3.17'

    def __init__(self,
                 in_channels,
                 out_channels,
                 inner_channels,
                 deform_groups=17,
                 dilations=(3, 6, 12, 18, 24),
                 trans_conv_kernel=1,
                 res_blocks_cfg=None,
                 offsets_kernel=3,
                 deform_conv_kernel=3,
                 in_index=0,
                 input_transform=None,
                 freeze_trans_layer=True,
                 norm_eval=False,
                 im2col_step=80,
                 align_corners=False):
        super().__init__()

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inner_channels = inner_channels
        self.deform_groups = deform_groups
        self.dilations = dilations
        self.trans_conv_kernel = trans_conv_kernel
        self.res_blocks_cfg = res_blocks_cfg
        self.offsets_kernel = offsets_kernel
        self.deform_conv_kernel = deform_conv_kernel
        self.in_index = in_index
        self.input_transform = input_transform
        self.freeze_trans_layer = freeze_trans_layer
        self.norm_eval = norm_eval
        self.im2col_step = im2col_step
        # Read by `_transform_inputs` in the 'resize_concat' path; it must
        # exist on the instance or that path fails with AttributeError.
        self.align_corners = align_corners

        # build the transition layer (backbone features -> heatmaps)
        assert trans_conv_kernel in [0, 1, 3]
        if trans_conv_kernel == 0:
            # 0 for Identity mapping.
            self.trans_layer = nn.Identity()
        else:
            # kernel 3 -> padding 1, kernel 1 -> padding 0
            self.trans_layer = build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=trans_conv_kernel,
                stride=1,
                padding=trans_conv_kernel // 2)

        # build chain of residual blocks
        if res_blocks_cfg is not None and not isinstance(res_blocks_cfg, dict):
            raise TypeError('res_blocks_cfg should be dict or None.')

        if res_blocks_cfg is None:
            block_type = 'BASIC'
            num_blocks = 20
        else:
            block_type = res_blocks_cfg.get('block', 'BASIC')
            num_blocks = res_blocks_cfg.get('num_blocks', 20)

        block = self.blocks_dict[block_type]

        # The first block changes the channel count from `out_channels` to
        # `inner_channels`, so its shortcut needs a 1x1 conv + BN projection.
        downsample = nn.Sequential(
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=out_channels,
                out_channels=inner_channels,
                kernel_size=1,
                stride=1,
                bias=False),
            build_norm_layer(dict(type='BN'), inner_channels)[1])

        res_layers = [
            block(
                in_channels=out_channels,
                out_channels=inner_channels,
                downsample=downsample)
        ]
        for _ in range(1, num_blocks):
            res_layers.append(block(inner_channels, inner_channels))
        self.offset_feats = nn.Sequential(*res_layers)

        # build offset layers
        self.num_offset_layers = len(dilations)
        assert self.num_offset_layers > 0, 'Number of offset layers ' \
            'should be larger than 0.'

        # two offset values per kernel position per deformable group
        target_offset_channels = 2 * offsets_kernel**2 * deform_groups

        self.offset_layers = nn.ModuleList([
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=inner_channels,
                out_channels=target_offset_channels,
                kernel_size=offsets_kernel,
                stride=1,
                dilation=dilation,
                padding=dilation,
                bias=False,
            ) for dilation in dilations
        ])

        # build deformable conv layers
        assert digit_version(mmcv.__version__) >= \
            digit_version(self.minimum_mmcv_version), \
            f'Current MMCV version: {mmcv.__version__}, ' \
            f'but MMCV >= {self.minimum_mmcv_version} is required, see ' \
            f'https://github.com/open-mmlab/mmcv/issues/1440, ' \
            f'Please install the latest MMCV.'

        if not has_mmcv_full:
            raise ImportError('Please install the full version of mmcv '
                              'to use `DeformConv2d`.')

        self.deform_conv_layers = nn.ModuleList([
            DeformConv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=deform_conv_kernel,
                stride=1,
                padding=int(deform_conv_kernel / 2) * dilation,
                dilation=dilation,
                deform_groups=deform_groups,
                im2col_step=self.im2col_step,
            ) for dilation in dilations
        ])

        self.freeze_layers()

    def freeze_layers(self):
        """Freeze the transition layer if `freeze_trans_layer` is set.

        The layer is put in eval mode and gradients are disabled for all of
        its parameters.
        """
        if self.freeze_trans_layer:
            self.trans_layer.eval()

            for param in self.trans_layer.parameters():
                param.requires_grad = False

    def init_weights(self):
        """Initialize the weights of all sub-modules.

        - ``nn.Conv2d``: normal init with std 0.001.
        - BatchNorm / GroupNorm: constant init with 1.
        - ``DeformConv2d``: the weight is replaced by a kernel that is 1 at
          the spatial center of filter ``k`` on input channel ``k`` and 0
          everywhere else.
        - offset layers: constant init with 0 (applied last, so it
          overrides the generic conv init above).
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)
            elif isinstance(m, DeformConv2d):
                filler = torch.zeros(
                    m.weight.size(),
                    dtype=torch.float32,
                    device=m.weight.device)
                center_h = m.weight.size(2) // 2
                center_w = m.weight.size(3) // 2
                for k in range(m.weight.size(0)):
                    filler[k, k, center_h, center_w] = 1.0
                m.weight = torch.nn.Parameter(filler)
                m.weight.requires_grad = True

        # posewarper offset layer weight initialization
        for m in self.offset_layers.modules():
            constant_init(m, 0)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor] | Tensor): multi-level img features.

        Returns:
            Tensor: The transformed inputs. Anything that is not a ``list``
                is returned unchanged.
        """
        if not isinstance(inputs, list):
            return inputs

        if self.input_transform == 'resize_concat':
            inputs = [inputs[i] for i in self.in_index]
            # resize every selected map to the spatial size of the first one
            upsampled_inputs = [
                resize(
                    input=x,
                    size=inputs[0].shape[2:],
                    mode='bilinear',
                    align_corners=self.align_corners) for x in inputs
            ]
            inputs = torch.cat(upsampled_inputs, dim=1)
        elif self.input_transform == 'multiple_select':
            inputs = [inputs[i] for i in self.in_index]
        else:
            inputs = inputs[self.in_index]

        return inputs

    def forward(self, inputs, frame_weight):
        """Warp per-frame heatmaps and fuse them with `frame_weight`.

        Args:
            inputs (list | tuple): backbone outputs. With more than one
                element, each element holds the features of one frame and
                element 0 is the one every frame is differenced against.
                With exactly one element, it is split along dim 0 into
                ``len(frame_weight)`` equal chunks, and the first chunk is
                the one the others are differenced against.
            frame_weight (Sequence): one weight per frame; frames whose
                weight equals 0 are skipped in the weighted sum.

        Returns:
            Tensor: the weighted sum of the warped heatmaps. If every frame
                weight is 0 the initial value ``0`` is returned as is.
        """
        assert isinstance(inputs, (list, tuple)), 'PoseWarperNeck inputs ' \
            'should be list or tuple, even though the length is 1, ' \
            'for unified processing.'

        output_heatmap = 0
        warp_layers = list(zip(self.offset_layers, self.deform_conv_layers))

        if len(inputs) > 1:
            feats = [self._transform_inputs(feat) for feat in inputs]
            feats = [self.trans_layer(feat) for feat in feats]

            # calculate difference features
            diff_features = [
                self.offset_feats(feats[0] - feat) for feat in feats
            ]

            for i, feat in enumerate(feats):
                if frame_weight[i] == 0:
                    continue
                # average the warped heatmaps over all dilation branches
                warped_heatmap = 0
                for offset_layer, deform_layer in warp_layers:
                    offset = offset_layer(diff_features[i])
                    warped_heatmap_tmp = deform_layer(feat, offset)
                    warped_heatmap += warped_heatmap_tmp / \
                        self.num_offset_layers

                output_heatmap += warped_heatmap * frame_weight[i]

        else:
            feats = inputs[0]
            feats = self._transform_inputs(feats)
            feats = self.trans_layer(feats)

            num_frames = len(frame_weight)
            batch_size = feats.size(0) // num_frames
            # the first `batch_size` entries are tiled so they can be
            # differenced against every frame in one pass
            ref_x = feats[:batch_size]
            ref_x_tiled = ref_x.repeat(num_frames, 1, 1, 1)

            offset_features = self.offset_feats(ref_x_tiled - feats)

            # average the warped heatmaps over all dilation branches
            warped_heatmap = 0
            for offset_layer, deform_layer in warp_layers:
                offset = offset_layer(offset_features)
                warped_heatmap_tmp = deform_layer(feats, offset)
                warped_heatmap += warped_heatmap_tmp / self.num_offset_layers

            for i in range(num_frames):
                if frame_weight[i] == 0:
                    continue
                output_heatmap += warped_heatmap[i * batch_size:(i + 1) *
                                                 batch_size] * frame_weight[i]

        return output_heatmap

    def train(self, mode=True):
        """Convert the model into training mode.

        The transition layer is re-frozen after the mode switch, and norm
        layers are put back in eval mode when `norm_eval` is set.
        """
        super().train(mode)
        self.freeze_layers()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
|