|
|
|
|
|
|
|
|
|
|
|
import mmcv |
|
import torch |
|
import torch.nn as nn |
|
import torch.nn.functional as F |
|
import torch.utils.checkpoint as cp |
|
from mmcv.cnn import (ConvModule, DepthwiseSeparableConvModule, |
|
build_conv_layer, build_norm_layer, constant_init, |
|
normal_init) |
|
from torch.nn.modules.batchnorm import _BatchNorm |
|
|
|
from mmpose.utils import get_root_logger |
|
from ..builder import BACKBONES |
|
from .utils import channel_shuffle, load_checkpoint |
|
|
|
|
|
class SpatialWeighting(nn.Module):
    """Spatial weighting module.

    Squeeze-and-excite style channel attention: global average pooling
    followed by two 1x1 ConvModules producing per-channel weights that
    rescale the input.

    Args:
        channels (int): The channels of the module.
        ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
            The last ConvModule uses Sigmoid by default.
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        # A single dict is broadcast to both ConvModules.
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        squeezed_channels = int(channels / ratio)
        self.global_avgpool = nn.AdaptiveAvgPool2d(1)
        # Squeeze: reduce channel count by `ratio`.
        self.conv1 = ConvModule(
            in_channels=channels,
            out_channels=squeezed_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[0])
        # Excite: restore channel count, gated by act_cfg[1] (Sigmoid by
        # default) to produce weights in (0, 1).
        self.conv2 = ConvModule(
            in_channels=squeezed_channels,
            out_channels=channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        """Rescale ``x`` channel-wise by the learned attention weights."""
        weight = self.conv2(self.conv1(self.global_avgpool(x)))
        return x * weight
|
|
|
|
|
class CrossResolutionWeighting(nn.Module):
    """Cross-resolution channel weighting module.

    Pools every branch down to the smallest branch's spatial size,
    computes joint channel weights over the concatenation, then splits the
    weights back and rescales each branch (after nearest upsampling).

    Args:
        channels (int): The channels of the module.
        ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: None.
        act_cfg (dict): Config dict for activation layer.
            Default: (dict(type='ReLU'), dict(type='Sigmoid')).
            The last ConvModule uses Sigmoid by default.
    """

    def __init__(self,
                 channels,
                 ratio=16,
                 conv_cfg=None,
                 norm_cfg=None,
                 act_cfg=(dict(type='ReLU'), dict(type='Sigmoid'))):
        super().__init__()
        # A single dict is broadcast to both ConvModules.
        if isinstance(act_cfg, dict):
            act_cfg = (act_cfg, act_cfg)
        assert len(act_cfg) == 2
        assert mmcv.is_tuple_of(act_cfg, dict)

        self.channels = channels
        total_channel = sum(channels)
        squeezed_channels = int(total_channel / ratio)
        # Squeeze over the concatenation of all branches.
        self.conv1 = ConvModule(
            in_channels=total_channel,
            out_channels=squeezed_channels,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[0])
        # Excite back to the full concatenated width.
        self.conv2 = ConvModule(
            in_channels=squeezed_channels,
            out_channels=total_channel,
            kernel_size=1,
            stride=1,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg[1])

    def forward(self, x):
        """Weight each branch in ``x`` using cross-resolution statistics."""
        smallest_size = x[-1].size()[-2:]
        # Bring every branch to the lowest resolution; the last branch is
        # already there.
        pooled = [F.adaptive_avg_pool2d(branch, smallest_size)
                  for branch in x[:-1]]
        pooled.append(x[-1])

        weights = self.conv2(self.conv1(torch.cat(pooled, dim=1)))
        # Split the joint weights back into per-branch chunks.
        weights = torch.split(weights, self.channels, dim=1)

        return [
            branch * F.interpolate(w, size=branch.size()[-2:], mode='nearest')
            for branch, w in zip(x, weights)
        ]
|
|
|
|
|
class ConditionalChannelWeighting(nn.Module):
    """Conditional channel weighting block.

    Splits each branch's channels in half; one half passes through
    untouched while the other is processed by cross-resolution weighting,
    a depthwise conv, and spatial weighting, ShuffleNetV2-style.

    Args:
        in_channels (int): The input channels of the block.
        stride (int): Stride of the 3x3 convolution layer.
        reduce_ratio (int): channel reduction ratio.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 stride,
                 reduce_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.with_cp = with_cp
        self.stride = stride
        assert stride in [1, 2]

        # Only half of each branch's channels go through the weighted path.
        branch_channels = [chans // 2 for chans in in_channels]

        self.cross_resolution_weighting = CrossResolutionWeighting(
            branch_channels,
            ratio=reduce_ratio,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg)

        # One depthwise 3x3 conv per resolution branch.
        self.depthwise_convs = nn.ModuleList([
            ConvModule(
                chans,
                chans,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=chans,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None) for chans in branch_channels
        ])

        # One spatial (squeeze-excite style) weighting per branch.
        self.spatial_weighting = nn.ModuleList([
            SpatialWeighting(channels=chans, ratio=4)
            for chans in branch_channels
        ])

    def forward(self, x):
        """Apply the weighted path to half the channels of every branch."""

        def _inner_forward(feats):
            halves = [feat.chunk(2, dim=1) for feat in feats]
            identity = [half[0] for half in halves]
            active = [half[1] for half in halves]

            active = self.cross_resolution_weighting(active)
            active = [dw(a) for a, dw in zip(active, self.depthwise_convs)]
            active = [sw(a) for a, sw in zip(active, self.spatial_weighting)]

            # Re-join with the identity half and shuffle so that the next
            # block's split mixes information across both halves.
            return [
                channel_shuffle(torch.cat([idn, act], dim=1), 2)
                for idn, act in zip(identity, active)
            ]

        if self.with_cp and x.requires_grad:
            return cp.checkpoint(_inner_forward, x)
        return _inner_forward(x)
|
|
|
|
|
class Stem(nn.Module):
    """Stem network block.

    Downsamples the input 4x overall: a stride-2 3x3 conv followed by a
    ShuffleNetV2-style stride-2 unit whose two branch outputs are
    concatenated and channel-shuffled.

    Args:
        in_channels (int): The input channels of the block.
        stem_channels (int): Output channels of the stem layer.
        out_channels (int): The output channels of the block.
        expand_ratio (int): adjusts number of channels of the hidden layer
            in InvertedResidual by this amount.
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 stem_channels,
                 out_channels,
                 expand_ratio,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 with_cp=False):
        super().__init__()
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.with_cp = with_cp

        # First stride-2 conv: halves the spatial resolution.
        self.conv1 = ConvModule(
            in_channels=in_channels,
            out_channels=stem_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg,
            act_cfg=dict(type='ReLU'))

        mid_channels = int(round(stem_channels * expand_ratio))
        # conv1's output is split in half between the two branches below.
        branch_channels = stem_channels // 2
        # Choose branch1's output width so that concatenating branch1 and
        # branch2 yields exactly `out_channels` channels.
        if stem_channels == self.out_channels:
            inc_channels = self.out_channels - branch_channels
        else:
            inc_channels = self.out_channels - stem_channels

        # branch1: depthwise stride-2 3x3 conv + pointwise 1x1 projection.
        self.branch1 = nn.Sequential(
            ConvModule(
                branch_channels,
                branch_channels,
                kernel_size=3,
                stride=2,
                padding=1,
                groups=branch_channels,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_channels,
                inc_channels,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=dict(type='ReLU')),
        )

        # branch2 (expand -> depthwise stride-2 -> linear): an
        # inverted-residual block without the residual connection.
        self.expand_conv = ConvModule(
            branch_channels,
            mid_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))
        self.depthwise_conv = ConvModule(
            mid_channels,
            mid_channels,
            kernel_size=3,
            stride=2,
            padding=1,
            groups=mid_channels,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=None)
        # Output width mirrors the `inc_channels` choice above so the two
        # branches concatenate to `out_channels`.
        self.linear_conv = ConvModule(
            mid_channels,
            branch_channels
            if stem_channels == self.out_channels else stem_channels,
            kernel_size=1,
            stride=1,
            padding=0,
            conv_cfg=conv_cfg,
            norm_cfg=norm_cfg,
            act_cfg=dict(type='ReLU'))

    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            x = self.conv1(x)
            # Split channels between the two stride-2 branches (assumes
            # NCHW input: dim 1 is channels).
            x1, x2 = x.chunk(2, dim=1)

            x2 = self.expand_conv(x2)
            x2 = self.depthwise_conv(x2)
            x2 = self.linear_conv(x2)

            out = torch.cat((self.branch1(x1), x2), dim=1)

            # Mix channels across the two branches.
            out = channel_shuffle(out, 2)

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
|
|
|
|
|
class IterativeHead(nn.Module):
    """Extra iterative head for feature learning.

    Refines multi-resolution features from low to high resolution: each
    refined feature is bilinearly upsampled and added to the next
    (higher-resolution) branch before that branch's projection.

    Args:
        in_channels (list[int]): The input channels of each branch,
            ordered from highest to lowest resolution.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
    """

    def __init__(self, in_channels, norm_cfg=dict(type='BN')):
        super().__init__()
        num_branches = len(in_channels)
        # Branches are processed from lowest to highest resolution.
        self.in_channels = in_channels[::-1]

        # The original code duplicated the whole module construction in an
        # if/else differing only in `out_channels`; factor that out.
        projects = []
        for i in range(num_branches):
            if i != num_branches - 1:
                # Project to the next (higher-resolution) branch's width so
                # the upsampled feature can be added to it.
                out_channels = self.in_channels[i + 1]
            else:
                # The last (highest-resolution) branch keeps its own width.
                out_channels = self.in_channels[i]
            projects.append(
                DepthwiseSeparableConvModule(
                    in_channels=self.in_channels[i],
                    out_channels=out_channels,
                    kernel_size=3,
                    stride=1,
                    padding=1,
                    norm_cfg=norm_cfg,
                    act_cfg=dict(type='ReLU'),
                    dw_act_cfg=None,
                    pw_act_cfg=dict(type='ReLU')))
        self.projects = nn.ModuleList(projects)

    def forward(self, x):
        """Forward function.

        Args:
            x (list[Tensor]): Features ordered from highest to lowest
                resolution.

        Returns:
            list[Tensor]: Refined features in the same order as the input.
        """
        x = x[::-1]

        y = []
        last_x = None
        for i, s in enumerate(x):
            if last_x is not None:
                # Upsample the previously refined (coarser) feature and
                # fuse it into the current branch.
                last_x = F.interpolate(
                    last_x,
                    size=s.size()[-2:],
                    mode='bilinear',
                    align_corners=True)
                s = s + last_x
            s = self.projects[i](s)
            y.append(s)
            last_x = s

        # Restore the original highest-to-lowest ordering.
        return y[::-1]
|
|
|
|
|
class ShuffleUnit(nn.Module):
    """InvertedResidual block for ShuffleNetV2 backbone.

    Args:
        in_channels (int): The input channels of the block.
        out_channels (int): The output channels of the block.
        stride (int): Stride of the 3x3 convolution layer. Default: 1
        conv_cfg (dict): Config dict for convolution layer.
            Default: None, which means using conv2d.
        norm_cfg (dict): Config dict for normalization layer.
            Default: dict(type='BN').
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='ReLU').
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 stride=1,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 act_cfg=dict(type='ReLU'),
                 with_cp=False):
        super().__init__()
        self.stride = stride
        self.with_cp = with_cp

        branch_features = out_channels // 2
        # With stride 1 the input is split in half and only one half is
        # transformed, so the input width must already equal the output
        # width. (The original also asserted the contrapositive — that any
        # channel mismatch requires stride != 1 — which could never fire
        # where this check does not; that dead check is removed.)
        if self.stride == 1:
            assert in_channels == branch_features * 2, (
                f'in_channels ({in_channels}) should equal to '
                f'branch_features * 2 ({branch_features * 2}) '
                'when stride is 1')

        if self.stride > 1:
            # Downsampling path applied to the full input: depthwise 3x3
            # (stride s) followed by a pointwise 1x1 projection.
            self.branch1 = nn.Sequential(
                ConvModule(
                    in_channels,
                    in_channels,
                    kernel_size=3,
                    stride=self.stride,
                    padding=1,
                    groups=in_channels,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=None),
                ConvModule(
                    in_channels,
                    branch_features,
                    kernel_size=1,
                    stride=1,
                    padding=0,
                    conv_cfg=conv_cfg,
                    norm_cfg=norm_cfg,
                    act_cfg=act_cfg),
            )

        # Main path: 1x1 expand -> depthwise 3x3 (stride s) -> 1x1 linear.
        self.branch2 = nn.Sequential(
            ConvModule(
                in_channels if (self.stride > 1) else branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=3,
                stride=self.stride,
                padding=1,
                groups=branch_features,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=None),
            ConvModule(
                branch_features,
                branch_features,
                kernel_size=1,
                stride=1,
                padding=0,
                conv_cfg=conv_cfg,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg))

    def forward(self, x):
        """Forward function."""

        def _inner_forward(x):
            if self.stride > 1:
                out = torch.cat((self.branch1(x), self.branch2(x)), dim=1)
            else:
                # Split channels; the first half passes through unchanged.
                x1, x2 = x.chunk(2, dim=1)
                out = torch.cat((x1, self.branch2(x2)), dim=1)

            out = channel_shuffle(out, 2)

            return out

        if self.with_cp and x.requires_grad:
            out = cp.checkpoint(_inner_forward, x)
        else:
            out = _inner_forward(x)

        return out
