Spaces:
Runtime error
Runtime error
import json
import os
import warnings
from os import path as osp

import cv2
import mmcv
import numpy as np

from .dataset_info import DatasetInfo
def _resolve_out_dirs(out_dir, vis_out_dir, pred_out_dir):
    """Return ``(vis_dir, pred_dir)``; each is ``None`` when that output is off.

    Explicit ``vis_out_dir`` / ``pred_out_dir`` win; otherwise the directory
    falls back to a sub-folder of ``out_dir``. Every selected directory is
    created (including missing parents) so later writes cannot fail on a
    missing path.
    """
    vis_dir = vis_out_dir or None
    pred_dir = pred_out_dir or None
    if out_dir:
        if vis_dir is None:
            vis_dir = osp.join(out_dir, 'visualizations')
        if pred_dir is None:
            pred_dir = osp.join(out_dir, 'predictions')
    for directory in (vis_dir, pred_dir):
        if directory is not None:
            os.makedirs(directory, exist_ok=True)
    return vis_dir, pred_dir


def _as_int_tuples(colors):
    """Convert an iterable of colours into tuples of plain Python ints."""
    return [tuple(int(channel) for channel in color) for color in colors]


def _draw_instances(img, keypoints, keypoint_scores, skeleton, kpt_colors,
                    link_colors, kpt_score_thr, radius, thickness):
    """Draw keypoints and skeleton links of every instance onto ``img``."""
    # Per instance: keypoint index -> coordinates, only for keypoints whose
    # score reaches the threshold.
    visible_per_instance = [
        {
            k_idx: point
            for k_idx, (point, score) in enumerate(zip(points, scores))
            if score >= kpt_score_thr
        }
        for points, scores in zip(keypoints, keypoint_scores)
    ]

    # All circles are drawn before any line, so links are painted on top.
    for visible in visible_per_instance:
        for k_idx, point in visible.items():
            cv2.circle(img, (int(point[0]), int(point[1])), radius,
                       kpt_colors[k_idx], -1)

    for visible in visible_per_instance:
        for link_idx, link in enumerate(skeleton):
            start, end = link[0], link[1]
            # A link is drawn only when both of its endpoints are visible.
            if start in visible and end in visible:
                cv2.line(img,
                         (int(visible[start][0]), int(visible[start][1])),
                         (int(visible[end][0]), int(visible[end][1])),
                         link_colors[link_idx], thickness)


