Spaces:
Runtime error
Runtime error
import torch.nn as nn | |
import torch.nn.functional as F | |
from annotator.uniformer.mmcv.cnn import ConvModule | |
from ..builder import NECKS | |
class MultiLevelNeck(nn.Module):
    """Multi-level neck.

    A neck structure that connects a ViT backbone to the decoder heads.
    Every input feature map is projected to ``out_channels`` with a 1x1
    ``ConvModule``; each output level is then produced by bilinearly
    resizing a projected map by the matching entry of ``scales`` and
    passing it through a 3x3 ``ConvModule``.

    Args:
        in_channels (List[int]): Number of input channels per scale.
        out_channels (int): Number of output channels (used at each scale).
        scales (Sequence[float]): Scale factors for each output feature map.
            Default: (0.5, 1, 2, 4).
        norm_cfg (dict): Config dict for normalization layer. Default: None.
        act_cfg (dict): Config dict for activation layer in ConvModule.
            Default: None.
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 scales=(0.5, 1, 2, 4),
                 norm_cfg=None,
                 act_cfg=None):
        super(MultiLevelNeck, self).__init__()
        assert isinstance(in_channels, list)
        self.in_channels = in_channels
        self.out_channels = out_channels
        # Store a private copy: the default is immutable and a caller-owned
        # sequence is never shared between (or mutated through) instances.
        self.scales = list(scales)
        self.num_outs = len(self.scales)

        # One 1x1 projection per input feature map.
        self.lateral_convs = nn.ModuleList([
            ConvModule(
                in_channel,
                out_channels,
                kernel_size=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for in_channel in in_channels
        ])
        # One 3x3 refinement conv per output level.
        self.convs = nn.ModuleList([
            ConvModule(
                out_channels,
                out_channels,
                kernel_size=3,
                padding=1,
                stride=1,
                norm_cfg=norm_cfg,
                act_cfg=act_cfg) for _ in range(self.num_outs)
        ])

    def forward(self, inputs):
        """Project, rescale and refine the input feature maps.

        Args:
            inputs (Sequence[Tensor]): One feature map per entry of
                ``in_channels``.

        Returns:
            tuple[Tensor]: ``num_outs`` feature maps, the i-th one resized
            by ``scales[i]``.
        """
        assert len(inputs) == len(self.in_channels)
        inputs = [
            lateral_conv(feat)
            for lateral_conv, feat in zip(self.lateral_convs, inputs)
        ]
        # A single input feature map is reused for every output level.
        # NOTE(review): any other input count is indexed per output level
        # below, so it is expected to provide at least ``num_outs`` maps.
        if len(inputs) == 1:
            inputs = [inputs[0] for _ in range(self.num_outs)]
        outs = []
        for i in range(self.num_outs):
            x_resize = F.interpolate(
                inputs[i], scale_factor=self.scales[i], mode='bilinear')
            outs.append(self.convs[i](x_resize))
        return tuple(outs)