|
|
|
|
|
class LiteHRModule(nn.Module):
    """High-Resolution Module for LiteHRNet.

    It contains conditional channel weighting blocks and
    shuffle blocks.

    Args:
        num_branches (int): Number of branches in the module.
        num_blocks (int): Number of blocks in the module.
        in_channels (list(int)): Number of input image channels.
        reduce_ratio (int): Channel reduction ratio.
        module_type (str): 'LITE' or 'NAIVE'
        multiscale_output (bool): Whether to output multi-scale features.
        with_fuse (bool): Whether to use fuse layers.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.
    """

    def __init__(
        self,
        num_branches,
        num_blocks,
        in_channels,
        reduce_ratio,
        module_type,
        multiscale_output=False,
        with_fuse=True,
        conv_cfg=None,
        norm_cfg=dict(type='BN'),
        with_cp=False,
    ):
        super().__init__()
        self._check_branches(num_branches, in_channels)

        self.in_channels = in_channels
        self.num_branches = num_branches

        self.module_type = module_type
        self.multiscale_output = multiscale_output
        self.with_fuse = with_fuse
        self.norm_cfg = norm_cfg
        self.conv_cfg = conv_cfg
        self.with_cp = with_cp

        # 'LITE': conditional channel weighting shared across all branches.
        # 'NAIVE': an independent shuffle-block chain per branch.
        if self.module_type.upper() == 'LITE':
            self.layers = self._make_weighting_blocks(num_blocks, reduce_ratio)
        elif self.module_type.upper() == 'NAIVE':
            self.layers = self._make_naive_branches(num_branches, num_blocks)
        else:
            raise ValueError("module_type should be either 'LITE' or 'NAIVE'.")
        if self.with_fuse:
            self.fuse_layers = self._make_fuse_layers()
            self.relu = nn.ReLU()

    def _check_branches(self, num_branches, in_channels):
        """Check input to avoid ValueError."""
        if num_branches != len(in_channels):
            error_msg = f'NUM_BRANCHES({num_branches}) ' \
                f'!= NUM_INCHANNELS({len(in_channels)})'
            raise ValueError(error_msg)

    def _make_weighting_blocks(self, num_blocks, reduce_ratio, stride=1):
        """Make channel weighting blocks."""
        layers = []
        for i in range(num_blocks):
            layers.append(
                ConditionalChannelWeighting(
                    self.in_channels,
                    stride=stride,
                    reduce_ratio=reduce_ratio,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp))

        return nn.Sequential(*layers)

    def _make_one_branch(self, branch_index, num_blocks, stride=1):
        """Make one branch."""
        layers = []
        # First block may downsample (stride); the rest keep stride 1.
        layers.append(
            ShuffleUnit(
                self.in_channels[branch_index],
                self.in_channels[branch_index],
                stride=stride,
                conv_cfg=self.conv_cfg,
                norm_cfg=self.norm_cfg,
                act_cfg=dict(type='ReLU'),
                with_cp=self.with_cp))
        for i in range(1, num_blocks):
            layers.append(
                ShuffleUnit(
                    self.in_channels[branch_index],
                    self.in_channels[branch_index],
                    stride=1,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    act_cfg=dict(type='ReLU'),
                    with_cp=self.with_cp))

        return nn.Sequential(*layers)

    def _make_naive_branches(self, num_branches, num_blocks):
        """Make branches."""
        branches = []

        for i in range(num_branches):
            branches.append(self._make_one_branch(i, num_blocks))

        return nn.ModuleList(branches)

    def _make_fuse_layers(self):
        """Make fuse layer."""
        if self.num_branches == 1:
            return None

        num_branches = self.num_branches
        in_channels = self.in_channels
        fuse_layers = []
        num_out_branches = num_branches if self.multiscale_output else 1
        # fuse_layers[i][j] maps branch j's feature to branch i's
        # resolution and channel count.
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Lower-resolution source: 1x1 channel projection then
                    # nearest upsampling by 2**(j - i).
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels[j],
                                in_channels[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i), mode='nearest')))
                elif j == i:
                    # Same branch: identity (handled in forward).
                    fuse_layer.append(None)
                else:
                    # Higher-resolution source: chain of stride-2
                    # depthwise-separable downsamples; only the final step
                    # changes the channel count (and omits the ReLU).
                    conv_downsamples = []
                    for k in range(i - j):
                        if k == i - j - 1:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[i],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[i])[1]))
                        else:
                            conv_downsamples.append(
                                nn.Sequential(
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=3,
                                        stride=2,
                                        padding=1,
                                        groups=in_channels[j],
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    build_conv_layer(
                                        self.conv_cfg,
                                        in_channels[j],
                                        in_channels[j],
                                        kernel_size=1,
                                        stride=1,
                                        padding=0,
                                        bias=False),
                                    build_norm_layer(self.norm_cfg,
                                                     in_channels[j])[1],
                                    nn.ReLU(inplace=True)))
                    fuse_layer.append(nn.Sequential(*conv_downsamples))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def forward(self, x):
        """Forward function."""
        if self.num_branches == 1:
            return [self.layers[0](x[0])]

        if self.module_type.upper() == 'LITE':
            # LITE layers consume the whole list of branch features.
            out = self.layers(x)
        elif self.module_type.upper() == 'NAIVE':
            # NAIVE branches run independently, one per resolution.
            for i in range(self.num_branches):
                x[i] = self.layers[i](x[i])
            out = x

        if self.with_fuse:
            out_fuse = []
            for i in range(len(self.fuse_layers)):
                # NOTE(review): for i == 0 the seed is out[0] itself and the
                # j == 0 iteration below adds out[0] again, so the first
                # output counts out[0] twice; the in-place `+=` also mutates
                # the tensor aliased by out[0]. This mirrors the upstream
                # Lite-HRNet implementation (and its released weights), so
                # it is documented rather than changed — confirm upstream
                # before "fixing".
                y = out[0] if i == 0 else self.fuse_layers[i][0](out[0])
                for j in range(self.num_branches):
                    if i == j:
                        y += out[j]
                    else:
                        y += self.fuse_layers[i][j](out[j])
                out_fuse.append(self.relu(y))
            out = out_fuse
        if not self.multiscale_output:
            out = [out[0]]
        return out
