import os.path as osp
import xml.etree.ElementTree as ET

import mmcv
import numpy as np
from PIL import Image

from .builder import DATASETS
from .custom import CustomDataset


@DATASETS.register_module()
class XMLDataset(CustomDataset):
    """XML dataset for detection.

    Annotations are read from ``<img_prefix>/<ann_subdir>/<img_id>.xml`` and
    images are expected at ``<img_prefix>/<img_subdir>/<img_id>.jpg``.

    Args:
        min_size (int | float, optional): The minimum size of bounding
            boxes in the images. If the width or height of a bounding box is
            less than ``min_size``, it is added to the ignored field.
        img_subdir (str): Sub-directory of ``img_prefix`` that holds the
            images. Default: 'JPEGImages'.
        ann_subdir (str): Sub-directory of ``img_prefix`` that holds the XML
            annotation files. Default: 'Annotations'.
        **kwargs: Forwarded unchanged to ``CustomDataset.__init__``.
    """

    def __init__(self,
                 min_size=None,
                 img_subdir='JPEGImages',
                 ann_subdir='Annotations',
                 **kwargs):
        assert self.CLASSES or kwargs.get(
            'classes', None), 'CLASSES in `XMLDataset` can not be None.'
        # Set before the base initializer runs so that these are available
        # if it invokes `load_annotations` / `_filter_imgs` while
        # constructing the dataset.
        self.img_subdir = img_subdir
        self.ann_subdir = ann_subdir
        super(XMLDataset, self).__init__(**kwargs)
        self.cat2label = {cat: i for i, cat in enumerate(self.CLASSES)}
        self.min_size = min_size

    def _load_xml_root(self, img_id):
        """Parse the annotation file of ``img_id`` and return its root.

        Args:
            img_id (str): Image identifier (file name without extension).

        Returns:
            xml.etree.ElementTree.Element: Root element of the parsed XML.
        """
        xml_path = osp.join(self.img_prefix, self.ann_subdir,
                            f'{img_id}.xml')
        return ET.parse(xml_path).getroot()

    def load_annotations(self, ann_file):
        """Load annotation from XML style ann_file.

        Args:
            ann_file (str): Path of a text file listing one image id per
                line.

        Returns:
            list[dict]: One dict per image with the keys ``id``,
                ``filename``, ``width`` and ``height``.
        """
        data_infos = []
        for img_id in mmcv.list_from_file(ann_file):
            # Keep the forward slash so the stored relative filename is the
            # same on every platform.
            filename = f'{self.img_subdir}/{img_id}.jpg'
            size = self._load_xml_root(img_id).find('size')
            if size is not None:
                width = int(size.find('width').text)
                height = int(size.find('height').text)
            else:
                # No <size> element in the XML: fall back to the image
                # itself. The context manager releases the file handle
                # immediately instead of leaking one per image.
                img_path = osp.join(self.img_prefix, self.img_subdir,
                                    f'{img_id}.jpg')
                with Image.open(img_path) as img:
                    width, height = img.size
            data_infos.append(
                dict(id=img_id, filename=filename, width=width,
                     height=height))
        return data_infos

    def _filter_imgs(self, min_size=32):
        """Filter images too small or without annotation.

        Args:
            min_size (int): Images whose shorter side is below this value
                are dropped. Default: 32.

        Returns:
            list[int]: Indices into ``self.data_infos`` that are kept.
        """
        valid_inds = []
        for i, img_info in enumerate(self.data_infos):
            if min(img_info['width'], img_info['height']) < min_size:
                continue
            if not self.filter_empty_gt:
                valid_inds.append(i)
                continue
            # Keep the image only if at least one object belongs to a
            # known class; `any` stops at the first match.
            root = self._load_xml_root(img_info['id'])
            if any(
                    obj.find('name').text in self.CLASSES
                    for obj in root.findall('object')):
                valid_inds.append(i)
        return valid_inds

    def get_ann_info(self, idx):
        """Get annotation from XML file by index.

        Args:
            idx (int): Index of data.

        Returns:
            dict: Annotation info of specified index with the keys
                ``bboxes`` and ``bboxes_ignore`` (float32, shape (n, 4)) and
                ``labels`` and ``labels_ignore`` (int64, shape (n, )).
        """
        img_id = self.data_infos[idx]['id']
        root = self._load_xml_root(img_id)
        bboxes = []
        labels = []
        bboxes_ignore = []
        labels_ignore = []
        for obj in root.findall('object'):
            name = obj.find('name').text
            if name not in self.CLASSES:
                continue
            label = self.cat2label[name]
            difficult = obj.find('difficult')
            difficult = 0 if difficult is None else int(difficult.text)
            bnd_box = obj.find('bndbox')
            # TODO: check whether it is necessary to use int
            # Coordinates may be float type
            bbox = [
                int(float(bnd_box.find(tag).text))
                for tag in ('xmin', 'ymin', 'xmax', 'ymax')
            ]
            ignore = False
            if self.min_size:
                # Size-based ignoring is only permitted outside test mode.
                assert not self.test_mode
                w = bbox[2] - bbox[0]
                h = bbox[3] - bbox[1]
                if w < self.min_size or h < self.min_size:
                    ignore = True
            if difficult or ignore:
                bboxes_ignore.append(bbox)
                labels_ignore.append(label)
            else:
                bboxes.append(bbox)
                labels.append(label)
        # Every box is shifted by one pixel on conversion; empty lists
        # become correctly shaped empty arrays.
        if not bboxes:
            bboxes = np.zeros((0, 4))
            labels = np.zeros((0, ))
        else:
            bboxes = np.array(bboxes, ndmin=2) - 1
            labels = np.array(labels)
        if not bboxes_ignore:
            bboxes_ignore = np.zeros((0, 4))
            labels_ignore = np.zeros((0, ))
        else:
            bboxes_ignore = np.array(bboxes_ignore, ndmin=2) - 1
            labels_ignore = np.array(labels_ignore)
        ann = dict(
            bboxes=bboxes.astype(np.float32),
            labels=labels.astype(np.int64),
            bboxes_ignore=bboxes_ignore.astype(np.float32),
            labels_ignore=labels_ignore.astype(np.int64))
        return ann

    def get_cat_ids(self, idx):
        """Get category ids in XML file by index.

        Args:
            idx (int): Index of data.

        Returns:
            list[int]: Labels of all objects in the image of specified
                index whose class is in ``CLASSES``, in document order
                (duplicates are kept).
        """
        img_id = self.data_infos[idx]['id']
        root = self._load_xml_root(img_id)
        names = (obj.find('name').text for obj in root.findall('object'))
        return [self.cat2label[name] for name in names
                if name in self.CLASSES]