# Page residue, not program text (apparently a hosting status banner): "Spaces: Runtime error"
"""old name: test_runtime_model6.py"""
import json
import os
import subprocess
import sys
import warnings
from time import time
from typing import Union, Tuple, Any

import pandas as pd
from mmdet.apis import inference_detector
from mmdet.apis import init_detector as det_init_detector
from mmpose.apis import inference_topdown
from mmpose.apis import init_model as pose_init_model
from mmpretrain import ImageClassificationInferencer
from mmpretrain.utils import register_all_modules

from .extensions.vis_pred_save import save_result
# Runs at import time, before any model below is built.
# NOTE(review): presumably this populates the mmpretrain registries so the
# classifier config can be resolved - confirm against the mmpretrain docs.
register_all_modules()

# Start timestamp of the run; `st` is read again by the __main__ block to
# report the total time. `ist` is not read anywhere in the visible source.
st = ist = time()
# irt = time() - st
# print(f'==Packages importing time is {irt}s==\n')
print('==Start==')

# Device handed to the detector and the classifier that are loaded at
# module level.
# DEVICE = 'cuda:0,1,2,3'
DEVICE = 'cpu'
# Directory containing this file. The detector and classifier paths below are
# joined onto it, so they do not depend on the current working directory.
abs_path = os.path.dirname(os.path.abspath(__file__))

yolo_config = os.path.join(abs_path, 'Model6_0_ClothesDetection/mmyolo/configs/custom_dataset/yolov6_s_fast.py')
yolo_checkpoint = os.path.join(abs_path, 'Model6_0_ClothesDetection/mmyolo/work_dirs/yolov6_s_df2_0.4/epoch_64.pth')
pretrain_config = os.path.join(abs_path, 'Model6_2_ProfileRecogition/mmpretrain/configs/resnext101_4xb32_2048e_3c_noF.py')
pretrain_checkpoint = os.path.join(abs_path, 'Model6_2_ProfileRecogition/mmpretrain/work_dirs/'
                                             'resnext101_4xb32_2048e_3c_noF/best_accuracy_top1_epoch_1520.pth')

# NOTE(review): unlike the paths above, the pose paths are NOT joined onto
# `abs_path`; they are resolved against the current working directory. The
# values are kept exactly as they were - confirm whether that is intended.
_POSE_ROOT = 'Model/Model6/Model6_1_ClothesKeyPoint'
_POSE_CONFIG_DIR = _POSE_ROOT + '/mmpose_1_x/configs/fashion_2d_keypoint/topdown_heatmap/deepfashion2'
_POSE_WORK_DIR = _POSE_ROOT + '/work_dirs_1-x'

# category -> (schedule tag embedded in the config name, epoch number embedded
# in the checkpoint file name). The key order is preserved in both dicts below.
_POSE_VARIANTS = {
    'short_sleeved_shirt': ('4xb32-60e', 50),
    'long_sleeved_shirt': ('4xb64-120e', 60),
    'short_sleeved_outwear': ('4xb8-150e', 120),
    'long_sleeved_outwear': ('4xb16-120e', 100),
    'vest': ('4xb64-120e', 90),
    'sling': ('4xb64-120e', 60),
    'shorts': ('4xb64-210e', 160),
    'trousers': ('4xb64-60e', 30),
    'skirt': ('4xb64-120e', 110),
    'short_sleeved_dress': ('4xb64-150e', 100),
    'long_sleeved_dress': ('4xb16-150e', 120),
    'vest_dress': ('4xb64-150e', 80),
    'sling_dress': ('4xb64-210e', 140),
}

# The config file stem doubles as the name of the matching work directory.
_POSE_STEMS = {
    category: f'td_hm_res50_{schedule}_deepfashion2_{category}_256x192'
    for category, (schedule, _) in _POSE_VARIANTS.items()
}

# category -> mmpose config file
pose_configs = {
    category: f'{_POSE_CONFIG_DIR}/{stem}.py'
    for category, stem in _POSE_STEMS.items()
}

# category -> mmpose checkpoint file
pose_checkpoints = {
    category: f'{_POSE_WORK_DIR}/{_POSE_STEMS[category]}/best_PCK_epoch_{epoch}.pth'
    for category, (_, epoch) in _POSE_VARIANTS.items()
}
# Build the detector once at import time and report how long loading took.
start_load = time()
yolo_inferencer = det_init_detector(yolo_config, yolo_checkpoint, device=DEVICE)
print('=' * 2 + 'The model loading time of MMYolo is {}s'.format(time() - start_load) + '=' * 2)

