Spaces:
Runtime error
Runtime error
# Copyright (c) 2019 Western Digital Corporation or its affiliates.
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
from mmcv.cnn import ConvModule | |
from ..builder import NECKS | |
class DetectionBlock(nn.Module):
    """Stack of five ConvModules used inside the YOLO neck.

    With ``out_channels = n`` the layers created here are, in order::

        conv1: 1x1 -> n
        conv2: 3x3 -> 2n
        conv3: 1x1 -> n
        conv4: 3x3 -> 2n
        conv5: 1x1 -> n

    Only these five layers are built by this block; the 3x3 layers are
    constructed with ``padding=1``. The number of input channels is
    arbitrary (``in_channels``).

    Args:
        in_channels (int): The number of input channels.
        out_channels (int): The number of output channels.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', negative_slope=0.1).
    """

    def __init__(self,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='LeakyReLU', negative_slope=0.1)):
        super().__init__()
        narrow = out_channels
        wide = out_channels * 2
        # Every layer shares the same conv / norm / activation settings.
        shared = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Alternate 1x1 "squeeze" layers with 3x3 "expand" layers.
        # Creation order is kept as conv1..conv5 so that module
        # registration order stays the same.
        self.conv1 = ConvModule(in_channels, narrow, 1, **shared)
        self.conv2 = ConvModule(narrow, wide, 3, padding=1, **shared)
        self.conv3 = ConvModule(wide, narrow, 1, **shared)
        self.conv4 = ConvModule(narrow, wide, 3, padding=1, **shared)
        self.conv5 = ConvModule(wide, narrow, 1, **shared)

    def forward(self, x):
        """Run ``x`` through conv1 ... conv5 in order and return the result."""
        feat = x
        for layer in (self.conv1, self.conv2, self.conv3, self.conv4,
                      self.conv5):
            feat = layer(feat)
        return feat
class YOLOV3Neck(nn.Module):
    """The neck of YOLOV3.

    It can be treated as a simplified version of FPN. It takes the
    multi-scale features from the backbone, and for every scale after the
    first one it upsamples the previous scale's output, concatenates it with
    the next input feature map and feeds the result to a DetectionBlock.

    Note:
        ``forward`` consumes ``feats`` in reversed order: ``feats[-1]`` is
        processed first (by ``detect1``), then ``feats[-2]``, and so on.
        ``in_channels`` and ``out_channels`` are indexed in that processing
        order, i.e. ``in_channels[0]`` is the channel count of ``feats[-1]``.

    Args:
        num_scales (int): The number of scales / stages.
        in_channels (Sequence[int]): The number of input channels per scale.
        out_channels (Sequence[int]): The number of output channels per
            scale.
        conv_cfg (dict): Config dict for convolution layer. Default: None.
        norm_cfg (dict): Dictionary to construct and config norm layer.
            Default: dict(type='BN', requires_grad=True)
        act_cfg (dict): Config dict for activation layer.
            Default: dict(type='LeakyReLU', negative_slope=0.1).
    """

    def __init__(self,
                 num_scales,
                 in_channels,
                 out_channels,
                 conv_cfg=None,
                 norm_cfg=dict(type='BN', requires_grad=True),
                 act_cfg=dict(type='LeakyReLU', negative_slope=0.1)):
        super(YOLOV3Neck, self).__init__()
        assert (num_scales == len(in_channels) == len(out_channels))
        self.num_scales = num_scales
        self.in_channels = in_channels
        self.out_channels = out_channels

        # shortcut
        cfg = dict(conv_cfg=conv_cfg, norm_cfg=norm_cfg, act_cfg=act_cfg)

        # Modules are registered under numbered names (detect1, conv1,
        # detect2, ...) so that an arbitrary number of scales is supported.
        self.detect1 = DetectionBlock(in_channels[0], out_channels[0], **cfg)
        for i in range(1, self.num_scales):
            in_c, out_c = self.in_channels[i], self.out_channels[i]
            # conv{i} is applied to the output of detect{i}, which has
            # out_channels[i - 1] channels. Using in_channels[i] here only
            # works when the two values happen to be equal.
            inter_c = self.out_channels[i - 1]
            self.add_module(f'conv{i}', ConvModule(inter_c, out_c, 1, **cfg))
            # in_c + out_c : High-lvl feats will be cat with low-lvl feats
            self.add_module(f'detect{i+1}',
                            DetectionBlock(in_c + out_c, out_c, **cfg))

    def forward(self, feats):
        """Fuse the multi-scale features.

        Args:
            feats (Sequence[Tensor]): One feature map per scale;
                ``len(feats)`` must equal ``num_scales``.

        Returns:
            tuple[Tensor]: One output per scale, in processing order
            (the output for ``feats[-1]`` comes first).
        """
        assert len(feats) == self.num_scales

        # processed from bottom (high-lvl) to top (low-lvl)
        outs = []
        out = self.detect1(feats[-1])
        outs.append(out)

        for i, x in enumerate(reversed(feats[:-1])):
            conv = getattr(self, f'conv{i+1}')
            tmp = conv(out)

            # Upsample by 2 so the previous output can be concatenated with
            # the next feature map along the channel dimension.
            tmp = F.interpolate(tmp, scale_factor=2)
            tmp = torch.cat((tmp, x), 1)

            detect = getattr(self, f'detect{i+2}')
            out = detect(tmp)
            outs.append(out)

        return tuple(outs)

    def init_weights(self):
        """Initialize the weights of module."""
        # init is done in ConvModule
        pass