|
|
|
|
|
import math |
|
|
|
import torch |
|
import torch.nn as nn |
|
|
|
from mmcv.cnn import (build_activation_layer, build_conv_layer, |
|
build_norm_layer, trunc_normal_init) |
|
from mmcv.cnn.bricks.transformer import build_dropout |
|
from mmcv.runner import BaseModule |
|
from torch.nn.functional import pad |
|
|
|
from ..builder import BACKBONES |
|
from .hrnet import Bottleneck, HRModule, HRNet |
|
|
|
|
|
def nlc_to_nchw(x, hw_shape):
    """Reshape a token sequence [N, L, C] into a feature map [N, C, H, W].

    Args:
        x (Tensor): Input of shape [N, L, C].
        hw_shape (Sequence[int]): Target (height, width) of the feature map;
            ``height * width`` must equal ``L``.

    Returns:
        Tensor: The same data laid out as [N, C, H, W].
    """
    height, width = hw_shape
    assert x.ndim == 3
    batch, seq_len, channels = x.shape
    assert seq_len == height * width, 'The seq_len doesn\'t match H, W'
    # Move channels in front of the sequence axis, then unfold the sequence
    # axis into the two spatial axes.
    channel_first = x.transpose(1, 2)
    return channel_first.reshape(batch, channels, height, width)
|
|
|
|
|
def nchw_to_nlc(x):
    """Flatten a feature map [N, C, H, W] into a token sequence [N, L, C].

    Args:
        x (Tensor): Input of shape [N, C, H, W].

    Returns:
        Tensor: Contiguous tensor of shape [N, L, C] with ``L = H * W``.
    """
    assert x.ndim == 4
    # Merge the two spatial axes into one sequence axis: [N, C, L].
    tokens = x.flatten(start_dim=2)
    # Put the sequence axis before channels: [N, L, C].
    tokens = tokens.transpose(1, 2)
    return tokens.contiguous()
|
|
|
|
|
def build_drop_path(drop_path_rate):
    """Build a ``DropPath`` layer with the given drop probability.

    Args:
        drop_path_rate (float): Value passed as ``drop_prob`` in the
            ``DropPath`` config handed to ``build_dropout``.

    Returns:
        The layer returned by ``build_dropout``.
    """
    drop_path_cfg = dict(type='DropPath', drop_prob=drop_path_rate)
    return build_dropout(drop_path_cfg)
