Spaces:
Runtime error
Runtime error
Upload [gradio]model_display.py
Browse files- [gradio]model_display.py +323 -0
[gradio]model_display.py
ADDED
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import gradio as gr
|
3 |
+
import requests
|
4 |
+
import base64
|
5 |
+
import pandas as pd
|
6 |
+
import cv2
|
7 |
+
from typing import Tuple
|
8 |
+
from PIL import Image
|
9 |
+
from io import BytesIO
|
10 |
+
|
11 |
+
import os
|
12 |
+
|
13 |
+
from Model.Model6.model6_inference import main as model6_inferencer
|
14 |
+
from mmyolo.utils import register_all_modules
|
15 |
+
|
16 |
+
register_all_modules()
|
17 |
+
|
18 |
+
|
19 |
+
def get_access_token(refatch=False) -> str:
    """Return an access token for the Baidu AI open platform.

    :param refatch: when True, request a fresh token from the OAuth
        endpoint; when False, return the cached (hard-coded) token.
    :return: the access token string.
    :raises RuntimeError: if ``refatch`` is True and the token request fails
        (the original silently returned ``None`` in that case).
    """
    if not refatch:
        # Token captured from a previous OAuth exchange.
        # NOTE(review): Baidu tokens expire after ~30 days; this cached
        # value is likely stale — confirm before relying on it.
        return "24.becefee37aba38ea43c546fc154d3016.2592000.1695047067.282335-30479502"

    # SECURITY: credentials are hard-coded; move them to environment
    # variables or a secrets store before deploying publicly.
    client_id = '7OtH60uo01ZNYN4yPyahlRSx'
    client_secret = 'D5AxcUpyQyIA7KgPplp7dnz5tM0UIljy'
    # Let requests build/escape the query string instead of %-formatting.
    response = requests.get(
        'https://aip.baidubce.com/oauth/2.0/token',
        params={
            'grant_type': 'client_credentials',
            'client_id': client_id,
            'client_secret': client_secret,
        },
    )
    # requests.Response is truthy for status codes < 400.
    if response:
        return response.json()['access_token']
    raise RuntimeError(
        "failed to fetch access_token: HTTP %s" % response.status_code)
|
54 |
+
|
55 |
+
|
56 |
+
def resize_image(img, max_length=2048, min_length=50) -> Tuple[np.ndarray, bool]:
    """Clamp an image so its longest side is at most ``max_length`` px and
    its shortest side is at least ``min_length`` px.

    :param img: input image as an H x W (x C) numpy array.
    :param max_length: upper bound for the longest side, in pixels.
    :param min_length: lower bound for the shortest side, in pixels.
    :return: tuple of (possibly resized image, True if any resize happened).
    """
    changed = False
    # Both scale factors are derived from the ORIGINAL dimensions,
    # mirroring the behaviour of the two independent checks.
    height, width = img.shape[0], img.shape[1]
    longest = max(height, width)
    shortest = min(height, width)

    def _rescale(image, factor):
        # cv2.resize takes (width, height) order.
        target = (int(image.shape[1] * factor), int(image.shape[0] * factor))
        return cv2.resize(image, target)

    if longest > max_length:
        img = _rescale(img, max_length / longest)
        changed = True
    if shortest < min_length:
        img = _rescale(img, min_length / shortest)
        changed = True
    return img, changed
|
75 |
+
|
76 |
+
|
77 |
+
def model1_det(x):
    """Human detection and attribute recognition via the Baidu AI API.

    :param x: input image from the web UI (RGB numpy array).
    :return: (annotated image, summary message with the detected person count)
    """

    def _Baidu_det(img):
        """Call the Baidu body-attribute endpoint on *img*.

        :param img: image as an RGB numpy.ndarray.
        :return: parsed JSON response dict, or None when the request failed.
        """
        request_url = "https://aip.baidubce.com/rest/2.0/image-classify/v1/body_attr"
        # cv2.imwrite expects BGR channel order; convert from the RGB array
        # the UI provides (keeps behaviour consistent with model2_rem —
        # without this the API receives channel-swapped colours).
        cv2.imwrite('test.jpg', cv2.cvtColor(img, cv2.COLOR_RGB2BGR))
        # Read the encoded file back and base64 it; close the handle
        # promptly instead of leaking it as the original did.
        with open('test.jpg', 'rb') as f:
            hex_image = base64.b64encode(f.read())
        # Request the 12 attributes displayed in the UI.
        params = {
            "image": hex_image,
            "type": "gender,age,upper_wear,lower_wear,upper_color,lower_color,"
                    "orientation,upper_cut,lower_cut,side_cut,occlusion,is_human"
        }
        access_token = get_access_token()
        request_url = request_url + "?access_token=" + access_token
        headers = {'content-type': 'application/x-www-form-urlencoded'}
        response = requests.post(request_url, data=params, headers=headers)
        if response:
            return response.json()

    def _get_attributes_list(r) -> dict:
        """Build one attribute DataFrame per detected person.

        :param r: JSON dict returned by the Baidu API.
        :return: mapping of person index -> DataFrame(attribute, value, accuracy).
        """
        all_humans_attributes_list = {}
        person_num = r['person_num']
        print('person_num:', person_num)
        for human_idx in range(person_num):
            attributes_dict = r['person_info'][human_idx]['attributes']
            rows = [[key, value['name'], value['score']]
                    for key, value in attributes_dict.items()]
            df = pd.DataFrame(rows, columns=['attribute', 'attribute_value', 'accuracy'])
            all_humans_attributes_list[human_idx] = df
        return all_humans_attributes_list

    def _show_img(img, bboxes):
        """Draw red bounding boxes on *img* in place.

        :param img: image to annotate (modified in place).
        :param bboxes: list of dicts with 'left', 'top', 'width', 'height'.
        :return: the annotated image.
        """
        # Box thickness scales with the image size.
        line_width = int(max(img.shape[1], img.shape[0]) / 400)
        red = [255, 0, 0]
        for bbox in bboxes:
            left, top = bbox['left'], bbox['top']
            right, bottom = left + bbox['width'], top + bbox['height']
            # Array slicing replaces the original per-pixel Python loops.
            img[top:top + line_width, left:right] = red
            img[bottom - line_width:bottom, left:right] = red
            img[top:bottom, left:left + line_width] = red
            img[top:bottom, right - line_width:right] = red
        return img

    result = _Baidu_det(x)
    # NOTE(review): computed but unused — kept for the commented-out
    # DataFrame output in the UI; wire it up or remove it.
    HAs_list = _get_attributes_list(result)
    locations = [info['location'] for info in result['person_info']]

    return _show_img(x, locations), f"模型检测到的人数为:{result['person_num']}人"
