YoloGesture main inference code

Files changed:
- .gitattributes +1 -0
- img/anticlockwise.jpg +0 -0
- img/back.jpg +0 -0
- img/clockwise.jpg +0 -0
- img/down.jpg +0 -0
- img/front.jpg +0 -0
- img/left.jpg +0 -0
- img/right.jpg +0 -0
- img/up.jpg +0 -0
- model_data/gesture.yaml +20 -0
- model_data/gesture_classes.txt +8 -0
- model_data/simhei.ttf +3 -0
- model_data/yolo_anchors.txt +1 -0
- model_data/yolotiny_anchors.txt +1 -0
- nets/CSPdarknet.py +174 -0
- nets/CSPdarknet53_tiny.py +143 -0
- nets/__init__.py +1 -0
- nets/attention.py +114 -0
- nets/yolo.py +185 -0
- nets/yolo_tiny.py +99 -0
- nets/yolo_training.py +476 -0
- nets/yolotiny_training.py +474 -0
- utils/__init__.py +1 -0
- utils/callbacks.py +71 -0
- utils/dataloader.py +360 -0
- utils/utils.py +62 -0
- utils/utils_bbox.py +227 -0
- utils/utils_fit.py +128 -0
- utils/utils_map.py +901 -0
.gitattributes CHANGED
@@ -32,3 +32,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
 *.zip filter=lfs diff=lfs merge=lfs -text
 *.zst filter=lfs diff=lfs merge=lfs -text
 *tfevents* filter=lfs diff=lfs merge=lfs -text
+model_data/simhei.ttf filter=lfs diff=lfs merge=lfs -text
img/anticlockwise.jpg ADDED
img/back.jpg ADDED
img/clockwise.jpg ADDED
img/down.jpg ADDED
img/front.jpg ADDED
img/left.jpg ADDED
img/right.jpg ADDED
img/up.jpg ADDED
model_data/gesture.yaml ADDED
@@ -0,0 +1,20 @@
+#------------------------------detect.py--------------------------------#
+# This part enables semi-automatic data annotation to reduce manual effort; it needs a weight trained in advance and saves results in Labelme format
+# dir_origin_path   where the images are stored
+# dir_save_path     where the Annotations are saved
+# ----------------------------------------------------------------------#
+dir_detect_path: ./JPEGImages
+detect_save_path: ./Annotation
+
+# ----------------------------- train.py -------------------------------#
+nc: 8 # number of classes
+classes: ["up","down","left","right","front","back","clockwise","anticlockwise"] # class names
+confidence: 0.5 # confidence threshold
+nms_iou: 0.3
+letterbox_image: False
+
+lr_decay_type: cos # learning-rate decay schedule; options are step and cos
+# Whether to use multi-threaded data loading.
+# Enabling it speeds up data loading but uses more memory.
+# Machines with limited memory can set this to 2 or 0; on Windows, 0 is recommended.
+num_workers: 4
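
Note (not part of the commit): the diff does not show how this YAML is consumed; as a minimal sketch, assuming PyYAML is available, the options can be loaded like this:

import yaml

# Load the detect/train options defined in model_data/gesture.yaml.
with open('model_data/gesture.yaml', encoding='utf-8') as f:
    cfg = yaml.safe_load(f)

print(cfg['nc'], cfg['confidence'], cfg['classes'])  # 8, 0.5, ['up', ...]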
model_data/gesture_classes.txt ADDED
@@ -0,0 +1,8 @@
+up
+down
+left
+right
+front
+back
+clockwise
+anticlockwise
model_data/simhei.ttf ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:aa4560dd8fe5645745fed3ffa301c3ca4d6c03cbd738145b613303961ba733b8
+size 9753388
model_data/yolo_anchors.txt ADDED
@@ -0,0 +1 @@
+12, 16, 19, 36, 40, 28, 36, 75, 76, 55, 72, 146, 142, 110, 192, 243, 459, 401
model_data/yolotiny_anchors.txt ADDED
@@ -0,0 +1 @@
+10,14, 23,27, 37,58, 81,82, 135,169, 344,319
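
Note (not part of the commit): a hypothetical helper for parsing these single-line anchor files into (w, h) pairs; the name get_anchors is an assumption for illustration only:

import numpy as np

def get_anchors(anchors_path):
    # Read the single comma-separated line and reshape into (w, h) pairs:
    # yolotiny_anchors.txt yields 6 anchors, yolo_anchors.txt yields 9.
    with open(anchors_path, encoding='utf-8') as f:
        anchors = [float(x) for x in f.readline().split(',')]
    return np.array(anchors).reshape(-1, 2)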
nets/CSPdarknet.py ADDED
@@ -0,0 +1,174 @@
+import math
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+
+
+#-------------------------------------------------#
+#   Mish activation function
+#-------------------------------------------------#
+class Mish(nn.Module):
+    def __init__(self):
+        super(Mish, self).__init__()
+
+    def forward(self, x):
+        return x * torch.tanh(F.softplus(x))
+
+#---------------------------------------------------#
+#   Convolution block -> convolution + normalization + activation
+#   Conv2d + BatchNormalization + Mish
+#---------------------------------------------------#
+class BasicConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
+        super(BasicConv, self).__init__()
+
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, kernel_size//2, bias=False)
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.activation = Mish()
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.activation(x)
+        return x
+
+#---------------------------------------------------#
+#   Component of the CSPdarknet structure block:
+#   the residual blocks stacked inside it
+#---------------------------------------------------#
+class Resblock(nn.Module):
+    def __init__(self, channels, hidden_channels=None):
+        super(Resblock, self).__init__()
+
+        if hidden_channels is None:
+            hidden_channels = channels
+
+        self.block = nn.Sequential(
+            BasicConv(channels, hidden_channels, 1),
+            BasicConv(hidden_channels, channels, 3)
+        )
+
+    def forward(self, x):
+        return x + self.block(x)
+
+#--------------------------------------------------------------------#
+#   CSPdarknet structure block.
+#   First compress height and width with a stride-2 convolution block.
+#   Then build a large residual branch, shortconv, that bypasses many residual structures.
+#   The trunk loops over num_blocks residual structures.
+#   The whole CSPdarknet block is one large residual block plus several small inner residual blocks.
+#--------------------------------------------------------------------#
+class Resblock_body(nn.Module):
+    def __init__(self, in_channels, out_channels, num_blocks, first):
+        super(Resblock_body, self).__init__()
+        #----------------------------------------------------------------#
+        #   Compress height and width with a stride-2 convolution block
+        #----------------------------------------------------------------#
+        self.downsample_conv = BasicConv(in_channels, out_channels, 3, stride=2)
+
+        if first:
+            #--------------------------------------------------------------------------#
+            #   Build the large residual branch self.split_conv0, which bypasses many residual structures
+            #--------------------------------------------------------------------------#
+            self.split_conv0 = BasicConv(out_channels, out_channels, 1)
+
+            #----------------------------------------------------------------#
+            #   The trunk loops over num_blocks residual structures
+            #----------------------------------------------------------------#
+            self.split_conv1 = BasicConv(out_channels, out_channels, 1)
+            self.blocks_conv = nn.Sequential(
+                Resblock(channels=out_channels, hidden_channels=out_channels//2),
+                BasicConv(out_channels, out_channels, 1)
+            )
+
+            self.concat_conv = BasicConv(out_channels*2, out_channels, 1)
+        else:
+            #--------------------------------------------------------------------------#
+            #   Build the large residual branch self.split_conv0, which bypasses many residual structures
+            #--------------------------------------------------------------------------#
+            self.split_conv0 = BasicConv(out_channels, out_channels//2, 1)
+
+            #----------------------------------------------------------------#
+            #   The trunk loops over num_blocks residual structures
+            #----------------------------------------------------------------#
+            self.split_conv1 = BasicConv(out_channels, out_channels//2, 1)
+            self.blocks_conv = nn.Sequential(
+                *[Resblock(out_channels//2) for _ in range(num_blocks)],
+                BasicConv(out_channels//2, out_channels//2, 1)
+            )
+
+            self.concat_conv = BasicConv(out_channels, out_channels, 1)
+
+    def forward(self, x):
+        x = self.downsample_conv(x)
+
+        x0 = self.split_conv0(x)
+
+        x1 = self.split_conv1(x)
+        x1 = self.blocks_conv(x1)
+
+        #------------------------------------#
+        #   Stack the large residual branch back on
+        #------------------------------------#
+        x = torch.cat([x1, x0], dim=1)
+        #------------------------------------#
+        #   Finally integrate the channels
+        #------------------------------------#
+        x = self.concat_conv(x)
+
+        return x
+
+#---------------------------------------------------#
+#   Main body of CSPdarknet53.
+#   Input is a 416x416x3 image.
+#   Output is three effective feature layers.
+#---------------------------------------------------#
+class CSPDarkNet(nn.Module):
+    def __init__(self, layers):
+        super(CSPDarkNet, self).__init__()
+        self.inplanes = 32
+        # 416,416,3 -> 416,416,32
+        self.conv1 = BasicConv(3, self.inplanes, kernel_size=3, stride=1)
+        self.feature_channels = [64, 128, 256, 512, 1024]
+
+        self.stages = nn.ModuleList([
+            # 416,416,32 -> 208,208,64
+            Resblock_body(self.inplanes, self.feature_channels[0], layers[0], first=True),
+            # 208,208,64 -> 104,104,128
+            Resblock_body(self.feature_channels[0], self.feature_channels[1], layers[1], first=False),
+            # 104,104,128 -> 52,52,256
+            Resblock_body(self.feature_channels[1], self.feature_channels[2], layers[2], first=False),
+            # 52,52,256 -> 26,26,512
+            Resblock_body(self.feature_channels[2], self.feature_channels[3], layers[3], first=False),
+            # 26,26,512 -> 13,13,1024
+            Resblock_body(self.feature_channels[3], self.feature_channels[4], layers[4], first=False)
+        ])
+
+        self.num_features = 1
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+
+    def forward(self, x):
+        x = self.conv1(x)
+
+        x = self.stages[0](x)
+        x = self.stages[1](x)
+        out3 = self.stages[2](x)
+        out4 = self.stages[3](out3)
+        out5 = self.stages[4](out4)
+
+        return out3, out4, out5
+
+def darknet53(pretrained):
+    model = CSPDarkNet([1, 2, 8, 8, 4])
+    if pretrained:
+        model.load_state_dict(torch.load("model_data/CSPdarknet53_backbone_weights.pth"))
+    return model
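
Note (not part of the commit): a quick shape sanity check, assuming the file above is importable as nets.CSPdarknet:

import torch
from nets.CSPdarknet import darknet53

# Build the backbone without pretrained weights and confirm the three
# effective feature layers for a 416x416 input.
model = darknet53(pretrained=False)
out3, out4, out5 = model(torch.randn(1, 3, 416, 416))
print(out3.shape)  # torch.Size([1, 256, 52, 52])
print(out4.shape)  # torch.Size([1, 512, 26, 26])
print(out5.shape)  # torch.Size([1, 1024, 13, 13])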
nets/CSPdarknet53_tiny.py ADDED
@@ -0,0 +1,143 @@
+import math
+
+import torch
+import torch.nn as nn
+
+
+#-------------------------------------------------#
+#   Convolution block
+#   Conv2d + BatchNorm2d + LeakyReLU
+#-------------------------------------------------#
+class BasicConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
+        super(BasicConv, self).__init__()
+
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, kernel_size//2, bias=False)
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.activation = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.activation(x)
+        return x
+
+
+'''
+                input
+                  |
+              BasicConv
+                  -----------------------
+                  |                     |
+             route_group              route
+                  |                     |
+              BasicConv                 |
+                  |                     |
+    -------------------                 |
+    |                 |                 |
+ route_1          BasicConv             |
+    |                 |                 |
+    -----------------cat                |
+                      |                 |
+        ----      BasicConv             |
+        |             |                 |
+      feat           cat-----------------
+                      |
+                 MaxPooling2D
+'''
+#---------------------------------------------------#
+#   CSPdarknet53-tiny structure block.
+#   There is one large residual branch
+#   that bypasses many residual structures.
+#---------------------------------------------------#
+class Resblock_body(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(Resblock_body, self).__init__()
+        self.out_channels = out_channels
+
+        self.conv1 = BasicConv(in_channels, out_channels, 3)
+
+        self.conv2 = BasicConv(out_channels//2, out_channels//2, 3)
+        self.conv3 = BasicConv(out_channels//2, out_channels//2, 3)
+
+        self.conv4 = BasicConv(out_channels, out_channels, 1)
+        self.maxpool = nn.MaxPool2d([2,2],[2,2])
+
+    def forward(self, x):
+        # Integrate features with a 3x3 convolution
+        x = self.conv1(x)
+        # Split off the large residual branch, route
+        route = x
+
+        c = self.out_channels
+        # Split the feature channels and take the second half as the trunk
+        x = torch.split(x, c//2, dim = 1)[1]
+        # Apply a 3x3 convolution to the trunk
+        x = self.conv2(x)
+        # Split off the small residual branch, route_1
+        route1 = x
+        # Apply another 3x3 convolution to the trunk
+        x = self.conv3(x)
+        # Concatenate the trunk with the residual branch
+        x = torch.cat([x,route1], dim = 1)
+
+        # Apply a 1x1 convolution to the concatenated result
+        x = self.conv4(x)
+        feat = x
+        x = torch.cat([route, x], dim = 1)
+
+        # Compress height and width with max pooling
+        x = self.maxpool(x)
+        return x,feat
+
+class CSPDarkNet(nn.Module):
+    def __init__(self):
+        super(CSPDarkNet, self).__init__()
+        # First compress height and width with two stride-2 3x3 convolutions
+        # 416,416,3 -> 208,208,32 -> 104,104,64
+        self.conv1 = BasicConv(3, 32, kernel_size=3, stride=2)
+        self.conv2 = BasicConv(32, 64, kernel_size=3, stride=2)
+
+        # 104,104,64 -> 52,52,128
+        self.resblock_body1 = Resblock_body(64, 64)
+        # 52,52,128 -> 26,26,256
+        self.resblock_body2 = Resblock_body(128, 128)
+        # 26,26,256 -> 13,13,512
+        self.resblock_body3 = Resblock_body(256, 256)
+        # 13,13,512 -> 13,13,512
+        self.conv3 = BasicConv(512, 512, kernel_size=3)
+
+        self.num_features = 1
+        # Initialize the weights
+        for m in self.modules():
+            if isinstance(m, nn.Conv2d):
+                n = m.kernel_size[0] * m.kernel_size[1] * m.out_channels
+                m.weight.data.normal_(0, math.sqrt(2. / n))
+            elif isinstance(m, nn.BatchNorm2d):
+                m.weight.data.fill_(1)
+                m.bias.data.zero_()
+
+
+    def forward(self, x):
+        # 416,416,3 -> 208,208,32 -> 104,104,64
+        x = self.conv1(x)
+        x = self.conv2(x)
+
+        # 104,104,64 -> 52,52,128
+        x, _ = self.resblock_body1(x)
+        # 52,52,128 -> 26,26,256
+        x, _ = self.resblock_body2(x)
+        # 26,26,256 -> x is 13,13,512
+        #           -> feat1 is 26,26,256
+        x, feat1 = self.resblock_body3(x)
+
+        # 13,13,512 -> 13,13,512
+        x = self.conv3(x)
+        feat2 = x
+        return feat1,feat2
+
+def darknet53_tiny(pretrained, **kwargs):
+    model = CSPDarkNet()
+    if pretrained:
+        model.load_state_dict(torch.load("model_data/CSPdarknet53_tiny_backbone_weights.pth"))
+    return model
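
Note (not part of the commit): a matching shape check for the tiny backbone, assuming the file above is importable as nets.CSPdarknet53_tiny:

import torch
from nets.CSPdarknet53_tiny import darknet53_tiny

# Instantiate without pretrained weights and confirm the two effective
# feature layers for a 416x416 input.
model = darknet53_tiny(pretrained=False)
feat1, feat2 = model(torch.randn(1, 3, 416, 416))
print(feat1.shape)  # torch.Size([1, 256, 26, 26])
print(feat2.shape)  # torch.Size([1, 512, 13, 13])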
nets/__init__.py ADDED
@@ -0,0 +1 @@
+#
nets/attention.py ADDED
@@ -0,0 +1,114 @@
+import torch
+import torch.nn as nn
+import math
+
+class se_block(nn.Module):
+    def __init__(self, channel, ratio=16):
+        super(se_block, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.fc = nn.Sequential(
+                nn.Linear(channel, channel // ratio, bias=False),
+                nn.ReLU(inplace=True),
+                nn.Linear(channel // ratio, channel, bias=False),
+                nn.Sigmoid()
+        )
+
+    def forward(self, x):
+        b, c, _, _ = x.size()
+        y = self.avg_pool(x).view(b, c)
+        y = self.fc(y).view(b, c, 1, 1)
+        return x * y
+
+class ChannelAttention(nn.Module):
+    def __init__(self, in_planes, ratio=8):
+        super(ChannelAttention, self).__init__()
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.max_pool = nn.AdaptiveMaxPool2d(1)
+
+        # Use 1x1 convolutions in place of fully connected layers
+        self.fc1   = nn.Conv2d(in_planes, in_planes // ratio, 1, bias=False)
+        self.relu1 = nn.ReLU()
+        self.fc2   = nn.Conv2d(in_planes // ratio, in_planes, 1, bias=False)
+
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        avg_out = self.fc2(self.relu1(self.fc1(self.avg_pool(x))))
+        max_out = self.fc2(self.relu1(self.fc1(self.max_pool(x))))
+        out = avg_out + max_out
+        return self.sigmoid(out)
+
+class SpatialAttention(nn.Module):
+    def __init__(self, kernel_size=7):
+        super(SpatialAttention, self).__init__()
+
+        assert kernel_size in (3, 7), 'kernel size must be 3 or 7'
+        padding = 3 if kernel_size == 7 else 1
+        self.conv1 = nn.Conv2d(2, 1, kernel_size, padding=padding, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        avg_out = torch.mean(x, dim=1, keepdim=True)
+        max_out, _ = torch.max(x, dim=1, keepdim=True)
+        x = torch.cat([avg_out, max_out], dim=1)
+        x = self.conv1(x)
+        return self.sigmoid(x)
+
+class cbam_block(nn.Module):
+    def __init__(self, channel, ratio=8, kernel_size=7):
+        super(cbam_block, self).__init__()
+        self.channelattention = ChannelAttention(channel, ratio=ratio)
+        self.spatialattention = SpatialAttention(kernel_size=kernel_size)
+
+    def forward(self, x):
+        x = x*self.channelattention(x)
+        x = x*self.spatialattention(x)
+        return x
+
+class eca_block(nn.Module):
+    def __init__(self, channel, b=1, gamma=2):
+        super(eca_block, self).__init__()
+        kernel_size = int(abs((math.log(channel, 2) + b) / gamma))
+        kernel_size = kernel_size if kernel_size % 2 else kernel_size + 1
+
+        self.avg_pool = nn.AdaptiveAvgPool2d(1)
+        self.conv = nn.Conv1d(1, 1, kernel_size=kernel_size, padding=(kernel_size - 1) // 2, bias=False)
+        self.sigmoid = nn.Sigmoid()
+
+    def forward(self, x):
+        y = self.avg_pool(x)
+        y = self.conv(y.squeeze(-1).transpose(-1, -2)).transpose(-1, -2).unsqueeze(-1)
+        y = self.sigmoid(y)
+        return x * y.expand_as(x)
+
+class CA_Block(nn.Module):
+    def __init__(self, channel, reduction=16):
+        super(CA_Block, self).__init__()
+
+        self.conv_1x1 = nn.Conv2d(in_channels=channel, out_channels=channel//reduction, kernel_size=1, stride=1, bias=False)
+
+        self.relu = nn.ReLU()
+        self.bn = nn.BatchNorm2d(channel//reduction)
+
+        self.F_h = nn.Conv2d(in_channels=channel//reduction, out_channels=channel, kernel_size=1, stride=1, bias=False)
+        self.F_w = nn.Conv2d(in_channels=channel//reduction, out_channels=channel, kernel_size=1, stride=1, bias=False)
+
+        self.sigmoid_h = nn.Sigmoid()
+        self.sigmoid_w = nn.Sigmoid()
+
+    def forward(self, x):
+        _, _, h, w = x.size()
+
+        x_h = torch.mean(x, dim = 3, keepdim = True).permute(0, 1, 3, 2)
+        x_w = torch.mean(x, dim = 2, keepdim = True)
+
+        x_cat_conv_relu = self.relu(self.bn(self.conv_1x1(torch.cat((x_h, x_w), 3))))
+
+        x_cat_conv_split_h, x_cat_conv_split_w = x_cat_conv_relu.split([h, w], 3)
+
+        s_h = self.sigmoid_h(self.F_h(x_cat_conv_split_h.permute(0, 1, 3, 2)))
+        s_w = self.sigmoid_w(self.F_w(x_cat_conv_split_w))
+
+        out = x * s_h.expand_as(x) * s_w.expand_as(x)
+        return out
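
Note (not part of the commit): all four attention blocks reweight a feature map without changing its shape, which is what lets yolo_tiny.py drop them in by index; a minimal sketch, assuming nets.attention is importable:

import torch
from nets.attention import se_block, cbam_block, eca_block, CA_Block

# Each block maps (B, C, H, W) -> (B, C, H, W).
x = torch.randn(1, 256, 26, 26)
for block in (se_block(256), cbam_block(256), eca_block(256), CA_Block(256)):
    assert block(x).shape == x.shape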
nets/yolo.py ADDED
@@ -0,0 +1,185 @@
+from collections import OrderedDict
+
+import torch
+import torch.nn as nn
+
+from nets.CSPdarknet import darknet53
+
+
+def conv2d(filter_in, filter_out, kernel_size, stride=1):
+    pad = (kernel_size - 1) // 2 if kernel_size else 0
+    return nn.Sequential(OrderedDict([
+        ("conv", nn.Conv2d(filter_in, filter_out, kernel_size=kernel_size, stride=stride, padding=pad, bias=False)),
+        ("bn", nn.BatchNorm2d(filter_out)),
+        ("relu", nn.LeakyReLU(0.1)),
+    ]))
+
+#---------------------------------------------------#
+#   SPP structure: pool with kernels of different sizes,
+#   then stack the pooled results
+#---------------------------------------------------#
+class SpatialPyramidPooling(nn.Module):
+    def __init__(self, pool_sizes=[5, 9, 13]):
+        super(SpatialPyramidPooling, self).__init__()
+
+        self.maxpools = nn.ModuleList([nn.MaxPool2d(pool_size, 1, pool_size//2) for pool_size in pool_sizes])
+
+    def forward(self, x):
+        features = [maxpool(x) for maxpool in self.maxpools[::-1]]
+        features = torch.cat(features + [x], dim=1)
+
+        return features
+
+#---------------------------------------------------#
+#   Convolution + upsampling
+#---------------------------------------------------#
+class Upsample(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(Upsample, self).__init__()
+
+        self.upsample = nn.Sequential(
+            conv2d(in_channels, out_channels, 1),
+            nn.Upsample(scale_factor=2, mode='nearest')
+        )
+
+    def forward(self, x):
+        x = self.upsample(x)
+        return x
+
+#---------------------------------------------------#
+#   Block of three convolutions
+#---------------------------------------------------#
+def make_three_conv(filters_list, in_filters):
+    m = nn.Sequential(
+        conv2d(in_filters, filters_list[0], 1),
+        conv2d(filters_list[0], filters_list[1], 3),
+        conv2d(filters_list[1], filters_list[0], 1),
+    )
+    return m
+
+#---------------------------------------------------#
+#   Block of five convolutions
+#---------------------------------------------------#
+def make_five_conv(filters_list, in_filters):
+    m = nn.Sequential(
+        conv2d(in_filters, filters_list[0], 1),
+        conv2d(filters_list[0], filters_list[1], 3),
+        conv2d(filters_list[1], filters_list[0], 1),
+        conv2d(filters_list[0], filters_list[1], 3),
+        conv2d(filters_list[1], filters_list[0], 1),
+    )
+    return m
+
+#---------------------------------------------------#
+#   Final yolov4 output head
+#---------------------------------------------------#
+def yolo_head(filters_list, in_filters):
+    m = nn.Sequential(
+        conv2d(in_filters, filters_list[0], 3),
+        nn.Conv2d(filters_list[0], filters_list[1], 1),
+    )
+    return m
+
+#---------------------------------------------------#
+#   yolo_body
+#---------------------------------------------------#
+class YoloBody(nn.Module):
+    def __init__(self, anchors_mask, num_classes, pretrained = False):
+        super(YoloBody, self).__init__()
+        #---------------------------------------------------#
+        #   Build the CSPdarknet53 backbone model.
+        #   It returns three effective feature layers with shapes:
+        #   52,52,256
+        #   26,26,512
+        #   13,13,1024
+        #---------------------------------------------------#
+        self.backbone   = darknet53(pretrained)
+
+        self.conv1      = make_three_conv([512,1024],1024)
+        self.SPP        = SpatialPyramidPooling()
+        self.conv2      = make_three_conv([512,1024],2048)
+
+        self.upsample1          = Upsample(512,256)
+        self.conv_for_P4        = conv2d(512,256,1)
+        self.make_five_conv1    = make_five_conv([256, 512],512)
+
+        self.upsample2          = Upsample(256,128)
+        self.conv_for_P3        = conv2d(256,128,1)
+        self.make_five_conv2    = make_five_conv([128, 256],256)
+
+        # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20) = 75
+        self.yolo_head3         = yolo_head([256, len(anchors_mask[0]) * (5 + num_classes)],128)
+
+        self.down_sample1       = conv2d(128,256,3,stride=2)
+        self.make_five_conv3    = make_five_conv([256, 512],512)
+
+        # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20) = 75
+        self.yolo_head2         = yolo_head([512, len(anchors_mask[1]) * (5 + num_classes)],256)
+
+        self.down_sample2       = conv2d(256,512,3,stride=2)
+        self.make_five_conv4    = make_five_conv([512, 1024],1024)
+
+        # 3*(5+num_classes) = 3*(5+20) = 3*(4+1+20) = 75
+        self.yolo_head1         = yolo_head([1024, len(anchors_mask[2]) * (5 + num_classes)],512)
+
+
+    def forward(self, x):
+        #  backbone
+        x2, x1, x0 = self.backbone(x)
+
+        # 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512 -> 13,13,2048
+        P5 = self.conv1(x0)
+        P5 = self.SPP(P5)
+        # 13,13,2048 -> 13,13,512 -> 13,13,1024 -> 13,13,512
+        P5 = self.conv2(P5)
+
+        # 13,13,512 -> 13,13,256 -> 26,26,256
+        P5_upsample = self.upsample1(P5)
+        # 26,26,512 -> 26,26,256
+        P4 = self.conv_for_P4(x1)
+        # 26,26,256 + 26,26,256 -> 26,26,512
+        P4 = torch.cat([P4,P5_upsample],axis=1)
+        # 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256
+        P4 = self.make_five_conv1(P4)
+
+        # 26,26,256 -> 26,26,128 -> 52,52,128
+        P4_upsample = self.upsample2(P4)
+        # 52,52,256 -> 52,52,128
+        P3 = self.conv_for_P3(x2)
+        # 52,52,128 + 52,52,128 -> 52,52,256
+        P3 = torch.cat([P3,P4_upsample],axis=1)
+        # 52,52,256 -> 52,52,128 -> 52,52,256 -> 52,52,128 -> 52,52,256 -> 52,52,128
+        P3 = self.make_five_conv2(P3)
+
+        # 52,52,128 -> 26,26,256
+        P3_downsample = self.down_sample1(P3)
+        # 26,26,256 + 26,26,256 -> 26,26,512
+        P4 = torch.cat([P3_downsample,P4],axis=1)
+        # 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256 -> 26,26,512 -> 26,26,256
+        P4 = self.make_five_conv3(P4)
+
+        # 26,26,256 -> 13,13,512
+        P4_downsample = self.down_sample2(P4)
+        # 13,13,512 + 13,13,512 -> 13,13,1024
+        P5 = torch.cat([P4_downsample,P5],axis=1)
+        # 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512 -> 13,13,1024 -> 13,13,512
+        P5 = self.make_five_conv4(P5)
+
+        #---------------------------------------------------#
+        #   Third feature layer
+        #   y3=(batch_size,75,52,52)
+        #---------------------------------------------------#
+        out2 = self.yolo_head3(P3)
+        #---------------------------------------------------#
+        #   Second feature layer
+        #   y2=(batch_size,75,26,26)
+        #---------------------------------------------------#
+        out1 = self.yolo_head2(P4)
+        #---------------------------------------------------#
+        #   First feature layer
+        #   y1=(batch_size,75,13,13)
+        #---------------------------------------------------#
+        out0 = self.yolo_head1(P5)
+
+        return out0, out1, out2
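
Note (not part of the commit): a shape check for the full body with the 8 gesture classes from gesture.yaml, so each head outputs 3 * (5 + 8) = 39 channels; assumes nets.yolo is importable:

import torch
from nets.yolo import YoloBody

anchors_mask = [[6, 7, 8], [3, 4, 5], [0, 1, 2]]
model = YoloBody(anchors_mask, num_classes=8)
out0, out1, out2 = model(torch.randn(1, 3, 416, 416))
print(out0.shape)  # torch.Size([1, 39, 13, 13])
print(out1.shape)  # torch.Size([1, 39, 26, 26])
print(out2.shape)  # torch.Size([1, 39, 52, 52])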
nets/yolo_tiny.py ADDED
@@ -0,0 +1,99 @@
+import torch
+import torch.nn as nn
+
+from nets.CSPdarknet53_tiny import darknet53_tiny
+from nets.attention import cbam_block, eca_block, se_block, CA_Block
+
+attention_block = [se_block, cbam_block, eca_block, CA_Block]
+
+#-------------------------------------------------#
+#   Convolution block -> convolution + normalization + activation
+#   Conv2d + BatchNormalization + LeakyReLU
+#-------------------------------------------------#
+class BasicConv(nn.Module):
+    def __init__(self, in_channels, out_channels, kernel_size, stride=1):
+        super(BasicConv, self).__init__()
+
+        self.conv = nn.Conv2d(in_channels, out_channels, kernel_size, stride, kernel_size//2, bias=False)
+        self.bn = nn.BatchNorm2d(out_channels)
+        self.activation = nn.LeakyReLU(0.1)
+
+    def forward(self, x):
+        x = self.conv(x)
+        x = self.bn(x)
+        x = self.activation(x)
+        return x
+
+#---------------------------------------------------#
+#   Convolution + upsampling
+#---------------------------------------------------#
+class Upsample(nn.Module):
+    def __init__(self, in_channels, out_channels):
+        super(Upsample, self).__init__()
+
+        self.upsample = nn.Sequential(
+            BasicConv(in_channels, out_channels, 1),
+            nn.Upsample(scale_factor=2, mode='nearest')
+        )
+
+    def forward(self, x):
+        x = self.upsample(x)
+        return x
+
+#---------------------------------------------------#
+#   Final yolov4 output head
+#---------------------------------------------------#
+def yolo_head(filters_list, in_filters):
+    m = nn.Sequential(
+        BasicConv(in_filters, filters_list[0], 3),
+        nn.Conv2d(filters_list[0], filters_list[1], 1),
+    )
+    return m
+#---------------------------------------------------#
+#   yolo_body
+#---------------------------------------------------#
+class YoloBodytiny(nn.Module):
+    def __init__(self, anchors_mask, num_classes, phi=0, pretrained=False):
+        super(YoloBodytiny, self).__init__()
+        self.phi            = phi
+        self.backbone       = darknet53_tiny(pretrained)
+
+        self.conv_for_P5    = BasicConv(512,256,1)
+        self.yolo_headP5    = yolo_head([512, len(anchors_mask[0]) * (5 + num_classes)],256)
+
+        self.upsample       = Upsample(256,128)
+        self.yolo_headP4    = yolo_head([256, len(anchors_mask[1]) * (5 + num_classes)],384)
+
+        if 1 <= self.phi and self.phi <= 4:
+            self.feat1_att      = attention_block[self.phi - 1](256)
+            self.feat2_att      = attention_block[self.phi - 1](512)
+            self.upsample_att   = attention_block[self.phi - 1](128)
+
+    def forward(self, x):
+        #---------------------------------------------------#
+        #   Run the CSPdarknet53_tiny backbone model
+        #   feat1 has shape 26,26,256
+        #   feat2 has shape 13,13,512
+        #---------------------------------------------------#
+        feat1, feat2 = self.backbone(x)
+        if 1 <= self.phi and self.phi <= 4:
+            feat1 = self.feat1_att(feat1)
+            feat2 = self.feat2_att(feat2)
+
+        # 13,13,512 -> 13,13,256
+        P5 = self.conv_for_P5(feat2)
+        # 13,13,256 -> 13,13,512 -> 13,13,255
+        out0 = self.yolo_headP5(P5)
+
+        # 13,13,256 -> 13,13,128 -> 26,26,128
+        P5_Upsample = self.upsample(P5)
+        # 26,26,256 + 26,26,128 -> 26,26,384
+        if 1 <= self.phi and self.phi <= 4:
+            P5_Upsample = self.upsample_att(P5_Upsample)
+        P4 = torch.cat([P5_Upsample,feat1],axis=1)
+
+        # 26,26,384 -> 26,26,256 -> 26,26,255
+        out1 = self.yolo_headP4(P4)
+
+        return out0, out1
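
Note (not part of the commit): the tiny body selects an attention block by phi (1=SE, 2=CBAM, 3=ECA, 4=CA); a minimal sketch with CBAM and the 8 gesture classes, assuming nets.yolo_tiny is importable:

import torch
from nets.yolo_tiny import YoloBodytiny

anchors_mask = [[3, 4, 5], [0, 1, 2]]
model = YoloBodytiny(anchors_mask, num_classes=8, phi=2)
out0, out1 = model(torch.randn(1, 3, 416, 416))
print(out0.shape)  # torch.Size([1, 39, 13, 13])
print(out1.shape)  # torch.Size([1, 39, 26, 26])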
nets/yolo_training.py
ADDED
@@ -0,0 +1,476 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import math
|
2 |
+
from functools import partial
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import torch
|
6 |
+
import torch.nn as nn
|
7 |
+
|
8 |
+
|
9 |
+
class YOLOLoss(nn.Module):
|
10 |
+
def __init__(self, anchors, num_classes, input_shape, cuda, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]], label_smoothing = 0, focal_loss = False, alpha = 0.25, gamma = 2):
|
11 |
+
super(YOLOLoss, self).__init__()
|
12 |
+
#-----------------------------------------------------------#
|
13 |
+
# 13x13的特征层对应的anchor是[142, 110],[192, 243],[459, 401]
|
14 |
+
# 26x26的特征层对应的anchor是[36, 75],[76, 55],[72, 146]
|
15 |
+
# 52x52的特征层对应的anchor是[12, 16],[19, 36],[40, 28]
|
16 |
+
#-----------------------------------------------------------#
|
17 |
+
self.anchors = anchors
|
18 |
+
self.num_classes = num_classes
|
19 |
+
self.bbox_attrs = 5 + num_classes
|
20 |
+
self.input_shape = input_shape
|
21 |
+
self.anchors_mask = anchors_mask
|
22 |
+
self.label_smoothing = label_smoothing
|
23 |
+
|
24 |
+
self.balance = [0.4, 1.0, 4]
|
25 |
+
self.box_ratio = 0.05
|
26 |
+
self.obj_ratio = 5 * (input_shape[0] * input_shape[1]) / (416 ** 2)
|
27 |
+
self.cls_ratio = 1 * (num_classes / 80)
|
28 |
+
|
29 |
+
self.focal_loss = focal_loss
|
30 |
+
self.focal_loss_ratio = 10
|
31 |
+
self.alpha = alpha
|
32 |
+
self.gamma = gamma
|
33 |
+
|
34 |
+
self.ignore_threshold = 0.5
|
35 |
+
self.cuda = cuda
|
36 |
+
|
37 |
+
def clip_by_tensor(self, t, t_min, t_max):
|
38 |
+
t = t.float()
|
39 |
+
result = (t >= t_min).float() * t + (t < t_min).float() * t_min
|
40 |
+
result = (result <= t_max).float() * result + (result > t_max).float() * t_max
|
41 |
+
return result
|
42 |
+
|
43 |
+
def MSELoss(self, pred, target):
|
44 |
+
return torch.pow(pred - target, 2)
|
45 |
+
|
46 |
+
def BCELoss(self, pred, target):
|
47 |
+
epsilon = 1e-7
|
48 |
+
pred = self.clip_by_tensor(pred, epsilon, 1.0 - epsilon)
|
49 |
+
output = - target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred)
|
50 |
+
return output
|
51 |
+
|
52 |
+
def box_ciou(self, b1, b2):
|
53 |
+
"""
|
54 |
+
输入为:
|
55 |
+
----------
|
56 |
+
b1: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
|
57 |
+
b2: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
|
58 |
+
|
59 |
+
返回为:
|
60 |
+
-------
|
61 |
+
ciou: tensor, shape=(batch, feat_w, feat_h, anchor_num, 1)
|
62 |
+
"""
|
63 |
+
#----------------------------------------------------#
|
64 |
+
# 求出预测框左上角右下角
|
65 |
+
#----------------------------------------------------#
|
66 |
+
b1_xy = b1[..., :2]
|
67 |
+
b1_wh = b1[..., 2:4]
|
68 |
+
b1_wh_half = b1_wh/2.
|
69 |
+
b1_mins = b1_xy - b1_wh_half
|
70 |
+
b1_maxes = b1_xy + b1_wh_half
|
71 |
+
#----------------------------------------------------#
|
72 |
+
# 求出真实框左上角右下角
|
73 |
+
#----------------------------------------------------#
|
74 |
+
b2_xy = b2[..., :2]
|
75 |
+
b2_wh = b2[..., 2:4]
|
76 |
+
b2_wh_half = b2_wh/2.
|
77 |
+
b2_mins = b2_xy - b2_wh_half
|
78 |
+
b2_maxes = b2_xy + b2_wh_half
|
79 |
+
|
80 |
+
#----------------------------------------------------#
|
81 |
+
# 求真实框和预测框所有的iou
|
82 |
+
#----------------------------------------------------#
|
83 |
+
intersect_mins = torch.max(b1_mins, b2_mins)
|
84 |
+
intersect_maxes = torch.min(b1_maxes, b2_maxes)
|
85 |
+
intersect_wh = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
|
86 |
+
intersect_area = intersect_wh[..., 0] * intersect_wh[..., 1]
|
87 |
+
b1_area = b1_wh[..., 0] * b1_wh[..., 1]
|
88 |
+
b2_area = b2_wh[..., 0] * b2_wh[..., 1]
|
89 |
+
union_area = b1_area + b2_area - intersect_area
|
90 |
+
iou = intersect_area / torch.clamp(union_area,min = 1e-6)
|
91 |
+
|
92 |
+
#----------------------------------------------------#
|
93 |
+
# 计算中心的差距
|
94 |
+
#----------------------------------------------------#
|
95 |
+
center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1)
|
96 |
+
|
97 |
+
#----------------------------------------------------#
|
98 |
+
# 找到包裹两个框的最小框的左上角和右下角
|
99 |
+
#----------------------------------------------------#
|
100 |
+
enclose_mins = torch.min(b1_mins, b2_mins)
|
101 |
+
enclose_maxes = torch.max(b1_maxes, b2_maxes)
|
102 |
+
enclose_wh = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
|
103 |
+
#----------------------------------------------------#
|
104 |
+
# 计算对角线距离
|
105 |
+
#----------------------------------------------------#
|
106 |
+
enclose_diagonal = torch.sum(torch.pow(enclose_wh,2), axis=-1)
|
107 |
+
ciou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal,min = 1e-6)
|
108 |
+
|
109 |
+
v = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0] / torch.clamp(b1_wh[..., 1],min = 1e-6)) - torch.atan(b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min = 1e-6))), 2)
|
110 |
+
alpha = v / torch.clamp((1.0 - iou + v), min=1e-6)
|
111 |
+
ciou = ciou - alpha * v
|
112 |
+
return ciou
|
113 |
+
|
114 |
+
#---------------------------------------------------#
|
115 |
+
# 平滑标签
|
116 |
+
#---------------------------------------------------#
|
117 |
+
def smooth_labels(self, y_true, label_smoothing, num_classes):
|
118 |
+
return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes
|
119 |
+
|
120 |
+
def forward(self, l, input, targets=None):
|
121 |
+
#----------------------------------------------------#
|
122 |
+
# l 代表使用的是第几个有效特征层
|
123 |
+
# input的shape为 bs, 3*(5+num_classes), 13, 13
|
124 |
+
# bs, 3*(5+num_classes), 26, 26
|
125 |
+
# bs, 3*(5+num_classes), 52, 52
|
126 |
+
# targets 真实框的标签情况 [batch_size, num_gt, 5]
|
127 |
+
#----------------------------------------------------#
|
128 |
+
#--------------------------------#
|
129 |
+
# 获得图片数量,特征层的高和宽
|
130 |
+
#--------------------------------#
|
131 |
+
bs = input.size(0)
|
132 |
+
in_h = input.size(2)
|
133 |
+
in_w = input.size(3)
|
134 |
+
#-----------------------------------------------------------------------#
|
135 |
+
# 计算步长
|
136 |
+
# 每一个特征点对应原来的图片上多少个像素点
|
137 |
+
#
|
138 |
+
# 如果特征层为13x13的话,一个特征点就对应原来的图片上的32个像素点
|
139 |
+
# 如果特征层为26x26的话,一个特征点就对应原来的图片上的16个像素点
|
140 |
+
# 如果特征层为52x52的话,一个特征点就对应原来的图片上的8个像素点
|
141 |
+
# stride_h = stride_w = 32、16、8
|
142 |
+
#-----------------------------------------------------------------------#
|
143 |
+
stride_h = self.input_shape[0] / in_h
|
144 |
+
stride_w = self.input_shape[1] / in_w
|
145 |
+
#-------------------------------------------------#
|
146 |
+
# 此时获得的scaled_anchors大小是相对于特征层的
|
147 |
+
#-------------------------------------------------#
|
148 |
+
scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]
|
149 |
+
#-----------------------------------------------#
|
150 |
+
# 输入的input一共有三个,他们的shape分别是
|
151 |
+
# bs, 3 * (5+num_classes), 13, 13 => bs, 3, 5 + num_classes, 13, 13 => batch_size, 3, 13, 13, 5 + num_classes
|
152 |
+
|
153 |
+
# batch_size, 3, 13, 13, 5 + num_classes
|
154 |
+
# batch_size, 3, 26, 26, 5 + num_classes
|
155 |
+
# batch_size, 3, 52, 52, 5 + num_classes
|
156 |
+
#-----------------------------------------------#
|
157 |
+
prediction = input.view(bs, len(self.anchors_mask[l]), self.bbox_attrs, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()
|
158 |
+
|
159 |
+
#-----------------------------------------------#
|
160 |
+
# 先验框的中心位置的调整参数
|
161 |
+
#-----------------------------------------------#
|
162 |
+
x = torch.sigmoid(prediction[..., 0])
|
163 |
+
y = torch.sigmoid(prediction[..., 1])
|
164 |
+
#-----------------------------------------------#
|
165 |
+
# 先验框的宽高调整参数
|
166 |
+
#-----------------------------------------------#
|
167 |
+
w = prediction[..., 2]
|
168 |
+
h = prediction[..., 3]
|
169 |
+
#-----------------------------------------------#
|
170 |
+
# 获得置信度,是否有物体
|
171 |
+
#-----------------------------------------------#
|
172 |
+
conf = torch.sigmoid(prediction[..., 4])
|
173 |
+
#-----------------------------------------------#
|
174 |
+
# 种类置信度
|
175 |
+
#-----------------------------------------------#
|
176 |
+
pred_cls = torch.sigmoid(prediction[..., 5:])
|
177 |
+
|
178 |
+
#-----------------------------------------------#
|
179 |
+
# 获得网络应该有的预测结果
|
180 |
+
#-----------------------------------------------#
|
181 |
+
y_true, noobj_mask, box_loss_scale = self.get_target(l, targets, scaled_anchors, in_h, in_w)
|
182 |
+
|
183 |
+
#---------------------------------------------------------------#
|
184 |
+
# 将预测结果进行解码,判断预测结果和真实值的重合程度
|
185 |
+
# 如果重合程度过大则忽略,因为这些特征点属于预测比较准确的特征点
|
186 |
+
# 作为负样本不合适
|
187 |
+
#----------------------------------------------------------------#
|
188 |
+
noobj_mask, pred_boxes = self.get_ignore(l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask)
|
189 |
+
|
190 |
+
if self.cuda:
|
191 |
+
y_true = y_true.type_as(x)
|
192 |
+
noobj_mask = noobj_mask.type_as(x)
|
193 |
+
box_loss_scale = box_loss_scale.type_as(x)
|
194 |
+
#--------------------------------------------------------------------------#
|
195 |
+
# box_loss_scale是真实框宽高的乘积,宽高均在0-1之间,因此乘积也在0-1之间。
|
196 |
+
# 2-宽高的乘积代表真实框越大,比重越小,小框的比重更大。
|
197 |
+
# 使用iou损失时,大中小目标的回归损失不存在比例失衡问题,故弃用
|
198 |
+
#--------------------------------------------------------------------------#
|
199 |
+
box_loss_scale = 2 - box_loss_scale
|
200 |
+
|
201 |
+
loss = 0
|
202 |
+
obj_mask = y_true[..., 4] == 1
|
203 |
+
n = torch.sum(obj_mask)
|
204 |
+
if n != 0:
|
205 |
+
#---------------------------------------------------------------#
|
206 |
+
# 计算预测结果和真实结果的差距
|
207 |
+
# loss_loc ciou回归损失
|
208 |
+
# loss_cls 分类损失
|
209 |
+
#---------------------------------------------------------------#
|
210 |
+
ciou = self.box_ciou(pred_boxes, y_true[..., :4]).type_as(x)
|
211 |
+
# loss_loc = torch.mean((1 - ciou)[obj_mask] * box_loss_scale[obj_mask])
|
212 |
+
loss_loc = torch.mean((1 - ciou)[obj_mask])
|
213 |
+
|
214 |
+
loss_cls = torch.mean(self.BCELoss(pred_cls[obj_mask], y_true[..., 5:][obj_mask]))
|
215 |
+
loss += loss_loc * self.box_ratio + loss_cls * self.cls_ratio
|
216 |
+
|
217 |
+
#---------------------------------------------------------------#
|
218 |
+
# 计算是否包含物体的置信度损失
|
219 |
+
#---------------------------------------------------------------#
|
220 |
+
if self.focal_loss:
|
221 |
+
pos_neg_ratio = torch.where(obj_mask, torch.ones_like(conf) * self.alpha, torch.ones_like(conf) * (1 - self.alpha))
|
222 |
+
hard_easy_ratio = torch.where(obj_mask, torch.ones_like(conf) - conf, conf) ** self.gamma
|
223 |
+
loss_conf = torch.mean((self.BCELoss(conf, obj_mask.type_as(conf)) * pos_neg_ratio * hard_easy_ratio)[noobj_mask.bool() | obj_mask]) * self.focal_loss_ratio
|
224 |
+
else:
|
225 |
+
loss_conf = torch.mean(self.BCELoss(conf, obj_mask.type_as(conf))[noobj_mask.bool() | obj_mask])
|
226 |
+
loss += loss_conf * self.balance[l] * self.obj_ratio
|
227 |
+
# if n != 0:
|
228 |
+
# print(loss_loc * self.box_ratio, loss_cls * self.cls_ratio, loss_conf * self.balance[l] * self.obj_ratio)
|
229 |
+
return loss
|
230 |
+
|
231 |
+
def calculate_iou(self, _box_a, _box_b):
|
232 |
+
#-----------------------------------------------------------#
|
233 |
+
# 计算真实框的左上角和右下角
|
234 |
+
#-----------------------------------------------------------#
|
235 |
+
b1_x1, b1_x2 = _box_a[:, 0] - _box_a[:, 2] / 2, _box_a[:, 0] + _box_a[:, 2] / 2
|
236 |
+
b1_y1, b1_y2 = _box_a[:, 1] - _box_a[:, 3] / 2, _box_a[:, 1] + _box_a[:, 3] / 2
|
237 |
+
#-----------------------------------------------------------#
|
238 |
+
# 计算先验框获得的预测框的左上角和右下角
|
239 |
+
#-----------------------------------------------------------#
|
240 |
+
b2_x1, b2_x2 = _box_b[:, 0] - _box_b[:, 2] / 2, _box_b[:, 0] + _box_b[:, 2] / 2
|
241 |
+
b2_y1, b2_y2 = _box_b[:, 1] - _box_b[:, 3] / 2, _box_b[:, 1] + _box_b[:, 3] / 2
|
242 |
+
|
243 |
+
#-----------------------------------------------------------#
|
244 |
+
# 将真实框和预测框都转化成左上角右下角的形式
|
245 |
+
#-----------------------------------------------------------#
|
246 |
+
box_a = torch.zeros_like(_box_a)
|
247 |
+
box_b = torch.zeros_like(_box_b)
|
248 |
+
box_a[:, 0], box_a[:, 1], box_a[:, 2], box_a[:, 3] = b1_x1, b1_y1, b1_x2, b1_y2
|
249 |
+
box_b[:, 0], box_b[:, 1], box_b[:, 2], box_b[:, 3] = b2_x1, b2_y1, b2_x2, b2_y2
|
250 |
+
|
251 |
+
#-----------------------------------------------------------#
|
252 |
+
# A为真实框的数量,B为先验框的数量
|
253 |
+
#-----------------------------------------------------------#
|
254 |
+
A = box_a.size(0)
|
255 |
+
B = box_b.size(0)
|
256 |
+
|
257 |
+
#-----------------------------------------------------------#
|
258 |
+
# 计算交的面积
|
259 |
+
#-----------------------------------------------------------#
|
260 |
+
max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
|
261 |
+
min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
|
262 |
+
inter = torch.clamp((max_xy - min_xy), min=0)
|
263 |
+
inter = inter[:, :, 0] * inter[:, :, 1]
|
264 |
+
#-----------------------------------------------------------#
|
265 |
+
# 计算预测框和真实框各自的面积
|
266 |
+
#-----------------------------------------------------------#
|
267 |
+
area_a = ((box_a[:, 2]-box_a[:, 0]) * (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter) # [A,B]
|
268 |
+
area_b = ((box_b[:, 2]-box_b[:, 0]) * (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter) # [A,B]
|
269 |
+
#-----------------------------------------------------------#
|
270 |
+
# 求IOU
|
271 |
+
#-----------------------------------------------------------#
|
272 |
+
union = area_a + area_b - inter
|
273 |
+
return inter / union # [A,B]
|
274 |
+
|
275 |
+
def get_target(self, l, targets, anchors, in_h, in_w):
|
276 |
+
#-----------------------------------------------------#
|
277 |
+
# 计算一共有多少张图片
|
278 |
+
#-----------------------------------------------------#
|
279 |
+
bs = len(targets)
|
280 |
+
#-----------------------------------------------------#
|
281 |
+
# 用于选取哪些先验框不包含物体
|
282 |
+
#-----------------------------------------------------#
|
283 |
+
noobj_mask = torch.ones(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
|
284 |
+
#-----------------------------------------------------#
|
285 |
+
# 让网络更加去关注小目标
|
286 |
+
#-----------------------------------------------------#
|
287 |
+
box_loss_scale = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
|
288 |
+
#-----------------------------------------------------#
|
289 |
+
# batch_size, 3, 13, 13, 5 + num_classes
|
290 |
+
#-----------------------------------------------------#
|
291 |
+
y_true = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, self.bbox_attrs, requires_grad = False)
|
292 |
+
for b in range(bs):
|
293 |
+
if len(targets[b])==0:
|
294 |
+
continue
|
295 |
+
batch_target = torch.zeros_like(targets[b])
|
296 |
+
#-------------------------------------------------------#
|
297 |
+
# 计算出正样本在特征层上的中心点
|
298 |
+
#-------------------------------------------------------#
|
299 |
+
batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
|
300 |
+
batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
|
301 |
+
batch_target[:, 4] = targets[b][:, 4]
|
302 |
+
batch_target = batch_target.cpu()
|
303 |
+
|
304 |
+
#-------------------------------------------------------#
|
305 |
+
# 将真实框转换一个形式
|
306 |
+
# num_true_box, 4
|
307 |
+
#-------------------------------------------------------#
|
308 |
+
gt_box = torch.FloatTensor(torch.cat((torch.zeros((batch_target.size(0), 2)), batch_target[:, 2:4]), 1))
|
309 |
+
#-------------------------------------------------------#
|
310 |
+
# 将先验框转换一个形式
|
311 |
+
# 9, 4
|
312 |
+
#-------------------------------------------------------#
|
313 |
+
anchor_shapes = torch.FloatTensor(torch.cat((torch.zeros((len(anchors), 2)), torch.FloatTensor(anchors)), 1))
|
314 |
+
#-------------------------------------------------------#
|
315 |
+
# 计算交并比
|
316 |
+
# self.calculate_iou(gt_box, anchor_shapes) = [num_true_box, 9]每一个真实框和9个先验框的重合情况
|
317 |
+
# best_ns:
|
318 |
+
# [每个真实框最大的重合度max_iou, 每一个真实框最重合的先验框的序号]
|
319 |
+
#-------------------------------------------------------#
|
320 |
+
best_ns = torch.argmax(self.calculate_iou(gt_box, anchor_shapes), dim=-1)
|
321 |
+
|
322 |
+
for t, best_n in enumerate(best_ns):
|
323 |
+
if best_n not in self.anchors_mask[l]:
|
324 |
+
continue
|
325 |
+
#----------------------------------------#
|
326 |
+
# 判断这个先验框是当前特征点的哪一个先验框
|
327 |
+
#----------------------------------------#
|
328 |
+
k = self.anchors_mask[l].index(best_n)
|
329 |
+
#----------------------------------------#
|
330 |
+
# 获得真实框属于哪个网格点
|
331 |
+
#----------------------------------------#
|
332 |
+
i = torch.floor(batch_target[t, 0]).long()
|
333 |
+
j = torch.floor(batch_target[t, 1]).long()
|
334 |
+
#----------------------------------------#
|
335 |
+
# 取出真实框的种类
|
336 |
+
#----------------------------------------#
|
337 |
+
c = batch_target[t, 4].long()
|
338 |
+
|
339 |
+
#----------------------------------------#
|
340 |
+
# noobj_mask代表无目标的特征点
|
341 |
+
#----------------------------------------#
|
342 |
+
noobj_mask[b, k, j, i] = 0
|
343 |
+
#----------------------------------------#
|
344 |
+
# tx、ty代表中心调整参数的真实值
|
345 |
+
#----------------------------------------#
|
346 |
+
y_true[b, k, j, i, 0] = batch_target[t, 0]
|
347 |
+
y_true[b, k, j, i, 1] = batch_target[t, 1]
|
348 |
+
y_true[b, k, j, i, 2] = batch_target[t, 2]
|
349 |
+
y_true[b, k, j, i, 3] = batch_target[t, 3]
|
350 |
+
y_true[b, k, j, i, 4] = 1
|
351 |
+
y_true[b, k, j, i, c + 5] = 1
|
352 |
+
#----------------------------------------#
|
353 |
+
# 用于获得xywh的比例
|
354 |
+
# 大目标loss权重小,小目标loss权重大
|
355 |
+
#----------------------------------------#
                    box_loss_scale[b, k, j, i] = batch_target[t, 2] * batch_target[t, 3] / in_w / in_h
        return y_true, noobj_mask, box_loss_scale

    def get_ignore(self, l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask):
        #-----------------------------------------------------#
        #   Count how many images are in the batch
        #-----------------------------------------------------#
        bs = len(targets)

        #-----------------------------------------------------#
        #   Generate the grid; the prior-box centres sit at the
        #   top-left corners of the grid cells
        #-----------------------------------------------------#
        grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(x.shape).type_as(x)
        grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(y.shape).type_as(x)

        # Generate the prior-box widths and heights
        scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
        anchor_w = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([0])).type_as(x)
        anchor_h = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([1])).type_as(x)

        anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
        anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
        #-------------------------------------------------------#
        #   Compute the adjusted prior-box centres and sizes
        #-------------------------------------------------------#
        pred_boxes_x = torch.unsqueeze(x + grid_x, -1)
        pred_boxes_y = torch.unsqueeze(y + grid_y, -1)
        pred_boxes_w = torch.unsqueeze(torch.exp(w) * anchor_w, -1)
        pred_boxes_h = torch.unsqueeze(torch.exp(h) * anchor_h, -1)
        pred_boxes   = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim = -1)

        for b in range(bs):
            #-------------------------------------------------------#
            #   Reshape the predictions
            #   pred_boxes_for_ignore   num_anchors, 4
            #-------------------------------------------------------#
            pred_boxes_for_ignore = pred_boxes[b].view(-1, 4)
            #-------------------------------------------------------#
            #   Convert the ground-truth boxes to feature-layer scale
            #   gt_box   num_true_box, 4
            #-------------------------------------------------------#
            if len(targets[b]) > 0:
                batch_target = torch.zeros_like(targets[b])
                #-------------------------------------------------------#
                #   Centres of the positive samples on the feature layer
                #-------------------------------------------------------#
                batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
                batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
                batch_target = batch_target[:, :4].type_as(x)
                #-------------------------------------------------------#
                #   Compute the IoU
                #   anch_ious   num_true_box, num_anchors
                #-------------------------------------------------------#
                anch_ious = self.calculate_iou(batch_target, pred_boxes_for_ignore)
                #-------------------------------------------------------#
                #   Best overlap of each prior box with any ground-truth box
                #   anch_ious_max   num_anchors
                #-------------------------------------------------------#
                anch_ious_max, _ = torch.max(anch_ious, dim = 0)
                anch_ious_max = anch_ious_max.view(pred_boxes[b].size()[:3])
                noobj_mask[b][anch_ious_max > self.ignore_threshold] = 0
        return noobj_mask, pred_boxes

def weights_init(net, init_type='normal', init_gain = 0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)
    print('initialize network with %s type' % init_type)
    net.apply(init_func)

def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        if iters <= warmup_total_iters:
            # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        elif iters >= total_iters - no_aug_iter:
            lr = min_lr
        else:
            lr = min_lr + 0.5 * (lr - min_lr) * (
                1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
            )
        return lr

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must be above 1.")
        n = iters // step_size
        out_lr = lr * decay_rate ** n
        return out_lr

    if lr_decay_type == "cos":
        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start    = max(warmup_lr_ratio * lr, 1e-6)
        no_aug_iter        = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
    else:
        decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        step_size  = total_iters / step_num
        func = partial(step_lr, lr, decay_rate, step_size)

    return func

def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    lr = lr_scheduler_func(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
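Usage note: a minimal sketch of how these scheduler helpers are typically wired into a training loop. The optimizer, learning rates and epoch count below are illustrative placeholders, not values taken from this repo:

import torch

model     = torch.nn.Linear(10, 2)   # placeholder network
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2, momentum=0.937)
# cosine decay from 1e-2 down to 1e-4 over 100 epochs
lr_scheduler_func = get_lr_scheduler("cos", lr=1e-2, min_lr=1e-4, total_iters=100)
for epoch in range(100):
    set_optimizer_lr(optimizer, lr_scheduler_func, epoch)
    # ... one epoch of training ...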
nets/yolotiny_training.py
ADDED
@@ -0,0 +1,474 @@
import math
from functools import partial

import numpy as np
import torch
import torch.nn as nn

class YOLOLosstiny(nn.Module):
    def __init__(self, anchors, num_classes, input_shape, cuda, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]], label_smoothing = 0):
        super(YOLOLosstiny, self).__init__()
        #-----------------------------------------------------------#
        #   The 13x13 feature layer uses anchors [81,82],[135,169],[344,319]
        #   The 26x26 feature layer uses anchors [10,14],[23,27],[37,58]
        #-----------------------------------------------------------#
        self.anchors         = anchors
        self.num_classes     = num_classes
        self.bbox_attrs      = 5 + num_classes
        self.input_shape     = input_shape
        self.anchors_mask    = anchors_mask
        self.label_smoothing = label_smoothing

        self.balance   = [0.4, 1.0, 4]
        self.box_ratio = 0.05
        self.obj_ratio = 5 * (input_shape[0] * input_shape[1]) / (416 ** 2)
        self.cls_ratio = 1 * (num_classes / 80)

        self.ignore_threshold = 0.5
        self.cuda = cuda

    def clip_by_tensor(self, t, t_min, t_max):
        t = t.float()
        result = (t >= t_min).float() * t + (t < t_min).float() * t_min
        result = (result <= t_max).float() * result + (result > t_max).float() * t_max
        return result

    def MSELoss(self, pred, target):
        return torch.pow(pred - target, 2)

    def BCELoss(self, pred, target):
        epsilon = 1e-7
        pred    = self.clip_by_tensor(pred, epsilon, 1.0 - epsilon)
        output  = - target * torch.log(pred) - (1.0 - target) * torch.log(1.0 - pred)
        return output

    def box_ciou(self, b1, b2):
        """
        Input:
        ----------
        b1: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh
        b2: tensor, shape=(batch, feat_w, feat_h, anchor_num, 4), xywh

        Returns:
        -------
        ciou: tensor, shape=(batch, feat_w, feat_h, anchor_num, 1)
        """
        #----------------------------------------------------#
        #   Top-left and bottom-right corners of the predicted boxes
        #----------------------------------------------------#
        b1_xy      = b1[..., :2]
        b1_wh      = b1[..., 2:4]
        b1_wh_half = b1_wh/2.
        b1_mins    = b1_xy - b1_wh_half
        b1_maxes   = b1_xy + b1_wh_half
        #----------------------------------------------------#
        #   Top-left and bottom-right corners of the ground-truth boxes
        #----------------------------------------------------#
        b2_xy      = b2[..., :2]
        b2_wh      = b2[..., 2:4]
        b2_wh_half = b2_wh/2.
        b2_mins    = b2_xy - b2_wh_half
        b2_maxes   = b2_xy + b2_wh_half

        #----------------------------------------------------#
        #   IoU between all ground-truth and predicted boxes
        #----------------------------------------------------#
        intersect_mins  = torch.max(b1_mins, b2_mins)
        intersect_maxes = torch.min(b1_maxes, b2_maxes)
        intersect_wh    = torch.max(intersect_maxes - intersect_mins, torch.zeros_like(intersect_maxes))
        intersect_area  = intersect_wh[..., 0] * intersect_wh[..., 1]
        b1_area         = b1_wh[..., 0] * b1_wh[..., 1]
        b2_area         = b2_wh[..., 0] * b2_wh[..., 1]
        union_area      = b1_area + b2_area - intersect_area
        iou             = intersect_area / torch.clamp(union_area, min = 1e-6)

        #----------------------------------------------------#
        #   Distance between the box centres
        #----------------------------------------------------#
        center_distance = torch.sum(torch.pow((b1_xy - b2_xy), 2), axis=-1)

        #----------------------------------------------------#
        #   Top-left and bottom-right corners of the smallest
        #   box enclosing both boxes
        #----------------------------------------------------#
        enclose_mins  = torch.min(b1_mins, b2_mins)
        enclose_maxes = torch.max(b1_maxes, b2_maxes)
        enclose_wh    = torch.max(enclose_maxes - enclose_mins, torch.zeros_like(intersect_maxes))
        #----------------------------------------------------#
        #   Squared diagonal of the enclosing box
        #----------------------------------------------------#
        enclose_diagonal = torch.sum(torch.pow(enclose_wh, 2), axis=-1)
        ciou = iou - 1.0 * (center_distance) / torch.clamp(enclose_diagonal, min = 1e-6)

        v     = (4 / (math.pi ** 2)) * torch.pow((torch.atan(b1_wh[..., 0] / torch.clamp(b1_wh[..., 1], min = 1e-6)) - torch.atan(b2_wh[..., 0] / torch.clamp(b2_wh[..., 1], min = 1e-6))), 2)
        alpha = v / torch.clamp((1.0 - iou + v), min=1e-6)
        ciou  = ciou - alpha * v
        return ciou

    #---------------------------------------------------#
    #   Label smoothing
    #---------------------------------------------------#
    def smooth_labels(self, y_true, label_smoothing, num_classes):
        return y_true * (1.0 - label_smoothing) + label_smoothing / num_classes

    def forward(self, l, input, targets=None):
        #----------------------------------------------------#
        #   l        index of the effective feature layer in use
        #   input    shape bs, 3*(5+num_classes), 13, 13
        #                  bs, 3*(5+num_classes), 26, 26
        #   targets  ground-truth boxes [batch_size, num_gt, 5]
        #----------------------------------------------------#
        #--------------------------------#
        #   Number of images, feature-layer height and width
        #--------------------------------#
        bs   = input.size(0)
        in_h = input.size(2)
        in_w = input.size(3)
        #-----------------------------------------------------------------------#
        #   Compute the stride:
        #   how many pixels of the original image each feature point covers
        #
        #   On a 13x13 feature layer one feature point covers 32 pixels
        #   On a 26x26 feature layer one feature point covers 16 pixels
        #   stride_h = stride_w = 32, 16
        #-----------------------------------------------------------------------#
        stride_h = self.input_shape[0] / in_h
        stride_w = self.input_shape[1] / in_w
        #-------------------------------------------------#
        #   scaled_anchors are now relative to the feature layer
        #-------------------------------------------------#
        scaled_anchors = [(a_w / stride_w, a_h / stride_h) for a_w, a_h in self.anchors]
        #-----------------------------------------------#
        #   The head outputs are reshaped as
        #   bs, 3 * (5+num_classes), 13, 13 => bs, 3, 5 + num_classes, 13, 13 => batch_size, 3, 13, 13, 5 + num_classes
        #
        #   batch_size, 3, 13, 13, 5 + num_classes
        #   batch_size, 3, 26, 26, 5 + num_classes
        #-----------------------------------------------#
        prediction = input.view(bs, len(self.anchors_mask[l]), self.bbox_attrs, in_h, in_w).permute(0, 1, 3, 4, 2).contiguous()

        #-----------------------------------------------#
        #   Adjustments to the prior-box centres
        #-----------------------------------------------#
        x = torch.sigmoid(prediction[..., 0])
        y = torch.sigmoid(prediction[..., 1])
        #-----------------------------------------------#
        #   Adjustments to the prior-box widths and heights
        #-----------------------------------------------#
        w = prediction[..., 2]
        h = prediction[..., 3]
        #-----------------------------------------------#
        #   Objectness confidence: is there an object?
        #-----------------------------------------------#
        conf = torch.sigmoid(prediction[..., 4])
        #-----------------------------------------------#
        #   Class confidences
        #-----------------------------------------------#
        pred_cls = torch.sigmoid(prediction[..., 5:])

        #-----------------------------------------------#
        #   The targets the network should have predicted
        #-----------------------------------------------#
        y_true, noobj_mask, box_loss_scale = self.get_target(l, targets, scaled_anchors, in_h, in_w)

        #---------------------------------------------------------------#
        #   Decode the predictions and measure their overlap with the
        #   ground truth; predictions with too much overlap are ignored,
        #   since those feature points are already fairly accurate and
        #   are unsuitable as negative samples
        #----------------------------------------------------------------#
        noobj_mask, pred_boxes = self.get_ignore(l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask)

        if self.cuda:
            y_true         = y_true.type_as(x)
            noobj_mask     = noobj_mask.type_as(x)
            box_loss_scale = box_loss_scale.type_as(x)
        #--------------------------------------------------------------------------#
        #   box_loss_scale is the product of the ground-truth width and height,
        #   both in 0-1, so the product is also in 0-1.
        #   2 - product means the larger the box, the smaller its weight,
        #   giving small boxes a larger weight.
        #   With the IoU loss the regression of large/medium/small targets is
        #   no longer imbalanced, so this scale is not used below.
        #--------------------------------------------------------------------------#
        box_loss_scale = 2 - box_loss_scale

        loss     = 0
        obj_mask = y_true[..., 4] == 1
        n        = torch.sum(obj_mask)
        if n != 0:
            #---------------------------------------------------------------#
            #   Difference between predictions and ground truth
            #   loss_loc    CIoU regression loss
            #   loss_cls    classification loss
            #---------------------------------------------------------------#
            ciou = self.box_ciou(pred_boxes, y_true[..., :4]).type_as(x)
            # loss_loc = torch.mean((1 - ciou)[obj_mask] * box_loss_scale[obj_mask])
            loss_loc = torch.mean((1 - ciou)[obj_mask])

            loss_cls = torch.mean(self.BCELoss(pred_cls[obj_mask], y_true[..., 5:][obj_mask]))
            loss += loss_loc * self.box_ratio + loss_cls * self.cls_ratio

        loss_conf = torch.mean(self.BCELoss(conf, obj_mask.type_as(conf))[noobj_mask.bool() | obj_mask])
        loss += loss_conf * self.balance[l] * self.obj_ratio
        # if n != 0:
        #     print(loss_loc * self.box_ratio, loss_cls * self.cls_ratio, loss_conf * self.balance[l] * self.obj_ratio)
        return loss

    def calculate_iou(self, _box_a, _box_b):
        #-----------------------------------------------------------#
        #   Top-left and bottom-right corners of the ground-truth boxes
        #-----------------------------------------------------------#
        b1_x1, b1_x2 = _box_a[:, 0] - _box_a[:, 2] / 2, _box_a[:, 0] + _box_a[:, 2] / 2
        b1_y1, b1_y2 = _box_a[:, 1] - _box_a[:, 3] / 2, _box_a[:, 1] + _box_a[:, 3] / 2
        #-----------------------------------------------------------#
        #   Top-left and bottom-right corners of the boxes obtained
        #   from the prior boxes
        #-----------------------------------------------------------#
        b2_x1, b2_x2 = _box_b[:, 0] - _box_b[:, 2] / 2, _box_b[:, 0] + _box_b[:, 2] / 2
        b2_y1, b2_y2 = _box_b[:, 1] - _box_b[:, 3] / 2, _box_b[:, 1] + _box_b[:, 3] / 2

        #-----------------------------------------------------------#
        #   Convert both to top-left / bottom-right form
        #-----------------------------------------------------------#
        box_a = torch.zeros_like(_box_a)
        box_b = torch.zeros_like(_box_b)
        box_a[:, 0], box_a[:, 1], box_a[:, 2], box_a[:, 3] = b1_x1, b1_y1, b1_x2, b1_y2
        box_b[:, 0], box_b[:, 1], box_b[:, 2], box_b[:, 3] = b2_x1, b2_y1, b2_x2, b2_y2

        #-----------------------------------------------------------#
        #   A is the number of ground-truth boxes, B the number of priors
        #-----------------------------------------------------------#
        A = box_a.size(0)
        B = box_b.size(0)

        #-----------------------------------------------------------#
        #   Intersection area
        #-----------------------------------------------------------#
        max_xy = torch.min(box_a[:, 2:].unsqueeze(1).expand(A, B, 2), box_b[:, 2:].unsqueeze(0).expand(A, B, 2))
        min_xy = torch.max(box_a[:, :2].unsqueeze(1).expand(A, B, 2), box_b[:, :2].unsqueeze(0).expand(A, B, 2))
        inter  = torch.clamp((max_xy - min_xy), min=0)
        inter  = inter[:, :, 0] * inter[:, :, 1]
        #-----------------------------------------------------------#
        #   Areas of the predicted and ground-truth boxes
        #-----------------------------------------------------------#
        area_a = ((box_a[:, 2]-box_a[:, 0]) * (box_a[:, 3]-box_a[:, 1])).unsqueeze(1).expand_as(inter)  # [A,B]
        area_b = ((box_b[:, 2]-box_b[:, 0]) * (box_b[:, 3]-box_b[:, 1])).unsqueeze(0).expand_as(inter)  # [A,B]
        #-----------------------------------------------------------#
        #   IoU
        #-----------------------------------------------------------#
        union = area_a + area_b - inter
        return inter / union  # [A,B]

    def get_target(self, l, targets, anchors, in_h, in_w):
        #-----------------------------------------------------#
        #   Count how many images are in the batch
        #-----------------------------------------------------#
        bs = len(targets)
        #-----------------------------------------------------#
        #   Selects which prior boxes contain no object
        #-----------------------------------------------------#
        noobj_mask = torch.ones(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
        #-----------------------------------------------------#
        #   Makes the network pay more attention to small targets
        #-----------------------------------------------------#
        box_loss_scale = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, requires_grad = False)
        #-----------------------------------------------------#
        #   batch_size, 3, 13, 13, 5 + num_classes
        #-----------------------------------------------------#
        y_true = torch.zeros(bs, len(self.anchors_mask[l]), in_h, in_w, self.bbox_attrs, requires_grad = False)
        for b in range(bs):
            if len(targets[b]) == 0:
                continue
            batch_target = torch.zeros_like(targets[b])
            #-------------------------------------------------------#
            #   Centres of the positive samples on the feature layer
            #-------------------------------------------------------#
            batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
            batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
            batch_target[:, 4] = targets[b][:, 4]
            batch_target = batch_target.cpu()

            #-------------------------------------------------------#
            #   Reshape the ground-truth boxes
            #   num_true_box, 4
            #-------------------------------------------------------#
            gt_box = torch.FloatTensor(torch.cat((torch.zeros((batch_target.size(0), 2)), batch_target[:, 2:4]), 1))
            #-------------------------------------------------------#
            #   Reshape the prior boxes
            #   num_anchors, 4
            #-------------------------------------------------------#
            anchor_shapes = torch.FloatTensor(torch.cat((torch.zeros((len(anchors), 2)), torch.FloatTensor(anchors)), 1))
            #-------------------------------------------------------#
            #   Compute the IoU
            #   self.calculate_iou(gt_box, anchor_shapes) = [num_true_box, num_anchors]
            #   overlap of each ground-truth box with every prior box
            #   best_ns:
            #   [best overlap of each ground-truth box, index of its best-matching prior box]
            #-------------------------------------------------------#
            iou     = self.calculate_iou(gt_box, anchor_shapes)
            best_ns = torch.argmax(iou, dim=-1)
            sort_ns = torch.argsort(iou, dim=-1, descending=True)

            def check_in_anchors_mask(index, anchors_mask):
                for sub_anchors_mask in anchors_mask:
                    if index in sub_anchors_mask:
                        return True
                return False

            for t, best_n in enumerate(best_ns):
                #----------------------------------------#
                #   Guard against a matched prior box that is not in anchors_mask
                #----------------------------------------#
                if not check_in_anchors_mask(best_n, self.anchors_mask):
                    for index in sort_ns[t]:
                        if check_in_anchors_mask(index, self.anchors_mask):
                            best_n = index
                            break

                if best_n not in self.anchors_mask[l]:
                    continue
                #----------------------------------------#
                #   Which prior box of the current feature point this is
                #----------------------------------------#
                k = self.anchors_mask[l].index(best_n)
                #----------------------------------------#
                #   Which grid cell the ground-truth box belongs to
                #----------------------------------------#
                i = torch.floor(batch_target[t, 0]).long()
                j = torch.floor(batch_target[t, 1]).long()
                #----------------------------------------#
                #   The class of the ground-truth box
                #----------------------------------------#
                c = batch_target[t, 4].long()

                #----------------------------------------#
                #   noobj_mask marks feature points without a target
                #----------------------------------------#
                noobj_mask[b, k, j, i] = 0
                #----------------------------------------#
                #   tx, ty are the true centre-adjustment values
                #----------------------------------------#
                y_true[b, k, j, i, 0] = batch_target[t, 0]
                y_true[b, k, j, i, 1] = batch_target[t, 1]
                y_true[b, k, j, i, 2] = batch_target[t, 2]
                y_true[b, k, j, i, 3] = batch_target[t, 3]
                y_true[b, k, j, i, 4] = 1
                y_true[b, k, j, i, c + 5] = 1
                #----------------------------------------#
                #   Weighting ratio derived from xywh:
                #   large targets get a small loss weight, small targets a large one
                #----------------------------------------#
                box_loss_scale[b, k, j, i] = batch_target[t, 2] * batch_target[t, 3] / in_w / in_h
        return y_true, noobj_mask, box_loss_scale

    def get_ignore(self, l, x, y, h, w, targets, scaled_anchors, in_h, in_w, noobj_mask):
        #-----------------------------------------------------#
        #   Count how many images are in the batch
        #-----------------------------------------------------#
        bs = len(targets)

        #-----------------------------------------------------#
        #   Generate the grid; the prior-box centres sit at the
        #   top-left corners of the grid cells
        #-----------------------------------------------------#
        grid_x = torch.linspace(0, in_w - 1, in_w).repeat(in_h, 1).repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(x.shape).type_as(x)
        grid_y = torch.linspace(0, in_h - 1, in_h).repeat(in_w, 1).t().repeat(
            int(bs * len(self.anchors_mask[l])), 1, 1).view(y.shape).type_as(x)

        # Generate the prior-box widths and heights
        scaled_anchors_l = np.array(scaled_anchors)[self.anchors_mask[l]]
        anchor_w = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([0])).type_as(x)
        anchor_h = torch.Tensor(scaled_anchors_l).index_select(1, torch.LongTensor([1])).type_as(x)

        anchor_w = anchor_w.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(w.shape)
        anchor_h = anchor_h.repeat(bs, 1).repeat(1, 1, in_h * in_w).view(h.shape)
        #-------------------------------------------------------#
        #   Compute the adjusted prior-box centres and sizes
        #-------------------------------------------------------#
        pred_boxes_x = torch.unsqueeze(x + grid_x, -1)
        pred_boxes_y = torch.unsqueeze(y + grid_y, -1)
        pred_boxes_w = torch.unsqueeze(torch.exp(w) * anchor_w, -1)
        pred_boxes_h = torch.unsqueeze(torch.exp(h) * anchor_h, -1)
        pred_boxes   = torch.cat([pred_boxes_x, pred_boxes_y, pred_boxes_w, pred_boxes_h], dim = -1)
        for b in range(bs):
            #-------------------------------------------------------#
            #   Reshape the predictions
            #   pred_boxes_for_ignore   num_anchors, 4
            #-------------------------------------------------------#
            pred_boxes_for_ignore = pred_boxes[b].view(-1, 4)
            #-------------------------------------------------------#
            #   Convert the ground-truth boxes to feature-layer scale
            #   gt_box   num_true_box, 4
            #-------------------------------------------------------#
            if len(targets[b]) > 0:
                batch_target = torch.zeros_like(targets[b])
                #-------------------------------------------------------#
                #   Centres of the positive samples on the feature layer
                #-------------------------------------------------------#
                batch_target[:, [0,2]] = targets[b][:, [0,2]] * in_w
                batch_target[:, [1,3]] = targets[b][:, [1,3]] * in_h
                batch_target = batch_target[:, :4].type_as(x)
                #-------------------------------------------------------#
                #   Compute the IoU
                #   anch_ious   num_true_box, num_anchors
                #-------------------------------------------------------#
                anch_ious = self.calculate_iou(batch_target, pred_boxes_for_ignore)
                #-------------------------------------------------------#
                #   Best overlap of each prior box with any ground-truth box
                #   anch_ious_max   num_anchors
                #-------------------------------------------------------#
                anch_ious_max, _ = torch.max(anch_ious, dim = 0)
                anch_ious_max = anch_ious_max.view(pred_boxes[b].size()[:3])
                noobj_mask[b][anch_ious_max > self.ignore_threshold] = 0
        return noobj_mask, pred_boxes

def weights_init(net, init_type='normal', init_gain = 0.02):
    def init_func(m):
        classname = m.__class__.__name__
        if hasattr(m, 'weight') and classname.find('Conv') != -1:
            if init_type == 'normal':
                torch.nn.init.normal_(m.weight.data, 0.0, init_gain)
            elif init_type == 'xavier':
                torch.nn.init.xavier_normal_(m.weight.data, gain=init_gain)
            elif init_type == 'kaiming':
                torch.nn.init.kaiming_normal_(m.weight.data, a=0, mode='fan_in')
            elif init_type == 'orthogonal':
                torch.nn.init.orthogonal_(m.weight.data, gain=init_gain)
            else:
                raise NotImplementedError('initialization method [%s] is not implemented' % init_type)
        elif classname.find('BatchNorm2d') != -1:
            torch.nn.init.normal_(m.weight.data, 1.0, 0.02)
            torch.nn.init.constant_(m.bias.data, 0.0)
    print('initialize network with %s type' % init_type)
    net.apply(init_func)

def get_lr_scheduler(lr_decay_type, lr, min_lr, total_iters, warmup_iters_ratio = 0.05, warmup_lr_ratio = 0.1, no_aug_iter_ratio = 0.05, step_num = 10):
    def yolox_warm_cos_lr(lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter, iters):
        if iters <= warmup_total_iters:
            # lr = (lr - warmup_lr_start) * iters / float(warmup_total_iters) + warmup_lr_start
            lr = (lr - warmup_lr_start) * pow(iters / float(warmup_total_iters), 2) + warmup_lr_start
        elif iters >= total_iters - no_aug_iter:
            lr = min_lr
        else:
            lr = min_lr + 0.5 * (lr - min_lr) * (
                1.0 + math.cos(math.pi * (iters - warmup_total_iters) / (total_iters - warmup_total_iters - no_aug_iter))
            )
        return lr

    def step_lr(lr, decay_rate, step_size, iters):
        if step_size < 1:
            raise ValueError("step_size must be above 1.")
        n = iters // step_size
        out_lr = lr * decay_rate ** n
        return out_lr

    if lr_decay_type == "cos":
        warmup_total_iters = min(max(warmup_iters_ratio * total_iters, 1), 3)
        warmup_lr_start    = max(warmup_lr_ratio * lr, 1e-6)
        no_aug_iter        = min(max(no_aug_iter_ratio * total_iters, 1), 15)
        func = partial(yolox_warm_cos_lr, lr, min_lr, total_iters, warmup_total_iters, warmup_lr_start, no_aug_iter)
    else:
        decay_rate = (min_lr / lr) ** (1 / (step_num - 1))
        step_size  = total_iters / step_num
        func = partial(step_lr, lr, decay_rate, step_size)

    return func

def set_optimizer_lr(optimizer, lr_scheduler_func, epoch):
    lr = lr_scheduler_func(epoch)
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
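Usage note: a minimal sketch of how YOLOLosstiny might be instantiated and called on the 13x13 head. The anchor values come from model_data/yolotiny_anchors.txt, while the anchors_mask, batch size and target values are illustrative assumptions, not configuration taken from this commit:

import numpy as np
import torch

anchors   = np.array([[10,14], [23,27], [37,58], [81,82], [135,169], [344,319]])
criterion = YOLOLosstiny(anchors, num_classes=8, input_shape=[416, 416],
                         cuda=False, anchors_mask=[[3,4,5], [1,2,3]])
out0    = torch.randn(2, 3 * (5 + 8), 13, 13)          # raw 13x13 head output
targets = [torch.tensor([[0.5, 0.5, 0.2, 0.3, 1.0]]),  # one box per image:
           torch.tensor([[0.3, 0.6, 0.1, 0.2, 4.0]])]  # normalised xywh + class
loss0   = criterion(0, out0, targets)                  # l=0 selects the 13x13 layer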
utils/__init__.py
ADDED
@@ -0,0 +1 @@
#
utils/callbacks.py
ADDED
@@ -0,0 +1,71 @@
import datetime
import os

import torch
import matplotlib
matplotlib.use('Agg')
import scipy.signal
from matplotlib import pyplot as plt
from torch.utils.tensorboard import SummaryWriter


class LossHistory():
    def __init__(self, log_dir, model, input_shape):
        time_str      = datetime.datetime.strftime(datetime.datetime.now(), '%Y_%m_%d_%H_%M_%S')
        self.log_dir  = os.path.join(log_dir, "loss_" + str(time_str))
        self.losses   = []
        self.val_loss = []

        os.makedirs(self.log_dir)
        self.writer = SummaryWriter(self.log_dir)
        try:
            dummy_input = torch.randn(2, 3, input_shape[0], input_shape[1])
            self.writer.add_graph(model, dummy_input)
        except:
            pass

    def append_loss(self, epoch, loss, val_loss):
        if not os.path.exists(self.log_dir):
            os.makedirs(self.log_dir)

        self.losses.append(loss)
        self.val_loss.append(val_loss)

        with open(os.path.join(self.log_dir, "epoch_loss.txt"), 'a') as f:
            f.write(str(loss))
            f.write("\n")
        with open(os.path.join(self.log_dir, "epoch_val_loss.txt"), 'a') as f:
            f.write(str(val_loss))
            f.write("\n")

        self.writer.add_scalar('loss', loss, epoch)
        self.writer.add_scalar('val_loss', val_loss, epoch)
        self.loss_plot()

    def loss_plot(self):
        iters = range(len(self.losses))

        plt.figure()
        plt.plot(iters, self.losses, 'red', linewidth = 2, label='train loss')
        plt.plot(iters, self.val_loss, 'coral', linewidth = 2, label='val loss')
        try:
            if len(self.losses) < 25:
                num = 5
            else:
                num = 15

            plt.plot(iters, scipy.signal.savgol_filter(self.losses, num, 3), 'green', linestyle = '--', linewidth = 2, label='smooth train loss')
            plt.plot(iters, scipy.signal.savgol_filter(self.val_loss, num, 3), '#8B4513', linestyle = '--', linewidth = 2, label='smooth val loss')
        except:
            pass

        plt.grid(True)
        plt.xlabel('Epoch')
        plt.ylabel('Loss')
        plt.legend(loc="upper right")

        plt.savefig(os.path.join(self.log_dir, "epoch_loss.png"))

        plt.cla()
        plt.close("all")
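Usage note: a minimal sketch of driving LossHistory from a training loop; the model and the loss values are placeholders:

import torch

model        = torch.nn.Conv2d(3, 8, 3)   # placeholder model
loss_history = LossHistory("logs", model, input_shape=[416, 416])
for epoch in range(3):
    train_loss, val_loss = 1.0 / (epoch + 1), 1.2 / (epoch + 1)   # placeholder values
    loss_history.append_loss(epoch, train_loss, val_loss)
# writes epoch_loss.txt, epoch_val_loss.txt and epoch_loss.png under logs/loss_<timestamp>/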
utils/dataloader.py
ADDED
@@ -0,0 +1,360 @@
from random import sample, shuffle

import cv2
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset

from utils.utils import cvtColor, preprocess_input


class YoloDataset(Dataset):
    def __init__(self, annotation_lines, input_shape, num_classes, epoch_length, mosaic, train, mosaic_ratio = 0.7):
        super(YoloDataset, self).__init__()
        self.annotation_lines = annotation_lines
        self.input_shape      = input_shape
        self.num_classes      = num_classes
        self.epoch_length     = epoch_length
        self.mosaic           = mosaic
        self.train            = train
        self.mosaic_ratio     = mosaic_ratio

        self.epoch_now        = -1
        self.length           = len(self.annotation_lines)

    def __len__(self):
        return self.length

    def __getitem__(self, index):
        index = index % self.length

        #---------------------------------------------------#
        #   Apply random data augmentation during training;
        #   no random augmentation during validation
        #---------------------------------------------------#
        if self.mosaic:
            if self.rand() < 0.5 and self.epoch_now < self.epoch_length * self.mosaic_ratio:
                lines = sample(self.annotation_lines, 3)
                lines.append(self.annotation_lines[index])
                shuffle(lines)
                image, box = self.get_random_data_with_Mosaic(lines, self.input_shape)
            else:
                image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train)
        else:
            image, box = self.get_random_data(self.annotation_lines[index], self.input_shape, random = self.train)
        image = np.transpose(preprocess_input(np.array(image, dtype=np.float32)), (2, 0, 1))
        box = np.array(box, dtype=np.float32)
        if len(box) != 0:
            box[:, [0, 2]] = box[:, [0, 2]] / self.input_shape[1]
            box[:, [1, 3]] = box[:, [1, 3]] / self.input_shape[0]

            box[:, 2:4] = box[:, 2:4] - box[:, 0:2]
            box[:, 0:2] = box[:, 0:2] + box[:, 2:4] / 2
        return image, box

    def rand(self, a=0, b=1):
        return np.random.rand()*(b-a) + a

    def get_random_data(self, annotation_line, input_shape, jitter=.3, hue=.1, sat=0.7, val=0.4, random=True):
        line = annotation_line.split()
        #------------------------------#
        #   Read the image and convert it to RGB
        #------------------------------#
        image = Image.open(line[0])
        image = cvtColor(image)
        #------------------------------#
        #   Image size and target size
        #------------------------------#
        iw, ih = image.size
        h, w   = input_shape
        #------------------------------#
        #   Read the boxes
        #------------------------------#
        box = np.array([np.array(list(map(int, box.split(',')))) for box in line[1:]])

        if not random:
            scale = min(w/iw, h/ih)
            nw = int(iw*scale)
            nh = int(ih*scale)
            dx = (w-nw)//2
            dy = (h-nh)//2

            #---------------------------------#
            #   Pad the unused part of the image with grey bars
            #---------------------------------#
            image      = image.resize((nw,nh), Image.BICUBIC)
            new_image  = Image.new('RGB', (w,h), (128,128,128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image, np.float32)

            #---------------------------------#
            #   Adjust the ground-truth boxes accordingly
            #---------------------------------#
            if len(box)>0:
                np.random.shuffle(box)
                box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
                box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
                box[:, 0:2][box[:, 0:2]<0] = 0
                box[:, 2][box[:, 2]>w] = w
                box[:, 3][box[:, 3]>h] = h
                box_w = box[:, 2] - box[:, 0]
                box_h = box[:, 3] - box[:, 1]
                box = box[np.logical_and(box_w>1, box_h>1)]  # discard invalid boxes

            return image_data, box

        #------------------------------------------#
        #   Rescale the image and distort its aspect ratio
        #------------------------------------------#
        new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
        scale = self.rand(.25, 2)
        if new_ar < 1:
            nh = int(scale*h)
            nw = int(nh*new_ar)
        else:
            nw = int(scale*w)
            nh = int(nw/new_ar)
        image = image.resize((nw,nh), Image.BICUBIC)

        #------------------------------------------#
        #   Pad the unused part of the image with grey bars
        #------------------------------------------#
        dx = int(self.rand(0, w-nw))
        dy = int(self.rand(0, h-nh))
        new_image = Image.new('RGB', (w,h), (128,128,128))
        new_image.paste(image, (dx, dy))
        image = new_image

        #------------------------------------------#
        #   Flip the image
        #------------------------------------------#
        flip = self.rand()<.5
        if flip: image = image.transpose(Image.FLIP_LEFT_RIGHT)

        image_data = np.array(image, np.uint8)
        #---------------------------------#
        #   Colour-space (HSV) augmentation:
        #   compute the augmentation parameters
        #---------------------------------#
        r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
        #---------------------------------#
        #   Convert the image to HSV
        #---------------------------------#
        hue, sat, val = cv2.split(cv2.cvtColor(image_data, cv2.COLOR_RGB2HSV))
        dtype = image_data.dtype
        #---------------------------------#
        #   Apply the transform
        #---------------------------------#
        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        image_data = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        image_data = cv2.cvtColor(image_data, cv2.COLOR_HSV2RGB)

        #---------------------------------#
        #   Adjust the ground-truth boxes accordingly
        #---------------------------------#
        if len(box)>0:
            np.random.shuffle(box)
            box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
            box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
            if flip: box[:, [0,2]] = w - box[:, [2,0]]
            box[:, 0:2][box[:, 0:2]<0] = 0
            box[:, 2][box[:, 2]>w] = w
            box[:, 3][box[:, 3]>h] = h
            box_w = box[:, 2] - box[:, 0]
            box_h = box[:, 3] - box[:, 1]
            box = box[np.logical_and(box_w>1, box_h>1)]

        return image_data, box

    def merge_bboxes(self, bboxes, cutx, cuty):
        merge_bbox = []
        for i in range(len(bboxes)):
            for box in bboxes[i]:
                tmp_box = []
                x1, y1, x2, y2 = box[0], box[1], box[2], box[3]

                if i == 0:
                    if y1 > cuty or x1 > cutx:
                        continue
                    if y2 >= cuty and y1 <= cuty:
                        y2 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x2 = cutx

                if i == 1:
                    if y2 < cuty or x1 > cutx:
                        continue
                    if y2 >= cuty and y1 <= cuty:
                        y1 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x2 = cutx

                if i == 2:
                    if y2 < cuty or x2 < cutx:
                        continue
                    if y2 >= cuty and y1 <= cuty:
                        y1 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x1 = cutx

                if i == 3:
                    if y1 > cuty or x2 < cutx:
                        continue
                    if y2 >= cuty and y1 <= cuty:
                        y2 = cuty
                    if x2 >= cutx and x1 <= cutx:
                        x1 = cutx
                tmp_box.append(x1)
                tmp_box.append(y1)
                tmp_box.append(x2)
                tmp_box.append(y2)
                tmp_box.append(box[-1])
                merge_bbox.append(tmp_box)
        return merge_bbox

    def get_random_data_with_Mosaic(self, annotation_line, input_shape, jitter=0.3, hue=.1, sat=0.7, val=0.4):
        h, w = input_shape
        min_offset_x = self.rand(0.3, 0.7)
        min_offset_y = self.rand(0.3, 0.7)

        image_datas = []
        box_datas   = []
        index       = 0
        for line in annotation_line:
            #---------------------------------#
            #   Split each annotation line
            #---------------------------------#
            line_content = line.split()
            #---------------------------------#
            #   Open the image
            #---------------------------------#
            image = Image.open(line_content[0])
            image = cvtColor(image)

            #---------------------------------#
            #   Image size
            #---------------------------------#
            iw, ih = image.size
            #---------------------------------#
            #   Box positions
            #---------------------------------#
            box = np.array([np.array(list(map(int, box.split(',')))) for box in line_content[1:]])

            #---------------------------------#
            #   Whether to flip the image
            #---------------------------------#
            flip = self.rand()<.5
            if flip and len(box)>0:
                image = image.transpose(Image.FLIP_LEFT_RIGHT)
                box[:, [0,2]] = iw - box[:, [2,0]]

            #------------------------------------------#
            #   Rescale the image and distort its aspect ratio
            #------------------------------------------#
            new_ar = iw/ih * self.rand(1-jitter,1+jitter) / self.rand(1-jitter,1+jitter)
            scale = self.rand(.4, 1)
            if new_ar < 1:
                nh = int(scale*h)
                nw = int(nh*new_ar)
            else:
                nw = int(scale*w)
                nh = int(nw/new_ar)
            image = image.resize((nw, nh), Image.BICUBIC)

            #-----------------------------------------------#
            #   Place each image in its quadrant of the mosaic
            #-----------------------------------------------#
            if index == 0:
                dx = int(w*min_offset_x) - nw
                dy = int(h*min_offset_y) - nh
            elif index == 1:
                dx = int(w*min_offset_x) - nw
                dy = int(h*min_offset_y)
            elif index == 2:
                dx = int(w*min_offset_x)
                dy = int(h*min_offset_y)
            elif index == 3:
                dx = int(w*min_offset_x)
                dy = int(h*min_offset_y) - nh

            new_image = Image.new('RGB', (w,h), (128,128,128))
            new_image.paste(image, (dx, dy))
            image_data = np.array(new_image)

            index = index + 1
            box_data = []
            #---------------------------------#
            #   Re-process the boxes
            #---------------------------------#
            if len(box)>0:
                np.random.shuffle(box)
                box[:, [0,2]] = box[:, [0,2]]*nw/iw + dx
                box[:, [1,3]] = box[:, [1,3]]*nh/ih + dy
                box[:, 0:2][box[:, 0:2]<0] = 0
                box[:, 2][box[:, 2]>w] = w
                box[:, 3][box[:, 3]>h] = h
                box_w = box[:, 2] - box[:, 0]
                box_h = box[:, 3] - box[:, 1]
                box = box[np.logical_and(box_w>1, box_h>1)]
                box_data = np.zeros((len(box),5))
                box_data[:len(box)] = box

            image_datas.append(image_data)
            box_datas.append(box_data)

        #---------------------------------#
        #   Cut the images and stitch them together
        #---------------------------------#
        cutx = int(w * min_offset_x)
        cuty = int(h * min_offset_y)

        new_image = np.zeros([h, w, 3])
        new_image[:cuty, :cutx, :] = image_datas[0][:cuty, :cutx, :]
        new_image[cuty:, :cutx, :] = image_datas[1][cuty:, :cutx, :]
        new_image[cuty:, cutx:, :] = image_datas[2][cuty:, cutx:, :]
        new_image[:cuty, cutx:, :] = image_datas[3][:cuty, cutx:, :]

        new_image = np.array(new_image, np.uint8)
        #---------------------------------#
        #   Colour-space (HSV) augmentation:
        #   compute the augmentation parameters
        #---------------------------------#
        r = np.random.uniform(-1, 1, 3) * [hue, sat, val] + 1
        #---------------------------------#
        #   Convert the image to HSV
        #---------------------------------#
        hue, sat, val = cv2.split(cv2.cvtColor(new_image, cv2.COLOR_RGB2HSV))
        dtype = new_image.dtype
        #---------------------------------#
        #   Apply the transform
        #---------------------------------#
        x = np.arange(0, 256, dtype=r.dtype)
        lut_hue = ((x * r[0]) % 180).astype(dtype)
        lut_sat = np.clip(x * r[1], 0, 255).astype(dtype)
        lut_val = np.clip(x * r[2], 0, 255).astype(dtype)

        new_image = cv2.merge((cv2.LUT(hue, lut_hue), cv2.LUT(sat, lut_sat), cv2.LUT(val, lut_val)))
        new_image = cv2.cvtColor(new_image, cv2.COLOR_HSV2RGB)

        #---------------------------------#
        #   Post-process the boxes
        #---------------------------------#
        new_boxes = self.merge_bboxes(box_datas, cutx, cuty)

        return new_image, new_boxes

# Used as collate_fn in the DataLoader
def yolo_dataset_collate(batch):
    images = []
    bboxes = []
    for img, box in batch:
        images.append(img)
        bboxes.append(box)
    images = torch.from_numpy(np.array(images)).type(torch.FloatTensor)
    bboxes = [torch.from_numpy(ann).type(torch.FloatTensor) for ann in bboxes]
    return images, bboxes
utils/utils.py
ADDED
@@ -0,0 +1,62 @@
import numpy as np
from PIL import Image

#---------------------------------------------------------#
#   Convert images to RGB so that greyscale images do not
#   break prediction. The code only supports RGB images;
#   every other type is converted to RGB.
#---------------------------------------------------------#
def cvtColor(image):
    if len(np.shape(image)) == 3 and np.shape(image)[2] == 3:
        return image
    else:
        image = image.convert('RGB')
        return image

#---------------------------------------------------#
#   Resize the input image
#---------------------------------------------------#
def resize_image(image, size, letterbox_image):
    iw, ih = image.size
    w, h   = size
    if letterbox_image:
        scale = min(w/iw, h/ih)
        nw = int(iw*scale)
        nh = int(ih*scale)

        image = image.resize((nw,nh), Image.BICUBIC)
        new_image = Image.new('RGB', size, (128,128,128))
        new_image.paste(image, ((w-nw)//2, (h-nh)//2))
    else:
        new_image = image.resize((w, h), Image.BICUBIC)
    return new_image

#---------------------------------------------------#
#   Load the class names
#---------------------------------------------------#
def get_classes(classes_path):
    with open(classes_path, encoding='utf-8') as f:
        class_names = f.readlines()
    class_names = [c.strip() for c in class_names]
    return class_names, len(class_names)

#---------------------------------------------------#
#   Load the prior boxes (anchors)
#---------------------------------------------------#
def get_anchors(anchors_path):
    '''loads the anchors from a file'''
    with open(anchors_path, encoding='utf-8') as f:
        anchors = f.readline()
    anchors = [float(x) for x in anchors.split(',')]
    anchors = np.array(anchors).reshape(-1, 2)
    return anchors, len(anchors)

#---------------------------------------------------#
#   Get the current learning rate
#---------------------------------------------------#
def get_lr(optimizer):
    for param_group in optimizer.param_groups:
        return param_group['lr']

def preprocess_input(image):
    image /= 255.0
    return image
utils/utils_bbox.py
ADDED
@@ -0,0 +1,227 @@
1 |
+
import torch
|
2 |
+
import torch.nn as nn
|
3 |
+
from torchvision.ops import nms
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
class DecodeBox():
|
7 |
+
def __init__(self, anchors, num_classes, input_shape, anchors_mask = [[6,7,8], [3,4,5], [0,1,2]]):
|
8 |
+
super(DecodeBox, self).__init__()
|
9 |
+
self.anchors = anchors
|
10 |
+
self.num_classes = num_classes
|
11 |
+
self.bbox_attrs = 5 + num_classes
|
12 |
+
self.input_shape = input_shape
|
13 |
+
#-----------------------------------------------------------#
|
14 |
+
# 13x13的特征层对应的anchor是[142, 110],[192, 243],[459, 401]
|
15 |
+
# 26x26的特征层对应的anchor是[36, 75],[76, 55],[72, 146]
|
16 |
+
# 52x52的特征层对应的anchor是[12, 16],[19, 36],[40, 28]
|
17 |
+
#-----------------------------------------------------------#
|
18 |
+
self.anchors_mask = anchors_mask
|
19 |
+
|
20 |
+
def decode_box(self, inputs):
|
21 |
+
outputs = []
|
22 |
+
for i, input in enumerate(inputs):
|
23 |
+
#-----------------------------------------------#
|
24 |
+
# 输入的input一共有三个,他们的shape分别是
|
25 |
+
# batch_size, 255, 13, 13
|
26 |
+
# batch_size, 255, 26, 26
|
27 |
+
# batch_size, 255, 52, 52
|
28 |
+
#-----------------------------------------------#
|
29 |
+
batch_size = input.size(0)
|
30 |
+
input_height = input.size(2)
|
31 |
+
input_width = input.size(3)
|
32 |
+
|
33 |
+
#-----------------------------------------------#
|
34 |
+
# 输入为416x416时
|
35 |
+
# stride_h = stride_w = 32、16、8
|
36 |
+
#-----------------------------------------------#
|
37 |
+
stride_h = self.input_shape[0] / input_height
|
38 |
+
stride_w = self.input_shape[1] / input_width
|
39 |
+
#-------------------------------------------------#
|
40 |
+
# 此时获得的scaled_anchors大小是相对于特征层的
|
41 |
+
#-------------------------------------------------#
|
42 |
+
scaled_anchors = [(anchor_width / stride_w, anchor_height / stride_h) for anchor_width, anchor_height in self.anchors[self.anchors_mask[i]]]
|
43 |
+
|
44 |
+
#-----------------------------------------------#
|
45 |
+
# 输入的input一共有三个,他们的shape分别是
|
46 |
+
# batch_size, 3, 13, 13, 85
|
47 |
+
# batch_size, 3, 26, 26, 85
|
48 |
+
# batch_size, 3, 52, 52, 85
|
49 |
+
#-----------------------------------------------#
|
50 |
+
prediction = input.view(batch_size, len(self.anchors_mask[i]),
|
51 |
+
self.bbox_attrs, input_height, input_width).permute(0, 1, 3, 4, 2).contiguous()
|
52 |
+
|
53 |
+
#-----------------------------------------------#
|
54 |
+
# 先验框的中心位置的调整参数
|
55 |
+
#-----------------------------------------------#
|
56 |
+
x = torch.sigmoid(prediction[..., 0])
|
57 |
+
y = torch.sigmoid(prediction[..., 1])
|
58 |
+
#-----------------------------------------------#
|
59 |
+
# 先验框的宽高调整参数
|
60 |
+
#-----------------------------------------------#
|
61 |
+
w = prediction[..., 2]
|
62 |
+
h = prediction[..., 3]
|
63 |
+
#-----------------------------------------------#
|
64 |
+
# 获得置信度,是否有物体
|
65 |
+
#-----------------------------------------------#
|
66 |
+
conf = torch.sigmoid(prediction[..., 4])
|
67 |
+
#-----------------------------------------------#
|
68 |
+
# 种类置信度
|
69 |
+
#-----------------------------------------------#
|
70 |
+
pred_cls = torch.sigmoid(prediction[..., 5:])
|
71 |
+
|
72 |
+
FloatTensor = torch.cuda.FloatTensor if x.is_cuda else torch.FloatTensor
|
73 |
+
LongTensor = torch.cuda.LongTensor if x.is_cuda else torch.LongTensor
|
74 |
+
|
75 |
+
#----------------------------------------------------------#
|
76 |
+
# 生成网格,先验框中心,网格左上角
|
77 |
+
# batch_size,3,13,13
|
78 |
+
#----------------------------------------------------------#
|
79 |
+
grid_x = torch.linspace(0, input_width - 1, input_width).repeat(input_height, 1).repeat(
|
80 |
+
batch_size * len(self.anchors_mask[i]), 1, 1).view(x.shape).type(FloatTensor)
|
81 |
+
grid_y = torch.linspace(0, input_height - 1, input_height).repeat(input_width, 1).t().repeat(
|
82 |
+
batch_size * len(self.anchors_mask[i]), 1, 1).view(y.shape).type(FloatTensor)
|
83 |
+
|
84 |
+
#----------------------------------------------------------#
|
85 |
+
# 按照网格格式生成先验框的宽高
|
86 |
+
# batch_size,3,13,13
|
87 |
+
#----------------------------------------------------------#
|
88 |
+
anchor_w = FloatTensor(scaled_anchors).index_select(1, LongTensor([0]))
|
89 |
+
anchor_h = FloatTensor(scaled_anchors).index_select(1, LongTensor([1]))
|
90 |
+
anchor_w = anchor_w.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(w.shape)
|
91 |
+
anchor_h = anchor_h.repeat(batch_size, 1).repeat(1, 1, input_height * input_width).view(h.shape)
|
92 |
+
|
93 |
+
#----------------------------------------------------------#
|
94 |
+
# 利用预测结果对先验框进行调整
|
95 |
+
# 首先调整先验框的中心,从先验框中心向右下角偏移
|
96 |
+
# 再调整先验框的宽高。
|
97 |
+
#----------------------------------------------------------#
|
98 |
+
pred_boxes = FloatTensor(prediction[..., :4].shape)
|
99 |
+
pred_boxes[..., 0] = x.data + grid_x
|
100 |
+
pred_boxes[..., 1] = y.data + grid_y
|
101 |
+
pred_boxes[..., 2] = torch.exp(w.data) * anchor_w
|
102 |
+
pred_boxes[..., 3] = torch.exp(h.data) * anchor_h
|
103 |
+
|
104 |
+
#----------------------------------------------------------#
|
105 |
+
# 将输出结果归一化成小数的形式
|
106 |
+
#----------------------------------------------------------#
|
107 |
+
_scale = torch.Tensor([input_width, input_height, input_width, input_height]).type(FloatTensor)
|
108 |
+
output = torch.cat((pred_boxes.view(batch_size, -1, 4) / _scale,
|
109 |
+
conf.view(batch_size, -1, 1), pred_cls.view(batch_size, -1, self.num_classes)), -1)
|
110 |
+
outputs.append(output.data)
|
111 |
+
return outputs
|
112 |
+
|
    def yolo_correct_boxes(self, box_xy, box_wh, input_shape, image_shape, letterbox_image):
        #-----------------------------------------------------------------#
        #   Put the y axis first so the boxes can be multiplied directly
        #   by the (height, width) image shape.
        #-----------------------------------------------------------------#
        box_yx = box_xy[..., ::-1]
        box_hw = box_wh[..., ::-1]
        input_shape = np.array(input_shape)
        image_shape = np.array(image_shape)

        if letterbox_image:
            #-----------------------------------------------------------------#
            #   offset is the shift of the valid image area relative to the
            #   top-left corner of the letterboxed input.
            #   new_shape is the image shape after resizing.
            #-----------------------------------------------------------------#
            new_shape = np.round(image_shape * np.min(input_shape / image_shape))
            offset  = (input_shape - new_shape) / 2. / input_shape
            scale   = input_shape / new_shape

            box_yx  = (box_yx - offset) * scale
            box_hw *= scale

        box_mins    = box_yx - (box_hw / 2.)
        box_maxes   = box_yx + (box_hw / 2.)
        boxes   = np.concatenate([box_mins[..., 0:1], box_mins[..., 1:2], box_maxes[..., 0:1], box_maxes[..., 1:2]], axis=-1)
        boxes  *= np.concatenate([image_shape, image_shape], axis=-1)
        return boxes

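    #--------------------------------------------------------------------------------------#
    #   A worked example of the letterbox correction (the 416x416 input and 480x640 image
    #   are illustrative, not values from this project):
    #       new_shape = round([480, 640] * min(416/480, 416/640)) = [312, 416]
    #       offset    = ([416, 416] - [312, 416]) / 2 / [416, 416] = [0.125, 0.0]
    #       scale     = [416/312, 416/416]                         = [1.333, 1.0]
    #   so a box center at (0.5, 0.5) of the padded input maps back to
    #   ((0.5 - 0.125) * 1.333, (0.5 - 0.0) * 1.0) = (0.5, 0.5) of the original image.
    #--------------------------------------------------------------------------------------#
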
    def non_max_suppression(self, prediction, num_classes, input_shape, image_shape, letterbox_image, conf_thres=0.5, nms_thres=0.4):
        #----------------------------------------------------------#
        #   Convert the predictions from center/width-height format
        #   to top-left/bottom-right corner format.
        #   prediction  [batch_size, num_anchors, 85]
        #----------------------------------------------------------#
        box_corner          = prediction.new(prediction.shape)
        box_corner[:, :, 0] = prediction[:, :, 0] - prediction[:, :, 2] / 2
        box_corner[:, :, 1] = prediction[:, :, 1] - prediction[:, :, 3] / 2
        box_corner[:, :, 2] = prediction[:, :, 0] + prediction[:, :, 2] / 2
        box_corner[:, :, 3] = prediction[:, :, 1] + prediction[:, :, 3] / 2
        prediction[:, :, :4] = box_corner[:, :, :4]

        output = [None for _ in range(len(prediction))]
        for i, image_pred in enumerate(prediction):
            #----------------------------------------------------------#
            #   Take the max over the class predictions.
            #   class_conf  [num_anchors, 1]    class confidence
            #   class_pred  [num_anchors, 1]    class index
            #----------------------------------------------------------#
            class_conf, class_pred = torch.max(image_pred[:, 5:5 + num_classes], 1, keepdim=True)

            #----------------------------------------------------------#
            #   First round of filtering by confidence.
            #----------------------------------------------------------#
            conf_mask = (image_pred[:, 4] * class_conf[:, 0] >= conf_thres).squeeze()

            #----------------------------------------------------------#
            #   Keep only the predictions above the threshold.
            #----------------------------------------------------------#
            image_pred = image_pred[conf_mask]
            class_conf = class_conf[conf_mask]
            class_pred = class_pred[conf_mask]
            if not image_pred.size(0):
                continue
            #-------------------------------------------------------------------------#
            #   detections  [num_anchors, 7]
            #   the 7 columns are: x1, y1, x2, y2, obj_conf, class_conf, class_pred
            #-------------------------------------------------------------------------#
            detections = torch.cat((image_pred[:, :5], class_conf.float(), class_pred.float()), 1)

            #------------------------------------------#
            #   Collect all classes present in the predictions.
            #------------------------------------------#
            unique_labels = detections[:, -1].cpu().unique()

            if prediction.is_cuda:
                unique_labels = unique_labels.cuda()
                detections = detections.cuda()

            for c in unique_labels:
                #------------------------------------------#
                #   Take all score-filtered detections of this class.
                #------------------------------------------#
                detections_class = detections[detections[:, -1] == c]

                #------------------------------------------#
                #   torchvision's built-in non-max suppression is faster!
                #------------------------------------------#
                keep = nms(
                    detections_class[:, :4],
                    detections_class[:, 4] * detections_class[:, 5],
                    nms_thres
                )
                max_detections = detections_class[keep]

                # # Sort by objectness * class confidence.
                # _, conf_sort_index = torch.sort(detections_class[:, 4]*detections_class[:, 5], descending=True)
                # detections_class = detections_class[conf_sort_index]
                # # Perform non-max suppression manually.
                # max_detections = []
                # while detections_class.size(0):
                #     # Keep the highest-confidence box, then drop the remaining boxes whose IoU with it exceeds nms_thres.
                #     max_detections.append(detections_class[0].unsqueeze(0))
                #     if len(detections_class) == 1:
                #         break
                #     ious = bbox_iou(max_detections[-1], detections_class[1:])
                #     detections_class = detections_class[1:][ious < nms_thres]
                # # Stack the kept boxes.
                # max_detections = torch.cat(max_detections).data

                # Add max detections to outputs
                output[i] = max_detections if output[i] is None else torch.cat((output[i], max_detections))

            if output[i] is not None:
                output[i] = output[i].cpu().numpy()
                box_xy, box_wh = (output[i][:, 0:2] + output[i][:, 2:4]) / 2, output[i][:, 2:4] - output[i][:, 0:2]
                output[i][:, :4] = self.yolo_correct_boxes(box_xy, box_wh, input_shape, image_shape, letterbox_image)
        return output
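The per-class loop above delegates the actual suppression to torchvision. A minimal, self-contained check of that call, with made-up boxes and scores (not data from this project):

import torch
from torchvision.ops import nms

boxes = torch.tensor([[ 0.,  0., 10., 10.],
                      [ 1.,  1., 11., 11.],    # IoU ~0.68 with the first box
                      [20., 20., 30., 30.]])   # disjoint, always kept
scores = torch.tensor([0.9, 0.8, 0.7])
keep = nms(boxes, scores, iou_threshold=0.4)
print(keep)                                    # tensor([0, 2]); the overlapping box is suppressed
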
utils/utils_fit.py
ADDED
@@ -0,0 +1,128 @@
import os

import torch
from tqdm import tqdm

from utils.utils import get_lr


def fit_one_epoch(model_train, model, yolo_loss, loss_history, optimizer, epoch, epoch_step, epoch_step_val, gen, gen_val, Epoch, cuda, fp16, scaler, save_period, save_dir, local_rank=0):
    loss        = 0
    val_loss    = 0

    if local_rank == 0:
        print('Start Train')
        pbar = tqdm(total=epoch_step, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)
    model_train.train()
    for iteration, batch in enumerate(gen):
        if iteration >= epoch_step:
            break

        images, targets = batch[0], batch[1]
        with torch.no_grad():
            if cuda:
                images  = images.cuda()
                targets = [ann.cuda() for ann in targets]
        #----------------------#
        #   Zero the gradients.
        #----------------------#
        optimizer.zero_grad()
        if not fp16:
            #----------------------#
            #   Forward pass.
            #----------------------#
            outputs = model_train(images)

            loss_value_all = 0
            #----------------------#
            #   Compute the loss.
            #----------------------#
            for l in range(len(outputs)):
                loss_item = yolo_loss(l, outputs[l], targets)
                loss_value_all += loss_item
            loss_value = loss_value_all

            #----------------------#
            #   Backward pass.
            #----------------------#
            loss_value.backward()
            optimizer.step()
        else:
            from torch.cuda.amp import autocast
            with autocast():
                #----------------------#
                #   Forward pass.
                #----------------------#
                outputs = model_train(images)

                loss_value_all = 0
                #----------------------#
                #   Compute the loss.
                #----------------------#
                for l in range(len(outputs)):
                    loss_item = yolo_loss(l, outputs[l], targets)
                    loss_value_all += loss_item
                loss_value = loss_value_all

            #----------------------#
            #   Backward pass with gradient scaling.
            #----------------------#
            scaler.scale(loss_value).backward()
            scaler.step(optimizer)
            scaler.update()

        loss += loss_value.item()

        if local_rank == 0:
            pbar.set_postfix(**{'loss'  : loss / (iteration + 1),
                                'lr'    : get_lr(optimizer)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Train')
        print('Start Validation')
        pbar = tqdm(total=epoch_step_val, desc=f'Epoch {epoch + 1}/{Epoch}', postfix=dict, mininterval=0.3)

    model_train.eval()
    for iteration, batch in enumerate(gen_val):
        if iteration >= epoch_step_val:
            break
        images, targets = batch[0], batch[1]
        with torch.no_grad():
            if cuda:
                images  = images.cuda()
                targets = [ann.cuda() for ann in targets]
            #----------------------#
            #   Zero the gradients.
            #----------------------#
            optimizer.zero_grad()
            #----------------------#
            #   Forward pass.
            #----------------------#
            outputs = model_train(images)

            loss_value_all = 0
            #----------------------#
            #   Compute the loss.
            #----------------------#
            for l in range(len(outputs)):
                loss_item = yolo_loss(l, outputs[l], targets)
                loss_value_all += loss_item
            loss_value = loss_value_all

        val_loss += loss_value.item()
        if local_rank == 0:
            pbar.set_postfix(**{'val_loss': val_loss / (iteration + 1)})
            pbar.update(1)

    if local_rank == 0:
        pbar.close()
        print('Finish Validation')
        loss_history.append_loss(epoch + 1, loss / epoch_step, val_loss / epoch_step_val)
        print('Epoch:' + str(epoch + 1) + '/' + str(Epoch))
        print('Total Loss: %.3f || Val Loss: %.3f ' % (loss / epoch_step, val_loss / epoch_step_val))
        if (epoch + 1) % save_period == 0 or epoch + 1 == Epoch:
            torch.save(model.state_dict(), os.path.join(save_dir, "ep%03d-loss%.3f-val_loss%.3f.pth" % (epoch + 1, loss / epoch_step, val_loss / epoch_step_val)))
        #   Always save the latest weights as well.
        torch.save(model.state_dict(), os.path.join(save_dir, "last.pth"))
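The fp16 branch follows the standard torch.cuda.amp recipe: scale the loss, step through the scaler, then update it. A minimal, self-contained sketch of the same pattern with a toy model; it needs a CUDA device, and the model and data are made up:

import torch
from torch.cuda.amp import GradScaler, autocast

model     = torch.nn.Linear(4, 2).cuda()
optimizer = torch.optim.SGD(model.parameters(), lr=1e-2)
scaler    = GradScaler()

x, y = torch.randn(8, 4).cuda(), torch.randn(8, 2).cuda()
optimizer.zero_grad()
with autocast():
    loss = torch.nn.functional.mse_loss(model(x), y)   # forward pass in mixed precision
scaler.scale(loss).backward()   # scale the loss so fp16 gradients do not underflow
scaler.step(optimizer)          # unscales the gradients, then calls optimizer.step()
scaler.update()                 # adapt the scale factor for the next iteration
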
utils/utils_map.py
ADDED
@@ -0,0 +1,901 @@
import glob
import json
import math
import operator
import os
import shutil
import sys

import cv2
import matplotlib.pyplot as plt
import numpy as np

'''
    0,0 ------> x (width)
     |
     |  (Left,Top)
     |      *_________
     |      |         |
     |      |         |
     y      |_________|
  (height)            *
                (Right,Bottom)
'''

def log_average_miss_rate(precision, fp_cumsum, num_images):
    """
        log-average miss rate:
            Calculated by averaging miss rates at 9 evenly spaced FPPI points
            between 1e-2 and 1e0, in log-space.

        output:
                lamr | log-average miss rate
                mr   | miss rate
                fppi | false positives per image

        references:
            [1] Dollar, Piotr, et al. "Pedestrian Detection: An Evaluation of the
               State of the Art." Pattern Analysis and Machine Intelligence, IEEE
               Transactions on 34.4 (2012): 743 - 761.
    """

    if precision.size == 0:
        lamr = 0
        mr = 1
        fppi = 0
        return lamr, mr, fppi

    fppi = fp_cumsum / float(num_images)
    mr = (1 - precision)

    fppi_tmp = np.insert(fppi, 0, -1.0)
    mr_tmp = np.insert(mr, 0, 1.0)

    ref = np.logspace(-2.0, 0.0, num = 9)
    for i, ref_i in enumerate(ref):
        j = np.where(fppi_tmp <= ref_i)[-1][-1]
        ref[i] = mr_tmp[j]

    lamr = math.exp(np.mean(np.log(np.maximum(1e-10, ref))))

    return lamr, mr, fppi

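# A tiny numeric walk-through of the function above (made-up cumulative values for
# 4 detections over 10 images; assumes log_average_miss_rate is in scope):
#     precision = np.array([0.5, 0.67, 0.75, 0.6]); fp_cumsum = np.array([1, 1, 1, 2])
#     fppi = [0.1, 0.1, 0.1, 0.2], mr = [0.5, 0.33, 0.25, 0.4]
#     sampling mr at the 9 log-spaced FPPI points gives [1, 1, 1, 1, 0.25, 0.25, 0.4, 0.4, 0.4],
#     whose geometric mean is lamr ~= 0.54
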
"""
 Print an error message and exit.
"""
def error(msg):
    print(msg)
    sys.exit(1)     # non-zero exit status to signal failure

"""
 Check if the value is a float strictly between 0.0 and 1.0.
"""
def is_float_between_0_and_1(value):
    try:
        val = float(value)
        return 0.0 < val < 1.0
    except ValueError:
        return False

"""
 Calculate the AP given the recall and precision arrays.
    1st) We compute a version of the measured precision/recall curve with
         precision monotonically decreasing.
    2nd) We compute the AP as the area under this curve by numerical integration.
"""
def voc_ap(rec, prec):
    """
    --- Official matlab code VOC2012---
    mrec=[0 ; rec ; 1];
    mpre=[0 ; prec ; 0];
    for i=numel(mpre)-1:-1:1
            mpre(i)=max(mpre(i),mpre(i+1));
    end
    i=find(mrec(2:end)~=mrec(1:end-1))+1;
    ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
    """
    rec.insert(0, 0.0)      # insert 0.0 at beginning of list
    rec.append(1.0)         # insert 1.0 at end of list
    mrec = rec[:]
    prec.insert(0, 0.0)     # insert 0.0 at beginning of list
    prec.append(0.0)        # insert 0.0 at end of list
    mpre = prec[:]
    """
     This part makes the precision monotonically decreasing
        (goes from the end to the beginning)
        matlab: for i=numel(mpre)-1:-1:1
                    mpre(i)=max(mpre(i),mpre(i+1));
    """
    for i in range(len(mpre)-2, -1, -1):
        mpre[i] = max(mpre[i], mpre[i+1])
    """
     This part creates a list of indexes where the recall changes
        matlab: i=find(mrec(2:end)~=mrec(1:end-1))+1;
    """
    i_list = []
    for i in range(1, len(mrec)):
        if mrec[i] != mrec[i-1]:
            i_list.append(i)    # if it was matlab it would be i + 1
    """
     The Average Precision (AP) is the area under the curve
        (numerical integration)
        matlab: ap=sum((mrec(i)-mrec(i-1)).*mpre(i));
    """
    ap = 0.0
    for i in i_list:
        ap += ((mrec[i]-mrec[i-1])*mpre[i])
    return ap, mrec, mpre

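# A worked example with toy values (not project data):
#     rec = [0.5, 0.5, 1.0], prec = [1.0, 0.5, 0.667]
#     after padding:            mrec = [0, 0.5, 0.5, 1.0, 1.0], mpre = [0, 1.0, 0.5, 0.667, 0]
#     after the monotonic pass: mpre = [1.0, 1.0, 0.667, 0.667, 0]
#     recall changes at indices 1 and 3, so
#     ap = (0.5 - 0) * 1.0 + (1.0 - 0.5) * 0.667 ~= 0.83
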
"""
 Convert the lines of a file to a list
"""
def file_lines_to_list(path):
    # open txt file lines to a list
    with open(path) as f:
        content = f.readlines()
    # remove whitespace characters like `\n` at the end of each line
    content = [x.strip() for x in content]
    return content

"""
 Draws text in image
"""
def draw_text_in_image(img, text, pos, color, line_width):
    font = cv2.FONT_HERSHEY_PLAIN
    fontScale = 1
    lineType = 1
    bottomLeftCornerOfText = pos
    cv2.putText(img, text,
            bottomLeftCornerOfText,
            font,
            fontScale,
            color,
            lineType)
    text_width, _ = cv2.getTextSize(text, font, fontScale, lineType)[0]
    return img, (line_width + text_width)

"""
 Plot - adjust axes
"""
def adjust_axes(r, t, fig, axes):
    # get text width for re-scaling
    bb = t.get_window_extent(renderer=r)
    text_width_inches = bb.width / fig.dpi
    # get axis width in inches
    current_fig_width = fig.get_figwidth()
    new_fig_width = current_fig_width + text_width_inches
    proportion = new_fig_width / current_fig_width
    # get axis limit
    x_lim = axes.get_xlim()
    axes.set_xlim([x_lim[0], x_lim[1]*proportion])

"""
 Draw plot using Matplotlib
"""
def draw_plot_func(dictionary, n_classes, window_title, plot_title, x_label, output_path, to_show, plot_color, true_p_bar):
    # sort the dictionary by increasing value, into a list of tuples
    sorted_dic_by_value = sorted(dictionary.items(), key=operator.itemgetter(1))
    # unpacking the list of tuples into two lists
    sorted_keys, sorted_values = zip(*sorted_dic_by_value)

    if true_p_bar != "":
        """
         Special case to draw in:
            - green -> TP: True Positives (object detected and matches ground-truth)
            - red -> FP: False Positives (object detected but does not match ground-truth)
            - orange -> FN: False Negatives (object not detected but present in the ground-truth)
        """
        fp_sorted = []
        tp_sorted = []
        for key in sorted_keys:
            fp_sorted.append(dictionary[key] - true_p_bar[key])
            tp_sorted.append(true_p_bar[key])
        plt.barh(range(n_classes), fp_sorted, align='center', color='crimson', label='False Positive')
        plt.barh(range(n_classes), tp_sorted, align='center', color='forestgreen', label='True Positive', left=fp_sorted)
        # add legend
        plt.legend(loc='lower right')
        """
         Write number on side of bar
        """
        fig = plt.gcf()     # gcf - get current figure
        axes = plt.gca()
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            fp_val = fp_sorted[i]
            tp_val = tp_sorted[i]
            fp_str_val = " " + str(fp_val)
            tp_str_val = fp_str_val + " " + str(tp_val)
            # trick to paint multicolor with offset:
            # first paint everything and then repaint the first number
            t = plt.text(val, i, tp_str_val, color='forestgreen', va='center', fontweight='bold')
            plt.text(val, i, fp_str_val, color='crimson', va='center', fontweight='bold')
            if i == (len(sorted_values)-1):     # largest bar
                adjust_axes(r, t, fig, axes)
    else:
        plt.barh(range(n_classes), sorted_values, color=plot_color)
        """
         Write number on side of bar
        """
        fig = plt.gcf()     # gcf - get current figure
        axes = plt.gca()
        r = fig.canvas.get_renderer()
        for i, val in enumerate(sorted_values):
            str_val = " " + str(val)    # add a space before
            if val < 1.0:
                str_val = " {0:.2f}".format(val)
            t = plt.text(val, i, str_val, color=plot_color, va='center', fontweight='bold')
            # re-set axes to show number inside the figure
            if i == (len(sorted_values)-1):     # largest bar
                adjust_axes(r, t, fig, axes)
    # set window title via the figure manager (canvas.set_window_title was removed in Matplotlib 3.6)
    fig.canvas.manager.set_window_title(window_title)
    # write classes in y axis
    tick_font_size = 12
    plt.yticks(range(n_classes), sorted_keys, fontsize=tick_font_size)
    """
     Re-scale height accordingly
    """
    init_height = fig.get_figheight()
    # compute the matrix height in points and inches
    dpi = fig.dpi
    height_pt = n_classes * (tick_font_size * 1.4)      # 1.4 (some spacing)
    height_in = height_pt / dpi
    # compute the required figure height
    top_margin = 0.15       # in percentage of the figure height
    bottom_margin = 0.05    # in percentage of the figure height
    figure_height = height_in / (1 - top_margin - bottom_margin)
    # set new height
    if figure_height > init_height:
        fig.set_figheight(figure_height)

    # set plot title
    plt.title(plot_title, fontsize=14)
    # set axis titles
    # plt.xlabel('classes')
    plt.xlabel(x_label, fontsize='large')
    # adjust size of window
    fig.tight_layout()
    # save the plot
    fig.savefig(output_path)
    # show image
    if to_show:
        plt.show()
    # close the plot
    plt.close()

def get_map(MINOVERLAP, draw_plot, path = './map_out'):
    GT_PATH             = os.path.join(path, 'ground-truth')
    DR_PATH             = os.path.join(path, 'detection-results')
    IMG_PATH            = os.path.join(path, 'images-optional')
    TEMP_FILES_PATH     = os.path.join(path, '.temp_files')
    RESULTS_FILES_PATH  = os.path.join(path, 'results')

    show_animation = True
    if os.path.exists(IMG_PATH):
        for dirpath, dirnames, files in os.walk(IMG_PATH):
            if not files:
                show_animation = False
    else:
        show_animation = False

    if not os.path.exists(TEMP_FILES_PATH):
        os.makedirs(TEMP_FILES_PATH)

    if os.path.exists(RESULTS_FILES_PATH):
        shutil.rmtree(RESULTS_FILES_PATH)
    if draw_plot:
        os.makedirs(os.path.join(RESULTS_FILES_PATH, "AP"))
        os.makedirs(os.path.join(RESULTS_FILES_PATH, "F1"))
        os.makedirs(os.path.join(RESULTS_FILES_PATH, "Recall"))
        os.makedirs(os.path.join(RESULTS_FILES_PATH, "Precision"))
    if show_animation:
        os.makedirs(os.path.join(RESULTS_FILES_PATH, "images", "detections_one_by_one"))

    ground_truth_files_list = glob.glob(GT_PATH + '/*.txt')
    if len(ground_truth_files_list) == 0:
        error("Error: No ground-truth files found!")
    ground_truth_files_list.sort()
    gt_counter_per_class     = {}
    counter_images_per_class = {}

    for txt_file in ground_truth_files_list:
        file_id     = txt_file.split(".txt", 1)[0]
        file_id     = os.path.basename(os.path.normpath(file_id))
        temp_path   = os.path.join(DR_PATH, (file_id + ".txt"))
        if not os.path.exists(temp_path):
            error_msg = "Error. File not found: {}\n".format(temp_path)
            error(error_msg)
        lines_list      = file_lines_to_list(txt_file)
        bounding_boxes  = []
        is_difficult    = False
        already_seen_classes = []
        for line in lines_list:
            try:
                if "difficult" in line:
                    class_name, left, top, right, bottom, _difficult = line.split()
                    is_difficult = True
                else:
                    class_name, left, top, right, bottom = line.split()
            except ValueError:
                # the class name contains spaces; rebuild it from the remaining tokens
                if "difficult" in line:
                    line_split  = line.split()
                    _difficult  = line_split[-1]
                    bottom      = line_split[-2]
                    right       = line_split[-3]
                    top         = line_split[-4]
                    left        = line_split[-5]
                    class_name  = ""
                    for name in line_split[:-5]:
                        class_name += name + " "
                    class_name  = class_name[:-1]
                    is_difficult = True
                else:
                    line_split  = line.split()
                    bottom      = line_split[-1]
                    right       = line_split[-2]
                    top         = line_split[-3]
                    left        = line_split[-4]
                    class_name  = ""
                    for name in line_split[:-4]:
                        class_name += name + " "
                    class_name = class_name[:-1]

            bbox = left + " " + top + " " + right + " " + bottom
            if is_difficult:
                bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False, "difficult":True})
                is_difficult = False
            else:
                bounding_boxes.append({"class_name":class_name, "bbox":bbox, "used":False})
                if class_name in gt_counter_per_class:
                    gt_counter_per_class[class_name] += 1
                else:
                    gt_counter_per_class[class_name] = 1

                if class_name not in already_seen_classes:
                    if class_name in counter_images_per_class:
                        counter_images_per_class[class_name] += 1
                    else:
                        counter_images_per_class[class_name] = 1
                    already_seen_classes.append(class_name)

        with open(TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json", 'w') as outfile:
            json.dump(bounding_boxes, outfile)

    gt_classes  = list(gt_counter_per_class.keys())
    gt_classes  = sorted(gt_classes)
    n_classes   = len(gt_classes)

    dr_files_list = glob.glob(DR_PATH + '/*.txt')
    dr_files_list.sort()
    for class_index, class_name in enumerate(gt_classes):
        bounding_boxes = []
        for txt_file in dr_files_list:
            file_id = txt_file.split(".txt",1)[0]
            file_id = os.path.basename(os.path.normpath(file_id))
            temp_path = os.path.join(GT_PATH, (file_id + ".txt"))
            if class_index == 0:
                if not os.path.exists(temp_path):
                    error_msg = "Error. File not found: {}\n".format(temp_path)
                    error(error_msg)
            lines = file_lines_to_list(txt_file)
            for line in lines:
                try:
                    tmp_class_name, confidence, left, top, right, bottom = line.split()
                except ValueError:
                    # the class name contains spaces; rebuild it from the remaining tokens
                    line_split      = line.split()
                    bottom          = line_split[-1]
                    right           = line_split[-2]
                    top             = line_split[-3]
                    left            = line_split[-4]
                    confidence      = line_split[-5]
                    tmp_class_name  = ""
                    for name in line_split[:-5]:
                        tmp_class_name += name + " "
                    tmp_class_name  = tmp_class_name[:-1]

                if tmp_class_name == class_name:
                    bbox = left + " " + top + " " + right + " " + bottom
                    bounding_boxes.append({"confidence":confidence, "file_id":file_id, "bbox":bbox})

        bounding_boxes.sort(key=lambda x:float(x['confidence']), reverse=True)
        with open(TEMP_FILES_PATH + "/" + class_name + "_dr.json", 'w') as outfile:
            json.dump(bounding_boxes, outfile)

    sum_AP = 0.0
    ap_dictionary = {}
    lamr_dictionary = {}
    with open(RESULTS_FILES_PATH + "/results.txt", 'w') as results_file:
        results_file.write("# AP and precision/recall per class\n")
        count_true_positives = {}

        for class_index, class_name in enumerate(gt_classes):
            count_true_positives[class_name] = 0
            dr_file = TEMP_FILES_PATH + "/" + class_name + "_dr.json"
            dr_data = json.load(open(dr_file))

            nd          = len(dr_data)
            tp          = [0] * nd
            fp          = [0] * nd
            score       = [0] * nd
            score05_idx = 0
            for idx, detection in enumerate(dr_data):
                file_id     = detection["file_id"]
                score[idx]  = float(detection["confidence"])
                if score[idx] > 0.5:
                    score05_idx = idx

                if show_animation:
                    ground_truth_img = glob.glob1(IMG_PATH, file_id + ".*")
                    if len(ground_truth_img) == 0:
                        error("Error. Image not found with id: " + file_id)
                    elif len(ground_truth_img) > 1:
                        error("Error. Multiple images with id: " + file_id)
                    else:
                        img = cv2.imread(IMG_PATH + "/" + ground_truth_img[0])
                        img_cumulative_path = RESULTS_FILES_PATH + "/images/" + ground_truth_img[0]
                        if os.path.isfile(img_cumulative_path):
                            img_cumulative = cv2.imread(img_cumulative_path)
                        else:
                            img_cumulative = img.copy()
                        bottom_border = 60
                        BLACK = [0, 0, 0]
                        img = cv2.copyMakeBorder(img, 0, bottom_border, 0, 0, cv2.BORDER_CONSTANT, value=BLACK)

                gt_file             = TEMP_FILES_PATH + "/" + file_id + "_ground_truth.json"
                ground_truth_data   = json.load(open(gt_file))
                ovmax       = -1
                gt_match    = -1
                bb          = [float(x) for x in detection["bbox"].split()]
                for obj in ground_truth_data:
                    if obj["class_name"] == class_name:
                        bbgt = [ float(x) for x in obj["bbox"].split() ]
                        bi = [max(bb[0],bbgt[0]), max(bb[1],bbgt[1]), min(bb[2],bbgt[2]), min(bb[3],bbgt[3])]
                        iw = bi[2] - bi[0] + 1
                        ih = bi[3] - bi[1] + 1
                        if iw > 0 and ih > 0:
                            ua = (bb[2] - bb[0] + 1) * (bb[3] - bb[1] + 1) + (bbgt[2] - bbgt[0]
                                            + 1) * (bbgt[3] - bbgt[1] + 1) - iw * ih
                            ov = iw * ih / ua
                            if ov > ovmax:
                                ovmax = ov
                                gt_match = obj

                if show_animation:
                    status = "NO MATCH FOUND!"

                min_overlap = MINOVERLAP
                if ovmax >= min_overlap:
                    if "difficult" not in gt_match:
                        if not bool(gt_match["used"]):
                            tp[idx] = 1
                            gt_match["used"] = True
                            count_true_positives[class_name] += 1
                            with open(gt_file, 'w') as f:
                                f.write(json.dumps(ground_truth_data))
                            if show_animation:
                                status = "MATCH!"
                        else:
                            fp[idx] = 1
                            if show_animation:
                                status = "REPEATED MATCH!"
                else:
                    fp[idx] = 1
                    if ovmax > 0:
                        status = "INSUFFICIENT OVERLAP"

                """
                 Draw image to show animation
                """
                if show_animation:
                    height, width = img.shape[:2]
                    white       = (255,255,255)
                    light_blue  = (255,200,100)
                    green       = (0,255,0)
                    light_red   = (30,30,255)
                    margin      = 10
                    # 1st line
                    v_pos = int(height - margin - (bottom_border / 2.0))
                    text = "Image: " + ground_truth_img[0] + " "
                    img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
                    text = "Class [" + str(class_index) + "/" + str(n_classes) + "]: " + class_name + " "
                    img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), light_blue, line_width)
                    if ovmax != -1:
                        color = light_red
                        if status == "INSUFFICIENT OVERLAP":
                            text = "IoU: {0:.2f}% ".format(ovmax*100) + "< {0:.2f}% ".format(min_overlap*100)
                        else:
                            text = "IoU: {0:.2f}% ".format(ovmax*100) + ">= {0:.2f}% ".format(min_overlap*100)
                            color = green
                        img, _ = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)
                    # 2nd line
                    v_pos += int(bottom_border / 2.0)
                    rank_pos = str(idx+1)
                    text = "Detection #rank: " + rank_pos + " confidence: {0:.2f}% ".format(float(detection["confidence"])*100)
                    img, line_width = draw_text_in_image(img, text, (margin, v_pos), white, 0)
                    color = light_red
                    if status == "MATCH!":
                        color = green
                    text = "Result: " + status + " "
                    img, line_width = draw_text_in_image(img, text, (margin + line_width, v_pos), color, line_width)

                    font = cv2.FONT_HERSHEY_SIMPLEX
                    if ovmax > 0:
                        bbgt = [ int(round(float(x))) for x in gt_match["bbox"].split() ]
                        cv2.rectangle(img,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
                        cv2.rectangle(img_cumulative,(bbgt[0],bbgt[1]),(bbgt[2],bbgt[3]),light_blue,2)
                        cv2.putText(img_cumulative, class_name, (bbgt[0],bbgt[1] - 5), font, 0.6, light_blue, 1, cv2.LINE_AA)
                    bb = [int(i) for i in bb]
                    cv2.rectangle(img,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
                    cv2.rectangle(img_cumulative,(bb[0],bb[1]),(bb[2],bb[3]),color,2)
                    cv2.putText(img_cumulative, class_name, (bb[0],bb[1] - 5), font, 0.6, color, 1, cv2.LINE_AA)

                    cv2.imshow("Animation", img)
                    cv2.waitKey(20)
                    output_img_path = RESULTS_FILES_PATH + "/images/detections_one_by_one/" + class_name + "_detection" + str(idx) + ".jpg"
                    cv2.imwrite(output_img_path, img)
                    cv2.imwrite(img_cumulative_path, img_cumulative)

            cumsum = 0
            for idx, val in enumerate(fp):
                fp[idx] += cumsum
                cumsum += val

            cumsum = 0
            for idx, val in enumerate(tp):
                tp[idx] += cumsum
                cumsum += val

            rec = tp[:]
            for idx, val in enumerate(tp):
                rec[idx] = float(tp[idx]) / np.maximum(gt_counter_per_class[class_name], 1)

            prec = tp[:]
            for idx, val in enumerate(tp):
                prec[idx] = float(tp[idx]) / np.maximum((fp[idx] + tp[idx]), 1)

            ap, mrec, mprec = voc_ap(rec[:], prec[:])
            F1 = np.array(rec)*np.array(prec)*2 / np.where((np.array(prec)+np.array(rec))==0, 1, (np.array(prec)+np.array(rec)))

            sum_AP  += ap
            text    = "{0:.2f}%".format(ap*100) + " = " + class_name + " AP "   # class_name + " AP = {0:.2f}%".format(ap*100)

            if len(prec)>0:
                F1_text         = "{0:.2f}".format(F1[score05_idx]) + " = " + class_name + " F1 "
                Recall_text     = "{0:.2f}%".format(rec[score05_idx]*100) + " = " + class_name + " Recall "
                Precision_text  = "{0:.2f}%".format(prec[score05_idx]*100) + " = " + class_name + " Precision "
            else:
                F1_text         = "0.00" + " = " + class_name + " F1 "
                Recall_text     = "0.00%" + " = " + class_name + " Recall "
                Precision_text  = "0.00%" + " = " + class_name + " Precision "

            rounded_prec    = [ '%.2f' % elem for elem in prec ]
            rounded_rec     = [ '%.2f' % elem for elem in rec ]
            results_file.write(text + "\n Precision: " + str(rounded_prec) + "\n Recall :" + str(rounded_rec) + "\n\n")
            if len(prec)>0:
                print(text + "\t||\tscore_threshold=0.5 : " + "F1=" + "{0:.2f}".format(F1[score05_idx])\
                    + " ; Recall=" + "{0:.2f}%".format(rec[score05_idx]*100) + " ; Precision=" + "{0:.2f}%".format(prec[score05_idx]*100))
            else:
                print(text + "\t||\tscore_threshold=0.5 : F1=0.00 ; Recall=0.00% ; Precision=0.00%")
            ap_dictionary[class_name] = ap

            n_images = counter_images_per_class[class_name]
            lamr, mr, fppi = log_average_miss_rate(np.array(rec), np.array(fp), n_images)
            lamr_dictionary[class_name] = lamr

            if draw_plot:
                plt.plot(rec, prec, '-o')
                area_under_curve_x = mrec[:-1] + [mrec[-2]] + [mrec[-1]]
                area_under_curve_y = mprec[:-1] + [0.0] + [mprec[-1]]
                plt.fill_between(area_under_curve_x, 0, area_under_curve_y, alpha=0.2, edgecolor='r')

                fig = plt.gcf()
                # use the figure manager (canvas.set_window_title was removed in Matplotlib 3.6)
                fig.canvas.manager.set_window_title('AP ' + class_name)

                plt.title('class: ' + text)
                plt.xlabel('Recall')
                plt.ylabel('Precision')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/AP/" + class_name + ".png")
                plt.cla()

                plt.plot(score, F1, "-", color='orangered')
                plt.title('class: ' + F1_text + "\nscore_threshold=0.5")
                plt.xlabel('Score_Threshold')
                plt.ylabel('F1')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/F1/" + class_name + ".png")
                plt.cla()

                plt.plot(score, rec, "-H", color='gold')
                plt.title('class: ' + Recall_text + "\nscore_threshold=0.5")
                plt.xlabel('Score_Threshold')
                plt.ylabel('Recall')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/Recall/" + class_name + ".png")
                plt.cla()

                plt.plot(score, prec, "-s", color='palevioletred')
                plt.title('class: ' + Precision_text + "\nscore_threshold=0.5")
                plt.xlabel('Score_Threshold')
                plt.ylabel('Precision')
                axes = plt.gca()
                axes.set_xlim([0.0,1.0])
                axes.set_ylim([0.0,1.05])
                fig.savefig(RESULTS_FILES_PATH + "/Precision/" + class_name + ".png")
                plt.cla()

        if show_animation:
            cv2.destroyAllWindows()

        results_file.write("\n# mAP of all classes\n")
        mAP     = sum_AP / n_classes
        text    = "mAP = {0:.2f}%".format(mAP*100)
        results_file.write(text + "\n")
        print(text)

    shutil.rmtree(TEMP_FILES_PATH)

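    # For intuition, the cumulative bookkeeping above turns per-detection flags into a PR
    # curve; with toy flags tp = [1, 0, 1, 1], fp = [0, 1, 0, 0] and 3 ground-truth boxes:
    #     tp_cum = [1, 1, 2, 3], fp_cum = [0, 1, 1, 1]
    #     rec  = tp_cum / 3                  = [0.33, 0.33, 0.67, 1.0]
    #     prec = tp_cum / (tp_cum + fp_cum)  = [1.0, 0.5, 0.67, 0.75]
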
    """
     Count total of detection-results
    """
    det_counter_per_class = {}
    for txt_file in dr_files_list:
        lines_list = file_lines_to_list(txt_file)
        for line in lines_list:
            class_name = line.split()[0]
            if class_name in det_counter_per_class:
                det_counter_per_class[class_name] += 1
            else:
                det_counter_per_class[class_name] = 1
    dr_classes = list(det_counter_per_class.keys())

    """
     Write number of ground-truth objects per class to results.txt
    """
    with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
        results_file.write("\n# Number of ground-truth objects per class\n")
        for class_name in sorted(gt_counter_per_class):
            results_file.write(class_name + ": " + str(gt_counter_per_class[class_name]) + "\n")

    """
     Finish counting true positives
    """
    for class_name in dr_classes:
        if class_name not in gt_classes:
            count_true_positives[class_name] = 0

    """
     Write number of detected objects per class to results.txt
    """
    with open(RESULTS_FILES_PATH + "/results.txt", 'a') as results_file:
        results_file.write("\n# Number of detected objects per class\n")
        for class_name in sorted(dr_classes):
            n_det = det_counter_per_class[class_name]
            text = class_name + ": " + str(n_det)
            text += " (tp:" + str(count_true_positives[class_name]) + ""
            text += ", fp:" + str(n_det - count_true_positives[class_name]) + ")\n"
            results_file.write(text)

    """
     Plot the total number of occurrences of each class in the ground-truth
    """
    if draw_plot:
        window_title = "ground-truth-info"
        plot_title = "ground-truth\n"
        plot_title += "(" + str(len(ground_truth_files_list)) + " files and " + str(n_classes) + " classes)"
        x_label = "Number of objects per class"
        output_path = RESULTS_FILES_PATH + "/ground-truth-info.png"
        to_show = False
        plot_color = 'forestgreen'
        draw_plot_func(
            gt_counter_per_class,
            n_classes,
            window_title,
            plot_title,
            x_label,
            output_path,
            to_show,
            plot_color,
            '',
            )

    # """
    #  Plot the total number of occurrences of each class in the "detection-results" folder
    # """
    # if draw_plot:
    #     window_title = "detection-results-info"
    #     # Plot title
    #     plot_title = "detection-results\n"
    #     plot_title += "(" + str(len(dr_files_list)) + " files and "
    #     count_non_zero_values_in_dictionary = sum(int(x) > 0 for x in list(det_counter_per_class.values()))
    #     plot_title += str(count_non_zero_values_in_dictionary) + " detected classes)"
    #     # end Plot title
    #     x_label = "Number of objects per class"
    #     output_path = RESULTS_FILES_PATH + "/detection-results-info.png"
    #     to_show = False
    #     plot_color = 'forestgreen'
    #     true_p_bar = count_true_positives
    #     draw_plot_func(
    #         det_counter_per_class,
    #         len(det_counter_per_class),
    #         window_title,
    #         plot_title,
    #         x_label,
    #         output_path,
    #         to_show,
    #         plot_color,
    #         true_p_bar
    #         )

    """
     Draw log-average miss rate plot (Show lamr of all classes in decreasing order)
    """
    if draw_plot:
        window_title = "lamr"
        plot_title = "log-average miss rate"
        x_label = "log-average miss rate"
        output_path = RESULTS_FILES_PATH + "/lamr.png"
        to_show = False
        plot_color = 'royalblue'
        draw_plot_func(
            lamr_dictionary,
            n_classes,
            window_title,
            plot_title,
            x_label,
            output_path,
            to_show,
            plot_color,
            ""
            )

    """
     Draw mAP plot (Show AP's of all classes in decreasing order)
    """
    if draw_plot:
        window_title = "mAP"
        plot_title = "mAP = {0:.2f}%".format(mAP*100)
        x_label = "Average Precision"
        output_path = RESULTS_FILES_PATH + "/mAP.png"
        to_show = True
        plot_color = 'royalblue'
        draw_plot_func(
            ap_dictionary,
            n_classes,
            window_title,
            plot_title,
            x_label,
            output_path,
            to_show,
            plot_color,
            ""
            )

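# Typical invocation, assuming the map_out layout described above (one txt per image
# under ground-truth/ and detection-results/):
#     get_map(MINOVERLAP=0.5, draw_plot=True, path='./map_out')
# writes results.txt plus per-class AP/F1/Recall/Precision plots under map_out/results.
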
def preprocess_gt(gt_path, class_names):
    image_ids = os.listdir(gt_path)
    results = {}

    images = []
    bboxes = []
    for i, image_id in enumerate(image_ids):
        lines_list      = file_lines_to_list(os.path.join(gt_path, image_id))
        boxes_per_image = []
        image           = {}
        image_id        = os.path.splitext(image_id)[0]
        image['file_name'] = image_id + '.jpg'
        image['width']     = 1
        image['height']    = 1
        #-----------------------------------------------------------------#
        #   Thanks to 多学学英语吧 for the tip, which fixed the
        #   'Results do not correspond to current coco set' error.
        #-----------------------------------------------------------------#
        image['id'] = str(image_id)

        for line in lines_list:
            difficult = 0
            if "difficult" in line:
                line_split  = line.split()
                left, top, right, bottom, _difficult = line_split[-5:]
                class_name  = ""
                for name in line_split[:-5]:
                    class_name += name + " "
                class_name  = class_name[:-1]
                difficult = 1
            else:
                line_split  = line.split()
                left, top, right, bottom = line_split[-4:]
                class_name  = ""
                for name in line_split[:-4]:
                    class_name += name + " "
                class_name  = class_name[:-1]

            left, top, right, bottom = float(left), float(top), float(right), float(bottom)
            cls_id  = class_names.index(class_name) + 1
            bbox    = [left, top, right - left, bottom - top, difficult, str(image_id), cls_id, (right - left) * (bottom - top) - 10.0]
            boxes_per_image.append(bbox)
        images.append(image)
        bboxes.extend(boxes_per_image)
    results['images'] = images

    categories = []
    for i, cls in enumerate(class_names):
        category = {}
        category['supercategory']   = cls
        category['name']            = cls
        category['id']              = i + 1
        categories.append(category)
    results['categories'] = categories

    annotations = []
    for i, box in enumerate(bboxes):
        annotation = {}
        annotation['area']        = box[-1]
        annotation['category_id'] = box[-2]
        annotation['image_id']    = box[-3]
        annotation['iscrowd']     = box[-4]
        annotation['bbox']        = box[:4]
        annotation['id']          = i
        annotations.append(annotation)
    results['annotations'] = annotations
    return results

def preprocess_dr(dr_path, class_names):
    image_ids = os.listdir(dr_path)
    results = []
    for image_id in image_ids:
        lines_list = file_lines_to_list(os.path.join(dr_path, image_id))
        image_id = os.path.splitext(image_id)[0]
        for line in lines_list:
            line_split = line.split()
            confidence, left, top, right, bottom = line_split[-5:]
            class_name = ""
            for name in line_split[:-5]:
                class_name += name + " "
            class_name = class_name[:-1]
            left, top, right, bottom = float(left), float(top), float(right), float(bottom)
            result = {}
            result["image_id"]      = str(image_id)
            result["category_id"]   = class_names.index(class_name) + 1
            result["bbox"]          = [left, top, right - left, bottom - top]
            result["score"]         = float(confidence)
            results.append(result)
    return results

def get_coco_map(class_names, path):
    from pycocotools.coco import COCO
    from pycocotools.cocoeval import COCOeval

    GT_PATH     = os.path.join(path, 'ground-truth')
    DR_PATH     = os.path.join(path, 'detection-results')
    COCO_PATH   = os.path.join(path, 'coco_eval')

    if not os.path.exists(COCO_PATH):
        os.makedirs(COCO_PATH)

    GT_JSON_PATH = os.path.join(COCO_PATH, 'instances_gt.json')
    DR_JSON_PATH = os.path.join(COCO_PATH, 'instances_dr.json')

    with open(GT_JSON_PATH, "w") as f:
        results_gt = preprocess_gt(GT_PATH, class_names)
        json.dump(results_gt, f, indent=4)

    with open(DR_JSON_PATH, "w") as f:
        results_dr = preprocess_dr(DR_PATH, class_names)
        json.dump(results_dr, f, indent=4)

    cocoGt      = COCO(GT_JSON_PATH)
    cocoDt      = cocoGt.loadRes(DR_JSON_PATH)
    cocoEval    = COCOeval(cocoGt, cocoDt, 'bbox')
    cocoEval.evaluate()
    cocoEval.accumulate()
    cocoEval.summarize()
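
For reference, a minimal driver for the COCO-style evaluation above; it assumes pycocotools is installed, the same map_out layout, and that the class list matches the one used to write the detection files:

if __name__ == "__main__":
    # e.g. read the project's class list from model_data/gesture_classes.txt
    with open('model_data/gesture_classes.txt') as f:
        class_names = [c.strip() for c in f.readlines()]
    get_coco_map(class_names, path='./map_out')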