|
|
|
|
|
@BACKBONES.register_module()
class LiteHRNet(nn.Module):
    """Lite-HRNet backbone.

    `Lite-HRNet: A Lightweight High-Resolution Network
    <https://arxiv.org/abs/2104.06403>`_.

    Code adapted from 'https://github.com/HRNet/Lite-HRNet'.

    Args:
        extra (dict): detailed configuration for each stage of HRNet.
        in_channels (int): Number of input image channels. Default: 3.
        conv_cfg (dict): dictionary to construct and config conv layer.
        norm_cfg (dict): dictionary to construct and config norm layer.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed.

    Example:
        >>> from mmpose.models import LiteHRNet
        >>> import torch
        >>> extra=dict(
        >>>    stem=dict(stem_channels=32, out_channels=32, expand_ratio=1),
        >>>    num_stages=3,
        >>>    stages_spec=dict(
        >>>        num_modules=(2, 4, 2),
        >>>        num_branches=(2, 3, 4),
        >>>        num_blocks=(2, 2, 2),
        >>>        module_type=('LITE', 'LITE', 'LITE'),
        >>>        with_fuse=(True, True, True),
        >>>        reduce_ratios=(8, 8, 8),
        >>>        num_channels=(
        >>>            (40, 80),
        >>>            (40, 80, 160),
        >>>            (40, 80, 160, 320),
        >>>        )),
        >>>    with_head=False)
        >>> self = LiteHRNet(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 40, 8, 8)
    """

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN'),
                 norm_eval=False,
                 with_cp=False):
        super().__init__()
        self.extra = extra
        self.conv_cfg = conv_cfg
        self.norm_cfg = norm_cfg
        self.norm_eval = norm_eval
        self.with_cp = with_cp

        # Stem downsamples the input before the HR stages.
        self.stem = Stem(
            in_channels,
            stem_channels=self.extra['stem']['stem_channels'],
            out_channels=self.extra['stem']['out_channels'],
            expand_ratio=self.extra['stem']['expand_ratio'],
            conv_cfg=self.conv_cfg,
            norm_cfg=self.norm_cfg)

        self.num_stages = self.extra['num_stages']
        self.stages_spec = self.extra['stages_spec']

        # Channel widths produced by the previous stage; seeded with the
        # stem's single-branch output and threaded through every stage.
        num_channels_last = [
            self.stem.out_channels,
        ]
        for i in range(self.num_stages):
            num_channels = self.stages_spec['num_channels'][i]
            num_channels = [num_channels[i] for i in range(len(num_channels))]
            # transition{i} adapts the previous stage's branches (count and
            # widths) to this stage's configuration.
            setattr(
                self, f'transition{i}',
                self._make_transition_layer(num_channels_last, num_channels))

            stage, num_channels_last = self._make_stage(
                self.stages_spec, i, num_channels, multiscale_output=True)
            setattr(self, f'stage{i}', stage)

        self.with_head = self.extra['with_head']
        if self.with_head:
            self.head_layer = IterativeHead(
                in_channels=num_channels_last,
                norm_cfg=self.norm_cfg,
            )

    def _make_transition_layer(self, num_channels_pre_layer,
                               num_channels_cur_layer):
        """Make transition layer."""
        num_branches_cur = len(num_channels_cur_layer)
        num_branches_pre = len(num_channels_pre_layer)

        transition_layers = []
        for i in range(num_branches_cur):
            if i < num_branches_pre:
                # Existing branch: only adapt when the channel widths
                # differ, otherwise pass through (None).
                if num_channels_cur_layer[i] != num_channels_pre_layer[i]:
                    transition_layers.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_pre_layer[i],
                                kernel_size=3,
                                stride=1,
                                padding=1,
                                groups=num_channels_pre_layer[i],
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_pre_layer[i])[1],
                            build_conv_layer(
                                self.conv_cfg,
                                num_channels_pre_layer[i],
                                num_channels_cur_layer[i],
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_channels_cur_layer[i])[1],
                            nn.ReLU()))
                else:
                    transition_layers.append(None)
            else:
                # New branch: derive it from the previous stage's lowest
                # resolution by chained stride-2 downsamples; only the last
                # step switches to the new branch's channel width.
                conv_downsamples = []
                for j in range(i + 1 - num_branches_pre):
                    in_channels = num_channels_pre_layer[-1]
                    out_channels = num_channels_cur_layer[i] \
                        if j == i - num_branches_pre else in_channels
                    conv_downsamples.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                in_channels,
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                groups=in_channels,
                                bias=False),
                            build_norm_layer(self.norm_cfg, in_channels)[1],
                            build_conv_layer(
                                self.conv_cfg,
                                in_channels,
                                out_channels,
                                kernel_size=1,
                                stride=1,
                                padding=0,
                                bias=False),
                            build_norm_layer(self.norm_cfg, out_channels)[1],
                            nn.ReLU()))
                transition_layers.append(nn.Sequential(*conv_downsamples))

        return nn.ModuleList(transition_layers)

    def _make_stage(self,
                    stages_spec,
                    stage_index,
                    in_channels,
                    multiscale_output=True):
        # Build one stage as a sequence of LiteHRModules; returns the stage
        # and the channel widths it outputs (used to build the next stage).
        num_modules = stages_spec['num_modules'][stage_index]
        num_branches = stages_spec['num_branches'][stage_index]
        num_blocks = stages_spec['num_blocks'][stage_index]
        reduce_ratio = stages_spec['reduce_ratios'][stage_index]
        with_fuse = stages_spec['with_fuse'][stage_index]
        module_type = stages_spec['module_type'][stage_index]

        modules = []
        for i in range(num_modules):
            # Only the last module of a stage may collapse to a single
            # output scale.
            if not multiscale_output and i == num_modules - 1:
                reset_multiscale_output = False
            else:
                reset_multiscale_output = True

            modules.append(
                LiteHRModule(
                    num_branches,
                    num_blocks,
                    in_channels,
                    reduce_ratio,
                    module_type,
                    multiscale_output=reset_multiscale_output,
                    with_fuse=with_fuse,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    with_cp=self.with_cp))
            in_channels = modules[-1].in_channels

        return nn.Sequential(*modules), in_channels

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            # Default init: small-variance normal for convs, unit
            # scale/zero shift for norm layers.
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Forward function."""
        x = self.stem(x)

        y_list = [x]
        for i in range(self.num_stages):
            x_list = []
            transition = getattr(self, f'transition{i}')
            for j in range(self.stages_spec['num_branches'][i]):
                if transition[j]:
                    # Branches beyond the previous stage's count are
                    # created from its lowest-resolution output.
                    if j >= len(y_list):
                        x_list.append(transition[j](y_list[-1]))
                    else:
                        x_list.append(transition[j](y_list[j]))
                else:
                    # None transition: pass the feature through unchanged.
                    x_list.append(y_list[j])
            y_list = getattr(self, f'stage{i}')(x_list)

        x = y_list
        if self.with_head:
            x = self.head_layer(x)

        # Only the highest-resolution feature is returned.
        return [x[0]]

    def train(self, mode=True):
        """Convert the model into training mode."""
        super().train(mode)
        if mode and self.norm_eval:
            # Freeze BatchNorm running statistics during training.
            for m in self.modules():
                if isinstance(m, _BatchNorm):
                    m.eval()
|
|