# Copyright (c) OpenMMLab. All rights reserved.
import copy

import torch.nn as nn
import torch.utils.checkpoint as cp
from mmcv.cnn import ConvModule, build_conv_layer, build_norm_layer
from mmcv.cnn.bricks import ContextBlock
from mmcv.utils.parrots_wrapper import _BatchNorm

from ..builder import BACKBONES
from .base_backbone import BaseBackbone


class ViPNAS_Bottleneck(nn.Module):
    """Bottleneck block for ViPNAS_ResNet.

    Computes ``conv1x1 -> norm -> relu -> convKxK (grouped, optionally
    dilated) -> norm -> relu -> conv1x1 -> norm -> [attention]``, adds the
    (optionally downsampled) identity and applies a final ReLU.

    Args:
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        expansion (int): The ratio of ``out_channels/mid_channels`` where
            ``mid_channels`` is the input/output channels of conv2.
            Default: 4.
        stride (int): stride of the block. Default: 1
        dilation (int): dilation of conv2. Default: 1
        downsample (nn.Module): downsample operation on identity branch.
            Default: None.
        style (str): ``"pytorch"`` or ``"caffe"``. If set to "pytorch", the
            stride-two layer is the KxK conv layer (conv2), otherwise the
            stride-two layer is the first 1x1 conv layer. Default: "pytorch".
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        kernel_size (int): kernel size of conv2 searched in ViPNAS.
            Default: 3.
        groups (int): group number of conv2 searched in ViPNAS. Default: 1.
        attention (bool): whether to use an attention module (ContextBlock)
            at the end of the block. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 expansion=4,
                 stride=1,
                 dilation=1,
                 downsample=None,
                 style='pytorch',
                 with_cp=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 kernel_size=3,
                 groups=1,
                 attention=False):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        assert style in ['pytorch', 'caffe']

        self.in_channels = in_channels
        self.out_channels = out_channels
        self.expansion = expansion
        assert out_channels % expansion == 0
        self.mid_channels = out_channels // expansion
        self.stride = stride
        self.dilation = dilation
        self.style = style
        self.with_cp = with_cp
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg

        # "pytorch" style strides in the KxK conv, "caffe" style strides in
        # the first 1x1 conv.
        if self.style == 'pytorch':
            self.conv1_stride = 1
            self.conv2_stride = stride
        else:
            self.conv1_stride = stride
            self.conv2_stride = 1

        self.norm1_name, norm1 = build_norm_layer(
            norm_cfg, self.mid_channels, postfix=1)
        self.norm2_name, norm2 = build_norm_layer(
            norm_cfg, self.mid_channels, postfix=2)
        self.norm3_name, norm3 = build_norm_layer(
            norm_cfg, out_channels, postfix=3)

        # NOTE: registration order (conv1, norm1, conv2, norm2, conv3, norm3,
        # attention, relu, downsample) is kept stable so parameter order
        # (e.g. for optimizer state) does not change.
        self.conv1 = build_conv_layer(
            conv_cfg,
            in_channels,
            self.mid_channels,
            kernel_size=1,
            stride=self.conv1_stride,
            bias=False)
        self.add_module(self.norm1_name, norm1)
        # A dilated KxK kernel spans ``dilation * (kernel_size - 1) + 1``
        # pixels, so padding must scale with ``dilation`` for the spatial size
        # to change only through ``stride`` (odd kernel sizes). With
        # dilation == 1 this equals ``kernel_size // 2``.
        self.conv2 = build_conv_layer(
            conv_cfg,
            self.mid_channels,
            self.mid_channels,
            kernel_size=kernel_size,
            stride=self.conv2_stride,
            padding=dilation * (kernel_size // 2),
            groups=groups,
            dilation=dilation,
            bias=False)
        self.add_module(self.norm2_name, norm2)
        self.conv3 = build_conv_layer(
            conv_cfg,
            self.mid_channels,
            out_channels,
            kernel_size=1,
            bias=False)
        self.add_module(self.norm3_name, norm3)

        if attention:
            # The ratio is at least 1/16 and grows for narrow blocks;
            # presumably this keeps the ContextBlock hidden width >= 16
            # channels (confirm against mmcv ContextBlock).
            self.attention = ContextBlock(out_channels,
                                          max(1.0 / 16, 16.0 / out_channels))
        else:
            self.attention = None

        self.relu = nn.ReLU(inplace=True)
        self.downsample = downsample

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1" """
        return getattr(self, self.norm1_name)

    @property
    def norm2(self):
        """nn.Module: the normalization layer named "norm2" """
        return getattr(self, self.norm2_name)

    @property
    def norm3(self):
        """nn.Module: the normalization layer named "norm3" """
        return getattr(self, self.norm3_name)

    def forward(self, x):
        """Forward function.

        Uses gradient checkpointing for the residual body when ``with_cp`` is
        set and the input requires grad; the final ReLU is always applied
        outside the checkpointed region.
        """

        def _inner_forward(x):
            identity = x

            out = self.conv1(x)
            out = self.norm1(out)
            out = self.relu(out)

            out = self.conv2(out)
            out = self.norm2(out)
            out = self.relu(out)

            out = self.conv3(out)
            out = self.norm3(out)

            if self.attention is not None:
                out = self.attention(out)

            if self.downsample is not None:
                identity = self.downsample(x)

            out += identity

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        out = self.relu(out)

        return out