def save_result(img,
                poses,
                img_name=None,
                radius=4,
                thickness=1,
                bbox_score_thr=None,
                kpt_score_thr=0.3,
                bbox_color='green',
                dataset_info=None,
                show=False,
                out_dir=None,
                vis_out_dir=None,
                pred_out_dir=None,):
    """Visualize the detection results on the image and optionally save them.

    Args:
        img (str | np.ndarray): Image filename or loaded image. An array is
            copied before drawing, so the caller's array is not modified.
        poses (dict[dict]): Maps a class label to a dict holding
            ``pose_model`` and ``pose_results``. ``pose_results[0]`` provides
            ``gt_instances.bboxes`` / ``gt_instances.bbox_scores`` and
            ``pred_instances.keypoints`` / ``pred_instances.keypoint_scores``.
            Empty entries are skipped.
        img_name (str): Image name used for the output files when ``img`` is
            an array. Falls back to ``'demo.jpg'``.
        radius (int): Radius of circles.
        thickness (int): Thickness of lines.
        bbox_score_thr (float): The threshold to visualize the bounding
            boxes. When not None, the score is appended to each box.
        kpt_score_thr (float): The threshold to visualize the keypoints.
        bbox_color (str | tuple[int]): Color of bounding boxes.
        dataset_info (DatasetInfo): Dataset info. When None, it is looked up
            per class from ``pose_model.cfg.dataset_info``.
        show (bool): Whether to show the image. Default False.
        out_dir (str): The output directory to save the visualizations and
            predictions results. If vis_out_dir is None, visualizations will
            be saved in ${out_dir}/visualizations. If pred_out_dir is None,
            predictions will be saved in ${out_dir}/predictions.
            Default None.
        vis_out_dir (str): The output directory to save the visualization
            results. Default None.
        pred_out_dir (str): The output directory to save the predictions
            results. Default None.

    Returns:
        np.ndarray: The image with boxes, keypoints and links drawn on it.

    Raises:
        TypeError: If ``img`` is neither a filename nor a numpy array.
        ValueError: If no dataset info is given and a class's pose model
            does not provide one in its config.
    """
    vis_dir, pred_dir = _resolve_out_dirs(out_dir, vis_out_dir, pred_out_dir)

    # read image
    img_path = None
    if isinstance(img, str):
        img_path = img
        img = mmcv.imread(img)
    elif isinstance(img, np.ndarray):
        img = img.copy()
    else:
        raise TypeError('img must be a filename or numpy array, '
                        f'but got {type(img)}')

    # collect the boxes of every class
    bbox_list = []
    class_name_list = []
    bbox_score_list = []
    for label, v in poses.items():
        if not v:
            continue
        instances = v['pose_results'][0].gt_instances
        for box, score in zip(instances.bboxes, instances.bbox_scores):
            if bbox_score_thr is not None:
                box = np.append(box, values=score)  # x1, y1, x2, y2, score
            bbox_score_list.append(score.tolist())
            bbox_list.append(box)
            class_name_list.append(label)
    bbox_array = np.array(bbox_list)

    # draw bbox; skipped when there is nothing to draw, because an empty
    # list becomes a 1-D array rather than an (N, 4|5) box array
    if bbox_list:
        img = mmcv.imshow_det_bboxes(
            img,
            bbox_array,
            # class_name_list holds one name per box, so each box gets its
            # own label index to look up the matching class name.
            np.arange(len(bbox_list)),
            class_names=class_name_list,
            score_thr=bbox_score_thr if bbox_score_thr is not None else 0,
            bbox_color=bbox_color,
            text_color='white',
            show=False,
        )

    # draw pose of different classes
    keypoints_list = []
    keypoint_scores_list = []
    for v in poses.values():
        if not v:
            continue
        pose_model = v['pose_model']
        predictions = v['pose_results'][0].pred_instances
        keypoints = predictions.keypoints
        keypoint_scores = predictions.keypoint_scores
        keypoints_list.extend(ks.tolist() for ks in keypoints)
        keypoint_scores_list.extend(kss.tolist() for kss in keypoint_scores)

        # Resolve the dataset info per class so that a skeleton taken from
        # one pose model is never reused for another class's model.
        class_info = dataset_info
        if (class_info is None and hasattr(pose_model, 'cfg')
                and 'dataset_info' in pose_model.cfg):
            class_info = DatasetInfo(pose_model.cfg.dataset_info)
        if class_info is None:
            warnings.warn(
                'dataset is deprecated.'
                'Please set `dataset_info` in the config.'
                'Check https://github.com/open-mmlab/mmpose/pull/663 for details.',
                DeprecationWarning)
            raise ValueError('dataset_info is not specified or set in the config file.')

        _draw_instances(
            img,
            keypoints,
            keypoint_scores,
            class_info.skeleton,
            _as_int_tuples(class_info.pose_kpt_color),
            _as_int_tuples(class_info.pose_link_color),
            kpt_score_thr,
            radius,
            thickness,
        )

    if show:
        mmcv.imshow(img, wait_time=0)

    if img_path is None:
        img_path = img_name if img_name is not None else 'demo.jpg'
    file_name = osp.basename(img_path)

    if vis_dir:
        mmcv.imwrite(img, osp.join(vis_dir, file_name))

    if pred_dir:
        pred_list = [
            dict(
                keypoints=kpts,
                keypoint_scores=kpt_scores,
                bbox=[box.tolist()],
                bbox_score=box_score,
            )
            for box, box_score, kpts, kpt_scores in zip(
                bbox_array, bbox_score_list, keypoints_list,
                keypoint_scores_list)
        ]
        # same file name as the image, with the extension replaced by .json
        out_file = osp.join(pred_dir, osp.splitext(file_name)[0] + '.json')
        with open(out_file, 'w') as f:
            json.dump(pred_list, f)

    return img