ATang0729 commited on
Commit
5a27f3e
1 Parent(s): 0521061

Upload [gradio]model_display.py

Browse files
Files changed (1) hide show
  1. [gradio]model_display.py +323 -0
[gradio]model_display.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import numpy as np
2
+ import gradio as gr
3
+ import requests
4
+ import base64
5
+ import pandas as pd
6
+ import cv2
7
+ from typing import Tuple
8
+ from PIL import Image
9
+ from io import BytesIO
10
+
11
+ import os
12
+
13
+ from Model.Model6.model6_inference import main as model6_inferencer
14
+ from mmyolo.utils import register_all_modules
15
+
16
+ register_all_modules()
17
+
18
+
19
+ def get_access_token(refatch=False) -> str:
20
+ """获取百度AI的access_token
21
+ :param refatch:是否重新获取access_token
22
+ :return:返回access_token"""
23
+ if refatch:
24
+ # client_id 为官网获取的AK, client_secret 为官网获取的SK
25
+ client_id = '7OtH60uo01ZNYN4yPyahlRSx'
26
+ client_secret = 'D5AxcUpyQyIA7KgPplp7dnz5tM0UIljy'
27
+ host = 'https://aip.baidubce.com/oauth/2.0/token?' \
28
+ 'grant_type=client_credentials&client_id=%s&client_secret=%s' % (client_id, client_secret)
29
+ response = requests.get(host)
30
+ # print(response)
31
+ if response:
32
+ return response.json()['access_token']
33
+ else:
34
+ r"""
35
+ {"refresh_token":"25.24b9368ce91f9bd62c8dad38b3436800.315360000.2007815067.282335-30479502",
36
+ "expires_in":2592000,
37
+ "session_key":
38
+ "9mzdWT\/YmQ7oEi9WCRWbXd0YCcrSYQY6kKZjObKunlcKcZt95j9\/q1aJqbVXihpQOXK84o5WLJ8e7d4cXOi0VUJJcz5YEQ==",
39
+ "access_token":"24.becefee37aba38ea43c546fc154d3016.2592000.1695047067.282335-30479502",
40
+ "scope":"public brain_all_scope brain_body_analysis brain_body_attr brain_body_number brain_driver_behavior
41
+ brain_body_seg brain_gesture_detect brain_body_tracking brain_hand_analysis wise_adapt
42
+ lebo_resource_base lightservice_public hetu_basic lightcms_map_poi kaidian_kaidian
43
+ ApsMisTest_Test\u6743\u9650 vis-classify_flower lpq_\u5f00\u653e cop_helloScope
44
+ ApsMis_fangdi_permission smartapp_snsapi_base smartapp_mapp_dev_manage iop_autocar oauth_tp_app
45
+ smartapp_smart_game_openapi oauth_sessionkey smartapp_swanid_verify smartapp_opensource_openapi
46
+ smartapp_opensource_recapi fake_face_detect_\u5f00\u653eScope
47
+ vis-ocr_\u865a\u62df\u4eba\u7269\u52a9\u7406 idl-video_\u865a\u62df\u4eba\u7269\u52a9\u7406
48
+ smartapp_component smartapp_search_plugin avatar_video_test b2b_tp_openapi b2b_tp_openapi_online
49
+ smartapp_gov_aladin_to_xcx","session_secret":"5c8c3dbb80b04f58bb33aa8077758679"
50
+ }
51
+ """
52
+ access_token = "24.becefee37aba38ea43c546fc154d3016.2592000.1695047067.282335-30479502"
53
+ return access_token
54
+
55
+
56
+ def resize_image(img, max_length=2048, min_length=50) -> Tuple[np.ndarray, bool]:
57
+ """Ensure that the longest side is shorter than 2048px and the shortest side is longer than 50px.
58
+ :param img: 前端传入的图片
59
+ :param max_length: 最长边像素
60
+ :param min_length: 最短边像素
61
+ :return: 返回处理后的图片和是否进行了resize的标志
62
+ """
63
+ flag = False
64
+ max_side = max(img.shape[0], img.shape[1])
65
+ min_side = min(img.shape[0], img.shape[1])
66
+ if max_side > max_length:
67
+ scale = max_length / max_side
68
+ img = cv2.resize(img, (int(img.shape[1] * scale), int(img.shape[0] * scale)))
69
+ flag = True
70
+ if min_side < min_length:
71
+ scale = min_length / min_side
72
+ img = cv2.resize(img, (int(img.shape[1] * scale), int(img.shape[0] * scale)))
73
+ flag = True
74
+ return img, flag
75
+
76
+
77
+ def model1_det(x):
78
+ """人体检测与属性识别
79
+ :param x:前端传入的图片
80
+ :return:返回检测结果
81
+ """
82
+
83
+ def _Baidu_det(img):
84
+ """调用百度AI接口进行人体检测与属性识别
85
+ :param img:前端传入的图片,格式为numpy.ndarray
86
+ :return:返回检测结果
87
+ """
88
+ request_url = "https://aip.baidubce.com/rest/2.0/image-classify/v1/body_attr"
89
+ # 保存图片到本地
90
+ cv2.imwrite('test.jpg', img)
91
+ # 二进制方式打开图片文件
92
+ f = open('test.jpg', 'rb')
93
+ hex_image = base64.b64encode(f.read())
94
+ # 选择二进制图片和需要输出的属性(12个)
95
+ params = {
96
+ "image": hex_image,
97
+ "type": "gender,age,upper_wear,lower_wear,upper_color,lower_color,"
98
+ "orientation,upper_cut,lower_cut,side_cut,occlusion,is_human"
99
+ }
100
+ access_token = get_access_token()
101
+ request_url = request_url + "?access_token=" + access_token
102
+ headers = {'content-type': 'application/x-www-form-urlencoded'}
103
+ response = requests.post(request_url, data=params, headers=headers)
104
+ if response:
105
+ return response.json()
106
+
107
+ def _get_attributes_list(r) -> dict:
108
+ """获取人体属性列表
109
+ :param r:百度AI接口返回的json数据
110
+ :return:返回人体属性列表
111
+ """
112
+ all_humans_attributes_list = {}
113
+ person_num = r['person_num']
114
+ print('person_num:', person_num)
115
+ for human_idx in range(person_num):
116
+ attributes_dict = r['person_info'][human_idx]['attributes']
117
+ attributes_list = []
118
+ for key, value in attributes_dict.items():
119
+ attribute = [key, value['name'], value['score']]
120
+ attributes_list.append(attribute)
121
+ new_value = ['attribute', 'attribute_value', 'accuracy']
122
+ attributes_list.insert(0, new_value)
123
+ df = pd.DataFrame(attributes_list[1:], columns=attributes_list[0])
124
+ all_humans_attributes_list[human_idx] = df
125
+ return all_humans_attributes_list
126
+
127
+ def _show_img(img, bboxes):
128
+ """显示图片
129
+ :param img:前端传入的图片
130
+ :param bboxes:检测框坐标
131
+ :return:处理完成的图片 """
132
+ line_width = int(max(img.shape[1], img.shape[0]) / 400)
133
+ for bbox in bboxes:
134
+ left, top, width, height = bbox['left'], bbox['top'], bbox['width'], bbox['height']
135
+ right, bottom = left + width, top + height
136
+ for i in range(left, right):
137
+ img[top:top + line_width, i] = [255, 0, 0]
138
+ img[bottom - line_width:bottom, i] = [255, 0, 0]
139
+ for i in range(top, bottom):
140
+ img[i, left:left + line_width] = [255, 0, 0]
141
+ img[i, right - line_width:right] = [255, 0, 0]
142
+ return img
143
+
144
+ result = _Baidu_det(x)
145
+ HAs_list = _get_attributes_list(result)
146
+ locations = []
147
+ for i in range(len(result['person_info'])):
148
+ locations.append(result['person_info'][i]['location'])
149
+
150
+ return _show_img(x, locations), f"模型检测到的人数为:{result['person_num']}人"
151
+
152
+
153
+ def model2_rem(x):
154
+ """背景消除
155
+ :param x: 前端传入的图片
156
+ :return: 返回处理后的图片
157
+ """
158
+
159
+ def _Baidu_rem(img):
160
+ """调用百度AI接口进行背景消除
161
+ :param img: 前端传入的图片,格式为numpy.ndarray
162
+ :return: 返回处理后的图片
163
+ """
164
+ request_url = "https://aip.baidubce.com/rest/2.0/image-classify/v1/body_seg"
165
+ bgr_image = cv2.cvtColor(img, cv2.COLOR_RGB2BGR)
166
+ cv2.imwrite('test.jpg', bgr_image)
167
+ f = open('test.jpg', 'rb')
168
+ hex_image = base64.b64encode(f.read())
169
+ params = {"image": hex_image}
170
+ access_token = get_access_token()
171
+ request_url = request_url + "?access_token=" + access_token
172
+ headers = {'content-type': 'application/x-www-form-urlencoded'}
173
+ response = requests.post(request_url, data=params, headers=headers)
174
+ if response:
175
+ encoded_image = response.json()["foreground"]
176
+ decoded_image = base64.b64decode(encoded_image)
177
+ image = Image.open(BytesIO(decoded_image))
178
+ image_array = np.array(image)
179
+ return image_array
180
+
181
+ resized_x, resized_f = resize_image(x)
182
+ new_img = _Baidu_rem(resized_x)
183
+ if resized_f:
184
+ resized_f = "图片尺寸已被修改至合适大小"
185
+ else:
186
+ resized_f = "图片尺寸无需修改"
187
+
188
+ return new_img, resized_f
189
+
190
+
191
+ def model3_ext(x: np.ndarray, num_clusters=12):
192
+ """主色调提取
193
+ :param x: 前端传入的图片
194
+ :param num_clusters: 聚类的数量
195
+ :return: 返回主色调条形卡片"""
196
+
197
+ # TODO: 编写颜色名称匹配算法[most important]
198
+ # TODO: 修改颜色条形卡片呈现形式,要求呈现颜色名称和比例[important]
199
+ def _find_name(color):
200
+ """根据颜色值查找颜色名称
201
+ :param color:颜色值
202
+ :return:返回颜色名称
203
+ """
204
+ pass
205
+
206
+ def _cluster(img, NUM_CLUSTERS):
207
+ """K-means 聚类提取主色调
208
+ :param img: 前端传入的图片
209
+ :param NUM_CLUSTERS: 聚类的数量
210
+ :return: 返回聚类结果
211
+ """
212
+ h, w, ch = img.shape
213
+ reshaped_x = np.float32(img.reshape((-1, 4)))
214
+ new_data_list = []
215
+ for i in range(len(reshaped_x)):
216
+ if reshaped_x[i][3] < 100:
217
+ continue
218
+ else:
219
+ new_data_list.append(reshaped_x[i])
220
+ reshaped_x = np.array(new_data_list)
221
+ reshaped_x = np.delete(reshaped_x, 3, axis=1)
222
+ criteria = (cv2.TERM_CRITERIA_EPS + cv2.TERM_CRITERIA_MAX_ITER, 10, 1.0)
223
+ NUM_CLUSTERS = NUM_CLUSTERS
224
+ ret, label, center = cv2.kmeans(reshaped_x, NUM_CLUSTERS, None, criteria,
225
+ NUM_CLUSTERS, cv2.KMEANS_RANDOM_CENTERS)
226
+ clusters = np.zeros([NUM_CLUSTERS], dtype=np.int32)
227
+ for i in range(len(label)):
228
+ clusters[label[i][0]] += 1
229
+ clusters = np.float32(clusters) / float(len(reshaped_x))
230
+ center = np.int32(center)
231
+ x_offset = 0
232
+ card = np.zeros((50, w, 3), dtype=np.uint8)
233
+ for c in np.argsort(clusters)[::-1]:
234
+ dx = int(clusters[c] * w)
235
+ b = center[c][0]
236
+ g = center[c][1]
237
+ r = center[c][2]
238
+ cv2.rectangle(card, (x_offset, 0), (x_offset + dx, 50),
239
+ (int(b), int(g), int(r)), -1)
240
+ x_offset += dx
241
+
242
+ return card, resized_f
243
+
244
+ resized_x, resized_f = resize_image(x)
245
+ card, resized_f = _cluster(resized_x, num_clusters)
246
+ if resized_f:
247
+ resized_f = "图片尺寸已被修改至合适大小"
248
+ else:
249
+ resized_f = "图片尺寸无需修改"
250
+
251
+ return card, resized_f
252
+
253
+
254
+ def model4_clo(x_path: str):
255
+ def _get_result(input_path: str, cls_results: dict) -> pd.DataFrame:
256
+ """convert the results of model6_2 to a dataframe
257
+ :param input_path: the (absolute) path of the image
258
+ :param cls_results: the results of model6_2
259
+
260
+ :return: a dataframe to display on the web
261
+ """
262
+ result_pd = []
263
+ img_name = os.path.basename(input_path)
264
+ pred_profile = cls_results[img_name][0]['pred_class']
265
+ pred_score = round(cls_results[img_name][0]['pred_score'], 2)
266
+ result_pd.append([img_name, pred_profile, pred_score])
267
+ df = pd.DataFrame(result_pd, columns=None)
268
+ return df
269
+
270
+ output_path_root = 'upload_to_web_tmp'
271
+ if not os.path.exists(output_path_root):
272
+ os.mkdir(output_path_root)
273
+ cls_result = model6_inferencer(x_path, output_path_root)
274
+
275
+ if cls_result:
276
+ # use np to read image·
277
+ x_name = os.path.basename(x_path)
278
+ pred_x = np.array(Image.open(os.path.join(output_path_root, 'visualizations', x_name)))
279
+
280
+ return pred_x, _get_result(x_path, cls_result), "识别成功!"
281
+ # TODO: 完善识别失败时的处理(model6_inference.py中)[important]
282
+ return x_path, pd.DataFrame(), "未检测到服装"
283
+
284
+
285
+ with gr.Blocks() as demo:
286
+ gr.Markdown("# Flip text or image files using this demo.")
287
+ with gr.Tab("人体检测模型"):
288
+ with gr.Row():
289
+ model1_input = gr.Image(height=400)
290
+ model1_output_img = gr.Image(height=400)
291
+ # model1_output_df = gr.DataFrame()
292
+ model1_button = gr.Button("开始检测")
293
+ with gr.Tab("背景消除模型"):
294
+ with gr.Row():
295
+ model2_input = gr.Image(height=400)
296
+ model2_output_img = gr.Image(height=400)
297
+ model2_button = gr.Button("开始消除")
298
+ with gr.Tab('主色调提取'):
299
+ with gr.Row():
300
+ with gr.Column():
301
+ # TODO: 参照“蒙娜丽莎”尝试修改前端界面[not important]
302
+ # TODO: 修改布局,使其更美观[moderately important]
303
+ model3_input = gr.Image(height=400, image_mode='RGBA')
304
+ model3_slider = gr.Slider(minimum=1, maximum=20, step=1, value=12,
305
+ min_width=400, label="聚类数量")
306
+ model3_output_img = gr.Image(height=400)
307
+ model3_button = gr.Button("开始提取")
308
+ with gr.Tab("廓形识别"):
309
+ with gr.Row():
310
+ model4_input = gr.Image(height=400, type="filepath")
311
+ model4_output_img = gr.Image(height=400)
312
+ model4_output_df = gr.DataFrame(headers=['img_name', 'pred_profile', 'pred_score'],
313
+ datatype=['str', 'str', 'number'])
314
+ model4_button = gr.Button("开始识别")
315
+ # 设置折叠内容
316
+ with gr.Accordion("模型运行信息"):
317
+ running_info = gr.Markdown("等待输入和运行...")
318
+
319
+ model1_button.click(model1_det, inputs=model1_input, outputs=[model1_output_img, running_info])
320
+ model2_button.click(model2_rem, inputs=model2_input, outputs=[model2_output_img, running_info])
321
+ model3_button.click(model3_ext, inputs=[model3_input, model3_slider], outputs=[model3_output_img, running_info])
322
+ model4_button.click(model4_clo, inputs=model4_input, outputs=[model4_output_img, model4_output_df, running_info])
323
+ demo.launch(share=True)