Spaces:
Runtime error
Runtime error
'''
* Copyright (c) 2023 Salesforce, Inc.
* All rights reserved.
* SPDX-License-Identifier: Apache License 2.0
* For full license text, see LICENSE.txt file in the repo root or http://www.apache.org/licenses/
* By Can Qin
* Modified from ControlNet repo: https://github.com/lllyasviel/ControlNet
* Copyright (c) 2023 Lvmin Zhang and Maneesh Agrawala
* Modified from MMCV repo: From https://github.com/open-mmlab/mmcv
* Copyright (c) OpenMMLab. All rights reserved.
'''
import os.path as osp | |
import annotator.uniformer.mmcv as mmcv | |
import numpy as np | |
from ..builder import PIPELINES | |
class LoadImageFromFile(object):
    """Pipeline step that reads an image from storage into ``results``.

    The image path is ``results['img_info']['filename']``, joined onto
    ``results['img_prefix']`` when that key is present and not None.

    Keys written to ``results``: "filename", "ori_filename", "img",
    "img_shape", "ori_shape" and "pad_shape" (all three hold the loaded
    image's shape), "scale_factor" (1.0) and "img_norm_cfg" (zero mean,
    unit std, ``to_rgb=False``).

    Args:
        to_float32 (bool): If True, cast the decoded image to float32;
            otherwise the array is left as returned by the decoder.
            Defaults to False.
        color_type (str): The ``flag`` argument passed to
            :func:`mmcv.imfrombytes`. Defaults to 'color'.
        file_client_args (dict): Keyword arguments used to build the
            :class:`mmcv.fileio.FileClient`.
            Defaults to ``dict(backend='disk')``.
        imdecode_backend (str): The ``backend`` argument passed to
            :func:`mmcv.imfrombytes`. Default: 'cv2'
    """

    def __init__(self,
                 to_float32=False,
                 color_type='color',
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='cv2'):
        # Decoding options.
        self.color_type = color_type
        self.imdecode_backend = imdecode_backend
        self.to_float32 = to_float32
        # Storage access: keep a private copy of the arguments; the client
        # itself is only built on the first call.
        self.file_client_args = file_client_args.copy()
        self.file_client = None

    def __call__(self, results):
        """Load the image and record its meta information.

        Args:
            results (dict): Result dict from :obj:`mmseg.CustomDataset`.

        Returns:
            dict: The same ``results`` dict, updated with the loaded image
                and its meta information.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        prefix = results.get('img_prefix')
        if prefix is None:
            filename = results['img_info']['filename']
        else:
            filename = osp.join(prefix, results['img_info']['filename'])

        img = mmcv.imfrombytes(
            self.file_client.get(filename),
            flag=self.color_type,
            backend=self.imdecode_backend)
        if self.to_float32:
            img = img.astype(np.float32)

        shape = img.shape
        # An image without a third axis is treated as single-channel.
        channels = shape[2] if len(shape) >= 3 else 1

        results['filename'] = filename
        results['ori_filename'] = results['img_info']['filename']
        results['img'] = img
        # Initial values for the default meta keys: nothing has been resized
        # or padded yet, so every shape entry is the loaded shape.
        for shape_key in ('img_shape', 'ori_shape', 'pad_shape'):
            results[shape_key] = shape
        results['scale_factor'] = 1.0
        results['img_norm_cfg'] = {
            'mean': np.zeros(channels, dtype=np.float32),
            'std': np.ones(channels, dtype=np.float32),
            'to_rgb': False,
        }
        return results

    def __repr__(self):
        return (f'{self.__class__.__name__}'
                f'(to_float32={self.to_float32},'
                f"color_type='{self.color_type}',"
                f"imdecode_backend='{self.imdecode_backend}')")
class LoadAnnotations(object):
    """Load annotations for semantic segmentation.

    The segmentation map path is ``results['ann_info']['seg_map']``, joined
    onto ``results['seg_prefix']`` when that key is present and not None.
    If ``results['label_map']`` is present and not None, label ids are
    remapped according to it. The loaded map is stored under
    "gt_semantic_seg" and that key is appended to ``results['seg_fields']``.

    Args:
        reduce_zero_label (bool): Whether reduce all label value by 1.
            Usually used for datasets where 0 is background label.
            Default: False.
        file_client_args (dict): Arguments to instantiate a FileClient.
            See :class:`mmcv.fileio.FileClient` for details.
            Defaults to ``dict(backend='disk')``.
        imdecode_backend (str): Backend for :func:`mmcv.imdecode`. Default:
            'pillow'
    """

    def __init__(self,
                 reduce_zero_label=False,
                 file_client_args=dict(backend='disk'),
                 imdecode_backend='pillow'):
        self.reduce_zero_label = reduce_zero_label
        # Copy so this instance never aliases the shared default dict or a
        # dict owned by the caller.
        self.file_client_args = file_client_args.copy()
        # Built lazily on the first call.
        self.file_client = None
        self.imdecode_backend = imdecode_backend

    @staticmethod
    def _transform_labels(gt_semantic_seg, label_map=None,
                          reduce_zero_label=False):
        """Apply label remapping and zero-label reduction to a label array.

        Args:
            gt_semantic_seg (np.ndarray): uint8 label array. It is modified
                in place when ``label_map`` is given.
            label_map (dict | None): Mapping ``old_id -> new_id``. ``None``
                disables remapping.
            reduce_zero_label (bool): If True, label 0 becomes 255 and every
                other label is decreased by 1 (255 stays 255).

        Returns:
            np.ndarray: The transformed label array.
        """
        # modify if custom classes
        if label_map is not None:
            # Compare against a snapshot of the original labels. Comparing
            # against the array being written would chain mappings: with
            # {0: 1, 1: 2}, pixels remapped 0 -> 1 would then also be
            # remapped 1 -> 2.
            original_labels = gt_semantic_seg.copy()
            for old_id, new_id in label_map.items():
                gt_semantic_seg[original_labels == old_id] = new_id
        # reduce zero_label
        if reduce_zero_label:
            # Move 0 to 255 first so the subtraction never wraps around in
            # uint8; the former zeros (now 254) are then restored to 255.
            gt_semantic_seg[gt_semantic_seg == 0] = 255
            gt_semantic_seg = gt_semantic_seg - 1
            gt_semantic_seg[gt_semantic_seg == 254] = 255
        return gt_semantic_seg

    def __call__(self, results):
        """Call function to load multiple types annotations.

        Args:
            results (dict): Result dict from :obj:`mmseg.CustomDataset`.

        Returns:
            dict: The dict contains loaded semantic segmentation annotations.
        """
        if self.file_client is None:
            self.file_client = mmcv.FileClient(**self.file_client_args)

        if results.get('seg_prefix', None) is not None:
            filename = osp.join(results['seg_prefix'],
                                results['ann_info']['seg_map'])
        else:
            filename = results['ann_info']['seg_map']

        img_bytes = self.file_client.get(filename)
        gt_semantic_seg = mmcv.imfrombytes(
            img_bytes, flag='unchanged',
            backend=self.imdecode_backend).squeeze().astype(np.uint8)

        gt_semantic_seg = self._transform_labels(
            gt_semantic_seg,
            label_map=results.get('label_map', None),
            reduce_zero_label=self.reduce_zero_label)

        results['gt_semantic_seg'] = gt_semantic_seg
        # Create the field list on demand so a results dict that was not
        # pre-initialised with 'seg_fields' does not raise KeyError.
        results.setdefault('seg_fields', []).append('gt_semantic_seg')
        return results

    def __repr__(self):
        repr_str = self.__class__.__name__
        repr_str += f'(reduce_zero_label={self.reduce_zero_label},'
        repr_str += f"imdecode_backend='{self.imdecode_backend}')"
        return repr_str