Spaces:
Running
Running
File size: 6,944 Bytes
4a3ab35 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 |
from collections import OrderedDict
import torch
import torch.nn as nn
from nets.CSPdarknet import darknet53
def conv2d(filter_in, filter_out, kernel_size, stride=1):
    """Build a Conv2d -> BatchNorm2d -> LeakyReLU(0.1) block.

    Args:
        filter_in: Number of input channels.
        filter_out: Number of output channels.
        kernel_size: Square kernel size of the convolution.
        stride: Convolution stride (default 1).

    Returns:
        An ``nn.Sequential`` whose children are named ``conv``, ``bn`` and
        ``relu``. The convolution has no bias term.
    """
    # (k - 1) // 2 keeps the spatial size for odd kernels at stride 1; a
    # falsy kernel size falls back to zero padding instead of going negative.
    padding = (kernel_size - 1) // 2 if kernel_size else 0

    layers = OrderedDict()
    layers["conv"] = nn.Conv2d(
        filter_in,
        filter_out,
        kernel_size=kernel_size,
        stride=stride,
        padding=padding,
        bias=False,
    )
    layers["bn"] = nn.BatchNorm2d(filter_out)
    layers["relu"] = nn.LeakyReLU(0.1)
    return nn.Sequential(layers)
#---------------------------------------------------#
# SPP structure: max-pool the input with several different kernel sizes,
# then stack (concatenate) the pooled results
#---------------------------------------------------#
class SpatialPyramidPooling(nn.Module):
    """Spatial pyramid pooling: stride-1 max-pools of several sizes, stacked on channels.

    Args:
        pool_sizes: Iterable of positive odd integer kernel sizes. Each pool
            uses stride 1 and padding ``size // 2``, which preserves the
            spatial size only for odd kernels.

    Raises:
        ValueError: If a pool size is not a positive odd number. Such a size
            would otherwise only fail later, inside ``forward``, when the
            pooled maps no longer match the input spatially.
    """

    def __init__(self, pool_sizes=(5, 9, 13)):
        super(SpatialPyramidPooling, self).__init__()
        # Materialize once so a one-shot iterable can be validated and used.
        pool_sizes = tuple(pool_sizes)
        for pool_size in pool_sizes:
            if pool_size < 1 or pool_size % 2 == 0:
                raise ValueError(
                    "pool sizes must be positive odd integers, got %r" % (pool_size,)
                )
        self.maxpools = nn.ModuleList(
            [nn.MaxPool2d(pool_size, 1, pool_size // 2) for pool_size in pool_sizes]
        )

    def forward(self, x):
        """Return the pooled maps (largest kernel first) followed by ``x``, joined on dim 1."""
        features = [maxpool(x) for maxpool in self.maxpools[::-1]]
        return torch.cat(features + [x], dim=1)
#---------------------------------------------------#
# Convolution + upsampling
#---------------------------------------------------#
class Upsample(nn.Module):
    """1x1 conv block followed by 2x nearest-neighbour upsampling.

    Args:
        in_channels: Channels of the incoming feature map.
        out_channels: Channels produced by the 1x1 conv block.
    """

    def __init__(self, in_channels, out_channels):
        super().__init__()
        channel_reduce = conv2d(in_channels, out_channels, 1)
        enlarge = nn.Upsample(scale_factor=2, mode='nearest')
        self.upsample = nn.Sequential(channel_reduce, enlarge)

    def forward(self, x):
        """Apply the conv block, then double the spatial resolution."""
        return self.upsample(x)
#---------------------------------------------------#
# Three-convolution block
#---------------------------------------------------#
def make_three_conv(filters_list, in_filters):
    """Stack three conv blocks: 1x1 -> 3x3 -> 1x1.

    Args:
        filters_list: ``filters_list[0]`` is the channel count of the 1x1
            blocks (and of the output); ``filters_list[1]`` is the channel
            count of the 3x3 block.
        in_filters: Channels of the incoming feature map.

    Returns:
        An ``nn.Sequential`` of the three conv blocks.
    """
    narrow = filters_list[0]
    wide = filters_list[1]
    layers = [
        conv2d(in_filters, narrow, 1),
        conv2d(narrow, wide, 3),
        conv2d(wide, narrow, 1),
    ]
    return nn.Sequential(*layers)
#---------------------------------------------------#
# Five-convolution block
#---------------------------------------------------#
def make_five_conv(filters_list, in_filters):
    """Stack five conv blocks: 1x1 -> 3x3 -> 1x1 -> 3x3 -> 1x1.

    Args:
        filters_list: ``filters_list[0]`` is the channel count of the 1x1
            blocks (and of the output); ``filters_list[1]`` is the channel
            count of the 3x3 blocks.
        in_filters: Channels of the incoming feature map.

    Returns:
        An ``nn.Sequential`` of the five conv blocks.
    """
    narrow = filters_list[0]
    wide = filters_list[1]

    layers = [conv2d(in_filters, narrow, 1)]
    # Two identical expand (3x3) / squeeze (1x1) pairs follow the entry block.
    for _ in range(2):
        layers.append(conv2d(narrow, wide, 3))
        layers.append(conv2d(wide, narrow, 1))
    return nn.Sequential(*layers)
#---------------------------------------------------#
# Produce the final YOLOv4 outputs
#---------------------------------------------------#
def yolo_head(filters_list, in_filters):
    """Build a prediction head: 3x3 conv block followed by a plain 1x1 convolution.

    Args:
        filters_list: ``filters_list[0]`` is the channel count of the 3x3
            conv block; ``filters_list[1]`` is the channel count of the
            final 1x1 convolution (the head's output).
        in_filters: Channels of the incoming feature map.

    Returns:
        An ``nn.Sequential`` of the two stages. The last stage is a bare
        ``nn.Conv2d`` with no normalization or activation after it.
    """
    hidden = filters_list[0]
    out_channels = filters_list[1]
    feature_stage = conv2d(in_filters, hidden, 3)
    predict_stage = nn.Conv2d(hidden, out_channels, 1)
    return nn.Sequential(feature_stage, predict_stage)
#---------------------------------------------------#
# yolo_body
#---------------------------------------------------#
class YoloBody(nn.Module):
    """YOLOv4 body: backbone, SPP, top-down and bottom-up feature fusion, three heads.

    Shape notes in the comments are written as (H, W, C) and are carried over
    from the original annotations; they appear to assume a 416x416 input --
    confirm against the backbone.

    Args:
        anchors_mask: Indexable with at least three entries. The length of
            entry 0 / 1 / 2 sets the anchor count of ``yolo_head3`` /
            ``yolo_head2`` / ``yolo_head1`` respectively.
        num_classes: Number of object classes; every anchor predicts
            ``5 + num_classes`` values.
        pretrained: Passed through unchanged to ``darknet53``.
    """

    def __init__(self, anchors_mask, num_classes, pretrained=False):
        super().__init__()
        # Values predicted per anchor. The original note breaks this down as
        # 4 + 1 + num_classes, e.g. 3 * (4 + 1 + 20) = 75 channels per head.
        per_anchor = 5 + num_classes

        #---------------------------------------------------#
        #   Build the CSPdarknet53 backbone.
        #   It yields three effective feature maps with shapes:
        #   52,52,256
        #   26,26,512
        #   13,13,1024
        #---------------------------------------------------#
        self.backbone = darknet53(pretrained)

        # Coarsest level: three convs, SPP, three convs.
        self.conv1 = make_three_conv([512, 1024], 1024)
        self.SPP = SpatialPyramidPooling()
        self.conv2 = make_three_conv([512, 1024], 2048)

        # Top-down path: coarsest -> middle level.
        self.upsample1 = Upsample(512, 256)
        self.conv_for_P4 = conv2d(512, 256, 1)
        self.make_five_conv1 = make_five_conv([256, 512], 512)

        # Top-down path: middle -> finest level.
        self.upsample2 = Upsample(256, 128)
        self.conv_for_P3 = conv2d(256, 128, 1)
        self.make_five_conv2 = make_five_conv([128, 256], 256)

        # Head on the finest level.
        self.yolo_head3 = yolo_head([256, len(anchors_mask[0]) * per_anchor], 128)

        # Bottom-up path: finest -> middle level, plus its head.
        self.down_sample1 = conv2d(128, 256, 3, stride=2)
        self.make_five_conv3 = make_five_conv([256, 512], 512)
        self.yolo_head2 = yolo_head([512, len(anchors_mask[1]) * per_anchor], 256)

        # Bottom-up path: middle -> coarsest level, plus its head.
        self.down_sample2 = conv2d(256, 512, 3, stride=2)
        self.make_five_conv4 = make_five_conv([512, 1024], 1024)
        self.yolo_head1 = yolo_head([1024, len(anchors_mask[2]) * per_anchor], 512)

    def forward(self, x):
        """Run the detector and return ``(out0, out1, out2)``.

        ``out0`` comes from ``yolo_head1`` (coarsest level), ``out1`` from
        ``yolo_head2`` (middle level) and ``out2`` from ``yolo_head3``
        (finest level).
        """
        # Backbone features; per the original shape notes, finest map first.
        fine, middle, coarse = self.backbone(x)

        # 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512 -> 13,13,2048
        p5 = self.SPP(self.conv1(coarse))
        # 13,13,2048 -> 13,13,512 -> 13,13,1024 -> 13,13,512
        p5 = self.conv2(p5)

        # ---- top-down fusion ----
        # 13,13,512 -> 13,13,256 -> 26,26,256
        p5_up = self.upsample1(p5)
        # 26,26,512 -> 26,26,256
        p4 = self.conv_for_P4(middle)
        # 26,26,256 + 26,26,256 -> 26,26,512, then five convs -> 26,26,256
        p4 = self.make_five_conv1(torch.cat([p4, p5_up], dim=1))

        # 26,26,256 -> 26,26,128 -> 52,52,128
        p4_up = self.upsample2(p4)
        # 52,52,256 -> 52,52,128
        p3 = self.conv_for_P3(fine)
        # 52,52,128 + 52,52,128 -> 52,52,256, then five convs -> 52,52,128
        p3 = self.make_five_conv2(torch.cat([p3, p4_up], dim=1))

        # ---- bottom-up fusion ----
        # 52,52,128 -> 26,26,256
        p3_down = self.down_sample1(p3)
        # 26,26,256 + 26,26,256 -> 26,26,512, then five convs -> 26,26,256
        p4 = self.make_five_conv3(torch.cat([p3_down, p4], dim=1))

        # 26,26,256 -> 13,13,512
        p4_down = self.down_sample2(p4)
        # 13,13,512 + 13,13,512 -> 13,13,1024, then five convs -> 13,13,512
        p5 = self.make_five_conv4(torch.cat([p4_down, p5], dim=1))

        #---------------------------------------------------#
        #   Third feature level
        #   y3=(batch_size,75,52,52)
        #---------------------------------------------------#
        head_fine = self.yolo_head3(p3)
        #---------------------------------------------------#
        #   Second feature level
        #   y2=(batch_size,75,26,26)
        #---------------------------------------------------#
        head_middle = self.yolo_head2(p4)
        #---------------------------------------------------#
        #   First feature level
        #   y1=(batch_size,75,13,13)
        #---------------------------------------------------#
        head_coarse = self.yolo_head1(p5)

        return head_coarse, head_middle, head_fine
|