|
151 |
+
|
152 |
+
|
153 |
+
def model2_rem(x):
    """Background removal via the Baidu body-segmentation API.

    :param x: input image from the web UI (RGB numpy array).
    :return: (foreground image with background removed, message describing
        whether the input was resized)
    """

    def _Baidu_rem(img):
        """Call the Baidu body_seg endpoint and decode the foreground.

        :param img: image as an RGB numpy.ndarray.
        :return: foreground as a numpy array, or None when the request failed.
        """
        request_url = "https://aip.baidubce.com/rest/2.0/image-classify/v1/body_seg"
        # cv2.imwrite expects BGR channel order.
        bgr_image = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
        cv2.imwrite('test.jpg', bgr_image)
        # Close the file promptly instead of leaking the handle.
        with open('test.jpg', 'rb') as f:
            hex_image = base64.b64encode(f.read())
        params = {"image": hex_image}
        access_token = get_access_token()
        request_url = request_url + "?access_token=" + access_token
        headers = {'content-type': 'application/x-www-form-urlencoded'}
        response = requests.post(request_url, data=params, headers=headers)
        if response:
            # "foreground" is a base64-encoded image with the background cut out.
            encoded_image = response.json()["foreground"]
            decoded_image = base64.b64decode(encoded_image)
            return np.array(Image.open(BytesIO(decoded_image)))

    # The API rejects images outside the 50..2048 px per-side range.
    resized_x, was_resized = resize_image(x)
    new_img = _Baidu_rem(resized_x)
    message = "图片尺寸已被修改至合适大小" if was_resized else "图片尺寸无需修改"
    return new_img, message
|
189 |
+
|
190 |
+
|
191 |
+
def model3_ext(x: np.ndarray, num_clusters=12):
    """Dominant-colour extraction.

    :param x: input image from the web UI — assumed RGBA (the UI declares
        ``image_mode='RGBA'``); the alpha channel is used to drop
        transparent pixels.
    :param num_clusters: number of k-means clusters.
    :return: (colour-bar card image, message describing whether the input
        was resized)
    """

    # TODO: write a colour-name matching algorithm [most important]
    # TODO: change the colour card rendering to show names and ratios [important]
    def _find_name(color):
        """Look up a human-readable name for a colour value.

        :param color: colour value.
        :return: colour name (not implemented yet).
        """
        pass

    def _cluster(img, NUM_CLUSTERS):
        """Extract dominant colours with k-means clustering.

        :param img: RGBA image as a numpy array.
        :param NUM_CLUSTERS: number of clusters.
        :return: a 50px-high colour bar; segment widths are proportional
            to each cluster's share of opaque pixels.
        :raises ValueError: if the image contains no opaque pixels.
        """
        h, w, ch = img.shape
        pixels = np.float32(img.reshape((-1, 4)))
        # Keep pixels with alpha >= 100 (vectorized version of the original
        # Python filtering loop), then drop the alpha channel.
        pixels = pixels[pixels[:, 3] >= 100][:, :3]
        if len(pixels) == 0:
            # The original crashed inside cv2.kmeans on this input.
            raise ValueError("image has no opaque pixels to cluster")
        criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
        ret, label, center = cv2.kmeans(pixels, NUM_CLUSTERS, None, criteria,
                                        NUM_CLUSTERS, cv2.KMEANS_RANDOM_CENTERS)
        # Fraction of opaque pixels assigned to each cluster.
        counts = np.bincount(label.ravel(), minlength=NUM_CLUSTERS)
        clusters = np.float32(counts) / float(len(pixels))
        center = np.int32(center)
        # Draw the clusters as adjacent rectangles, widest first.
        card = np.zeros((50, w, 3), dtype=np.uint8)
        x_offset = 0
        for c in np.argsort(clusters)[::-1]:
            dx = int(clusters[c] * w)
            b, g, r = center[c][0], center[c][1], center[c][2]
            cv2.rectangle(card, (x_offset, 0), (x_offset + dx, 50),
                          (int(b), int(g), int(r)), -1)
            x_offset += dx
        # Return only the card: the original returned the enclosing scope's
        # `resized_f`, silently coupling the helper to the caller's locals.
        return card

    resized_x, was_resized = resize_image(x)
    card = _cluster(resized_x, num_clusters)
    message = "图片尺寸已被修改至合适大小" if was_resized else "图片尺寸无需修改"
    return card, message
