Spaces:
Runtime error
Runtime error
from collections import OrderedDict | |
from mmcv.utils import print_log | |
from mmdet.core import eval_map, eval_recalls | |
from .builder import DATASETS | |
from .xml_style import XMLDataset | |
# NOTE(review): ``DATASETS`` is imported in this module but the class below is
# not decorated with ``@DATASETS.register_module()``. If the dataset is meant
# to be built from a config (``type='VOCDataset'``), the decorator may have
# been lost -- confirm before relying on registry lookup.
class VOCDataset(XMLDataset):
    """PASCAL VOC detection dataset.

    Annotation loading is inherited from ``XMLDataset``. This class supplies
    the 20 VOC class names, infers the dataset year from ``img_prefix`` and
    implements VOC-style ``mAP`` / ``recall`` evaluation.
    """

    CLASSES = ('aeroplane', 'bicycle', 'bird', 'boat', 'bottle', 'bus', 'car',
               'cat', 'chair', 'cow', 'diningtable', 'dog', 'horse',
               'motorbike', 'person', 'pottedplant', 'sheep', 'sofa', 'train',
               'tvmonitor')

    def __init__(self, **kwargs):
        """Initialize the dataset and record which VOC year it belongs to.

        Args:
            **kwargs: Forwarded unchanged to ``XMLDataset.__init__``.

        Raises:
            ValueError: If neither ``'VOC2007'`` nor ``'VOC2012'`` occurs in
                ``self.img_prefix``.
        """
        super().__init__(**kwargs)
        # ``self.year`` decides which dataset descriptor ``evaluate`` hands to
        # ``eval_map``; 'VOC2007' is checked first if both substrings occur.
        if 'VOC2007' in self.img_prefix:
            self.year = 2007
        elif 'VOC2012' in self.img_prefix:
            self.year = 2012
        else:
            raise ValueError('Cannot infer dataset year from img_prefix')

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate in VOC protocol.

        Args:
            results (list[list | tuple]): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Options are
                'mAP', 'recall'. A list must contain exactly one metric.
            logger (logging.Logger | str, optional): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | list[float]): IoU threshold(s). A single float is
                treated as a one-element list. Default: 0.5.
            scale_ranges (list[tuple], optional): Accepted for interface
                compatibility but currently NOT forwarded to ``eval_map``;
                all bounding boxes are always included. Default: None.

        Returns:
            dict[str, float]: AP/recall metrics. For 'mAP': one
            ``AP<thr*100>`` entry per threshold (rounded to 3 decimals) plus
            the unrounded mean over thresholds under ``'mAP'``. For 'recall':
            ``recall@<num>@<iou>`` entries, plus ``AR@<num>`` (mean over IoU
            thresholds) when more than one threshold is given.

        Raises:
            KeyError: If ``metric`` is not 'mAP' or 'recall'.
        """
        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = OrderedDict()
        # Normalize a single float threshold to a list so both branches can
        # iterate over thresholds uniformly.
        iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
        if metric == 'mAP':
            assert isinstance(iou_thrs, list)
            ds_name = 'voc07' if self.year == 2007 else self.CLASSES
            mean_aps = []
            for thr in iou_thrs:
                print_log(f'\n{"-" * 15}iou_thr: {thr}{"-" * 15}',
                          logger=logger)
                # NOTE(review): ``scale_ranges`` is deliberately left as None
                # here (as in the original); the rounding below assumes a
                # scalar mean AP per threshold -- confirm eval_map's return
                # shape before forwarding the caller's value.
                mean_ap, _ = eval_map(
                    results,
                    annotations,
                    scale_ranges=None,
                    iou_thr=thr,
                    dataset=ds_name,
                    logger=logger)
                mean_aps.append(mean_ap)
                eval_results[f'AP{int(thr * 100):02d}'] = round(mean_ap, 3)
            eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            # Use the normalized list: iterating the raw ``iou_thr`` crashes
            # with TypeError when it is a single float (the default).
            recalls = eval_recalls(
                gt_bboxes, results, proposal_nums, iou_thrs, logger=logger)
            for i, num in enumerate(proposal_nums):
                for j, iou in enumerate(iou_thrs):
                    eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
            if recalls.shape[1] > 1:
                # Average recall across IoU thresholds for each proposal count.
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        return eval_results