Spaces:
Runtime error
Runtime error
File size: 8,047 Bytes
b328990 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 132 133 134 135 136 137 138 139 140 141 142 143 144 145 146 147 148 149 150 151 152 153 154 155 156 157 158 159 160 161 162 163 164 165 166 167 168 169 170 171 172 173 174 175 176 177 178 179 180 181 182 183 184 185 186 187 188 189 190 191 192 193 194 195 196 197 198 199 200 201 202 203 204 205 206 207 208 209 210 |
from .dataset_info import DatasetInfo
import cv2
import mmcv
import numpy as np
import os
from os import path as osp
import json
def _resolve_out_dir(explicit_dir, out_dir, sub_name):
    """Return where to save one kind of output, or ``False`` to skip saving.

    An explicit (truthy) directory wins. Otherwise, when ``out_dir`` is set,
    ``out_dir/sub_name`` is used and created, including any missing parents.
    """
    if explicit_dir:
        return explicit_dir
    if out_dir:
        path = osp.join(out_dir, sub_name)
        # makedirs (not mkdir) so a not-yet-existing out_dir does not crash.
        os.makedirs(path, exist_ok=True)
        return path
    return False


def _to_color_tuples(colors):
    """Convert each colour entry into a tuple of plain Python ints for cv2."""
    return [tuple(int(x) for x in color) for color in colors]


def _draw_pose(img, keypoints, keypoint_scores, skeleton, kpt_colors,
               link_colors, radius, thickness, kpt_score_thr):
    """Draw one class's keypoints and skeleton links onto ``img`` in place.

    Keypoints whose score is below ``kpt_score_thr`` are skipped. A link is
    drawn only when both of its endpoints were kept. All circles are drawn
    before any line.
    """
    # One dict per instance: keypoint index -> keypoint coords (x, y first).
    visible = [
        {k: kpt for k, (kpt, score) in enumerate(zip(kpts, scores))
         if score >= kpt_score_thr}
        for kpts, scores in zip(keypoints, keypoint_scores)
    ]
    for inst in visible:
        for k, kpt in inst.items():
            cv2.circle(img, (int(kpt[0]), int(kpt[1])), radius,
                       kpt_colors[k], -1)
    for inst in visible:
        for link_idx, link in enumerate(skeleton):
            start, end = link[0], link[1]
            if start in inst and end in inst:
                p, q = inst[start], inst[end]
                cv2.line(img, (int(p[0]), int(p[1])), (int(q[0]), int(q[1])),
                         link_colors[link_idx], thickness)