|
252 |
+
|
253 |
+
|
254 |
+
def model4_clo(x_path: str):
    """Clothing silhouette (profile) recognition.

    :param x_path: filesystem path of the uploaded image.
    :return: (visualisation image or original path, result DataFrame,
        status message)
    """

    def _get_result(input_path: str, cls_results: dict) -> pd.DataFrame:
        """Convert the results of model6_2 to a dataframe.

        :param input_path: the (absolute) path of the image.
        :param cls_results: the results of model6_2, keyed by image name.
        :return: a one-row dataframe to display on the web.
        """
        img_name = os.path.basename(input_path)
        pred_profile = cls_results[img_name][0]['pred_class']
        pred_score = round(cls_results[img_name][0]['pred_score'], 2)
        # Name the columns to match the gr.DataFrame headers declared in
        # the UI (the original passed columns=None).
        return pd.DataFrame([[img_name, pred_profile, pred_score]],
                            columns=['img_name', 'pred_profile', 'pred_score'])

    output_path_root = 'upload_to_web_tmp'
    # exist_ok avoids the check-then-create race of exists()/mkdir().
    os.makedirs(output_path_root, exist_ok=True)
    cls_result = model6_inferencer(x_path, output_path_root)

    if cls_result:
        # Load the annotated visualisation written by the inferencer.
        x_name = os.path.basename(x_path)
        pred_x = np.array(Image.open(os.path.join(output_path_root, 'visualizations', x_name)))

        return pred_x, _get_result(x_path, cls_result), "识别成功!"
    # TODO: improve the failure handling (in model6_inference.py) [important]
    return x_path, pd.DataFrame(), "未检测到服装"
|
283 |
+
|
284 |
+
|
285 |
+
with gr.Blocks() as demo:
    # Top-level web UI: one tab per model, plus a shared status accordion.
    gr.Markdown("# Flip text or image files using this demo.")
    with gr.Tab("人体检测模型"):
        # Human detection and attribute recognition (model1_det).
        with gr.Row():
            model1_input = gr.Image(height=400)
            model1_output_img = gr.Image(height=400)
            # model1_output_df = gr.DataFrame()
        model1_button = gr.Button("开始检测")
    with gr.Tab("背景消除模型"):
        # Background removal (model2_rem).
        with gr.Row():
            model2_input = gr.Image(height=400)
            model2_output_img = gr.Image(height=400)
        model2_button = gr.Button("开始消除")
    with gr.Tab('主色调提取'):
        # Dominant-colour extraction (model3_ext); RGBA input so the alpha
        # channel can be used to skip transparent pixels.
        with gr.Row():
            with gr.Column():
                # TODO: restyle the front end after the "Mona Lisa" demo [not important]
                # TODO: improve the layout [moderately important]
                model3_input = gr.Image(height=400, image_mode='RGBA')
                model3_slider = gr.Slider(minimum=1, maximum=20, step=1, value=12,
                                          min_width=400, label="聚类数量")
            model3_output_img = gr.Image(height=400)
        model3_button = gr.Button("开始提取")
    with gr.Tab("廓形识别"):
        # Clothing silhouette recognition (model4_clo); the input is passed
        # to the callback as a file path, not an array.
        with gr.Row():
            model4_input = gr.Image(height=400, type="filepath")
            model4_output_img = gr.Image(height=400)
            model4_output_df = gr.DataFrame(headers=['img_name', 'pred_profile', 'pred_score'],
                                            datatype=['str', 'str', 'number'])
        model4_button = gr.Button("开始识别")
    # Collapsible panel all callbacks write their status message into.
    with gr.Accordion("模型运行信息"):
        running_info = gr.Markdown("等待输入和运行...")

    model1_button.click(model1_det, inputs=model1_input, outputs=[model1_output_img, running_info])
    model2_button.click(model2_rem, inputs=model2_input, outputs=[model2_output_img, running_info])
    model3_button.click(model3_ext, inputs=[model3_input, model3_slider], outputs=[model3_output_img, running_info])
    model4_button.click(model4_clo, inputs=model4_input, outputs=[model4_output_img, model4_output_df, running_info])
# share=True exposes the app via a public gradio.live tunnel.
demo.launch(share=True)
|