# Copyright (c) OpenMMLab. All rights reserved.
import os.path as osp
import warnings
from collections import OrderedDict

import mmcv
import numpy as np
from mmcv.utils import print_log
from terminaltables import AsciiTable
from torch.utils.data import Dataset

from mmdet.core import eval_map, eval_recalls
from .builder import DATASETS
from .pipelines import Compose


@DATASETS.register_module()
class CustomDataset(Dataset):
    """Custom dataset for detection.

    The annotation format is shown as follows. The `ann` field is optional for
    testing.

    .. code-block:: none

        [
            {
                'filename': 'a.jpg',
                'width': 1280,
                'height': 720,
                'ann': {
                    'bboxes': (n, 4) in (x1, y1, x2, y2) order.
                    'labels': (n, ),
                    'bboxes_ignore': (k, 4), (optional field)
                    'labels_ignore': (k, ) (optional field)
                }
            },
            ...
        ]

    Args:
        ann_file (str): Annotation file path.
        pipeline (list[dict]): Processing pipeline.
        classes (str | Sequence[str], optional): Specify classes to load.
            If is None, ``cls.CLASSES`` will be used. Default: None.
        data_root (str, optional): Data root for ``ann_file``,
            ``img_prefix``, ``seg_prefix``, ``proposal_file`` if specified.
            Relative paths are joined onto it; absolute paths are kept.
        img_prefix (str, optional): Stored on the dataset and exposed to the
            pipeline as ``results['img_prefix']``. Default: ''.
        seg_prefix (str, optional): Stored on the dataset and exposed to the
            pipeline as ``results['seg_prefix']``. Default: None.
        seg_suffix (str, optional): Stored as ``self.seg_suffix``; it is not
            consumed by this base class itself. Default: '.png'.
        proposal_file (str, optional): If given, proposals are loaded from
            it with :meth:`load_proposals` and passed to the pipeline as
            ``results['proposals']``. Default: None.
        test_mode (bool, optional): If set True, annotation will not be
            loaded.
        filter_empty_gt (bool, optional): If set true, images without bounding
            boxes of the dataset's classes will be filtered out. This option
            only works when `test_mode=False`, i.e., we never filter images
            during tests. Note that ``CustomDataset`` itself does not
            implement this filtering and only emits a warning when the
            option is set; see :meth:`_filter_imgs`.
        file_client_args (dict, optional): Keyword arguments forwarded to
            ``mmcv.FileClient``. Default: dict(backend='disk').
    """

    CLASSES = None

    PALETTE = None

    def __init__(self,
                 ann_file,
                 pipeline,
                 classes=None,
                 data_root=None,
                 img_prefix='',
                 seg_prefix=None,
                 seg_suffix='.png',
                 proposal_file=None,
                 test_mode=False,
                 filter_empty_gt=True,
                 file_client_args=dict(backend='disk')):
        self.ann_file = ann_file
        self.data_root = data_root
        self.img_prefix = img_prefix
        self.seg_prefix = seg_prefix
        self.seg_suffix = seg_suffix
        self.proposal_file = proposal_file
        self.test_mode = test_mode
        self.filter_empty_gt = filter_empty_gt
        # The shared default dict is only unpacked here, never mutated.
        self.file_client = mmcv.FileClient(**file_client_args)
        self.CLASSES = self.get_classes(classes)

        # join paths if data_root is specified
        if self.data_root is not None:
            if not osp.isabs(self.ann_file):
                self.ann_file = osp.join(self.data_root, self.ann_file)
            if not (self.img_prefix is None or osp.isabs(self.img_prefix)):
                self.img_prefix = osp.join(self.data_root, self.img_prefix)
            if not (self.seg_prefix is None or osp.isabs(self.seg_prefix)):
                self.seg_prefix = osp.join(self.data_root, self.seg_prefix)
            if not (self.proposal_file is None
                    or osp.isabs(self.proposal_file)):
                self.proposal_file = osp.join(self.data_root,
                                              self.proposal_file)

        # load annotations (and proposals)
        self.data_infos = self._load_with_file_client(self.ann_file,
                                                      self.load_annotations)
        if self.proposal_file is not None:
            self.proposals = self._load_with_file_client(
                self.proposal_file, self.load_proposals)
        else:
            self.proposals = None

        # filter images too small and containing no annotations
        if not test_mode:
            valid_inds = self._filter_imgs()
            self.data_infos = [self.data_infos[i] for i in valid_inds]
            if self.proposals is not None:
                # keep proposals aligned with the surviving images
                self.proposals = [self.proposals[i] for i in valid_inds]
            # set group flag for the sampler
            self._set_group_flag()

        # processing pipeline
        self.pipeline = Compose(pipeline)

    def _load_with_file_client(self, path, loader):
        """Call ``loader`` on ``path``, going through the file client.

        When the file client provides ``get_local_path``, ``loader`` receives
        the local path yielded by that context manager. Otherwise a warning
        is emitted and ``loader`` is called with ``path`` unchanged.

        Args:
            path (str): Path of the file to load.
            loader (callable): Function taking a path and returning the
                loaded content, e.g. ``self.load_annotations``.

        Returns:
            Whatever ``loader`` returns.
        """
        if hasattr(self.file_client, 'get_local_path'):
            with self.file_client.get_local_path(path) as local_path:
                return loader(local_path)

        # The message reports the path actually being loaded, so a proposal
        # file is no longer reported as the annotation file.
        warnings.warn(
            'The used MMCV version does not have get_local_path. '
            f'We treat the {path} as local paths and it '
            'might cause errors if the path is not a local path. '
            'Please use MMCV>= 1.3.16 if you meet errors.')
        return loader(path)

    def __len__(self):
        """Total number of samples of data."""
        return len(self.data_infos)

    def load_annotations(self, ann_file):
        """Load annotation from annotation file."""
        return mmcv.load(ann_file)

    def load_proposals(self, proposal_file):
        """Load proposal from proposal file."""
        return mmcv.load(proposal_file)

    def get_ann_info(self, idx):
        """Get annotation by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index.
        """

        return self.data_infos[idx]['ann']

    def get_cat_ids(self, idx):
        """Get category ids by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: All categories in the image of specified index.
        """

        # ``np.int`` was only an alias of the builtin ``int`` and has been
        # removed from NumPy, so the builtin is used directly.
        return self.data_infos[idx]['ann']['labels'].astype(int).tolist()

    def pre_pipeline(self, results):
        """Prepare results dict for pipeline.

        Adds the dataset-level prefixes and the proposal file path to
        ``results`` and initialises the ``*_fields`` lists as empty. The
        dict is modified in place.

        Args:
            results (dict): Result dict to be passed to the pipeline.
        """
        results['img_prefix'] = self.img_prefix
        results['seg_prefix'] = self.seg_prefix
        results['proposal_file'] = self.proposal_file
        results['bbox_fields'] = []
        results['mask_fields'] = []
        results['seg_fields'] = []

    def _filter_imgs(self, min_size=32):
        """Filter images too small.

        Args:
            min_size (int): Minimum allowed length of the shorter image
                side. Default: 32.

        Returns:
            list[int]: Indices of the images whose shorter side is at least
            ``min_size``.
        """
        if self.filter_empty_gt:
            warnings.warn(
                'CustomDataset does not support filtering empty gt images.')
        return [
            i for i, img_info in enumerate(self.data_infos)
            if min(img_info['width'], img_info['height']) >= min_size
        ]

    def _set_group_flag(self):
        """Set flag according to image aspect ratio.

        Images with aspect ratio greater than 1 will be set as group 1,
        otherwise group 0.
        """
        self.flag = np.zeros(len(self), dtype=np.uint8)
        for i, img_info in enumerate(self.data_infos):
            if img_info['width'] / img_info['height'] > 1:
                self.flag[i] = 1

    def _rand_another(self, idx):
        """Get another random index from the same group as the given index."""
        pool = np.where(self.flag == self.flag[idx])[0]
        return np.random.choice(pool)

    def __getitem__(self, idx):
        """Get training/test data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training/test data (with annotation if `test_mode` is set
            False).
        """

        if self.test_mode:
            return self.prepare_test_img(idx)
        while True:
            data = self.prepare_train_img(idx)
            if data is None:
                # The pipeline returned nothing for this sample; retry with
                # another index drawn from the same aspect-ratio group.
                idx = self._rand_another(idx)
                continue
            return data

    def prepare_train_img(self, idx):
        """Get training data and annotations after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Training data and annotation after pipeline with new keys
            introduced by pipeline.
        """

        img_info = self.data_infos[idx]
        ann_info = self.get_ann_info(idx)
        results = dict(img_info=img_info, ann_info=ann_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    def prepare_test_img(self, idx):
        """Get testing data after pipeline.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Testing data after pipeline with new keys introduced by
            pipeline.
        """

        img_info = self.data_infos[idx]
        results = dict(img_info=img_info)
        if self.proposals is not None:
            results['proposals'] = self.proposals[idx]
        self.pre_pipeline(results)
        return self.pipeline(results)

    @classmethod
    def get_classes(cls, classes=None):
        """Get class names of current dataset.

        Args:
            classes (Sequence[str] | str | None): If classes is None, use
                default CLASSES defined by builtin dataset. If classes is a
                string, take it as a file name. The file contains the name of
                classes where each line contains one class name. If classes is
                a tuple or list, override the CLASSES defined by the dataset.

        Returns:
            tuple[str] or list[str]: Names of categories of the dataset.

        Raises:
            ValueError: If ``classes`` is not None, a str, a tuple or a list.
        """
        if classes is None:
            return cls.CLASSES

        if isinstance(classes, str):
            # take it as a file path
            class_names = mmcv.list_from_file(classes)
        elif isinstance(classes, (tuple, list)):
            class_names = classes
        else:
            raise ValueError(f'Unsupported type {type(classes)} of classes.')

        return class_names

    def get_cat2imgs(self):
        """Get a dict with class as key and img_ids as values, which will be
        used in :class:`ClassAwareSampler`.

        Returns:
            dict[list]: A dict of per-label image list,
            the item of the dict indicates a label index,
            corresponds to the image index that contains the label.

        Raises:
            ValueError: If ``self.CLASSES`` is None.
        """
        if self.CLASSES is None:
            raise ValueError('self.CLASSES can not be None')
        # sort the label index
        cat2imgs = {i: [] for i in range(len(self.CLASSES))}
        for i in range(len(self)):
            # de-duplicate so each image is listed once per category
            cat_ids = set(self.get_cat_ids(i))
            for cat in cat_ids:
                cat2imgs[cat].append(i)
        return cat2imgs

    def format_results(self, results, **kwargs):
        """Place holder to format result to dataset specific output."""

    def evaluate(self,
                 results,
                 metric='mAP',
                 logger=None,
                 proposal_nums=(100, 300, 1000),
                 iou_thr=0.5,
                 scale_ranges=None):
        """Evaluate the dataset.

        Args:
            results (list): Testing results of the dataset.
            metric (str | list[str]): Metrics to be evaluated. Only a single
                metric, either 'mAP' or 'recall', is supported.
            logger (logging.Logger | None | str): Logger used for printing
                related information during evaluation. Default: None.
            proposal_nums (Sequence[int]): Proposal number used for evaluating
                recalls, such as recall@100, recall@1000.
                Default: (100, 300, 1000).
            iou_thr (float | list[float]): IoU threshold. Default: 0.5.
            scale_ranges (list[tuple] | None): Scale ranges for evaluating mAP.
                Default: None.

        Returns:
            OrderedDict: Metric name to value. For 'mAP' it holds one
            ``AP<thr>`` entry per IoU threshold plus their mean under
            ``'mAP'``; for 'recall' it holds ``recall@<num>@<iou>`` entries
            and, when several IoU thresholds are used, ``AR@<num>``.

        Raises:
            KeyError: If ``metric`` is not one of the supported metrics.
        """

        if not isinstance(metric, str):
            assert len(metric) == 1
            metric = metric[0]
        allowed_metrics = ['mAP', 'recall']
        if metric not in allowed_metrics:
            raise KeyError(f'metric {metric} is not supported')
        annotations = [self.get_ann_info(i) for i in range(len(self))]
        eval_results = OrderedDict()
        iou_thrs = [iou_thr] if isinstance(iou_thr, float) else iou_thr
        if metric == 'mAP':
            assert isinstance(iou_thrs, list)
            mean_aps = []
            # ``thr`` avoids shadowing the ``iou_thr`` parameter.
            for thr in iou_thrs:
                print_log(f'\n{"-" * 15}iou_thr: {thr}{"-" * 15}')
                mean_ap, _ = eval_map(
                    results,
                    annotations,
                    scale_ranges=scale_ranges,
                    iou_thr=thr,
                    dataset=self.CLASSES,
                    logger=logger)
                mean_aps.append(mean_ap)
                eval_results[f'AP{int(thr * 100):02d}'] = round(mean_ap, 3)
            eval_results['mAP'] = sum(mean_aps) / len(mean_aps)
        elif metric == 'recall':
            gt_bboxes = [ann['bboxes'] for ann in annotations]
            recalls = eval_recalls(
                gt_bboxes, results, proposal_nums, iou_thr, logger=logger)
            for i, num in enumerate(proposal_nums):
                for j, iou in enumerate(iou_thrs):
                    eval_results[f'recall@{num}@{iou}'] = recalls[i, j]
            if recalls.shape[1] > 1:
                # average recall over the IoU thresholds
                ar = recalls.mean(axis=1)
                for i, num in enumerate(proposal_nums):
                    eval_results[f'AR@{num}'] = ar[i]
        return eval_results

    def __repr__(self):
        """Summarize the dataset as text with a per-category instance table.

        Returns:
            str: Dataset type, number of images and, when ``CLASSES`` is set,
            an ASCII table of instance counts per category.
        """
        dataset_type = 'Test' if self.test_mode else 'Train'
        result = (f'\n{self.__class__.__name__} {dataset_type} dataset '
                  f'with number of images {len(self)}, '
                  f'and instance counts: \n')
        if self.CLASSES is None:
            result += 'Category names are not provided. \n'
            return result
        # one extra trailing slot counts images without any label
        instance_count = np.zeros(len(self.CLASSES) + 1).astype(int)
        # count the instance number in each image
        for idx in range(len(self)):
            label = self.get_ann_info(idx)['labels']
            unique, counts = np.unique(label, return_counts=True)
            if len(unique) > 0:
                # add the occurrence number to each class
                instance_count[unique] += counts
            else:
                # background is the last index
                instance_count[-1] += 1
        # create a table with category count
        table_data = [['category', 'count'] * 5]
        row_data = []
        for cls_idx, count in enumerate(instance_count):
            if cls_idx < len(self.CLASSES):
                row_data += [
                    f'{cls_idx} [{self.CLASSES[cls_idx]}]', f'{count}'
                ]
            else:
                # add the background number
                row_data += ['-1 background', f'{count}']
            # five (category, count) pairs per table row
            if len(row_data) == 10:
                table_data.append(row_data)
                row_data = []
        if len(row_data) >= 2:
            # drop the trailing pair when its count is zero
            if row_data[-1] == '0':
                row_data = row_data[:-2]
            if len(row_data) >= 2:
                table_data.append([])
                table_data.append(row_data)

        table = AsciiTable(table_data)
        result += table.table
        return result