|
|
|
import copy |
|
|
|
import torch.nn as nn |
|
from mmcv.cnn import ConvModule, MaxPool2d, constant_init, normal_init |
|
from torch.nn.modules.batchnorm import _BatchNorm |
|
|
|
from mmpose.utils import get_root_logger |
|
from ..builder import BACKBONES |
|
from .base_backbone import BaseBackbone |
|
from .utils import load_checkpoint |
|
|
|
|
|
class HourglassAEModule(nn.Module):
    """Modified Hourglass Module for HourglassNet_AE backbone.

    Generate module recursively and use BasicBlock as the base unit.

    Args:
        depth (int): Depth of current HourglassModule.
        stage_channels (list[int]): Feature channels of sub-modules in current
            and follow-up HourglassModule. Must have at least ``depth + 1``
            entries, since each recursion level consumes one entry.
        norm_cfg (dict): Dictionary to construct and config norm layer.
    """

    def __init__(self,
                 depth,
                 stage_channels,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Deep-copy so the (mutable, possibly shared) config dict passed by
        # the caller is never mutated by downstream layer construction.
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        self.depth = depth

        cur_channel = stage_channels[0]
        next_channel = stage_channels[1]

        # High-resolution (skip) branch at the current level.
        self.up1 = ConvModule(
            cur_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)

        self.pool1 = MaxPool2d(2, 2)

        self.low1 = ConvModule(
            cur_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)

        if self.depth > 1:
            # Bugfix: propagate norm_cfg into the recursive sub-module.
            # Previously it was omitted, so every nested level silently fell
            # back to the default BN config regardless of the caller's choice.
            self.low2 = HourglassAEModule(
                depth - 1, stage_channels[1:], norm_cfg=norm_cfg)
        else:
            self.low2 = ConvModule(
                next_channel, next_channel, 3, padding=1, norm_cfg=norm_cfg)

        self.low3 = ConvModule(
            next_channel, cur_channel, 3, padding=1, norm_cfg=norm_cfg)

        # Upsample back to the input resolution for the residual sum below.
        self.up2 = nn.UpsamplingNearest2d(scale_factor=2)

    def forward(self, x):
        """Model forward function."""
        up1 = self.up1(x)
        pool1 = self.pool1(x)
        low1 = self.low1(pool1)
        low2 = self.low2(low1)
        low3 = self.low3(low2)
        up2 = self.up2(low3)
        # Element-wise sum of the skip branch and the upsampled deep branch;
        # requires the input spatial size to be divisible by 2**depth.
        return up1 + up2
|
|
|
|
|
@BACKBONES.register_module()
class HourglassAENet(BaseBackbone):
    """Hourglass-AE Network proposed by Newell et al.

    Associative Embedding: End-to-End Learning for Joint
    Detection and Grouping.

    More details can be found in the `paper
    <https://arxiv.org/abs/1611.05424>`__ .

    Args:
        downsample_times (int): Downsample times in a HourglassModule.
        num_stacks (int): Number of HourglassModule modules stacked,
            1 for Hourglass-52, 2 for Hourglass-104.
        out_channels (int): Number of output channels of each stack's
            prediction conv (heatmaps + tags).
        stage_channels (list[int]): Feature channel of each sub-module in a
            HourglassModule.
        feat_channels (int): Feature channel of conv after a HourglassModule.
        norm_cfg (dict): Dictionary to construct and config norm layer.

    Example:
        >>> from mmpose.models import HourglassAENet
        >>> import torch
        >>> self = HourglassAENet()
        >>> self.eval()
        >>> inputs = torch.rand(1, 3, 512, 512)
        >>> level_outputs = self.forward(inputs)
        >>> for level_output in level_outputs:
        ...     print(tuple(level_output.shape))
        (1, 34, 128, 128)
    """

    def __init__(self,
                 downsample_times=4,
                 num_stacks=1,
                 out_channels=34,
                 stage_channels=(256, 384, 512, 640, 768),
                 feat_channels=256,
                 norm_cfg=dict(type='BN', requires_grad=True)):
        # Deep-copy so the caller's (possibly shared) config dict is never
        # mutated by downstream layer construction.
        norm_cfg = copy.deepcopy(norm_cfg)
        super().__init__()

        self.num_stacks = num_stacks
        assert self.num_stacks >= 1
        # Each recursion level of HourglassAEModule consumes one entry of
        # stage_channels, so one extra entry is needed for the deepest level.
        assert len(stage_channels) > downsample_times

        cur_channels = stage_channels[0]

        # Stem downsamples the input by 4x (stride-2 conv + 2x2 max-pool)
        # and maps 3 -> feat_channels.
        self.stem = nn.Sequential(
            ConvModule(3, 64, 7, padding=3, stride=2, norm_cfg=norm_cfg),
            ConvModule(64, 128, 3, padding=1, norm_cfg=norm_cfg),
            MaxPool2d(2, 2),
            ConvModule(128, 128, 3, padding=1, norm_cfg=norm_cfg),
            ConvModule(128, feat_channels, 3, padding=1, norm_cfg=norm_cfg),
        )

        # One hourglass (plus two 3x3 refinement convs) per stack.
        self.hourglass_modules = nn.ModuleList([
            nn.Sequential(
                HourglassAEModule(
                    downsample_times, stage_channels, norm_cfg=norm_cfg),
                ConvModule(
                    feat_channels,
                    feat_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg),
                ConvModule(
                    feat_channels,
                    feat_channels,
                    3,
                    padding=1,
                    norm_cfg=norm_cfg)) for _ in range(num_stacks)
        ])

        # 1x1 prediction heads (no norm/act) producing the per-stack outputs.
        self.out_convs = nn.ModuleList([
            ConvModule(
                cur_channels,
                out_channels,
                1,
                padding=0,
                norm_cfg=None,
                act_cfg=None) for _ in range(num_stacks)
        ])

        # Project the previous stack's prediction back to feat_channels so it
        # can be fused into the next stack's input (none needed after the
        # last stack, hence num_stacks - 1).
        self.remap_out_convs = nn.ModuleList([
            ConvModule(
                out_channels,
                feat_channels,
                1,
                norm_cfg=norm_cfg,
                act_cfg=None) for _ in range(num_stacks - 1)
        ])

        # Likewise project the previous stack's hourglass features for fusion.
        self.remap_feature_convs = nn.ModuleList([
            ConvModule(
                feat_channels,
                feat_channels,
                1,
                norm_cfg=norm_cfg,
                act_cfg=None) for _ in range(num_stacks - 1)
        ])

        self.relu = nn.ReLU(inplace=True)

    def init_weights(self, pretrained=None):
        """Initialize the weights in backbone.

        Args:
            pretrained (str, optional): Path to pre-trained weights.
                Defaults to None.

        Raises:
            TypeError: If ``pretrained`` is neither a str nor None.
        """
        if isinstance(pretrained, str):
            logger = get_root_logger()
            load_checkpoint(self, pretrained, strict=False, logger=logger)
        elif pretrained is None:
            for m in self.modules():
                if isinstance(m, nn.Conv2d):
                    normal_init(m, std=0.001)
                elif isinstance(m, (_BatchNorm, nn.GroupNorm)):
                    constant_init(m, 1)
        else:
            raise TypeError('pretrained must be a str or None')

    def forward(self, x):
        """Model forward function.

        Args:
            x (Tensor): Input image tensor of shape (N, 3, H, W); H and W
                must be divisible by 4 * 2**downsample_times.

        Returns:
            list[Tensor]: One output tensor per stack, each of shape
            (N, out_channels, H/4, W/4).
        """
        inter_feat = self.stem(x)
        out_feats = []

        for ind in range(self.num_stacks):
            single_hourglass = self.hourglass_modules[ind]
            out_conv = self.out_convs[ind]

            hourglass_feat = single_hourglass(inter_feat)
            out_feat = out_conv(hourglass_feat)
            out_feats.append(out_feat)

            # Between stacks, fuse the remapped prediction and hourglass
            # features back into the running feature map (residual-style
            # intermediate supervision, as in stacked hourglass networks).
            if ind < self.num_stacks - 1:
                inter_feat = inter_feat + self.remap_out_convs[ind](
                    out_feat) + self.remap_feature_convs[ind](
                        hourglass_feat)

        return out_feats
|
|