def get_expansion(block, expansion=None):
    """Get the expansion of a residual block.

    The block expansion will be obtained by the following order:

    1. If ``expansion`` is given, just return it.
    2. If ``block`` has the attribute ``expansion``, then return
       ``block.expansion``.
    3. Return the default value according to the block type:
       1 for ``ViPNAS_Bottleneck`` (including subclasses). Note that
       ``ViPNAS_Bottleneck`` only sets ``expansion`` on instances, so the
       class itself falls through to this default.

    Args:
        block (class): The block class.
        expansion (int | None): The given expansion ratio.

    Returns:
        int: The expansion of the block.

    Raises:
        TypeError: If ``expansion`` is neither an int nor None, or if it is
            None and no default is known for ``block``.
    """
    if isinstance(expansion, int):
        assert expansion > 0
    elif expansion is None:
        if hasattr(block, 'expansion'):
            expansion = block.expansion
        elif issubclass(block, ViPNAS_Bottleneck):
            expansion = 1
        else:
            raise TypeError(f'expansion is not specified for {block.__name__}')
    else:
        raise TypeError('expansion must be an integer or None')

    return expansion


class ViPNAS_ResLayer(nn.Sequential):
    """ViPNAS_ResLayer to build ResNet style backbone.

    Args:
        block (nn.Module): Residual block used to build ViPNAS ResLayer.
        num_blocks (int): Number of blocks.
        in_channels (int): Input channels of this block.
        out_channels (int): Output channels of this block.
        expansion (int, optional): The expansion passed to every block. If
            not specified, it is resolved by ``get_expansion``: the class
            attribute ``block.expansion`` if present, otherwise 1 for
            ``ViPNAS_Bottleneck``. Default: None.
        stride (int): stride of the first block. Default: 1.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False
        conv_cfg (dict): dictionary to construct and config conv layer.
            Default: None
        norm_cfg (dict): dictionary to construct and config norm layer.
            Default: dict(type='BN')
        downsample_first (bool): Downsample at the first block or last block.
            False for Hourglass, True for ResNet. Default: True
        kernel_size (int): Kernel Size of the corresponding convolution layer
            searched in the block. Default: 3.
        groups (int): Group number of the corresponding convolution layer
            searched in the block. Default: 1.
        attention (bool): Whether to use attention module in the end of the
            block. Default: False.
        **kwargs: Extra keyword arguments (e.g. ``style``, ``dilation``,
            ``with_cp``) forwarded unchanged to every block.
    """

    def __init__(self,
                 block,
                 num_blocks,
                 in_channels,
                 out_channels,
                 expansion=None,
                 stride=1,
                 avg_down=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 downsample_first=True,
                 kernel_size=3,
                 groups=1,
                 attention=False,
                 **kwargs):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        self.block = block
        self.expansion = get_expansion(block, expansion)

        # The identity branch needs a projection whenever the spatial size
        # or channel count changes.
        downsample = None
        if stride != 1 or in_channels != out_channels:
            downsample = []
            conv_stride = stride
            if avg_down and stride != 1:
                # Do the spatial reduction with AvgPool and keep the 1x1
                # projection at stride 1.
                conv_stride = 1
                downsample.append(
                    nn.AvgPool2d(
                        kernel_size=stride,
                        stride=stride,
                        ceil_mode=True,
                        count_include_pad=False))
            downsample.extend([
                build_conv_layer(
                    conv_cfg,
                    in_channels,
                    out_channels,
                    kernel_size=1,
                    stride=conv_stride,
                    bias=False),
                build_norm_layer(norm_cfg, out_channels)[1]
            ])
            downsample = nn.Sequential(*downsample)

        layers = []
        if downsample_first:
            layers.append(
                block(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    expansion=self.expansion,
                    stride=stride,
                    downsample=downsample,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    kernel_size=kernel_size,
                    groups=groups,
                    attention=attention,
                    **kwargs))
            in_channels = out_channels
            for _ in range(1, num_blocks):
                layers.append(
                    block(
                        in_channels=in_channels,
                        out_channels=out_channels,
                        expansion=self.expansion,
                        stride=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        kernel_size=kernel_size,
                        groups=groups,
                        attention=attention,
                        **kwargs))
        else:  # downsample_first=False is for HourglassModule
            for _ in range(0, num_blocks - 1):
                layers.append(
                    block(
                        in_channels=in_channels,
                        out_channels=in_channels,
                        expansion=self.expansion,
                        stride=1,
                        conv_cfg=conv_cfg,
                        norm_cfg=norm_cfg,
                        kernel_size=kernel_size,
                        groups=groups,
                        attention=attention,
                        **kwargs))
            layers.append(
                block(
                    in_channels=in_channels,
                    out_channels=out_channels,
                    expansion=self.expansion,
                    stride=stride,
                    downsample=downsample,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    kernel_size=kernel_size,
                    groups=groups,
                    attention=attention,
                    **kwargs))

        super().__init__(*layers)