# Same for the image classifier; `start_load` is reset so the two timings
# are independent.
start_load = time()
pretrain_inferencer = ImageClassificationInferencer(model=pretrain_config,
                                                    pretrained=pretrain_checkpoint,
                                                    device=DEVICE)
print('=' * 2 + 'The model loading time of MMPretrain is {}s'.format(time() - start_load) + '=' * 2)
def get_bbox_results_by_classes(result) -> dict:
    """Group the confident detections of one image by clothing category.

    :param result: the result of mmyolo inference; only
        ``result.pred_instances`` (its ``bboxes``, ``labels`` and ``scores``)
        is read
    :return: a dict mapping each of the 13 category names to a list of
        ``[x1, y1, x2, y2]`` boxes whose score is above 0.3; a category with
        no such detection maps to an empty list
    """
    # Order matters: a predicted label is used as an index into this tuple.
    class_names = (
        'short_sleeved_shirt',
        'long_sleeved_shirt',
        'short_sleeved_outwear',
        'long_sleeved_outwear',
        'vest',
        'sling',
        'shorts',
        'trousers',
        'skirt',
        'short_sleeved_dress',
        'long_sleeved_dress',
        'vest_dress',
        'sling_dress',
    )
    bbox_results_by_classes = {name: [] for name in class_names}

    pred_instances = result.pred_instances
    # Index with the boolean mask itself. Wrapping it in a list (`x[[mask]]`)
    # is treated as a multi-dimensional index by current NumPy and is only
    # accepted through legacy behaviour elsewhere.
    keep = pred_instances.scores > 0.3
    labels = pred_instances.labels[keep]
    bboxes = pred_instances.bboxes[keep]

    for label, bbox in zip(labels, bboxes):
        class_name = class_names[int(label)]
        bbox_results_by_classes[class_name].append([bbox[0], bbox[1], bbox[2], bbox[3]])
    return bbox_results_by_classes
def mmyolo_inference(img: Union[str, list], model) -> Tuple[Any, float]:
    """Run the detector on ``img`` and time the call.

    :param img: the image path, or a list, passed straight to
        ``inference_detector``
    :param model: the detector passed as the first argument to
        ``inference_detector``
    :return: a tuple ``(result, seconds)`` with whatever
        ``inference_detector`` returned and the elapsed wall-clock time
    """
    started = time()
    result = inference_detector(model, img)
    elapsed = time() - started
    return result, elapsed
def mmpose_inference(person_results: dict, use_bbox: bool,
                     mmyolo_cfg_path: str, mmyolo_ckf_path: str,
                     img: str, output_path_root: str, save=True, device='cpu') -> float:
    """Run keypoint estimation for every category that has detections.

    :param person_results: the result of mmyolo inference, grouped by category
        (category name -> list of boxes)
    :param use_bbox: whether to use bbox to inference the pose results
    :param mmyolo_cfg_path: the file path of mmyolo config
    :param mmyolo_ckf_path: the file path of mmyolo checkpoint
    :param img: the path of the image to inference
    :param output_path_root: the root path of the output
    :param save: whether to save the inference result, including the image and the predicted json file.
                 If `save` is False, `output_path_root` will be invalid.
    :param device: the device to inference
    :return: the elapsed wall-clock time in seconds of the inference loop;
        the time spent in ``save_result`` is not included
    """
    mmpose_st = time()
    # One (initially empty) slot per category; filled only on the bbox path.
    poses = {
        label: {}
        for label in (
            'short_sleeved_shirt',
            'long_sleeved_shirt',
            'short_sleeved_outwear',
            'long_sleeved_outwear',
            'vest',
            'sling',
            'shorts',
            'trousers',
            'skirt',
            'short_sleeved_dress',
            'long_sleeved_dress',
            'vest_dress',
            'sling_dress',
        )
    }
    for label, person_result in person_results.items():
        if len(person_result) == 0:
            continue
        pose_config = pose_configs[label]
        pose_checkpoint = pose_checkpoints[label]

        if use_bbox:
            # NOTE(review): the pose model is rebuilt for every category on
            # every call, so model loading is part of the returned time.
            pose_model = pose_init_model(
                pose_config,
                pose_checkpoint,
                device=device
            )
            pose_results = inference_topdown(pose_model, img, person_result, bbox_format='xyxy')
            poses[label]['pose_results'] = pose_results
            poses[label]['pose_model'] = pose_model
            continue

        # Detector-driven path: nothing is stored in `poses` here, because
        # `pose_model` / `pose_results` do not exist on this path.
        from mmpose.apis import MMPoseInferencer
        warnings.warn('use_bbox is False, '
                      'which means using MMPoseInferencer to inference the pose results without use_bbox '
                      'and may be wrong')
        inferencer = MMPoseInferencer(
            pose2d=pose_config,
            pose2d_weights=pose_checkpoint,
            det_model=mmyolo_cfg_path,
            det_weights=mmyolo_ckf_path,
            device=device
        )
        result_generator = inferencer(img, out_dir='upload_to_web_tmp', return_vis=True)
        # The inferencer is lazy; pulling one item runs it for `img`.
        next(result_generator)

    mmpose_et = time()
    if save:
        save_result(img, poses, out_dir=output_path_root)
    return mmpose_et - mmpose_st
