Spaces:
Running
Running
上传主要代码
Browse files- Pipfile +19 -0
- app.py +400 -0
- get_yaml.py +14 -0
- instructions.md +9 -0
- packages.txt +2 -0
- predict.py +194 -0
- requirements.txt +15 -0
- yolo.py +422 -0
Pipfile
ADDED
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
[[source]]
|
2 |
+
name = "pypi"
|
3 |
+
url = "https://pypi.org/simple"
|
4 |
+
verify_ssl = true
|
5 |
+
|
6 |
+
[dev-packages]
|
7 |
+
|
8 |
+
[packages]
|
9 |
+
streamlit = ">0.49.0"
|
10 |
+
opencv-python = "*"
|
11 |
+
numpy = "*"
|
12 |
+
torchvision = "0.9.1"
|
13 |
+
torch = "1.8.1"
|
14 |
+
Pillow = "8.2.0"
|
15 |
+
pyyaml = "6.0"
|
16 |
+
matplotlib = "*"
|
17 |
+
opencv-python-headless = "4.5.2.52"
|
18 |
+
av = "*"
|
19 |
+
streamlit-webrtc = "0.36.1"
|
app.py
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""Create an Object Detection Web App using PyTorch and Streamlit."""
|
2 |
+
# import libraries
|
3 |
+
from PIL import Image
|
4 |
+
from torchvision import models, transforms
|
5 |
+
import torch
|
6 |
+
import streamlit as st
|
7 |
+
from yolo import YOLO
|
8 |
+
import os
|
9 |
+
import urllib
|
10 |
+
import numpy as np
|
11 |
+
from streamlit_webrtc import webrtc_streamer, WebRtcMode, RTCConfiguration
|
12 |
+
import av
|
13 |
+
# 设置网页的icon
|
14 |
+
st.set_page_config(page_title='Gesture Detector', page_icon='✌',
|
15 |
+
layout='centered', initial_sidebar_state='expanded')
|
16 |
+
|
17 |
+
RTC_CONFIGURATION = RTCConfiguration(
|
18 |
+
{
|
19 |
+
"RTCIceServer": [{
|
20 |
+
"urls": ["stun:stun.l.google.com:19302"],
|
21 |
+
"username": "pikachu",
|
22 |
+
"credential": "1234",
|
23 |
+
}]
|
24 |
+
}
|
25 |
+
)
|
26 |
+
def main():
|
27 |
+
# Render the readme as markdown using st.markdown.
|
28 |
+
readme_text = st.markdown(open("instructions.md",encoding='utf-8').read())
|
29 |
+
|
30 |
+
|
31 |
+
# Once we have the dependencies, add a selector for the app mode on the sidebar.
|
32 |
+
st.sidebar.title("What to do")
|
33 |
+
app_mode = st.sidebar.selectbox("Choose the app mode",
|
34 |
+
["Show instructions", "Run the app", "Show the source code"])
|
35 |
+
if app_mode == "Show instructions":
|
36 |
+
st.sidebar.success('To continue select "Run the app".')
|
37 |
+
elif app_mode == "Show the source code":
|
38 |
+
readme_text.empty()
|
39 |
+
st.code(open("app.py",encoding='utf-8').read())
|
40 |
+
elif app_mode == "Run the app":
|
41 |
+
# Download external dependencies.
|
42 |
+
for filename in EXTERNAL_DEPENDENCIES.keys():
|
43 |
+
download_file(filename)
|
44 |
+
|
45 |
+
readme_text.empty()
|
46 |
+
run_the_app()
|
47 |
+
|
48 |
+
# External files to download.
|
49 |
+
EXTERNAL_DEPENDENCIES = {
|
50 |
+
"yolov4_tiny.pth": {
|
51 |
+
"url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_tiny.pth",
|
52 |
+
"size": 23631189
|
53 |
+
},
|
54 |
+
"yolov4_SE.pth": {
|
55 |
+
"url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_SE.pth",
|
56 |
+
"size": 23806027
|
57 |
+
},
|
58 |
+
"yolov4_CBAM.pth":{
|
59 |
+
"url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_CBAM.pth",
|
60 |
+
"size": 23981478
|
61 |
+
},
|
62 |
+
"yolov4_ECA.pth":{
|
63 |
+
"url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_ECA.pth",
|
64 |
+
"size": 23632688
|
65 |
+
},
|
66 |
+
"yolov4_weights_ep150_608.pth":{
|
67 |
+
"url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_weights_ep150_608.pth",
|
68 |
+
"size": 256423031
|
69 |
+
},
|
70 |
+
"yolov4_weights_ep150_416.pth":{
|
71 |
+
"url": "https://github.com/Dreaming-future/my_weights/releases/download/v1.3/yolov4_weights_ep150_416.pth",
|
72 |
+
"size": 256423031
|
73 |
+
},
|
74 |
+
}
|
75 |
+
|
76 |
+
|
77 |
+
# This file downloader demonstrates Streamlit animation.
|
78 |
+
def download_file(file_path):
|
79 |
+
# Don't download the file twice. (If possible, verify the download using the file length.)
|
80 |
+
if os.path.exists(file_path):
|
81 |
+
if "size" not in EXTERNAL_DEPENDENCIES[file_path]:
|
82 |
+
return
|
83 |
+
elif os.path.getsize(file_path) == EXTERNAL_DEPENDENCIES[file_path]["size"]:
|
84 |
+
return
|
85 |
+
# print(os.path.getsize(file_path))
|
86 |
+
# These are handles to two visual elements to animate.
|
87 |
+
weights_warning, progress_bar = None, None
|
88 |
+
try:
|
89 |
+
weights_warning = st.warning("Downloading %s..." % file_path)
|
90 |
+
progress_bar = st.progress(0)
|
91 |
+
with open(file_path, "wb") as output_file:
|
92 |
+
with urllib.request.urlopen(EXTERNAL_DEPENDENCIES[file_path]["url"]) as response:
|
93 |
+
length = int(response.info()["Content-Length"])
|
94 |
+
counter = 0.0
|
95 |
+
MEGABYTES = 2.0 ** 20.0
|
96 |
+
while True:
|
97 |
+
data = response.read(8192)
|
98 |
+
if not data:
|
99 |
+
break
|
100 |
+
counter += len(data)
|
101 |
+
output_file.write(data)
|
102 |
+
|
103 |
+
# We perform animation by overwriting the elements.
|
104 |
+
weights_warning.warning("Downloading %s... (%6.2f/%6.2f MB)" %
|
105 |
+
(file_path, counter / MEGABYTES, length / MEGABYTES))
|
106 |
+
progress_bar.progress(min(counter / length, 1.0))
|
107 |
+
except Exception as e:
|
108 |
+
print(e)
|
109 |
+
# Finally, we remove these visual elements by calling .empty().
|
110 |
+
finally:
|
111 |
+
if weights_warning is not None:
|
112 |
+
weights_warning.empty()
|
113 |
+
if progress_bar is not None:
|
114 |
+
progress_bar.empty()
|
115 |
+
|
116 |
+
# This is the main app app itself, which appears when the user selects "Run the app".
|
117 |
+
def run_the_app():
|
118 |
+
class Config():
|
119 |
+
def __init__(self, weights = 'yolov4_tiny.pth', tiny = True, phi = 0, shape = 416,nms_iou = 0.3, confidence = 0.5):
|
120 |
+
self.weights = weights
|
121 |
+
self.tiny = tiny
|
122 |
+
self.phi = phi
|
123 |
+
self.cuda = False
|
124 |
+
self.shape = shape
|
125 |
+
self.confidence = confidence
|
126 |
+
self.nms_iou = nms_iou
|
127 |
+
# set title of app
|
128 |
+
st.markdown('<h1 align="center">✌ Gesture Detection</h1>',
|
129 |
+
unsafe_allow_html=True)
|
130 |
+
st.sidebar.markdown("# Gesture Detection on?")
|
131 |
+
activities = ["Example","Image", "Camera", "FPS", "Heatmap","Real Time", "Video"]
|
132 |
+
choice = st.sidebar.selectbox("Choose among the given options:", activities)
|
133 |
+
phi = st.sidebar.selectbox("yolov4-tiny 使用的自注意力模式:",('0tiny','1SE','2CABM','3ECA'))
|
134 |
+
print("")
|
135 |
+
|
136 |
+
tiny = st.sidebar.checkbox('是否使用 yolov4 tiny 模型')
|
137 |
+
if not tiny:
|
138 |
+
shape = st.sidebar.selectbox("Choose shape to Input:", [416,608])
|
139 |
+
conf,nms = object_detector_ui()
|
140 |
+
@st.cache
|
141 |
+
def get_yolo(tiny,phi,conf,nms,shape=416):
|
142 |
+
weights = 'yolov4_tiny.pth'
|
143 |
+
if tiny:
|
144 |
+
if phi == '0tiny':
|
145 |
+
weights = 'yolov4_tiny.pth'
|
146 |
+
elif phi == '1SE':
|
147 |
+
weights = 'yolov4_SE.pth'
|
148 |
+
elif phi == '2CABM':
|
149 |
+
weights = 'yolov4_CBAM.pth'
|
150 |
+
elif phi == '3ECA':
|
151 |
+
weights = 'yolov4_ECA.pth'
|
152 |
+
else:
|
153 |
+
if shape == 608:
|
154 |
+
weights = 'yolov4_weights_ep150_608.pth'
|
155 |
+
elif shape == 416:
|
156 |
+
weights = 'yolov4_weights_ep150_416.pth'
|
157 |
+
opt = Config(weights = weights, tiny = tiny , phi = int(phi[0]), shape = shape,nms_iou = nms, confidence = conf)
|
158 |
+
yolo = YOLO(opt)
|
159 |
+
return yolo
|
160 |
+
|
161 |
+
if tiny:
|
162 |
+
yolo = get_yolo(tiny, phi, conf, nms)
|
163 |
+
st.write("YOLOV4 tiny 模型加载完毕")
|
164 |
+
else:
|
165 |
+
yolo = get_yolo(tiny, phi, conf, nms, shape)
|
166 |
+
st.write("YOLOV4 模型加载完毕")
|
167 |
+
|
168 |
+
if choice == 'Image':
|
169 |
+
detect_image(yolo)
|
170 |
+
elif choice =='Camera':
|
171 |
+
detect_camera(yolo)
|
172 |
+
elif choice == 'FPS':
|
173 |
+
detect_fps(yolo)
|
174 |
+
elif choice == "Heatmap":
|
175 |
+
detect_heatmap(yolo)
|
176 |
+
elif choice == "Example":
|
177 |
+
detect_example(yolo)
|
178 |
+
elif choice == "Real Time":
|
179 |
+
detect_realtime(yolo)
|
180 |
+
elif choice == "Video":
|
181 |
+
detect_video(yolo)
|
182 |
+
|
183 |
+
|
184 |
+
|
185 |
+
# This sidebar UI lets the user select parameters for the YOLO object detector.
|
186 |
+
def object_detector_ui():
|
187 |
+
st.sidebar.markdown("# Model")
|
188 |
+
confidence_threshold = st.sidebar.slider("Confidence threshold", 0.0, 1.0, 0.5, 0.01)
|
189 |
+
overlap_threshold = st.sidebar.slider("Overlap threshold", 0.0, 1.0, 0.3, 0.01)
|
190 |
+
return confidence_threshold, overlap_threshold
|
191 |
+
|
192 |
+
def predict(image,yolo):
|
193 |
+
"""Return predictions.
|
194 |
+
|
195 |
+
Parameters
|
196 |
+
----------
|
197 |
+
:param image: uploaded image
|
198 |
+
:type image: jpg
|
199 |
+
:rtype: list
|
200 |
+
:return: none
|
201 |
+
"""
|
202 |
+
crop = False
|
203 |
+
count = False
|
204 |
+
try:
|
205 |
+
# image = Image.open(image)
|
206 |
+
r_image = yolo.detect_image(image, crop = crop, count=count)
|
207 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
208 |
+
result = transform(r_image)
|
209 |
+
st.image(result.permute(1,2,0).numpy(), caption = 'Processed Image.', use_column_width = True)
|
210 |
+
except Exception as e:
|
211 |
+
print(e)
|
212 |
+
|
213 |
+
def fps(image,yolo):
|
214 |
+
test_interval = 50
|
215 |
+
tact_time = yolo.get_FPS(image, test_interval)
|
216 |
+
st.write(str(tact_time) + ' seconds, ', str(1/tact_time),'FPS, @batch_size 1')
|
217 |
+
return tact_time
|
218 |
+
# print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
|
219 |
+
|
220 |
+
|
221 |
+
def detect_image(yolo):
|
222 |
+
# enable users to upload images for the model to make predictions
|
223 |
+
file_up = st.file_uploader("Upload an image", type = ["jpg","png","jpeg"])
|
224 |
+
classes = ["up","down","left","right","front","back","clockwise","anticlockwise"]
|
225 |
+
class_to_idx = {cls: idx for (idx, cls) in enumerate(classes)}
|
226 |
+
st.sidebar.markdown("See the model preformance and play with it")
|
227 |
+
if file_up is not None:
|
228 |
+
with st.spinner(text='Preparing Image'):
|
229 |
+
# display image that user uploaded
|
230 |
+
image = Image.open(file_up)
|
231 |
+
st.image(image, caption = 'Uploaded Image.', use_column_width = True)
|
232 |
+
st.balloons()
|
233 |
+
detect = st.button("开始检测Image")
|
234 |
+
if detect:
|
235 |
+
st.write("")
|
236 |
+
st.write("Just a second ...")
|
237 |
+
predict(image,yolo)
|
238 |
+
st.balloons()
|
239 |
+
|
240 |
+
|
241 |
+
|
242 |
+
def detect_camera(yolo):
|
243 |
+
picture = st.camera_input("Take a picture")
|
244 |
+
if picture:
|
245 |
+
filters_to_funcs = {
|
246 |
+
"No filter": predict,
|
247 |
+
"Heatmap": heatmap,
|
248 |
+
"FPS": fps,
|
249 |
+
}
|
250 |
+
filters = st.selectbox("...and now, apply a filter!", filters_to_funcs.keys())
|
251 |
+
image = Image.open(picture)
|
252 |
+
with st.spinner(text='Preparing Image'):
|
253 |
+
filters_to_funcs[filters](image,yolo)
|
254 |
+
st.balloons()
|
255 |
+
|
256 |
+
def detect_fps(yolo):
|
257 |
+
file_up = st.file_uploader("Upload an image", type = ["jpg","png","jpeg"])
|
258 |
+
classes = ["up","down","left","right","front","back","clockwise","anticlockwise"]
|
259 |
+
class_to_idx = {cls: idx for (idx, cls) in enumerate(classes)}
|
260 |
+
st.sidebar.markdown("See the model preformance and play with it")
|
261 |
+
if file_up is not None:
|
262 |
+
# display image that user uploaded
|
263 |
+
image = Image.open(file_up)
|
264 |
+
st.image(image, caption = 'Uploaded Image.', use_column_width = True)
|
265 |
+
st.balloons()
|
266 |
+
detect = st.button("开始检测 FPS")
|
267 |
+
if detect:
|
268 |
+
with st.spinner(text='Preparing Image'):
|
269 |
+
st.write("")
|
270 |
+
st.write("Just a second ...")
|
271 |
+
tact_time = fps(image,yolo)
|
272 |
+
# st.write(str(tact_time) + ' seconds, ', str(1/tact_time),'FPS, @batch_size 1')
|
273 |
+
st.balloons()
|
274 |
+
|
275 |
+
def heatmap(image,yolo):
|
276 |
+
heatmap_save_path = "heatmap_vision.png"
|
277 |
+
yolo.detect_heatmap(image, heatmap_save_path)
|
278 |
+
img = Image.open(heatmap_save_path)
|
279 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
280 |
+
result = transform(img)
|
281 |
+
st.image(result.permute(1,2,0).numpy(), caption = 'Processed Image.', use_column_width = True)
|
282 |
+
|
283 |
+
def detect_heatmap(yolo):
|
284 |
+
file_up = st.file_uploader("Upload an image", type = ["jpg","png","jpeg"])
|
285 |
+
classes = ["up","down","left","right","front","back","clockwise","anticlockwise"]
|
286 |
+
class_to_idx = {cls: idx for (idx, cls) in enumerate(classes)}
|
287 |
+
st.sidebar.markdown("See the model preformance and play with it")
|
288 |
+
if file_up is not None:
|
289 |
+
# display image that user uploaded
|
290 |
+
image = Image.open(file_up)
|
291 |
+
st.image(image, caption = 'Uploaded Image.', use_column_width = True)
|
292 |
+
st.balloons()
|
293 |
+
detect = st.button("开始检测 heatmap")
|
294 |
+
if detect:
|
295 |
+
with st.spinner(text='Preparing Heatmap'):
|
296 |
+
st.write("")
|
297 |
+
st.write("Just a second ...")
|
298 |
+
heatmap(image,yolo)
|
299 |
+
st.balloons()
|
300 |
+
|
301 |
+
def detect_example(yolo):
|
302 |
+
st.sidebar.title("Choose an Image as a example")
|
303 |
+
images = os.listdir('./img')
|
304 |
+
images.sort()
|
305 |
+
image = st.sidebar.selectbox("Image Name", images)
|
306 |
+
st.sidebar.markdown("See the model preformance and play with it")
|
307 |
+
image = Image.open(os.path.join('img',image))
|
308 |
+
st.image(image, caption = 'Choose Image.', use_column_width = True)
|
309 |
+
st.balloons()
|
310 |
+
detect = st.button("开始检测Image")
|
311 |
+
if detect:
|
312 |
+
st.write("")
|
313 |
+
st.write("Just a second ...")
|
314 |
+
predict(image,yolo)
|
315 |
+
st.balloons()
|
316 |
+
|
317 |
+
def detect_realtime(yolo):
|
318 |
+
|
319 |
+
class VideoProcessor:
|
320 |
+
def recv(self, frame):
|
321 |
+
img = frame.to_ndarray(format="bgr24")
|
322 |
+
img = Image.fromarray(img)
|
323 |
+
crop = False
|
324 |
+
count = False
|
325 |
+
r_image = yolo.detect_image(img, crop = crop, count=count)
|
326 |
+
transform = transforms.Compose([transforms.ToTensor()])
|
327 |
+
result = transform(r_image)
|
328 |
+
result = result.permute(1,2,0).numpy()
|
329 |
+
result = (result * 255).astype(np.uint8)
|
330 |
+
return av.VideoFrame.from_ndarray(result, format="bgr24")
|
331 |
+
|
332 |
+
webrtc_ctx = webrtc_streamer(
|
333 |
+
key="example",
|
334 |
+
mode=WebRtcMode.SENDRECV,
|
335 |
+
rtc_configuration=RTC_CONFIGURATION,
|
336 |
+
media_stream_constraints={"video": True, "audio": False},
|
337 |
+
async_processing=False,
|
338 |
+
video_processor_factory=VideoProcessor
|
339 |
+
)
|
340 |
+
|
341 |
+
import cv2
|
342 |
+
import time
|
343 |
+
def detect_video(yolo):
|
344 |
+
file_up = st.file_uploader("Upload a video", type = ["mp4"])
|
345 |
+
print(file_up)
|
346 |
+
classes = ["up","down","left","right","front","back","clockwise","anticlockwise"]
|
347 |
+
|
348 |
+
if file_up is not None:
|
349 |
+
video_path = 'video.mp4'
|
350 |
+
st.video(file_up)
|
351 |
+
with open(video_path, 'wb') as f:
|
352 |
+
f.write(file_up.read())
|
353 |
+
detect = st.button("开始检测 Video")
|
354 |
+
|
355 |
+
if detect:
|
356 |
+
video_save_path = 'video2.mp4'
|
357 |
+
# display image that user uploaded
|
358 |
+
capture = cv2.VideoCapture(video_path)
|
359 |
+
|
360 |
+
video_fps = st.slider("Video FPS", 5, 30, int(capture.get(cv2.CAP_PROP_FPS)), 1)
|
361 |
+
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
362 |
+
size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
363 |
+
out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
|
364 |
+
|
365 |
+
|
366 |
+
|
367 |
+
while(True):
|
368 |
+
# 读取某一帧
|
369 |
+
ref, frame = capture.read()
|
370 |
+
if not ref:
|
371 |
+
break
|
372 |
+
# 转变成Image
|
373 |
+
# frame = Image.fromarray(np.uint8(frame))
|
374 |
+
# 格式转变,BGRtoRGB
|
375 |
+
frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
|
376 |
+
# 转变成Image
|
377 |
+
frame = Image.fromarray(np.uint8(frame))
|
378 |
+
# 进行检测
|
379 |
+
frame = np.array(yolo.detect_image(frame))
|
380 |
+
# RGBtoBGR满足opencv显示格式
|
381 |
+
frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
|
382 |
+
|
383 |
+
# print("fps= %.2f"%(fps))
|
384 |
+
# frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
385 |
+
out.write(frame)
|
386 |
+
|
387 |
+
out.release()
|
388 |
+
capture.release()
|
389 |
+
print("Save processed video to the path :" + video_save_path)
|
390 |
+
|
391 |
+
with open(video_save_path, "rb") as file:
|
392 |
+
btn = st.download_button(
|
393 |
+
label="Download Video",
|
394 |
+
data=file,
|
395 |
+
file_name="video.mp4",
|
396 |
+
)
|
397 |
+
st.balloons()
|
398 |
+
|
399 |
+
if __name__ == "__main__":
|
400 |
+
main()
|
get_yaml.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import sys
|
3 |
+
import yaml
|
4 |
+
|
5 |
+
def get_config():
|
6 |
+
yaml_path = 'model_data/gesture.yaml'
|
7 |
+
f = open(yaml_path,'r',encoding='utf-8')
|
8 |
+
config = yaml.load(f,Loader =yaml.FullLoader)
|
9 |
+
f.close()
|
10 |
+
return config
|
11 |
+
|
12 |
+
if __name__ == "__main__":
|
13 |
+
config = get_config()
|
14 |
+
print(config)
|
instructions.md
ADDED
@@ -0,0 +1,9 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# ✌ Gesture Detection
|
2 |
+
|
3 |
+
|
4 |
+
这是一个基于无人机视觉图像手势识别控制系统,选择了YOLOv4模型进行训练
|
5 |
+
|
6 |
+
**YOLOv4 = CSPDarknet53(主干) + SPP** **附加模块(颈** **) +** **PANet** **路径聚合(颈** **) + YOLOv3(头部)**
|
7 |
+
|
8 |
+
![img](https://pdf.cdn.readpaper.com/parsed/fetch_target/699143cdb334ecfc63caf8192472490c_0_Figure_1.png)
|
9 |
+
|
packages.txt
ADDED
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
1 |
+
freeglut3-dev
|
2 |
+
libgtk2.0-dev
|
predict.py
ADDED
@@ -0,0 +1,194 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#-----------------------------------------------------------------------#
|
2 |
+
# predict.py将单张图片预测、摄像头检测、FPS测试和目录遍历检测等功能
|
3 |
+
# 整合到了一个py文件中,通过指定mode进行模式的修改。
|
4 |
+
#-----------------------------------------------------------------------#
|
5 |
+
import time
|
6 |
+
import yaml
|
7 |
+
import cv2
|
8 |
+
import numpy as np
|
9 |
+
from PIL import Image
|
10 |
+
from get_yaml import get_config
|
11 |
+
from yolo import YOLO
|
12 |
+
import argparse
|
13 |
+
if __name__ == "__main__":
|
14 |
+
parser = argparse.ArgumentParser()
|
15 |
+
parser.add_argument('--weights',type=str,default='model_data/yolotiny_SE_ep100.pth',help='initial weights path')
|
16 |
+
parser.add_argument('--tiny',action='store_true',help='使用yolotiny模型')
|
17 |
+
parser.add_argument('--phi',type=int,default=1,help='yolov4tiny注意力机制类型')
|
18 |
+
parser.add_argument('--mode',type=str,choices=['dir_predict', 'video', 'fps','predict','heatmap','export_onnx'],default="dir_predict",help='预测的模式')
|
19 |
+
parser.add_argument('--cuda',action='store_true',help='表示是否使用GPU')
|
20 |
+
parser.add_argument('--shape',type=int,default=416,help='输入图像的shape')
|
21 |
+
parser.add_argument('--video',type=str,default='',help='需要检测的视频文件')
|
22 |
+
parser.add_argument('--save-video',type=str,default='',help='保存视频的位置')
|
23 |
+
parser.add_argument('--confidence',type=float,default=0.5,help='只有得分大于置信度的预测框会被保留下来')
|
24 |
+
parser.add_argument('--nms_iou',type=float,default=0.3,help='非极大抑制所用到的nms_iou大小')
|
25 |
+
opt = parser.parse_args()
|
26 |
+
print(opt)
|
27 |
+
|
28 |
+
# 配置文件
|
29 |
+
config = get_config()
|
30 |
+
yolo = YOLO(opt)
|
31 |
+
|
32 |
+
#----------------------------------------------------------------------------------------------------------#
|
33 |
+
# mode用于指定测试的模式:
|
34 |
+
# 'predict' 表示单张图片预测,如果想对预测过程进行修改,如保存图片,截取对象等,可以先看下方详细的注释
|
35 |
+
# 'video' 表示视频检测,可调用摄像头或者视频进行检测,详情查看下方注释。
|
36 |
+
# 'fps' 表示测试fps,使用的图片是img里面的street.jpg,详情查看下方注释。
|
37 |
+
# 'dir_predict' 表示遍历文件夹进行检测并保存。默认遍历img文件夹,保存img_out文件夹,详情查看下方注释。
|
38 |
+
# 'heatmap' 表示进行预测结果的热力图可视化,详情查看下方注释。
|
39 |
+
# 'export_onnx' 表示将模型导出为onnx,需要pytorch1.7.1以上。
|
40 |
+
#----------------------------------------------------------------------------------------------------------#
|
41 |
+
mode = opt.mode
|
42 |
+
#-------------------------------------------------------------------------#
|
43 |
+
# crop 指定了是否在单张图片预测后对目标进行截取
|
44 |
+
# count 指定了是否进行目标的计数
|
45 |
+
# crop、count仅在mode='predict'时有效
|
46 |
+
#-------------------------------------------------------------------------#
|
47 |
+
crop = False
|
48 |
+
count = False
|
49 |
+
#----------------------------------------------------------------------------------------------------------#
|
50 |
+
# video_path 用于指定视频的路径,当video_path=0时表示检测摄像头
|
51 |
+
# 想要检测视频,则设置如video_path = "xxx.mp4"即可,代表读取出根目录下的xxx.mp4文件。
|
52 |
+
# video_save_path 表示视频保存的路径,当video_save_path=""时表示不保存
|
53 |
+
# 想要保存视频,则设置如video_save_path = "yyy.mp4"即可,代表保存为根目录下的yyy.mp4文件。
|
54 |
+
# video_fps 用于保存的视频的fps
|
55 |
+
#
|
56 |
+
# video_path、video_save_path和video_fps仅在mode='video'时有效
|
57 |
+
# 保存视频时需要ctrl+c退出或者运行到最后一帧才会完成完整的保存步骤。
|
58 |
+
#----------------------------------------------------------------------------------------------------------#
|
59 |
+
video_path = 0 if opt.video == '' else opt.video
|
60 |
+
video_save_path = opt.save_video
|
61 |
+
video_fps = 25.0
|
62 |
+
#----------------------------------------------------------------------------------------------------------#
|
63 |
+
# test_interval 用于指定测量fps的时候,图片检测的次数。理论上test_interval越大,fps越准确。
|
64 |
+
# fps_image_path 用于指定测试的fps图片
|
65 |
+
#
|
66 |
+
# test_interval和fps_image_path仅在mode='fps'有效
|
67 |
+
#----------------------------------------------------------------------------------------------------------#
|
68 |
+
test_interval = 100
|
69 |
+
fps_image_path = "img/up.jpg"
|
70 |
+
#-------------------------------------------------------------------------#
|
71 |
+
# dir_origin_path 指定了用于检测的图片的文件夹路径
|
72 |
+
# dir_save_path 指定了检测完图片的保存路径
|
73 |
+
#
|
74 |
+
# dir_origin_path和dir_save_path���在mode='dir_predict'时有效
|
75 |
+
#-------------------------------------------------------------------------#
|
76 |
+
dir_origin_path = "img/"
|
77 |
+
dir_save_path = "img_out/"
|
78 |
+
#-------------------------------------------------------------------------#
|
79 |
+
# heatmap_save_path 热力图的保存路径,默认保存在model_data下
|
80 |
+
#
|
81 |
+
# heatmap_save_path仅在mode='heatmap'有效
|
82 |
+
#-------------------------------------------------------------------------#
|
83 |
+
heatmap_save_path = "model_data/heatmap_vision.png"
|
84 |
+
#-------------------------------------------------------------------------#
|
85 |
+
# simplify 使用Simplify onnx
|
86 |
+
# onnx_save_path 指定了onnx的保存路径
|
87 |
+
#-------------------------------------------------------------------------#
|
88 |
+
simplify = True
|
89 |
+
onnx_save_path = "model_data/models.onnx"
|
90 |
+
|
91 |
+
if mode == "predict":
|
92 |
+
'''
|
93 |
+
1、如果想要进行检测完的图片的保存,利用r_image.save("img.jpg")即可保存,直接在predict.py里进行修改即可。
|
94 |
+
2、如果想要获得预测框的坐标,可以进入yolo.detect_image函数,在绘图部分读取top,left,bottom,right这四个值。
|
95 |
+
3、如果想要利用预测框截取下目标,可以进入yolo.detect_image函数,在绘图部分利用获取到的top,left,bottom,right这四个值
|
96 |
+
在原图上利用矩阵的方式进行截取。
|
97 |
+
4、如果想要在预测图上写额外的字,比如检测到的特定目标的数量,可以进入yolo.detect_image函数,在绘图部分对predicted_class进行判断,
|
98 |
+
比如判断if predicted_class == 'car': 即可判断当前目标是否为车,然后记录数量即可。利用draw.text即可写字。
|
99 |
+
'''
|
100 |
+
while True:
|
101 |
+
img = input('Input image filename:')
|
102 |
+
try:
|
103 |
+
image = Image.open(img)
|
104 |
+
except:
|
105 |
+
print('Open Error! Try again!')
|
106 |
+
continue
|
107 |
+
else:
|
108 |
+
r_image = yolo.detect_image(image, crop = crop, count=count)
|
109 |
+
r_image.show()
|
110 |
+
r_image.save(dir_save_path + 'img_result.jpg')
|
111 |
+
|
112 |
+
elif mode == "video":
|
113 |
+
capture = cv2.VideoCapture(video_path)
|
114 |
+
if video_save_path != '':
|
115 |
+
fourcc = cv2.VideoWriter_fourcc(*'XVID')
|
116 |
+
size = (int(capture.get(cv2.CAP_PROP_FRAME_WIDTH)), int(capture.get(cv2.CAP_PROP_FRAME_HEIGHT)))
|
117 |
+
out = cv2.VideoWriter(video_save_path, fourcc, video_fps, size)
|
118 |
+
|
119 |
+
ref, frame = capture.read()
|
120 |
+
if not ref:
|
121 |
+
raise ValueError("未能正确读取摄像头(视频),请注意是否正确安装摄像头(是否正确填写视频路径)。")
|
122 |
+
|
123 |
+
fps = 0.0
|
124 |
+
while(True):
|
125 |
+
t1 = time.time()
|
126 |
+
# 读取某一帧
|
127 |
+
ref, frame = capture.read()
|
128 |
+
if not ref:
|
129 |
+
break
|
130 |
+
# 格式转变,BGRtoRGB
|
131 |
+
frame = cv2.cvtColor(frame,cv2.COLOR_BGR2RGB)
|
132 |
+
# 转变成Image
|
133 |
+
frame = Image.fromarray(np.uint8(frame))
|
134 |
+
# 进行检测
|
135 |
+
frame = np.array(yolo.detect_image(frame))
|
136 |
+
# RGBtoBGR满足opencv显示格式
|
137 |
+
frame = cv2.cvtColor(frame,cv2.COLOR_RGB2BGR)
|
138 |
+
|
139 |
+
fps = ( fps + (1./(time.time()-t1)) ) / 2
|
140 |
+
print("fps= %.2f"%(fps))
|
141 |
+
frame = cv2.putText(frame, "fps= %.2f"%(fps), (0, 40), cv2.FONT_HERSHEY_SIMPLEX, 1, (0, 255, 0), 2)
|
142 |
+
|
143 |
+
cv2.imshow("video",frame)
|
144 |
+
c= cv2.waitKey(1) & 0xff
|
145 |
+
if video_save_path != '':
|
146 |
+
out.write(frame)
|
147 |
+
|
148 |
+
if c==27:
|
149 |
+
capture.release()
|
150 |
+
break
|
151 |
+
|
152 |
+
print("Video Detection Done!")
|
153 |
+
capture.release()
|
154 |
+
if video_save_path != '':
|
155 |
+
print("Save processed video to the path :" + video_save_path)
|
156 |
+
out.release()
|
157 |
+
cv2.destroyAllWindows()
|
158 |
+
|
159 |
+
elif mode == "fps":
|
160 |
+
img = Image.open(fps_image_path)
|
161 |
+
tact_time = yolo.get_FPS(img, test_interval)
|
162 |
+
print(str(tact_time) + ' seconds, ' + str(1/tact_time) + 'FPS, @batch_size 1')
|
163 |
+
|
164 |
+
elif mode == "dir_predict":
|
165 |
+
import os
|
166 |
+
|
167 |
+
from tqdm import tqdm
|
168 |
+
|
169 |
+
img_names = os.listdir(dir_origin_path)
|
170 |
+
for img_name in tqdm(img_names):
|
171 |
+
if img_name.lower().endswith(('.bmp', '.dib', '.png', '.jpg', '.jpeg', '.pbm', '.pgm', '.ppm', '.tif', '.tiff')):
|
172 |
+
image_path = os.path.join(dir_origin_path, img_name)
|
173 |
+
image = Image.open(image_path)
|
174 |
+
r_image = yolo.detect_image(image)
|
175 |
+
if not os.path.exists(dir_save_path):
|
176 |
+
os.makedirs(dir_save_path)
|
177 |
+
r_image.save(os.path.join(dir_save_path, img_name.replace(".jpg", ".png")), quality=95, subsampling=0)
|
178 |
+
|
179 |
+
elif mode == "heatmap":
|
180 |
+
while True:
|
181 |
+
img = input('Input image filename:')
|
182 |
+
try:
|
183 |
+
image = Image.open(img)
|
184 |
+
except:
|
185 |
+
print('Open Error! Try again!')
|
186 |
+
continue
|
187 |
+
else:
|
188 |
+
yolo.detect_heatmap(image, heatmap_save_path)
|
189 |
+
|
190 |
+
elif mode == "export_onnx":
|
191 |
+
yolo.convert_to_onnx(simplify, onnx_save_path)
|
192 |
+
|
193 |
+
else:
|
194 |
+
raise AssertionError("Please specify the correct mode: 'predict', 'video', 'fps', 'heatmap', 'export_onnx', 'dir_predict'.")
|
requirements.txt
ADDED
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
scipy
|
2 |
+
numpy
|
3 |
+
matplotlib
|
4 |
+
opencv_python
|
5 |
+
torch==1.8.1
|
6 |
+
torchvision==0.9.1
|
7 |
+
tqdm==4.60.0
|
8 |
+
Pillow==8.2.0
|
9 |
+
h5py==2.10.0
|
10 |
+
tensorboard
|
11 |
+
pyyaml==6.0
|
12 |
+
torchinfo
|
13 |
+
labelimg==1.8.6
|
14 |
+
streamlit==1.8.1
|
15 |
+
opencv-python-headless==4.5.2.52
|
yolo.py
ADDED
@@ -0,0 +1,422 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import colorsys
|
2 |
+
import os
|
3 |
+
import time
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from PIL import ImageDraw, ImageFont
|
9 |
+
|
10 |
+
from nets.yolo import YoloBody
|
11 |
+
from nets.yolo_tiny import YoloBodytiny
|
12 |
+
from utils.utils import (cvtColor, get_anchors, get_classes, preprocess_input,
|
13 |
+
resize_image)
|
14 |
+
from utils.utils_bbox import DecodeBox
|
15 |
+
from get_yaml import get_config
|
16 |
+
import argparse
|
17 |
+
'''
|
18 |
+
训练自己的数据集必看注释!
|
19 |
+
'''
|
20 |
+
class YOLO(object):
|
21 |
+
# 配置文件
|
22 |
+
config = get_config()
|
23 |
+
_defaults = {
|
24 |
+
#--------------------------------------------------------------------------#
|
25 |
+
# 使用自己训练好的模型进行预测一定要修改model_path和classes_path!
|
26 |
+
# model_path指向logs文件夹下的权值文件,classes_path指向model_data下的txt
|
27 |
+
#
|
28 |
+
# 训练好后logs文件夹下存在多个权值文件,选择验证集损失较低的即可。
|
29 |
+
# 验证集损失较低不代表mAP较高,仅代表该权值在验证集上泛化性能较好。
|
30 |
+
# 如果出现shape不匹配,同时要注意训练时的model_path和classes_path参数的修改
|
31 |
+
#--------------------------------------------------------------------------#
|
32 |
+
"class_names" : config['classes'],
|
33 |
+
"num_classes" : config['nc'],
|
34 |
+
#---------------------------------------------------------------------#
|
35 |
+
# anchors_path代表先验框对应的txt文件,一般不修改。
|
36 |
+
# anchors_mask用于帮助代码找到对应的先验框,一般不修改。
|
37 |
+
#---------------------------------------------------------------------#
|
38 |
+
"anchors_path" : 'model_data/yolo_anchors.txt',
|
39 |
+
"anchors_mask" : [[6, 7, 8], [3, 4, 5], [0, 1, 2]],
|
40 |
+
#---------------------------------------------------------------------#
|
41 |
+
# 只有得分大于置信度的预测框会被保留下来
|
42 |
+
#---------------------------------------------------------------------#
|
43 |
+
"confidence" : 0.5, # 0.5,
|
44 |
+
#---------------------------------------------------------------------#
|
45 |
+
# 非极大抑制所用到的nms_iou大小
|
46 |
+
#---------------------------------------------------------------------#
|
47 |
+
"nms_iou" : 0.3, # 0.3,
|
48 |
+
#---------------------------------------------------------------------#
|
49 |
+
# 该变量用于控制是否使用letterbox_image对输入图像进行不失真的resize,
|
50 |
+
# 在多次测试后,发现关闭letterbox_image直接resize的效果更好
|
51 |
+
#---------------------------------------------------------------------#
|
52 |
+
"letterbox_image" : config['letterbox_image'], # False,
|
53 |
+
}
|
54 |
+
|
55 |
+
|
56 |
+
|
57 |
+
@classmethod
|
58 |
+
def get_defaults(cls, n):
|
59 |
+
if n in cls._defaults:
|
60 |
+
return cls._defaults[n]
|
61 |
+
else:
|
62 |
+
return "Unrecognized attribute name '" + n + "'"
|
63 |
+
|
64 |
+
#---------------------------------------------------#
|
65 |
+
# 初始化YOLO
|
66 |
+
#---------------------------------------------------#
|
67 |
+
def __init__(self, opt, **kwargs):
|
68 |
+
self.__dict__.update(self._defaults)
|
69 |
+
for name, value in kwargs.items():
|
70 |
+
setattr(self, name, value)
|
71 |
+
self.phi = opt.phi
|
72 |
+
self.tiny = opt.tiny
|
73 |
+
self.cuda = opt.cuda
|
74 |
+
self.input_shape = [opt.shape,opt.shape]
|
75 |
+
self.model_path = opt.weights
|
76 |
+
self.phi = opt.phi
|
77 |
+
self.confidence = opt.confidence
|
78 |
+
self.nms_iou = opt.nms_iou
|
79 |
+
if self.tiny:
|
80 |
+
self.anchors_mask = [[3,4,5], [1,2,3]]
|
81 |
+
self.anchors_path = 'model_data/yolotiny_anchors.txt'
|
82 |
+
#---------------------------------------------------#
|
83 |
+
# 获得种类和先验框的数量
|
84 |
+
#---------------------------------------------------#
|
85 |
+
# self.class_names, self.num_classes = get_classes(self.classes_path)
|
86 |
+
self.anchors, self.num_anchors = get_anchors(self.anchors_path)
|
87 |
+
self.bbox_util = DecodeBox(self.anchors, self.num_classes, (self.input_shape[0], self.input_shape[1]), self.anchors_mask)
|
88 |
+
|
89 |
+
#---------------------------------------------------#
|
90 |
+
# 画框设置不同的颜色
|
91 |
+
#---------------------------------------------------#
|
92 |
+
hsv_tuples = [(x / self.num_classes, 1., 1.) for x in range(self.num_classes)]
|
93 |
+
self.colors = list(map(lambda x: colorsys.hsv_to_rgb(*x), hsv_tuples))
|
94 |
+
self.colors = list(map(lambda x: (int(x[0] * 255), int(x[1] * 255), int(x[2] * 255)), self.colors))
|
95 |
+
self.generate()
|
96 |
+
|
97 |
+
#---------------------------------------------------#
|
98 |
+
# 生成模型
|
99 |
+
#---------------------------------------------------#
|
100 |
+
def generate(self, onnx=False):
|
101 |
+
#---------------------------------------------------#
|
102 |
+
# 建立yolo模型,载入yolo模型的权重
|
103 |
+
#---------------------------------------------------#
|
104 |
+
|
105 |
+
if not self.tiny:
|
106 |
+
self.net = YoloBody(self.anchors_mask, self.num_classes)
|
107 |
+
elif self.tiny:
|
108 |
+
self.net = YoloBodytiny(self.anchors_mask, self.num_classes, self.phi)
|
109 |
+
device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
|
110 |
+
self.net.load_state_dict(torch.load(self.model_path, map_location=device))
|
111 |
+
self.net = self.net.eval()
|
112 |
+
|
113 |
+
print('{} model, anchors, and classes loaded.'.format(self.model_path))
|
114 |
+
if not onnx:
|
115 |
+
if self.cuda:
|
116 |
+
self.net = nn.DataParallel(self.net)
|
117 |
+
self.net = self.net.cuda()
|
118 |
+
|
119 |
+
#---------------------------------------------------#
|
120 |
+
# 检测图片
|
121 |
+
#---------------------------------------------------#
|
122 |
+
def detect_image(self, image, crop = False, count = False):
|
123 |
+
#---------------------------------------------------#
|
124 |
+
# 计算输入图片的高和宽
|
125 |
+
#---------------------------------------------------#
|
126 |
+
image_shape = np.array(np.shape(image)[0:2])
|
127 |
+
#---------------------------------------------------------#
|
128 |
+
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
|
129 |
+
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
|
130 |
+
#---------------------------------------------------------#
|
131 |
+
image = cvtColor(image)
|
132 |
+
#---------------------------------------------------------#
|
133 |
+
# 给图像增加灰条,实现不失真的resize
|
134 |
+
# 也可以直接resize进行识别
|
135 |
+
#---------------------------------------------------------#
|
136 |
+
image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
|
137 |
+
#---------------------------------------------------------#
|
138 |
+
# 添加上batch_size维度
|
139 |
+
#---------------------------------------------------------#
|
140 |
+
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
|
141 |
+
|
142 |
+
with torch.no_grad():
|
143 |
+
images = torch.from_numpy(image_data)
|
144 |
+
if self.cuda:
|
145 |
+
images = images.cuda()
|
146 |
+
#---------------------------------------------------------#
|
147 |
+
# 将图像输入网络当中进行预测!
|
148 |
+
#---------------------------------------------------------#
|
149 |
+
outputs = self.net(images)
|
150 |
+
outputs = self.bbox_util.decode_box(outputs)
|
151 |
+
#---------------------------------------------------------#
|
152 |
+
# 将预测框进行堆叠,然后进行非极大抑制
|
153 |
+
#---------------------------------------------------------#
|
154 |
+
results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
|
155 |
+
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
|
156 |
+
|
157 |
+
if results[0] is None:
|
158 |
+
return image
|
159 |
+
|
160 |
+
top_label = np.array(results[0][:, 6], dtype = 'int32')
|
161 |
+
top_conf = results[0][:, 4] * results[0][:, 5]
|
162 |
+
top_boxes = results[0][:, :4]
|
163 |
+
#---------------------------------------------------------#
|
164 |
+
# 设置字体与边框厚度
|
165 |
+
#---------------------------------------------------------#
|
166 |
+
font = ImageFont.truetype(font='model_data/simhei.ttf', size=np.floor(3e-2 * image.size[1] + 0.5).astype('int32'))
|
167 |
+
thickness = int(max((image.size[0] + image.size[1]) // np.mean(self.input_shape), 1))
|
168 |
+
#---------------------------------------------------------#
|
169 |
+
# 计数
|
170 |
+
#---------------------------------------------------------#
|
171 |
+
if count:
|
172 |
+
print("top_label:", top_label)
|
173 |
+
classes_nums = np.zeros([self.num_classes])
|
174 |
+
for i in range(self.num_classes):
|
175 |
+
num = np.sum(top_label == i)
|
176 |
+
if num > 0:
|
177 |
+
print(self.class_names[i], " : ", num)
|
178 |
+
classes_nums[i] = num
|
179 |
+
print("classes_nums:", classes_nums)
|
180 |
+
#---------------------------------------------------------#
|
181 |
+
# 是否进行目标的裁剪
|
182 |
+
#---------------------------------------------------------#
|
183 |
+
if crop:
|
184 |
+
for i, c in list(enumerate(top_label)):
|
185 |
+
top, left, bottom, right = top_boxes[i]
|
186 |
+
top = max(0, np.floor(top).astype('int32'))
|
187 |
+
left = max(0, np.floor(left).astype('int32'))
|
188 |
+
bottom = min(image.size[1], np.floor(bottom).astype('int32'))
|
189 |
+
right = min(image.size[0], np.floor(right).astype('int32'))
|
190 |
+
|
191 |
+
dir_save_path = "img_crop"
|
192 |
+
if not os.path.exists(dir_save_path):
|
193 |
+
os.makedirs(dir_save_path)
|
194 |
+
crop_image = image.crop([left, top, right, bottom])
|
195 |
+
crop_image.save(os.path.join(dir_save_path, "crop_" + str(i) + ".png"), quality=95, subsampling=0)
|
196 |
+
print("save crop_" + str(i) + ".png to " + dir_save_path)
|
197 |
+
#---------------------------------------------------------#
|
198 |
+
# 图像绘制
|
199 |
+
#---------------------------------------------------------#
|
200 |
+
for i, c in list(enumerate(top_label)):
|
201 |
+
predicted_class = self.class_names[int(c)]
|
202 |
+
box = top_boxes[i]
|
203 |
+
score = top_conf[i]
|
204 |
+
|
205 |
+
top, left, bottom, right = box
|
206 |
+
|
207 |
+
top = max(0, np.floor(top).astype('int32'))
|
208 |
+
left = max(0, np.floor(left).astype('int32'))
|
209 |
+
bottom = min(image.size[1], np.floor(bottom).astype('int32'))
|
210 |
+
right = min(image.size[0], np.floor(right).astype('int32'))
|
211 |
+
|
212 |
+
label = '{} {:.2f}'.format(predicted_class, score)
|
213 |
+
draw = ImageDraw.Draw(image)
|
214 |
+
label_size = draw.textsize(label, font)
|
215 |
+
label = label.encode('utf-8')
|
216 |
+
print(label, top, left, bottom, right)
|
217 |
+
|
218 |
+
if top - label_size[1] >= 0:
|
219 |
+
text_origin = np.array([left, top - label_size[1]])
|
220 |
+
else:
|
221 |
+
text_origin = np.array([left, top + 1])
|
222 |
+
|
223 |
+
for i in range(thickness):
|
224 |
+
draw.rectangle([left + i, top + i, right - i, bottom - i], outline=self.colors[c])
|
225 |
+
draw.rectangle([tuple(text_origin), tuple(text_origin + label_size)], fill=self.colors[c])
|
226 |
+
draw.text(text_origin, str(label,'UTF-8'), fill=(0, 0, 0), font=font)
|
227 |
+
del draw
|
228 |
+
|
229 |
+
return image
|
230 |
+
|
231 |
+
def get_FPS(self, image, test_interval):
|
232 |
+
image_shape = np.array(np.shape(image)[0:2])
|
233 |
+
#---------------------------------------------------------#
|
234 |
+
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
|
235 |
+
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
|
236 |
+
#---------------------------------------------------------#
|
237 |
+
image = cvtColor(image)
|
238 |
+
#---------------------------------------------------------#
|
239 |
+
# 给图像增加灰条,实现不失真的resize
|
240 |
+
# 也可以直接resize进行识别
|
241 |
+
#---------------------------------------------------------#
|
242 |
+
image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
|
243 |
+
#---------------------------------------------------------#
|
244 |
+
# 添加上batch_size维度
|
245 |
+
#---------------------------------------------------------#
|
246 |
+
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
|
247 |
+
|
248 |
+
with torch.no_grad():
|
249 |
+
images = torch.from_numpy(image_data)
|
250 |
+
if self.cuda:
|
251 |
+
images = images.cuda()
|
252 |
+
#---------------------------------------------------------#
|
253 |
+
# 将图像输入网络当中进行预测!
|
254 |
+
#---------------------------------------------------------#
|
255 |
+
outputs = self.net(images)
|
256 |
+
outputs = self.bbox_util.decode_box(outputs)
|
257 |
+
#---------------------------------------------------------#
|
258 |
+
# 将预测框进行堆叠,然后进行非极大抑制
|
259 |
+
#---------------------------------------------------------#
|
260 |
+
results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
|
261 |
+
image_shape, self.letterbox_image, conf_thres=self.confidence, nms_thres=self.nms_iou)
|
262 |
+
|
263 |
+
t1 = time.time()
|
264 |
+
for _ in range(test_interval):
|
265 |
+
with torch.no_grad():
|
266 |
+
#---------------------------------------------------------#
|
267 |
+
# 将图像输入网络当中进行预测!
|
268 |
+
#---------------------------------------------------------#
|
269 |
+
outputs = self.net(images)
|
270 |
+
outputs = self.bbox_util.decode_box(outputs)
|
271 |
+
#---------------------------------------------------------#
|
272 |
+
# 将预测框进行堆叠,然后进行非极大抑制
|
273 |
+
#---------------------------------------------------------#
|
274 |
+
results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
|
275 |
+
image_shape, self.letterbox_image, conf_thres=self.confidence, nms_thres=self.nms_iou)
|
276 |
+
|
277 |
+
t2 = time.time()
|
278 |
+
tact_time = (t2 - t1) / test_interval
|
279 |
+
return tact_time
|
280 |
+
|
281 |
+
def detect_heatmap(self, image, heatmap_save_path):
|
282 |
+
import cv2
|
283 |
+
import matplotlib.pyplot as plt
|
284 |
+
def sigmoid(x):
|
285 |
+
y = 1.0 / (1.0 + np.exp(-x))
|
286 |
+
return y
|
287 |
+
#---------------------------------------------------------#
|
288 |
+
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
|
289 |
+
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
|
290 |
+
#---------------------------------------------------------#
|
291 |
+
image = cvtColor(image)
|
292 |
+
#---------------------------------------------------------#
|
293 |
+
# 给图像增加灰条,实现不失真的resize
|
294 |
+
# 也可以直接resize进行识别
|
295 |
+
#---------------------------------------------------------#
|
296 |
+
image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
|
297 |
+
#---------------------------------------------------------#
|
298 |
+
# 添加上batch_size维度
|
299 |
+
#---------------------------------------------------------#
|
300 |
+
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
|
301 |
+
|
302 |
+
with torch.no_grad():
|
303 |
+
images = torch.from_numpy(image_data)
|
304 |
+
if self.cuda:
|
305 |
+
images = images.cuda()
|
306 |
+
#---------------------------------------------------------#
|
307 |
+
# 将图像输入网络当中进行预测!
|
308 |
+
#---------------------------------------------------------#
|
309 |
+
outputs = self.net(images)
|
310 |
+
|
311 |
+
plt.imshow(image, alpha=1)
|
312 |
+
plt.axis('off')
|
313 |
+
mask = np.zeros((image.size[1], image.size[0]))
|
314 |
+
for sub_output in outputs:
|
315 |
+
sub_output = sub_output.cpu().numpy()
|
316 |
+
b, c, h, w = np.shape(sub_output)
|
317 |
+
sub_output = np.transpose(np.reshape(sub_output, [b, 3, -1, h, w]), [0, 3, 4, 1, 2])[0]
|
318 |
+
score = np.max(sigmoid(sub_output[..., 4]), -1)
|
319 |
+
score = cv2.resize(score, (image.size[0], image.size[1]))
|
320 |
+
normed_score = (score * 255).astype('uint8')
|
321 |
+
mask = np.maximum(mask, normed_score)
|
322 |
+
|
323 |
+
plt.imshow(mask, alpha=0.5, interpolation='nearest', cmap="jet")
|
324 |
+
|
325 |
+
plt.axis('off')
|
326 |
+
plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
|
327 |
+
plt.margins(0, 0)
|
328 |
+
plt.savefig(heatmap_save_path, dpi=200, bbox_inches='tight', pad_inches = -0.1)
|
329 |
+
print("Save to the " + heatmap_save_path)
|
330 |
+
plt.show()
|
331 |
+
|
332 |
+
def convert_to_onnx(self, simplify, model_path):
|
333 |
+
import onnx
|
334 |
+
self.generate(onnx=True)
|
335 |
+
|
336 |
+
im = torch.zeros(1, 3, *self.input_shape).to('cpu') # image size(1, 3, 512, 512) BCHW
|
337 |
+
input_layer_names = ["images"]
|
338 |
+
output_layer_names = ["output"]
|
339 |
+
|
340 |
+
# Export the model
|
341 |
+
print(f'Starting export with onnx {onnx.__version__}.')
|
342 |
+
torch.onnx.export(self.net,
|
343 |
+
im,
|
344 |
+
f = model_path,
|
345 |
+
verbose = False,
|
346 |
+
opset_version = 12,
|
347 |
+
training = torch.onnx.TrainingMode.EVAL,
|
348 |
+
do_constant_folding = True,
|
349 |
+
input_names = input_layer_names,
|
350 |
+
output_names = output_layer_names,
|
351 |
+
dynamic_axes = None)
|
352 |
+
|
353 |
+
# Checks
|
354 |
+
model_onnx = onnx.load(model_path) # load onnx model
|
355 |
+
onnx.checker.check_model(model_onnx) # check onnx model
|
356 |
+
|
357 |
+
# Simplify onnx
|
358 |
+
if simplify:
|
359 |
+
import onnxsim
|
360 |
+
print(f'Simplifying with onnx-simplifier {onnxsim.__version__}.')
|
361 |
+
model_onnx, check = onnxsim.simplify(
|
362 |
+
model_onnx,
|
363 |
+
dynamic_input_shape=False,
|
364 |
+
input_shapes=None)
|
365 |
+
assert check, 'assert check failed'
|
366 |
+
onnx.save(model_onnx, model_path)
|
367 |
+
|
368 |
+
print('Onnx model save as {}'.format(model_path))
|
369 |
+
|
370 |
+
def get_map_txt(self, image_id, image, class_names, map_out_path):
|
371 |
+
f = open(os.path.join(map_out_path, "detection-results/"+image_id+".txt"),"w")
|
372 |
+
image_shape = np.array(np.shape(image)[0:2])
|
373 |
+
#---------------------------------------------------------#
|
374 |
+
# 在这里将图像转换成RGB图像,防止灰度图在预测时报错。
|
375 |
+
# 代码仅仅支持RGB图像的预测,所有其它类型的图像都会转化成RGB
|
376 |
+
#---------------------------------------------------------#
|
377 |
+
image = cvtColor(image)
|
378 |
+
#---------------------------------------------------------#
|
379 |
+
# 给图像增加灰条,实现不失真的resize
|
380 |
+
# 也可以直接resize进行识别
|
381 |
+
#---------------------------------------------------------#
|
382 |
+
image_data = resize_image(image, (self.input_shape[1],self.input_shape[0]), self.letterbox_image)
|
383 |
+
#---------------------------------------------------------#
|
384 |
+
# 添加上batch_size维度
|
385 |
+
#---------------------------------------------------------#
|
386 |
+
image_data = np.expand_dims(np.transpose(preprocess_input(np.array(image_data, dtype='float32')), (2, 0, 1)), 0)
|
387 |
+
|
388 |
+
with torch.no_grad():
|
389 |
+
images = torch.from_numpy(image_data)
|
390 |
+
if self.cuda:
|
391 |
+
images = images.cuda()
|
392 |
+
#---------------------------------------------------------#
|
393 |
+
# 将图像输入网络当中进行预测!
|
394 |
+
#---------------------------------------------------------#
|
395 |
+
outputs = self.net(images)
|
396 |
+
outputs = self.bbox_util.decode_box(outputs)
|
397 |
+
#---------------------------------------------------------#
|
398 |
+
# 将预测框进行堆叠,然后进行非极大抑制
|
399 |
+
#---------------------------------------------------------#
|
400 |
+
results = self.bbox_util.non_max_suppression(torch.cat(outputs, 1), self.num_classes, self.input_shape,
|
401 |
+
image_shape, self.letterbox_image, conf_thres = self.confidence, nms_thres = self.nms_iou)
|
402 |
+
|
403 |
+
if results[0] is None:
|
404 |
+
return
|
405 |
+
|
406 |
+
top_label = np.array(results[0][:, 6], dtype = 'int32')
|
407 |
+
top_conf = results[0][:, 4] * results[0][:, 5]
|
408 |
+
top_boxes = results[0][:, :4]
|
409 |
+
|
410 |
+
for i, c in list(enumerate(top_label)):
|
411 |
+
predicted_class = self.class_names[int(c)]
|
412 |
+
box = top_boxes[i]
|
413 |
+
score = str(top_conf[i])
|
414 |
+
|
415 |
+
top, left, bottom, right = box
|
416 |
+
if predicted_class not in class_names:
|
417 |
+
continue
|
418 |
+
|
419 |
+
f.write("%s %s %s %s %s %s\n" % (predicted_class, score[:6], str(int(left)), str(int(top)), str(int(right)),str(int(bottom))))
|
420 |
+
|
421 |
+
f.close()
|
422 |
+
return
|