@BACKBONES.register_module()
class ViPNAS_ResNet(BaseBackbone):
    """ViPNAS_ResNet backbone.

    "ViPNAS: Efficient Video Pose Estimation via Neural Architecture Search"
    More details can be found in the `paper
    <https://arxiv.org/abs/2105.10154>`__ .

    Args:
        depth (int): Network depth. Only 50 is supported (the only key of
            ``arch_settings``); the searched ``dep`` config determines the
            actual number of blocks.
        in_channels (int): Number of input image channels. Default: 3.
        num_stages (int): Stages of the network. Default: 4.
        strides (Sequence[int]): Strides of the first block of each stage.
            Default: ``(1, 2, 2, 2)``.
        dilations (Sequence[int]): Dilation of each stage.
            Default: ``(1, 1, 1, 1)``.
        out_indices (Sequence[int]): Output from which stages. If only one
            stage is specified, a single tensor (feature map) is returned,
            otherwise multiple stages are specified, a tuple of tensors will
            be returned. Default: ``(3, )``.
        style (str): `pytorch` or `caffe`. If set to "pytorch", the stride-two
            layer is the KxK conv layer, otherwise the stride-two layer is
            the first 1x1 conv layer. Default: "pytorch".
        deep_stem (bool): Replace the single KxK stem conv with three 3x3
            convs. Default: False.
        avg_down (bool): Use AvgPool instead of stride conv when
            downsampling in the bottleneck. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval
            mode). -1 means not freezing any parameters. Default: -1.
        conv_cfg (dict | None): The config dict for conv layers.
            Default: None.
        norm_cfg (dict): The config dict for norm layers.
            Default: dict(type='BN', requires_grad=True).
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
        zero_init_residual (bool): Stored on the instance for config
            compatibility. Note: ``init_weights`` in this class does not
            currently apply zero-init to the last norm layer of the residual
            blocks. Default: True.
        wid (Sequence[int]): Searched width config. ``wid[0]`` is the stem
            width; stage ``i`` outputs ``wid[i + 1] * expansion`` channels.
        expan (Sequence[int | None]): Searched expansion ratio config for
            each stage (``expan[i + 1]`` for stage ``i``; ``expan[0]`` is
            unused).
        dep (Sequence[int | None]): Searched depth config. ``dep[i + 1]`` is
            the number of blocks of stage ``i``; ``dep[0]`` is unused.
        ks (Sequence[int]): Searched kernel size config. ``ks[0]`` is the
            stem conv kernel size (non-deep stem); ``ks[i + 1]`` is the conv2
            kernel size of stage ``i``.
        group (Sequence[int | None]): Searched group number config for each
            stage (``group[i + 1]`` for stage ``i``; ``group[0]`` is unused).
        att (Sequence[bool | None]): Searched attention config for each stage
            (``att[i + 1]`` for stage ``i``; ``att[0]`` is unused).
    """

    arch_settings = {
        50: ViPNAS_Bottleneck,
    }

    def __init__(self,
                 depth,
                 in_channels=3,
                 num_stages=4,
                 strides=(1, 2, 2, 2),
                 dilations=(1, 1, 1, 1),
                 out_indices=(3, ),
                 style='pytorch',
                 deep_stem=False,
                 avg_down=False,
                 frozen_stages=-1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=True,
                 wid=(48, 80, 160, 304, 608),
                 expan=(None, 1, 1, 1, 1),
                 dep=(None, 4, 6, 7, 3),
                 ks=(7, 3, 5, 5, 5),
                 group=(None, 16, 16, 16, 16),
                 att=(None, True, False, True, True)):
        # Protect mutable default arguments
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()
        if depth not in self.arch_settings:
            raise KeyError(f'invalid depth {depth} for resnet')
        self.depth = depth
        # The stem is built with ``wid[0]`` channels (see _make_stem_layer
        # below); ``dep[0]`` is only a placeholder.
        self.stem_channels = wid[0]
        self.num_stages = num_stages
        assert 1 <= num_stages <= 4
        self.strides = strides
        self.dilations = dilations
        assert len(strides) == len(dilations) == num_stages
        self.out_indices = out_indices
        assert max(out_indices) < num_stages
        self.style = style
        self.deep_stem = deep_stem
        self.avg_down = avg_down
        self.frozen_stages = frozen_stages
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp
        self.norm_eval = norm_eval
        self.zero_init_residual = zero_init_residual
        self.block = self.arch_settings[depth]
        self.stage_blocks = dep[1:1 + num_stages]

        self._make_stem_layer(in_channels, wid[0], ks[0])

        self.res_layers = []
        _in_channels = wid[0]
        for i, num_blocks in enumerate(self.stage_blocks):
            expansion = get_expansion(self.block, expan[i + 1])
            _out_channels = wid[i + 1] * expansion
            stride = strides[i]
            dilation = dilations[i]
            res_layer = self.make_res_layer(
                block=self.block,
                num_blocks=num_blocks,
                in_channels=_in_channels,
                out_channels=_out_channels,
                expansion=expansion,
                stride=stride,
                dilation=dilation,
                style=self.style,
                avg_down=self.avg_down,
                with_cp=with_cp,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                kernel_size=ks[i + 1],
                groups=group[i + 1],
                attention=att[i + 1])
            _in_channels = _out_channels
            layer_name = f'layer{i + 1}'
            self.add_module(layer_name, res_layer)
            self.res_layers.append(layer_name)

        self._freeze_stages()

        self.feat_dim = res_layer[-1].out_channels

    def make_res_layer(self, **kwargs):
        """Make a ViPNAS ResLayer."""
        return ViPNAS_ResLayer(**kwargs)

    @property
    def norm1(self):
        """nn.Module: the normalization layer named "norm1".

        Only defined when ``deep_stem`` is False (the name is set in
        ``_make_stem_layer``).
        """
        return getattr(self, self.norm1_name)

    def _make_stem_layer(self, in_channels, stem_channels, kernel_size):
        """Make stem layer.

        Builds either three 3x3 ConvModules (``deep_stem``) or a single
        stride-2 ``kernel_size`` conv + norm + ReLU, followed in both cases
        by a stride-2 3x3 max pool.
        """
        if self.deep_stem:
            self.stem = nn.Sequential(
                ConvModule(
                    in_channels,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=2,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=True),
                ConvModule(
                    stem_channels // 2,
                    stem_channels // 2,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=True),
                ConvModule(
                    stem_channels // 2,
                    stem_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    inplace=True))
        else:
            self.conv1 = build_conv_layer(
                self.conv_cfg,
                in_channels,
                stem_channels,
                kernel_size=kernel_size,
                stride=2,
                padding=kernel_size // 2,
                bias=False)
            self.norm1_name, norm1 = build_norm_layer(
                self.norm_cfg, stem_channels, postfix=1)
            self.add_module(self.norm1_name, norm1)
            self.relu = nn.ReLU(inplace=True)
        self.maxpool = nn.MaxPool2d(kernel_size=3, stride=2, padding=1)

    def _freeze_stages(self):
        """Freeze parameters.

        When ``frozen_stages >= 0``, puts the stem and residual stages
        ``layer1`` .. ``layer{frozen_stages}`` into eval mode and sets
        ``requires_grad = False`` on their parameters.
        """
        if self.frozen_stages >= 0:
            if self.deep_stem:
                self.stem.eval()
                for param in self.stem.parameters():
                    param.requires_grad = False
            else:
                self.norm1.eval()
                for m in [self.conv1, self.norm1]:
                    for param in m.parameters():
                        param.requires_grad = False

        for i in range(1, self.frozen_stages + 1):
            m = getattr(self, f'layer{i}')
            m.eval()
            for param in m.parameters():
                param.requires_grad = False

    def init_weights(self, pretrained=None):
        """Initialize model weights.

        Always delegates to the base class first (presumably it loads
        ``pretrained`` when given — confirm in BaseBackbone). When
        ``pretrained`` is None, every ``nn.Conv2d`` weight is drawn from
        N(0, 0.001^2) with zero bias, and every ``nn.BatchNorm2d`` gets
        weight 1 and bias 0.
        """
        super().init_weights(pretrained)
        if pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    nn.init.normal_(m.weight, std=0.001)
                    # Convs are built with bias=False here, but custom
                    # conv_cfg layers may still register a bias.
                    if m.bias is not None:
                        nn.init.constant_(m.bias, 0)
                elif isinstance(m, nn.BatchNorm2d):
                    nn.init.constant_(m.weight, 1)
                    nn.init.constant_(m.bias, 0)

    def forward(self, x):
        """Forward function.

        Returns:
            Tensor | tuple[Tensor]: The output of the single stage in
            ``out_indices``, or a tuple of outputs when several stages are
            selected.
        """
        if self.deep_stem:
            x = self.stem(x)
        else:
            x = self.conv1(x)
            x = self.norm1(x)
            x = self.relu(x)
        x = self.maxpool(x)
        outs = []
        for i, layer_name in enumerate(self.res_layers):
            res_layer = getattr(self, layer_name)
            x = res_layer(x)
            if i in self.out_indices:
                outs.append(x)
        if len(outs) == 1:
            return outs[0]
        return tuple(outs)

    def train(self, mode=True):
        """Convert the model into training mode.

        Re-applies stage freezing and, when ``norm_eval`` is set in training
        mode, keeps every ``_BatchNorm`` layer in eval mode.

        Returns:
            ViPNAS_ResNet: ``self``, matching the ``nn.Module.train``
            contract.
        """
        super().train(mode)
        self._freeze_stages()
        if mode and self.norm_eval:
            for m in self.modules():
                # trick: eval have effect on BatchNorm only
                if isinstance(m, _BatchNorm):
                    m.eval()
        return self