def mmpretrain_inference(img: Union[str, list], model) -> Tuple[Any, float]:
    """Run the classifier on ``img`` and time the call.

    :param img: the image path, or a list, passed straight to ``model``
    :param model: a callable invoked as ``model(img)``
    :return: a tuple ``(cls_result, seconds)`` with whatever ``model``
        returned and the elapsed wall-clock time
    """
    started = time()
    cls_result = model(img)
    elapsed = time() - started
    return cls_result, elapsed
def main(img_path: str, output_path_root='upload_to_web_tmp', use_bbox=True, device='cpu', test_runtime=False) -> dict:
    """Run detection, keypoint estimation and classification on image(s).

    :param img_path: the path of the image or the folder of images; in a
        folder only regular files are processed
    :param output_path_root: the root path of the output
    :param use_bbox: whether to use bbox to inference the pose results
    :param device: the device to inference
    :param test_runtime: whether to test the runtime; when True the per-image
        timings are written to ``runtimes.csv`` in the working directory
    :return: the results of model6_2 in form of dictionary
        (image file name -> classifier result)
    :raises ValueError: if ``img_path`` is neither a file nor a folder
    """
    if os.path.isdir(img_path):
        candidates = [os.path.join(img_path, img_name) for img_name in os.listdir(img_path)]
        # Sub-directories cannot be fed to the models; skip them.
        img_paths = [path for path in candidates if os.path.isfile(path)]
    elif os.path.isfile(img_path):
        img_paths = [img_path]
    else:
        message = '==Img_path must be a path of an imgage or a folder!=='
        print(message)
        raise ValueError(message)

    # The header is kept apart from the rows so it is not written out twice.
    runtime_columns = ['img_name',
                       'runtime_mmyolo', 'percent1',
                       'runtime_mmpose', 'percent2',
                       'runtime_mmpretrain', 'percent3',
                       'runtime_total']
    runtime_rows = []
    cls_results = {}

    for img in img_paths:
        print(f'==Start to inference {img}==')
        yolo_result, runtime_mmyolo = mmyolo_inference(img, yolo_inferencer)
        print(f'==mmyolo running time is {runtime_mmyolo}s==')

        person_results = get_bbox_results_by_classes(yolo_result)
        runtime_mmpose = mmpose_inference(
            person_results=person_results,
            use_bbox=use_bbox,
            mmyolo_cfg_path=yolo_config,
            mmyolo_ckf_path=yolo_checkpoint,
            img=img,
            output_path_root=output_path_root,
            save=True,
            device=device
        )
        print(f'mmpose running time is {runtime_mmpose}s')

        cls_result, runtime_mmpretrain = mmpretrain_inference(img, pretrain_inferencer)
        print(f'mmpretrain running time is {runtime_mmpretrain}s')

        img_name = os.path.basename(img)
        cls_results[img_name] = cls_result

        if test_runtime:
            runtime_total = runtime_mmyolo + runtime_mmpose + runtime_mmpretrain
            percent1 = str(round(runtime_mmyolo / runtime_total * 100, 2)) + '%'
            percent2 = str(round(runtime_mmpose / runtime_total * 100, 2)) + '%'
            percent3 = str(round(runtime_mmpretrain / runtime_total * 100, 2)) + '%'
            runtime_rows.append([img_name,
                                 runtime_mmyolo, percent1,
                                 runtime_mmpose, percent2,
                                 runtime_mmpretrain, percent3,
                                 runtime_total])

    if test_runtime:
        df = pd.DataFrame(runtime_rows, columns=runtime_columns)
        df.to_csv('runtimes.csv', index=False)
    return cls_results
if __name__ == "__main__":
    # Script entry point: process every file under `data-test/`.
    # main(1)
    main('data-test/')
    # main('data-test/000002.jpg')
    # `st` was taken at module level, so this total also covers model loading.
    rt = time() - st
    print(f'==Totol time cost is {rt}s==')