|
|
|
|
|
class WindowMSA(BaseModule):
    """Window based multi-head self-attention (W-MSA) module with relative
    position bias.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int]): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        with_rpe (bool, optional): If True, use relative position bias.
            Default: True.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 with_rpe=True,
                 init_cfg=None):

        super().__init__(init_cfg=init_cfg)
        self.embed_dims = embed_dims
        self.window_size = window_size
        self.num_heads = num_heads
        head_embed_dims = embed_dims // num_heads
        self.scale = qk_scale or head_embed_dims**-0.5

        self.with_rpe = with_rpe
        if self.with_rpe:
            # One bias per head for each of the (2*Wh-1) * (2*Ww-1) entries
            # addressed by ``relative_position_index``.
            self.relative_position_bias_table = nn.Parameter(
                torch.zeros(
                    (2 * window_size[0] - 1) * (2 * window_size[1] - 1),
                    num_heads))

            # Index into the table for every (query, key) token pair of a
            # window; stored as a buffer so it follows the module's device.
            Wh, Ww = self.window_size
            rel_index_coords = self.double_step_seq(2 * Ww - 1, Wh, 1, Ww)
            rel_position_index = rel_index_coords + rel_index_coords.T
            rel_position_index = rel_position_index.flip(1).contiguous()
            self.register_buffer('relative_position_index',
                                 rel_position_index)

        self.qkv = nn.Linear(embed_dims, embed_dims * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop_rate)
        self.proj = nn.Linear(embed_dims, embed_dims)
        self.proj_drop = nn.Dropout(proj_drop_rate)

        self.softmax = nn.Softmax(dim=-1)

    def init_weights(self):
        """Initialize the relative position bias table (if present) with a
        truncated normal distribution (std=0.02)."""
        # The table only exists when relative position bias is enabled.
        if not self.with_rpe:
            return
        # ``trunc_normal_init`` expects a module and only touches its
        # ``weight``/``bias`` attributes, so it is a no-op on a bare
        # parameter; the tensor-level initializer is required here.
        from mmcv.cnn.utils.weight_init import trunc_normal_
        trunc_normal_(self.relative_position_bias_table, std=0.02)

    def forward(self, x, mask=None):
        """Apply windowed multi-head self-attention.

        Args:
            x (tensor): input features with shape of (B*num_windows, N, C)
            mask (tensor | None, Optional): mask with shape of (num_windows,
                Wh*Ww, Wh*Ww), value should be between (-inf, 0].

        Returns:
            tensor: output features with shape of (B*num_windows, N, C)
        """
        B, N, C = x.shape
        # (B, N, 3*C) -> (3, B, num_heads, N, C // num_heads)
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads,
                                  C // self.num_heads).permute(2, 0, 3, 1, 4)
        q, k, v = qkv[0], qkv[1], qkv[2]

        q = q * self.scale
        attn = (q @ k.transpose(-2, -1))

        if self.with_rpe:
            num_tokens = self.window_size[0] * self.window_size[1]
            relative_position_bias = self.relative_position_bias_table[
                self.relative_position_index.view(-1)].view(
                    num_tokens, num_tokens, -1)
            # (Wh*Ww, Wh*Ww, num_heads) -> (num_heads, Wh*Ww, Wh*Ww)
            relative_position_bias = relative_position_bias.permute(
                2, 0, 1).contiguous()
            attn = attn + relative_position_bias.unsqueeze(0)

        if mask is not None:
            # Broadcast the per-window mask over the batch and head axes.
            nW = mask.shape[0]
            attn = attn.view(B // nW, nW, self.num_heads, N,
                             N) + mask.unsqueeze(1).unsqueeze(0)
            attn = attn.view(-1, self.num_heads, N, N)
        attn = self.softmax(attn)

        attn = self.attn_drop(attn)

        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
        x = self.proj(x)
        x = self.proj_drop(x)
        return x

    @staticmethod
    def double_step_seq(step1, len1, step2, len2):
        """Return a (1, len1 * len2) tensor holding every sum
        ``i * step1 + j * step2`` for ``i < len1`` and ``j < len2``."""
        seq1 = torch.arange(0, step1 * len1, step1)
        seq2 = torch.arange(0, step2 * len2, step2)
        return (seq1[:, None] + seq2[None, :]).reshape(1, -1)
|
|
|
|
|
class LocalWindowSelfAttention(BaseModule):
    r"""Local-window Self Attention (LSA) module with relative position bias.

    This module is the short-range self-attention module in the
    `Interlaced Sparse Self-Attention <https://arxiv.org/abs/1907.12273>`_.

    Args:
        embed_dims (int): Number of input channels.
        num_heads (int): Number of attention heads.
        window_size (tuple[int] | int): The height and width of the window.
        qkv_bias (bool, optional): If True, add a learnable bias to q, k, v.
            Default: True.
        qk_scale (float | None, optional): Override default qk scale of
            head_dim ** -0.5 if set. Default: None.
        attn_drop_rate (float, optional): Dropout ratio of attention weight.
            Default: 0.0
        proj_drop_rate (float, optional): Dropout ratio of output. Default: 0.
        with_rpe (bool, optional): If True, use relative position bias.
            Default: True.
        with_pad_mask (bool, optional): If True, mask out the padded tokens in
            the attention process. Default: False.
        init_cfg (dict | None, optional): The Config for initialization.
            Default: None.
    """

    def __init__(self,
                 embed_dims,
                 num_heads,
                 window_size,
                 qkv_bias=True,
                 qk_scale=None,
                 attn_drop_rate=0.,
                 proj_drop_rate=0.,
                 with_rpe=True,
                 with_pad_mask=False,
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        if isinstance(window_size, int):
            window_size = (window_size, window_size)
        self.window_size = window_size
        self.with_pad_mask = with_pad_mask
        self.attn = WindowMSA(
            embed_dims=embed_dims,
            num_heads=num_heads,
            window_size=window_size,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop_rate=attn_drop_rate,
            proj_drop_rate=proj_drop_rate,
            with_rpe=with_rpe,
            init_cfg=init_cfg)

    def forward(self, x, H, W, **kwargs):
        """Forward function.

        Args:
            x (Tensor): Input tokens of shape (B, N, C) with ``N = H * W``.
            H (int): Height of the feature map the tokens were taken from.
            W (int): Width of the feature map the tokens were taken from.
            **kwargs: Extra keyword arguments forwarded to ``self.attn``.

        Returns:
            Tensor: Output tokens of shape (B, N, C).
        """
        B, N, C = x.shape
        x = x.view(B, H, W, C)
        Wh, Ww = self.window_size

        # Number of windows along each axis after padding.
        num_win_h = math.ceil(H / Wh)
        num_win_w = math.ceil(W / Ww)

        # Center-pad the feature map so both spatial sizes become multiples
        # of the window size. The pad spec runs from the last axis backwards:
        # (C: none, W: left/right, H: top/bottom).
        pad_h = num_win_h * Wh - H
        pad_w = num_win_w * Ww - W
        pad_spec = (0, 0, pad_w // 2, pad_w - pad_w // 2, pad_h // 2,
                    pad_h - pad_h // 2)
        x = pad(x, pad_spec)

        # Partition into windows: (B * num_windows, Wh * Ww, C).
        x = x.view(B, num_win_h, Wh, num_win_w, Ww, C)
        x = x.permute(0, 1, 3, 2, 4, 5)
        x = x.reshape(-1, Wh * Ww, C)

        # Mask the padded tokens whenever padding was added on EITHER axis;
        # requiring both axes would leave single-axis padding unmasked.
        if self.with_pad_mask and (pad_h > 0 or pad_w > 0):
            # 0 for real tokens, -inf for padded tokens.
            pad_mask = x.new_zeros(1, H, W, 1)
            pad_mask = pad(pad_mask, list(pad_spec), value=-float('inf'))
            pad_mask = pad_mask.view(1, num_win_h, Wh, num_win_w, Ww, 1)
            pad_mask = pad_mask.permute(1, 3, 0, 2, 4, 5)
            pad_mask = pad_mask.reshape(-1, Wh * Ww)
            # Repeat the key mask for every query row of the window:
            # (num_windows, Wh * Ww, Wh * Ww).
            pad_mask = pad_mask[:, None, :].expand([-1, Wh * Ww, -1])
            out = self.attn(x, pad_mask, **kwargs)
        else:
            out = self.attn(x, **kwargs)

        # Merge windows back into the padded feature map.
        out = out.reshape(B, num_win_h, num_win_w, Wh, Ww, C)
        out = out.permute(0, 1, 3, 2, 4, 5)
        out = out.reshape(B, H + pad_h, W + pad_w, C)

        # Crop the padding away again.
        out = out[:, pad_h // 2:H + pad_h // 2, pad_w // 2:W + pad_w // 2]
        return out.reshape(B, N, C)
|
|
|
|
|
class CrossFFN(BaseModule):
    r"""Convolutional feed-forward network of HRFormer.

    The layers are applied in this order: 1x1 conv -> norm -> act ->
    3x3 depthwise conv -> norm -> act -> 1x1 conv -> norm -> act.

    Args:
        in_features (int): The feature dimension.
        hidden_features (int, optional): The hidden dimension of FFNs.
            Defaults: The same as in_features.
        out_features (int, optional): The output dimension.
            Defaults: The same as in_features.
        act_cfg (dict, optional): Config of the activation layers that follow
            the two 1x1 convs. Default: dict(type='GELU').
        dw_act_cfg (dict, optional): Config of activation layer appended
            right after DW Conv. Default: dict(type='GELU').
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='SyncBN').
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
    """

    def __init__(self,
                 in_features,
                 hidden_features=None,
                 out_features=None,
                 act_cfg=dict(type='GELU'),
                 dw_act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN'),
                 init_cfg=None):
        super().__init__(init_cfg=init_cfg)
        # Unset widths fall back to the input width.
        out_features = out_features or in_features
        hidden_features = hidden_features or in_features

        # Pointwise expansion.
        self.fc1 = nn.Conv2d(in_features, hidden_features, kernel_size=1)
        self.act1 = build_activation_layer(act_cfg)
        _, self.norm1 = build_norm_layer(norm_cfg, hidden_features)

        # Depthwise 3x3 conv (one group per channel), resolution preserved.
        self.dw3x3 = nn.Conv2d(
            hidden_features,
            hidden_features,
            kernel_size=3,
            stride=1,
            groups=hidden_features,
            padding=1)
        self.act2 = build_activation_layer(dw_act_cfg)
        _, self.norm2 = build_norm_layer(norm_cfg, hidden_features)

        # Pointwise projection.
        self.fc2 = nn.Conv2d(hidden_features, out_features, kernel_size=1)
        self.act3 = build_activation_layer(act_cfg)
        _, self.norm3 = build_norm_layer(norm_cfg, out_features)

        # Execution order used by ``forward``.
        self.layers = [
            self.fc1,
            self.norm1,
            self.act1,
            self.dw3x3,
            self.norm2,
            self.act2,
            self.fc2,
            self.norm3,
            self.act3,
        ]

    def forward(self, x, H, W):
        """Forward function.

        Args:
            x (Tensor): Input tokens of shape [N, L, C] with ``L = H * W``.
            H (int): Height of the feature map.
            W (int): Width of the feature map.

        Returns:
            Tensor: Output tokens of shape [N, L, C_out].
        """
        feat = nlc_to_nchw(x, (H, W))
        for module in self.layers:
            feat = module(feat)
        return nchw_to_nlc(feat)
|
|
|
|
|
class HRFormerBlock(BaseModule):
    """High-Resolution Block for HRFormer.

    A residual local-window self-attention followed by a residual
    convolutional FFN, both wrapped by the same drop-path layer.

    Args:
        in_features (int): The input dimension.
        out_features (int): The output dimension.
        num_heads (int): The number of head within each LSA.
        window_size (int, optional): The window size for the LSA.
            Default: 7
        mlp_ratio (int, optional): The expansion ratio of FFN.
            Default: 4
        drop_path (float, optional): Drop path rate; an identity layer is
            used when it is not greater than 0. Default: 0.0
        act_cfg (dict, optional): Config of activation layer.
            Default: dict(type='GELU').
        norm_cfg (dict, optional): Config of norm layer.
            Default: dict(type='SyncBN').
        transformer_norm_cfg (dict, optional): Config of transformer norm
            layer. Default: dict(type='LN', eps=1e-6).
        init_cfg (dict | list | None, optional): The init config.
            Default: None.
        **kwargs: Extra keyword arguments forwarded to
            ``LocalWindowSelfAttention``.
    """

    expansion = 1

    def __init__(self,
                 in_features,
                 out_features,
                 num_heads,
                 window_size=7,
                 mlp_ratio=4.0,
                 drop_path=0.0,
                 act_cfg=dict(type='GELU'),
                 norm_cfg=dict(type='SyncBN'),
                 transformer_norm_cfg=dict(type='LN', eps=1e-6),
                 init_cfg=None,
                 **kwargs):
        super().__init__(init_cfg=init_cfg)
        self.num_heads = num_heads
        self.window_size = window_size
        self.mlp_ratio = mlp_ratio

        # Attention sub-layer with its pre-norm.
        _, self.norm1 = build_norm_layer(transformer_norm_cfg, in_features)
        self.attn = LocalWindowSelfAttention(
            in_features,
            num_heads=num_heads,
            window_size=window_size,
            init_cfg=None,
            **kwargs)

        # Feed-forward sub-layer with its pre-norm.
        _, self.norm2 = build_norm_layer(transformer_norm_cfg, out_features)
        hidden_features = int(in_features * mlp_ratio)
        self.ffn = CrossFFN(
            in_features=in_features,
            hidden_features=hidden_features,
            out_features=out_features,
            norm_cfg=norm_cfg,
            act_cfg=act_cfg,
            dw_act_cfg=act_cfg,
            init_cfg=None)

        if drop_path > 0.0:
            self.drop_path = build_drop_path(drop_path)
        else:
            self.drop_path = nn.Identity()

    def forward(self, x):
        """Forward function.

        Args:
            x (Tensor): Feature map of shape (B, C, H, W).

        Returns:
            Tensor: Feature map of shape (B, C, H, W).
        """
        B, C, H, W = x.size()

        # (B, C, H, W) -> (B, H * W, C)
        tokens = x.view(B, C, -1).permute(0, 2, 1)

        attn_out = self.attn(self.norm1(tokens), H, W)
        tokens = tokens + self.drop_path(attn_out)

        ffn_out = self.ffn(self.norm2(tokens), H, W)
        tokens = tokens + self.drop_path(ffn_out)

        # (B, H * W, C) -> (B, C, H, W)
        return tokens.permute(0, 2, 1).view(B, C, H, W)

    def extra_repr(self):
        """(Optional) Set the extra information about this module."""
        template = 'num_heads={}, window_size={}, mlp_ratio={}'
        return template.format(self.num_heads, self.window_size,
                               self.mlp_ratio)
|
|
|
|
|
class HRFomerModule(HRModule):
    """High-Resolution Module for HRFormer.

    Args:
        num_branches (int): The number of branches in the HRFormerModule.
        block (nn.Module): The building block of HRFormer.
            The block should be the HRFormerBlock.
        num_blocks (tuple): The number of blocks in each branch.
            The length must be equal to num_branches.
        num_inchannels (tuple): The number of input channels in each branch.
            The length must be equal to num_branches.
        num_channels (tuple): The number of channels in each branch.
            The length must be equal to num_branches.
        num_heads (tuple): The number of heads within the LSAs.
        num_window_sizes (tuple): The window size for the LSAs.
        num_mlp_ratios (tuple): The expansion ratio for the FFNs.
        multiscale_output (bool, optional): Whether to output multi-level
            features produced by multiple branches. If False, only the first
            level feature will be output. Default: True.
        drop_paths (float | Sequence[float], optional): The drop path rate
            of each block inside a branch, indexed by block position. A
            single number is used for every block. Default: 0.0
        with_rpe (bool, optional): Passed to every block as ``with_rpe``.
            Default: True.
        with_pad_mask (bool, optional): Passed to every block as
            ``with_pad_mask``. Default: False.
        conv_cfg (dict, optional): Config of the conv layers.
            Default: None.
        norm_cfg (dict, optional): Config of the norm layers appended
            right after conv. Default: dict(type='SyncBN', requires_grad=True)
        transformer_norm_cfg (dict, optional): Config of the norm layers.
            Default: dict(type='LN', eps=1e-6)
        with_cp (bool): Use checkpoint or not. Using checkpoint will save some
            memory while slowing down the training speed. Default: False
        upsample_cfg(dict, optional): The config of upsample layers in fuse
            layers. Default: dict(mode='bilinear', align_corners=False)
    """

    def __init__(self,
                 num_branches,
                 block,
                 num_blocks,
                 num_inchannels,
                 num_channels,
                 num_heads,
                 num_window_sizes,
                 num_mlp_ratios,
                 multiscale_output=True,
                 drop_paths=0.0,
                 with_rpe=True,
                 with_pad_mask=False,
                 conv_cfg=None,
                 norm_cfg=dict(type='SyncBN', requires_grad=True),
                 transformer_norm_cfg=dict(type='LN', eps=1e-6),
                 with_cp=False,
                 upsample_cfg=dict(mode='bilinear', align_corners=False)):

        # ``_make_one_branch`` indexes ``drop_paths`` per block, so a scalar
        # (including the default) is broadcast to one entry per block.
        if isinstance(drop_paths, (int, float)):
            drop_paths = [drop_paths] * max(num_blocks)

        # These attributes are read by ``_make_one_branch``, so they are
        # assigned before the parent constructor runs.
        self.transformer_norm_cfg = transformer_norm_cfg
        self.drop_paths = drop_paths
        self.num_heads = num_heads
        self.num_window_sizes = num_window_sizes
        self.num_mlp_ratios = num_mlp_ratios
        self.with_rpe = with_rpe
        self.with_pad_mask = with_pad_mask

        super().__init__(num_branches, block, num_blocks, num_inchannels,
                         num_channels, multiscale_output, with_cp, conv_cfg,
                         norm_cfg, upsample_cfg)

    def _make_one_branch(self,
                         branch_index,
                         block,
                         num_blocks,
                         num_channels,
                         stride=1):
        """Build one branch as a sequence of ``num_blocks[branch_index]``
        blocks (at least one block is always created)."""
        # HRFormerBlock does not support downsample and channels change.
        assert stride == 1 and self.in_channels[branch_index] == num_channels[
            branch_index]

        def build_block(block_index):
            # ``self.in_channels`` is read at call time so that blocks built
            # after the first one see the updated channel count.
            return block(
                self.in_channels[branch_index],
                num_channels[branch_index],
                num_heads=self.num_heads[branch_index],
                window_size=self.num_window_sizes[branch_index],
                mlp_ratio=self.num_mlp_ratios[branch_index],
                drop_path=self.drop_paths[block_index],
                norm_cfg=self.norm_cfg,
                transformer_norm_cfg=self.transformer_norm_cfg,
                init_cfg=None,
                with_rpe=self.with_rpe,
                with_pad_mask=self.with_pad_mask)

        layers = [build_block(0)]
        self.in_channels[
            branch_index] = self.in_channels[branch_index] * block.expansion
        for i in range(1, num_blocks[branch_index]):
            layers.append(build_block(i))
        return nn.Sequential(*layers)

    def _make_fuse_layers(self):
        """Build fuse layers.

        Returns:
            nn.ModuleList | None: ``fuse_layers[i][j]`` maps the output of
            branch ``j`` to the channels of branch ``i`` (``None`` on the
            diagonal). Returns ``None`` when there is a single branch.
        """
        if self.num_branches == 1:
            return None
        num_branches = self.num_branches
        num_inchannels = self.in_channels
        num_out_branches = num_branches if self.multiscale_output else 1

        fuse_layers = []
        for i in range(num_out_branches):
            fuse_layer = []
            for j in range(num_branches):
                if j > i:
                    # Branch j has more strided stages than branch i:
                    # 1x1 conv to match channels, then upsample by 2^(j-i).
                    fuse_layer.append(
                        nn.Sequential(
                            build_conv_layer(
                                self.conv_cfg,
                                num_inchannels[j],
                                num_inchannels[i],
                                kernel_size=1,
                                stride=1,
                                bias=False),
                            build_norm_layer(self.norm_cfg,
                                             num_inchannels[i])[1],
                            nn.Upsample(
                                scale_factor=2**(j - i),
                                mode=self.upsample_cfg['mode'],
                                align_corners=self.
                                upsample_cfg['align_corners'])))
                elif j == i:
                    fuse_layer.append(None)
                else:
                    # Chain (i - j) stride-2 depthwise 3x3 convs, each
                    # followed by a 1x1 conv; only the last 1x1 conv changes
                    # the channel count and it is not followed by a ReLU.
                    conv3x3s = []
                    for k in range(i - j):
                        is_last = k == i - j - 1
                        if is_last:
                            num_outchannels_conv3x3 = num_inchannels[i]
                        else:
                            num_outchannels_conv3x3 = num_inchannels[j]
                        sub_modules = [
                            build_conv_layer(
                                self.conv_cfg,
                                num_inchannels[j],
                                num_inchannels[j],
                                kernel_size=3,
                                stride=2,
                                padding=1,
                                groups=num_inchannels[j],
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg,
                                             num_inchannels[j])[1],
                            build_conv_layer(
                                self.conv_cfg,
                                num_inchannels[j],
                                num_outchannels_conv3x3,
                                kernel_size=1,
                                stride=1,
                                bias=False,
                            ),
                            build_norm_layer(self.norm_cfg,
                                             num_outchannels_conv3x3)[1]
                        ]
                        if not is_last:
                            sub_modules.append(nn.ReLU(False))
                        conv3x3s.append(nn.Sequential(*sub_modules))
                    fuse_layer.append(nn.Sequential(*conv3x3s))
            fuse_layers.append(nn.ModuleList(fuse_layer))

        return nn.ModuleList(fuse_layers)

    def get_num_inchannels(self):
        """Return the number of input channels."""
        return self.in_channels
|
|
|
|
|
@BACKBONES.register_module()
class HRFormer(HRNet):
    """HRFormer backbone.

    This backbone is the implementation of `HRFormer: High-Resolution
    Transformer for Dense Prediction <https://arxiv.org/abs/2110.09408>`_.

    Args:
        extra (dict): Detailed configuration for each stage of HRNet.
            There must be 4 stages, the configuration for each stage must have
            5 keys:

                - num_modules (int): The number of HRModule in this stage.
                - num_branches (int): The number of branches in the HRModule.
                - block (str): The type of block. Must be a key of
                    ``blocks_dict``.
                - num_blocks (tuple): The number of blocks in each branch.
                    The length must be equal to num_branches.
                - num_channels (tuple): The number of channels in each branch.
                    The length must be equal to num_branches.

            Stages 2-4 additionally need ``num_heads``, ``window_sizes`` and
            ``mlp_ratios`` (one entry per branch). Optional top-level keys:
            ``drop_path_rate`` (float, default 0.0), ``upsample`` (dict,
            default bilinear with ``align_corners=False``), ``with_rpe``
            (bool, default True) and ``with_pad_mask`` (bool, default False).
            The dict passed in is not modified.
        in_channels (int): Number of input image channels. Normally 3.
        conv_cfg (dict): Dictionary to construct and config conv layer.
            Default: None.
        norm_cfg (dict): Config of norm layer.
            Default: dict(type='BN', requires_grad=True).
        transformer_norm_cfg (dict): Config of transformer norm layer.
            Use `LN` by default.
        norm_eval (bool): Whether to set norm layers to eval mode, namely,
            freeze running stats (mean and var). Note: Effect on Batch Norm
            and its variants only. Default: False.
        with_cp (bool): Use checkpoint or not. Default: False.
        zero_init_residual (bool): Whether to use zero init for last norm layer
            in resblocks to let them behave as identity. Default: False.
        frozen_stages (int): Stages to be frozen (stop grad and set eval mode).
            -1 means not freezing any parameters. Default: -1.

    Example:
        >>> from mmpose.models import HRFormer
        >>> import torch
        >>> extra = dict(
        >>>     drop_path_rate=0.0,
        >>>     stage1=dict(
        >>>         num_modules=1,
        >>>         num_branches=1,
        >>>         block='BOTTLENECK',
        >>>         num_blocks=(2, ),
        >>>         num_channels=(64, )),
        >>>     stage2=dict(
        >>>         num_modules=1,
        >>>         num_branches=2,
        >>>         block='HRFORMERBLOCK',
        >>>         window_sizes=(7, 7),
        >>>         num_heads=(1, 2),
        >>>         mlp_ratios=(4, 4),
        >>>         num_blocks=(2, 2),
        >>>         num_channels=(32, 64)),
        >>>     stage3=dict(
        >>>         num_modules=4,
        >>>         num_branches=3,
        >>>         block='HRFORMERBLOCK',
        >>>         window_sizes=(7, 7, 7),
        >>>         num_heads=(1, 2, 4),
        >>>         mlp_ratios=(4, 4, 4),
        >>>         num_blocks=(2, 2, 2),
        >>>         num_channels=(32, 64, 128)),
        >>>     stage4=dict(
        >>>         num_modules=2,
        >>>         num_branches=4,
        >>>         block='HRFORMERBLOCK',
        >>>         window_sizes=(7, 7, 7, 7),
        >>>         num_heads=(1, 2, 4, 8),
        >>>         mlp_ratios=(4, 4, 4, 4),
        >>>         num_blocks=(2, 2, 2, 2),
        >>>         num_channels=(32, 64, 128, 256)))
        >>> self = HRFormer(extra, in_channels=1)
        >>> self.eval()
        >>> inputs = torch.rand(1, 1, 32, 32)
        >>> level_outputs = self.forward(inputs)
        >>> for level_out in level_outputs:
        ...     print(tuple(level_out.shape))
        (1, 32, 8, 8)
        (1, 64, 4, 4)
        (1, 128, 2, 2)
        (1, 256, 1, 1)
    """

    blocks_dict = {'BOTTLENECK': Bottleneck, 'HRFORMERBLOCK': HRFormerBlock}

    def __init__(self,
                 extra,
                 in_channels=3,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 transformer_norm_cfg=dict(type='LN', eps=1e-6),
                 norm_eval=False,
                 with_cp=False,
                 zero_init_residual=False,
                 frozen_stages=-1):
        import copy

        # Derived entries are written into ``extra`` below; work on a private
        # copy so the caller's config is left untouched.
        extra = copy.deepcopy(extra)

        # Stochastic depth: the rate grows linearly from 0 to
        # ``drop_path_rate`` over all transformer blocks of stages 2-4
        # (counted on the first branch of each stage).
        depths = [
            extra[stage]['num_blocks'][0] * extra[stage]['num_modules']
            for stage in ['stage2', 'stage3', 'stage4']
        ]
        depth_s2, depth_s3, _ = depths
        # A missing rate means drop path is disabled.
        drop_path_rate = extra.get('drop_path_rate', 0.0)
        dpr = torch.linspace(0, drop_path_rate, sum(depths)).tolist()
        extra['stage2']['drop_path_rates'] = dpr[0:depth_s2]
        extra['stage3']['drop_path_rates'] = dpr[depth_s2:depth_s2 + depth_s3]
        extra['stage4']['drop_path_rates'] = dpr[depth_s2 + depth_s3:]

        # HRFormer uses bilinear upsampling in fuse layers unless overridden.
        upsample_cfg = extra.get('upsample', {
            'mode': 'bilinear',
            'align_corners': False
        })
        extra['upsample'] = upsample_cfg
        # These attributes are read by ``_make_stage``, so they are assigned
        # before the parent constructor runs.
        self.transformer_norm_cfg = transformer_norm_cfg
        self.with_rpe = extra.get('with_rpe', True)
        self.with_pad_mask = extra.get('with_pad_mask', False)

        super().__init__(extra, in_channels, conv_cfg, norm_cfg, norm_eval,
                         with_cp, zero_init_residual, frozen_stages)

    def _make_stage(self,
                    layer_config,
                    num_inchannels,
                    multiscale_output=True):
        """Make each stage.

        Args:
            layer_config (dict): Stage configuration; see ``extra`` in the
                class docstring. ``drop_path_rates`` must also be present.
            num_inchannels (list[int]): Input channels of each branch.
            multiscale_output (bool): If False, the last module of the stage
                is built with ``multiscale_output=False``. Default: True.

        Returns:
            tuple: The stage as ``nn.Sequential`` and the channel counts
            reported by its last module.
        """
        num_modules = layer_config['num_modules']
        num_branches = layer_config['num_branches']
        num_blocks = layer_config['num_blocks']
        num_channels = layer_config['num_channels']
        block = self.blocks_dict[layer_config['block']]
        num_heads = layer_config['num_heads']
        num_window_sizes = layer_config['window_sizes']
        num_mlp_ratios = layer_config['mlp_ratios']
        drop_path_rates = layer_config['drop_path_rates']

        modules = []
        for i in range(num_modules):
            # multiscale_output is only used at the last module
            reset_multiscale_output = (
                bool(multiscale_output) or i != num_modules - 1)

            # Each module consumes its own slice of the stage's drop path
            # schedule, sized by the block count of the first branch.
            module_drop_paths = drop_path_rates[num_blocks[0] *
                                                i:num_blocks[0] * (i + 1)]

            modules.append(
                HRFomerModule(
                    num_branches,
                    block,
                    num_blocks,
                    num_inchannels,
                    num_channels,
                    num_heads,
                    num_window_sizes,
                    num_mlp_ratios,
                    reset_multiscale_output,
                    drop_paths=module_drop_paths,
                    with_rpe=self.with_rpe,
                    with_pad_mask=self.with_pad_mask,
                    conv_cfg=self.conv_cfg,
                    norm_cfg=self.norm_cfg,
                    transformer_norm_cfg=self.transformer_norm_cfg,
                    with_cp=self.with_cp,
                    upsample_cfg=self.upsample_cfg))
            num_inchannels = modules[-1].get_num_inchannels()

        return nn.Sequential(*modules), num_inchannels
|
|