def save_result(img,
                poses,
                img_name=None,
                radius=4,
                thickness=1,
                bbox_score_thr=None,
                kpt_score_thr=0.3,
                bbox_color='green',
                dataset_info=None,
                show=False,
                out_dir=None,
                vis_out_dir=None,
                pred_out_dir=None,):
    """Visualize the detection results on the image and optionally save them.

    Args:
        img (str | np.ndarray): Image filename or loaded image. An array is
            copied, so the caller's image is not modified.
        poses (dict[dict]): Maps a class name to a dict holding
            ``pose_model`` and ``pose_results``. ``pose_results[0]`` must
            expose ``gt_instances.bboxes`` / ``gt_instances.bbox_scores``
            and ``pred_instances.keypoints`` /
            ``pred_instances.keypoint_scores``. Empty entries are skipped.
        img_name (str): Name used for output files when ``img`` is an array
            (falls back to ``'demo.jpg'``).
        radius (int): Radius of keypoint circles.
        thickness (int): Thickness of skeleton lines.
        bbox_score_thr (float): The threshold to visualize the bounding boxes.
            When None, boxes are drawn without scores.
        kpt_score_thr (float): Keypoints scoring below this are not drawn.
        bbox_color (str | tuple[int]): Color of bounding boxes.
        dataset_info (DatasetInfo): Dataset info used for every class. When
            None, each class loads its own from ``pose_model.cfg``.
        show (bool): Whether to show the image. Default False.
        out_dir (str): The output directory for visualizations and predictions.
            If vis_out_dir is None, visualizations go to
            ``${out_dir}/visualizations``. If pred_out_dir is None,
            predictions go to ``${out_dir}/predictions``. Default None.
        vis_out_dir (str): Output directory for visualizations. Default None.
        pred_out_dir (str): Output directory for predictions. Default None.

    Returns:
        np.ndarray: The image with boxes, keypoints and skeletons drawn.

    Raises:
        TypeError: If ``img`` is neither a str nor an np.ndarray.
        ValueError: If no dataset info is available for a class.
    """
    vis_out_flag = _resolve_out_dir(vis_out_dir, out_dir, 'visualizations')
    pred_out_flag = _resolve_out_dir(pred_out_dir, out_dir, 'predictions')

    # read image
    img_path = None
    if isinstance(img, str):
        img_path = img
        img = mmcv.imread(img)
    elif isinstance(img, np.ndarray):
        img = img.copy()
    else:
        raise TypeError('img must be a filename or numpy array, '
                        f'but got {type(img)}')

    # collect boxes of all classes; labels index into class_name_list
    with_score = bbox_score_thr is not None
    bbox_list, label_list, class_name_list, bbox_score_list = [], [], [], []
    for label, v in poses.items():
        if not v:
            continue
        gt = v['pose_results'][0].gt_instances
        for b, s in zip(gt.bboxes, gt.bbox_scores):
            if with_score:
                b = np.append(b, values=s)  # switch to x1, y1, x2, y2, score
            bbox_score_list.append(s.tolist())
            bbox_list.append(b)
            label_list.append(len(class_name_list))
        class_name_list.append(label)
    if bbox_list:
        bboxes = np.array(bbox_list)
    else:
        # Keep a 2-D (0, 4|5) shape so imshow_det_bboxes accepts "no boxes".
        bboxes = np.zeros((0, 5 if with_score else 4), dtype=np.float32)
    labels = np.array(label_list, dtype=np.int64)

    # draw bbox
    img = mmcv.imshow_det_bboxes(
        img,
        bboxes,
        labels,
        class_names=class_name_list,
        score_thr=bbox_score_thr if with_score else 0,
        bbox_color=bbox_color,
        text_color='white',
        show=False,
    )

    # draw pose of different classes
    keypoints_list, keypoint_scores_list = [], []
    for label, v in poses.items():
        if not v:
            continue
        pose_model = v['pose_model']
        pred = v['pose_results'][0].pred_instances
        keypoints = pred.keypoints
        keypoint_scores = pred.keypoint_scores
        keypoints_list.extend(ks.tolist() for ks in keypoints)
        keypoint_scores_list.extend(kss.tolist() for kss in keypoint_scores)
        # Resolve per class: each class may come from a different pose model
        # with its own skeleton/colours; never reuse another class's info.
        cls_info = dataset_info
        if (cls_info is None and hasattr(pose_model, 'cfg')
                and 'dataset_info' in pose_model.cfg):
            cls_info = DatasetInfo(pose_model.cfg.dataset_info)
        if cls_info is None:
            raise ValueError(
                'dataset_info is not specified or set in the config file.')
        _draw_pose(img, keypoints, keypoint_scores, cls_info.skeleton,
                   _to_color_tuples(cls_info.pose_kpt_color),
                   _to_color_tuples(cls_info.pose_link_color),
                   radius, thickness, kpt_score_thr)

    if show:
        mmcv.imshow(img, wait_time=0)
    if img_path is None:
        img_path = img_name if img_name is not None else 'demo.jpg'
    if vis_out_flag:
        out_file = osp.join(vis_out_flag, osp.basename(img_path))
        mmcv.imwrite(img, out_file)
    if pred_out_flag:
        # bbox is always [x1, y1, x2, y2]; the score lives in bbox_score, so
        # the schema does not depend on bbox_score_thr.
        pred_list = [
            dict(
                keypoints=kpts,
                keypoint_scores=kpt_scores,
                bbox=[box.tolist()[:4]],
                bbox_score=score,
            )
            for box, kpts, kpt_scores, score in zip(
                bboxes, keypoints_list, keypoint_scores_list, bbox_score_list)
        ]
        # An explicitly passed pred_out_dir may not exist yet.
        os.makedirs(pred_out_flag, exist_ok=True)
        stem = osp.splitext(osp.basename(img_path))[0]
        out_file = osp.join(pred_out_flag, stem + '.json')
        with open(out_file, 'w') as f:
            json.dump(pred_list, f)
    return img
|