abhishek's picture
abhishek HF staff
first commit
58f667f
raw
history blame
15.9 kB
import os.path as osp
import mmcv
import numpy as np
import pycocotools.mask as maskUtils
from mmdet.core import BitmapMasks, PolygonMasks
from ..builder import PIPELINES
@PIPELINES.register_module()
class LoadImageFromFile(object):
"""Load an image from file.
Required keys are "img_prefix" and "img_info" (a dict that must contain the
key "filename"). Added or updated keys are "filename", "img", "img_shape",
"ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
"scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
Defaults to 'color'.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
"""
def __init__(self,
to_float32=False,
color_type='color',
file_client_args=dict(backend='disk')):
self.to_float32 = to_float32
self.color_type = color_type
self.file_client_args = file_client_args.copy()
self.file_client = None
def __call__(self, results):
"""Call functions to load image and get image meta information.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded image and meta information.
"""
if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args)
if results['img_prefix'] is not None:
filename = osp.join(results['img_prefix'],
results['img_info']['filename'])
else:
filename = results['img_info']['filename']
img_bytes = self.file_client.get(filename)
img = mmcv.imfrombytes(img_bytes, flag=self.color_type)
if self.to_float32:
img = img.astype(np.float32)
results['filename'] = filename
results['ori_filename'] = results['img_info']['filename']
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
results['img_fields'] = ['img']
return results
def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'to_float32={self.to_float32}, '
f"color_type='{self.color_type}', "
f'file_client_args={self.file_client_args})')
return repr_str
@PIPELINES.register_module()
class LoadImageFromWebcam(LoadImageFromFile):
"""Load an image from webcam.
Similar with :obj:`LoadImageFromFile`, but the image read from webcam is in
``results['img']``.
"""
def __call__(self, results):
"""Call functions to add image meta information.
Args:
results (dict): Result dict with Webcam read image in
``results['img']``.
Returns:
dict: The dict contains loaded image and meta information.
"""
img = results['img']
if self.to_float32:
img = img.astype(np.float32)
results['filename'] = None
results['ori_filename'] = None
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
results['img_fields'] = ['img']
return results
@PIPELINES.register_module()
class LoadMultiChannelImageFromFiles(object):
"""Load multi-channel images from a list of separate channel files.
Required keys are "img_prefix" and "img_info" (a dict that must contain the
key "filename", which is expected to be a list of filenames).
Added or updated keys are "filename", "img", "img_shape",
"ori_shape" (same as `img_shape`), "pad_shape" (same as `img_shape`),
"scale_factor" (1.0) and "img_norm_cfg" (means=0 and stds=1).
Args:
to_float32 (bool): Whether to convert the loaded image to a float32
numpy array. If set to False, the loaded image is an uint8 array.
Defaults to False.
color_type (str): The flag argument for :func:`mmcv.imfrombytes`.
Defaults to 'color'.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
"""
def __init__(self,
to_float32=False,
color_type='unchanged',
file_client_args=dict(backend='disk')):
self.to_float32 = to_float32
self.color_type = color_type
self.file_client_args = file_client_args.copy()
self.file_client = None
def __call__(self, results):
"""Call functions to load multiple images and get images meta
information.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded images and meta information.
"""
if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args)
if results['img_prefix'] is not None:
filename = [
osp.join(results['img_prefix'], fname)
for fname in results['img_info']['filename']
]
else:
filename = results['img_info']['filename']
img = []
for name in filename:
img_bytes = self.file_client.get(name)
img.append(mmcv.imfrombytes(img_bytes, flag=self.color_type))
img = np.stack(img, axis=-1)
if self.to_float32:
img = img.astype(np.float32)
results['filename'] = filename
results['ori_filename'] = results['img_info']['filename']
results['img'] = img
results['img_shape'] = img.shape
results['ori_shape'] = img.shape
# Set initial values for default meta_keys
results['pad_shape'] = img.shape
results['scale_factor'] = 1.0
num_channels = 1 if len(img.shape) < 3 else img.shape[2]
results['img_norm_cfg'] = dict(
mean=np.zeros(num_channels, dtype=np.float32),
std=np.ones(num_channels, dtype=np.float32),
to_rgb=False)
return results
def __repr__(self):
repr_str = (f'{self.__class__.__name__}('
f'to_float32={self.to_float32}, '
f"color_type='{self.color_type}', "
f'file_client_args={self.file_client_args})')
return repr_str
@PIPELINES.register_module()
class LoadAnnotations(object):
"""Load mutiple types of annotations.
Args:
with_bbox (bool): Whether to parse and load the bbox annotation.
Default: True.
with_label (bool): Whether to parse and load the label annotation.
Default: True.
with_mask (bool): Whether to parse and load the mask annotation.
Default: False.
with_seg (bool): Whether to parse and load the semantic segmentation
annotation. Default: False.
poly2mask (bool): Whether to convert the instance masks from polygons
to bitmaps. Default: True.
file_client_args (dict): Arguments to instantiate a FileClient.
See :class:`mmcv.fileio.FileClient` for details.
Defaults to ``dict(backend='disk')``.
"""
def __init__(self,
with_bbox=True,
with_label=True,
with_mask=False,
with_seg=False,
poly2mask=True,
file_client_args=dict(backend='disk')):
self.with_bbox = with_bbox
self.with_label = with_label
self.with_mask = with_mask
self.with_seg = with_seg
self.poly2mask = poly2mask
self.file_client_args = file_client_args.copy()
self.file_client = None
def _load_bboxes(self, results):
"""Private function to load bounding box annotations.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded bounding box annotations.
"""
ann_info = results['ann_info']
results['gt_bboxes'] = ann_info['bboxes'].copy()
gt_bboxes_ignore = ann_info.get('bboxes_ignore', None)
if gt_bboxes_ignore is not None:
results['gt_bboxes_ignore'] = gt_bboxes_ignore.copy()
results['bbox_fields'].append('gt_bboxes_ignore')
results['bbox_fields'].append('gt_bboxes')
return results
def _load_labels(self, results):
"""Private function to load label annotations.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded label annotations.
"""
results['gt_labels'] = results['ann_info']['labels'].copy()
return results
def _poly2mask(self, mask_ann, img_h, img_w):
"""Private function to convert masks represented with polygon to
bitmaps.
Args:
mask_ann (list | dict): Polygon mask annotation input.
img_h (int): The height of output mask.
img_w (int): The width of output mask.
Returns:
numpy.ndarray: The decode bitmap mask of shape (img_h, img_w).
"""
if isinstance(mask_ann, list):
# polygon -- a single object might consist of multiple parts
# we merge all parts into one mask rle code
rles = maskUtils.frPyObjects(mask_ann, img_h, img_w)
rle = maskUtils.merge(rles)
elif isinstance(mask_ann['counts'], list):
# uncompressed RLE
rle = maskUtils.frPyObjects(mask_ann, img_h, img_w)
else:
# rle
rle = mask_ann
mask = maskUtils.decode(rle)
return mask
def process_polygons(self, polygons):
"""Convert polygons to list of ndarray and filter invalid polygons.
Args:
polygons (list[list]): Polygons of one instance.
Returns:
list[numpy.ndarray]: Processed polygons.
"""
polygons = [np.array(p) for p in polygons]
valid_polygons = []
for polygon in polygons:
if len(polygon) % 2 == 0 and len(polygon) >= 6:
valid_polygons.append(polygon)
return valid_polygons
def _load_masks(self, results):
"""Private function to load mask annotations.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded mask annotations.
If ``self.poly2mask`` is set ``True``, `gt_mask` will contain
:obj:`PolygonMasks`. Otherwise, :obj:`BitmapMasks` is used.
"""
h, w = results['img_info']['height'], results['img_info']['width']
gt_masks = results['ann_info']['masks']
if self.poly2mask:
gt_masks = BitmapMasks(
[self._poly2mask(mask, h, w) for mask in gt_masks], h, w)
else:
gt_masks = PolygonMasks(
[self.process_polygons(polygons) for polygons in gt_masks], h,
w)
results['gt_masks'] = gt_masks
results['mask_fields'].append('gt_masks')
return results
def _load_semantic_seg(self, results):
"""Private function to load semantic segmentation annotations.
Args:
results (dict): Result dict from :obj:`dataset`.
Returns:
dict: The dict contains loaded semantic segmentation annotations.
"""
if self.file_client is None:
self.file_client = mmcv.FileClient(**self.file_client_args)
filename = osp.join(results['seg_prefix'],
results['ann_info']['seg_map'])
img_bytes = self.file_client.get(filename)
results['gt_semantic_seg'] = mmcv.imfrombytes(
img_bytes, flag='unchanged').squeeze()
results['seg_fields'].append('gt_semantic_seg')
return results
def __call__(self, results):
"""Call function to load multiple types annotations.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded bounding box, label, mask and
semantic segmentation annotations.
"""
if self.with_bbox:
results = self._load_bboxes(results)
if results is None:
return None
if self.with_label:
results = self._load_labels(results)
if self.with_mask:
results = self._load_masks(results)
if self.with_seg:
results = self._load_semantic_seg(results)
return results
def __repr__(self):
repr_str = self.__class__.__name__
repr_str += f'(with_bbox={self.with_bbox}, '
repr_str += f'with_label={self.with_label}, '
repr_str += f'with_mask={self.with_mask}, '
repr_str += f'with_seg={self.with_seg}, '
repr_str += f'poly2mask={self.poly2mask}, '
repr_str += f'poly2mask={self.file_client_args})'
return repr_str
@PIPELINES.register_module()
class LoadProposals(object):
"""Load proposal pipeline.
Required key is "proposals". Updated keys are "proposals", "bbox_fields".
Args:
num_max_proposals (int, optional): Maximum number of proposals to load.
If not specified, all proposals will be loaded.
"""
def __init__(self, num_max_proposals=None):
self.num_max_proposals = num_max_proposals
def __call__(self, results):
"""Call function to load proposals from file.
Args:
results (dict): Result dict from :obj:`mmdet.CustomDataset`.
Returns:
dict: The dict contains loaded proposal annotations.
"""
proposals = results['proposals']
if proposals.shape[1] not in (4, 5):
raise AssertionError(
'proposals should have shapes (n, 4) or (n, 5), '
f'but found {proposals.shape}')
proposals = proposals[:, :4]
if self.num_max_proposals is not None:
proposals = proposals[:self.num_max_proposals]
if len(proposals) == 0:
proposals = np.array([[0, 0, 0, 0]], dtype=np.float32)
results['proposals'] = proposals
results['bbox_fields'].append('proposals')
return results
def __repr__(self):
return self.__class__.__name__ + \
f'(num_max_proposals={self.num_max_proposals})'
@PIPELINES.register_module()
class FilterAnnotations(object):
"""Filter invalid annotations.
Args:
min_gt_bbox_wh (tuple[int]): Minimum width and height of ground truth
boxes.
"""
def __init__(self, min_gt_bbox_wh):
# TODO: add more filter options
self.min_gt_bbox_wh = min_gt_bbox_wh
def __call__(self, results):
assert 'gt_bboxes' in results
gt_bboxes = results['gt_bboxes']
w = gt_bboxes[:, 2] - gt_bboxes[:, 0]
h = gt_bboxes[:, 3] - gt_bboxes[:, 1]
keep = (w > self.min_gt_bbox_wh[0]) & (h > self.min_gt_bbox_wh[1])
if not keep.any():
return None
else:
keys = ('gt_bboxes', 'gt_labels', 'gt_masks', 'gt_semantic_seg')
for key in keys:
if key in results:
results[key] = results[key][keep]
return results