File size: 12,514 Bytes
2cd560a
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
253
254
255
256
257
258
259
260
261
262
263
264
265
266
267
268
269
270
271
272
273
274
275
276
277
278
279
280
281
282
283
284
285
286
287
288
289
290
291
292
293
294
295
296
297
298
299
300
301
302
303
304
305
306
307
308
309
310
311
312
313
314
315
316
317
318
319
320
321
322
323
324
325
326
327
328
329
330
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import torch
import torch.nn as nn
from mmcv.cnn import (build_conv_layer, build_norm_layer, constant_init,
                      normal_init)
from mmcv.utils import digit_version
from torch.nn.modules.batchnorm import _BatchNorm

from mmpose.models.utils.ops import resize
from ..backbones.resnet import BasicBlock, Bottleneck
from ..builder import NECKS

# `DeformConv2d` is only available in the full version of mmcv. Record whether
# the import worked so the neck can report a clear error when it is built.
try:
    from mmcv.ops import DeformConv2d
except ImportError:  # also covers ModuleNotFoundError (a subclass)
    has_mmcv_full = False
else:
    has_mmcv_full = True


@NECKS.register_module()
class PoseWarperNeck(nn.Module):
    """PoseWarper neck.

    `"Learning temporal pose estimation from sparsely-labeled videos"
    <https://arxiv.org/abs/1906.04016>`_.

    Args:
        in_channels (int): Number of input channels from backbone
        out_channels (int): Number of output channels
        inner_channels (int): Number of intermediate channels of the res block
        deform_groups (int): Number of groups in the deformable conv
        dilations (list|tuple): different dilations of the offset conv layers
        trans_conv_kernel (int): the kernel of the trans conv layer, which is
            used to get heatmap from the output of backbone. Must be one of
            0, 1 or 3, where 0 means an identity mapping. Default: 1
        res_blocks_cfg (dict|None): config of residual blocks. If None,
            use the default values. If not None, it should contain the
            following keys:

            - block (str): the type of residual block, Default: 'BASIC'.
            - num_blocks (int):  the number of blocks, Default: 20.

        offsets_kernel (int): the kernel of offset conv layer.
        deform_conv_kernel (int): the kernel of deformable conv layer.
        in_index (int|Sequence[int]): Input feature index. Default: 0
        input_transform (str|None): Transformation type of input features.
            Options: 'resize_concat', 'multiple_select', None.
            Default: None.

            - 'resize_concat': Multiple feature maps will be resized to \
                the same size as first one and then concat together. \
                Usually used in FCN head of HRNet.
            - 'multiple_select': Multiple feature maps will be bundle into \
                a list and passed into decode head.
            - None: Only one select feature map is allowed.

        freeze_trans_layer (bool): Whether to freeze the transition layer
            (stop grad and set eval mode). Default: True.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        im2col_step (int): the argument `im2col_step` in deformable conv,
            Default: 80.
        align_corners (bool): the `align_corners` argument passed to
            `resize` when `input_transform` is 'resize_concat'.
            Default: False.
    """
    blocks_dict = {'BASIC': BasicBlock, 'BOTTLENECK': Bottleneck}
    minimum_mmcv_version = '1.3.17'

    def __init__(self,
                 in_channels,
                 out_channels,
                 inner_channels,
                 deform_groups=17,
                 dilations=(3, 6, 12, 18, 24),
                 trans_conv_kernel=1,
                 res_blocks_cfg=None,
                 offsets_kernel=3,
                 deform_conv_kernel=3,
                 in_index=0,
                 input_transform=None,
                 freeze_trans_layer=True,
                 norm_eval=False,
                 im2col_step=80,
                 align_corners=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.inner_channels = inner_channels
        self.deform_groups = deform_groups
        self.dilations = dilations
        self.trans_conv_kernel = trans_conv_kernel
        self.res_blocks_cfg = res_blocks_cfg
        self.offsets_kernel = offsets_kernel
        self.deform_conv_kernel = deform_conv_kernel
        self.in_index = in_index
        self.input_transform = input_transform
        self.freeze_trans_layer = freeze_trans_layer
        self.norm_eval = norm_eval
        self.im2col_step = im2col_step
        # Read by `_transform_inputs` in the 'resize_concat' branch.
        self.align_corners = align_corners

        # build the transition layer (backbone features -> heatmaps)
        assert trans_conv_kernel in [0, 1, 3]
        if trans_conv_kernel == 0:
            # 0 for Identity mapping.
            self.trans_layer = nn.Identity()
        else:
            # padding of kernel // 2 keeps the spatial size (3 -> 1, 1 -> 0)
            self.trans_layer = build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=in_channels,
                out_channels=out_channels,
                kernel_size=trans_conv_kernel,
                stride=1,
                padding=trans_conv_kernel // 2)

        # build chain of residual blocks
        if res_blocks_cfg is not None and not isinstance(res_blocks_cfg, dict):
            raise TypeError('res_blocks_cfg should be dict or None.')

        if res_blocks_cfg is None:
            block_type = 'BASIC'
            num_blocks = 20
        else:
            block_type = res_blocks_cfg.get('block', 'BASIC')
            num_blocks = res_blocks_cfg.get('num_blocks', 20)

        block = self.blocks_dict[block_type]

        # The first block changes the channel count from `out_channels` to
        # `inner_channels`, so its shortcut needs a 1x1 projection.
        downsample = nn.Sequential(
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=out_channels,
                out_channels=inner_channels,
                kernel_size=1,
                stride=1,
                bias=False),
            build_norm_layer(dict(type='BN'), inner_channels)[1])
        res_layers = [
            block(
                in_channels=out_channels,
                out_channels=inner_channels,
                downsample=downsample)
        ]
        res_layers.extend(
            block(inner_channels, inner_channels)
            for _ in range(1, num_blocks))
        self.offset_feats = nn.Sequential(*res_layers)

        # build offset layers, one per dilation rate
        self.num_offset_layers = len(dilations)
        assert self.num_offset_layers > 0, 'Number of offset layers ' \
            'should be larger than 0.'

        # one (x, y) pair per kernel position and per deformable group
        target_offset_channels = 2 * offsets_kernel**2 * deform_groups

        self.offset_layers = nn.ModuleList([
            build_conv_layer(
                cfg=dict(type='Conv2d'),
                in_channels=inner_channels,
                out_channels=target_offset_channels,
                kernel_size=offsets_kernel,
                stride=1,
                dilation=dilation,
                padding=dilation,
                bias=False,
            ) for dilation in dilations
        ])

        # build deformable conv layers
        assert digit_version(mmcv.__version__) >= \
            digit_version(self.minimum_mmcv_version), \
            f'Current MMCV version: {mmcv.__version__}, ' \
            f'but MMCV >= {self.minimum_mmcv_version} is required, see ' \
            f'https://github.com/open-mmlab/mmcv/issues/1440, ' \
            f'Please install the latest MMCV.'

        if not has_mmcv_full:
            raise ImportError('Please install the full version of mmcv '
                              'to use `DeformConv2d`.')

        self.deform_conv_layers = nn.ModuleList([
            DeformConv2d(
                in_channels=out_channels,
                out_channels=out_channels,
                kernel_size=deform_conv_kernel,
                stride=1,
                padding=(deform_conv_kernel // 2) * dilation,
                dilation=dilation,
                deform_groups=deform_groups,
                im2col_step=self.im2col_step,
            ) for dilation in dilations
        ])

        self.freeze_layers()

    def freeze_layers(self):
        """Freeze the transition layer if `freeze_trans_layer` is set.

        The layer is put into eval mode and its parameters stop receiving
        gradients.
        """
        if self.freeze_trans_layer:
            self.trans_layer.eval()

            for param in self.trans_layer.parameters():
                param.requires_grad = False

    def init_weights(self):
        """Initialize the weights of all sub-modules.

        Regular convs get a normal init (std=0.001), norm layers are set to
        1, deformable convs get a kernel with a single 1 at the spatial
        center of channel pair (k, k), and the offset convs are then reset
        to zero.
        """
        for m in self.modules():
            if isinstance(m, nn.Conv2d):
                normal_init(m, std=0.001)
            elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                constant_init(m, 1)
            elif isinstance(m, DeformConv2d):
                num_out, _, kernel_h, kernel_w = m.weight.shape
                filler = torch.zeros(
                    m.weight.shape,
                    dtype=torch.float32,
                    device=m.weight.device)
                # Only the center tap of output channel k reads input
                # channel k; every other weight stays zero.
                for k in range(num_out):
                    filler[k, k, kernel_h // 2, kernel_w // 2] = 1.0
                m.weight = torch.nn.Parameter(filler, requires_grad=True)

        # posewarper offset layer weight initialization
        for m in self.offset_layers:
            constant_init(m, 0)

    def _transform_inputs(self, inputs):
        """Transform inputs for decoder.

        Args:
            inputs (list[Tensor] | Tensor): multi-level img features.

        Returns:
            Tensor | list[Tensor]: The transformed inputs. Anything that is
            not a list is returned unchanged.
        """
        if not isinstance(inputs, list):
            return inputs

        if self.input_transform == 'resize_concat':
            selected = [inputs[i] for i in self.in_index]
            # bring every level to the resolution of the first selected one
            target_size = selected[0].shape[2:]
            upsampled_inputs = [
                resize(
                    input=x,
                    size=target_size,
                    mode='bilinear',
                    align_corners=self.align_corners) for x in selected
            ]
            return torch.cat(upsampled_inputs, dim=1)

        if self.input_transform == 'multiple_select':
            return [inputs[i] for i in self.in_index]

        return inputs[self.in_index]

    def _warp_heatmap(self, feats, offset_features):
        """Warp `feats` with every dilation branch and average the results.

        Args:
            feats (Tensor): heatmaps fed to the deformable convs.
            offset_features (Tensor): features from which each offset layer
                predicts the offsets of its deformable conv.

        Returns:
            Tensor: mean of the deformable conv outputs over all branches.
        """
        warped_heatmap = 0
        for offset_layer, deform_conv in zip(self.offset_layers,
                                             self.deform_conv_layers):
            offset = offset_layer(offset_features)
            warped_heatmap += deform_conv(feats, offset) / \
                self.num_offset_layers
        return warped_heatmap

    def forward(self, inputs, frame_weight):
        """Warp the heatmaps of all frames and sum them by frame weight.

        Args:
            inputs (list|tuple): backbone outputs. With more than one
                element, each element holds the features of one frame and
                the first one is the reference. With a single element, the
                frames are stacked along the batch dimension, the first
                `batch_size` entries being the reference.
            frame_weight (Sequence): weight of each frame; frames with
                weight 0 are skipped.

        Returns:
            Tensor | int: the weighted sum of the warped heatmaps, or the
            int 0 if every frame weight is 0.
        """
        assert isinstance(inputs, (list, tuple)), 'PoseWarperNeck inputs ' \
            'should be list or tuple, even though the length is 1, ' \
            'for unified processing.'

        output_heatmap = 0
        if len(inputs) > 1:
            feats = [
                self.trans_layer(self._transform_inputs(feat))
                for feat in inputs
            ]

            # calculate difference features w.r.t. the reference frame
            diff_features = [
                self.offset_feats(feats[0] - feat) for feat in feats
            ]

            for i, (feat, diff) in enumerate(zip(feats, diff_features)):
                if frame_weight[i] == 0:
                    continue
                output_heatmap += self._warp_heatmap(feat,
                                                     diff) * frame_weight[i]

        else:
            feats = self.trans_layer(self._transform_inputs(inputs[0]))

            num_frames = len(frame_weight)
            batch_size = feats.size(0) // num_frames
            # tile the reference frame so that it lines up with every frame
            ref_x_tiled = feats[:batch_size].repeat(num_frames, 1, 1, 1)

            offset_features = self.offset_feats(ref_x_tiled - feats)
            warped_heatmap = self._warp_heatmap(feats, offset_features)

            for i in range(num_frames):
                if frame_weight[i] == 0:
                    continue
                start = i * batch_size
                output_heatmap += \
                    warped_heatmap[start:start + batch_size] * frame_weight[i]

        return output_heatmap

    def train(self, mode=True):
        """Convert the model into training mode.

        The transition layer is re-frozen afterwards and, when `norm_eval`
        is set, batch norm layers are kept in eval mode.

        Returns:
            PoseWarperNeck: self, as `nn.Module.train` does.
        """
        super().train(mode)
        self.freeze_layers()
        if mode and self.norm_eval:
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self