Spaces:
Runtime error
Runtime error
# Copyright (c) OpenMMLab. All rights reserved. | |
import numpy as np | |
class DatasetInfo:
    """Index-friendly view of a dataset meta-information dict.

    Args:
        dataset_info (dict): Dataset meta information. The keys read here are
            ``dataset_name``, ``paper_info``, ``keypoint_info``,
            ``skeleton_info``, ``joint_weights`` and ``sigmas``; a missing
            key raises ``KeyError``.

    ``keypoint_info`` maps a keypoint id to a dict with a required ``name``
    and optional ``color``, ``type`` (``'upper'`` / ``'lower'``) and ``swap``
    (name of the mirrored keypoint) entries. ``skeleton_info`` maps a link id
    to a dict with a required ``link`` (a pair of keypoint names) and an
    optional ``color``.
    """

    # Visualization color used when a keypoint or link has no ``color`` entry.
    _DEFAULT_COLOR = (255, 128, 0)

    def __init__(self, dataset_info):
        self._dataset_info = dataset_info
        self.dataset_name = self._dataset_info['dataset_name']
        self.paper_info = self._dataset_info['paper_info']
        self.keypoint_info = self._dataset_info['keypoint_info']
        self.skeleton_info = self._dataset_info['skeleton_info']
        # ``[:, None]`` appends a trailing axis, so a flat list of K weights
        # becomes a (K, 1) float32 column.
        self.joint_weights = np.array(
            self._dataset_info['joint_weights'], dtype=np.float32)[:, None]
        self.sigmas = np.array(self._dataset_info['sigmas'])

        # Order matters: the skeleton parser resolves link endpoint names
        # through ``keypoint_name2id``, which the keypoint parser builds.
        self._parse_keypoint_info()
        self._parse_skeleton_info()

    def _parse_skeleton_info(self):
        """Parse skeleton information.

        Must run after ``_parse_keypoint_info``. A link naming an unknown
        keypoint raises ``KeyError``. Sets:

        - link_num (int): number of links.
        - skeleton (list((2,))): list of links (id).
        - skeleton_name (list((2,))): list of links (name), stored exactly
            as given in ``skeleton_info``.
        - pose_link_color (np.ndarray): the color of the link for
            visualization.
        """
        self.link_num = len(self.skeleton_info)
        self.pose_link_color = []
        self.skeleton_name = []
        self.skeleton = []
        for link_cfg in self.skeleton_info.values():
            link = link_cfg['link']
            self.skeleton_name.append(link)
            self.skeleton.append([
                self.keypoint_name2id[link[0]], self.keypoint_name2id[link[1]]
            ])
            self.pose_link_color.append(
                link_cfg.get('color', self._DEFAULT_COLOR))
        self.pose_link_color = np.array(self.pose_link_color)

    def _parse_keypoint_info(self):
        """Parse keypoint information.

        All lists below follow the iteration order of ``keypoint_info``.
        A ``swap`` entry naming an unknown keypoint raises ``KeyError``.
        Sets:

        - keypoint_num (int): number of keypoints.
        - keypoint_id2name (dict): mapping keypoint id to keypoint name.
        - keypoint_name2id (dict): mapping keypoint name to keypoint id.
        - upper_body_ids (list): ids whose ``type`` is ``'upper'``.
        - lower_body_ids (list): ids whose ``type`` is ``'lower'``.
        - flip_index (list): for each keypoint, the id it maps to when
            flipped (its own id if it has no ``swap`` partner).
        - flip_pairs (list((2,))): list of flip pairs (id).
        - flip_index_name (list): same as ``flip_index`` but with names.
        - flip_pairs_name (list((2,))): list of flip pairs (name); each
            mirrored pair appears once, keyed by whichever member comes first.
        - pose_kpt_color (np.ndarray): the color of the keypoint for
            visualization.
        """
        self.keypoint_num = len(self.keypoint_info)
        self.keypoint_id2name = {}
        self.keypoint_name2id = {}
        self.pose_kpt_color = []
        self.upper_body_ids = []
        self.lower_body_ids = []
        self.flip_index_name = []
        self.flip_pairs_name = []
        # Pairs already recorded (as tuples) so the reverse-pair check is
        # O(1) instead of a scan over ``flip_pairs_name``.
        recorded_pairs = set()

        for kid, kpt_cfg in self.keypoint_info.items():
            keypoint_name = kpt_cfg['name']
            self.keypoint_id2name[kid] = keypoint_name
            self.keypoint_name2id[keypoint_name] = kid
            self.pose_kpt_color.append(
                kpt_cfg.get('color', self._DEFAULT_COLOR))

            # Named ``kpt_type`` so the builtin ``type`` is not shadowed.
            # Values other than 'upper' / 'lower' are ignored.
            kpt_type = kpt_cfg.get('type', '')
            if kpt_type == 'upper':
                self.upper_body_ids.append(kid)
            elif kpt_type == 'lower':
                self.lower_body_ids.append(kid)

            swap_keypoint = kpt_cfg.get('swap', '')
            if swap_keypoint == keypoint_name or swap_keypoint == '':
                # No mirror partner: the keypoint flips onto itself.
                self.flip_index_name.append(keypoint_name)
            else:
                self.flip_index_name.append(swap_keypoint)
                # Skip if the partner already recorded this pair reversed.
                if (swap_keypoint, keypoint_name) not in recorded_pairs:
                    self.flip_pairs_name.append([keypoint_name, swap_keypoint])
                    recorded_pairs.add((keypoint_name, swap_keypoint))

        self.flip_pairs = [[
            self.keypoint_name2id[pair[0]], self.keypoint_name2id[pair[1]]
        ] for pair in self.flip_pairs_name]
        self.flip_index = [
            self.keypoint_name2id[name] for name in self.flip_index_name
        ]
        self.pose_kpt_color = np.array(self.pose_kpt_color)