RockeyCoss
add code files”
51f6859
raw
history blame contribute delete
No virus
2.07 kB
# Copyright (c) OpenMMLab. All rights reserved.
import mmcv
import numpy as np
def palette_val(palette):
    """Convert a 0-255 palette into the 0-1 float form matplotlib uses.

    Args:
        palette (list[tuple]): Color tuples whose channels are on a 0-255
            scale; every channel is divided by 255.

    Returns:
        list[tuple[float]]: One tuple of scaled channels per input color,
        in the same order as ``palette``.
    """
    return [tuple(channel / 255 for channel in color) for color in palette]
def _random_palette(num_classes, seed=42):
    """Generate ``num_classes`` reproducible pseudo-random RGB colors.

    A private ``RandomState`` is used so the global NumPy RNG is neither
    reseeded nor advanced. The colors equal those obtained by calling
    ``np.random.seed(seed)`` followed by ``np.random.randint``.

    Args:
        num_classes (int): Number of colors to generate.
        seed (int): Seed of the private generator. Defaults to 42.

    Returns:
        list[tuple[int]]: ``num_classes`` tuples of three ints in [0, 255].
    """
    rng = np.random.RandomState(seed)
    colors = rng.randint(0, 256, size=(num_classes, 3))
    # ``tolist`` turns NumPy integer scalars into builtin ints, matching the
    # documented ``tuple[int]`` return type.
    return [tuple(color) for color in colors.tolist()]


def _palette_from_name(name, num_classes):
    """Resolve a string palette (dataset name, 'random' or color name)."""
    if name == 'random':
        return _random_palette(num_classes)
    if name == 'coco':
        from mmdet.datasets import CocoDataset, CocoPanopticDataset
        dataset_palette = CocoDataset.PALETTE
        # Fall back to the panoptic palette when the detection palette is
        # too short for the requested number of classes.
        if len(dataset_palette) < num_classes:
            dataset_palette = CocoPanopticDataset.PALETTE
        return dataset_palette
    if name == 'citys':
        from mmdet.datasets import CityscapesDataset
        return CityscapesDataset.PALETTE
    if name == 'voc':
        from mmdet.datasets import VOCDataset
        return VOCDataset.PALETTE
    # Any other string is treated as a color name. mmcv.color_val returns
    # BGR, so it is reversed to RGB and repeated for every class.
    return [mmcv.color_val(name)[::-1]] * num_classes


def get_palette(palette, num_classes):
    """Get palette from various inputs.

    Args:
        palette (list[tuple] | str | tuple | :obj:`Color` | None): palette
            inputs. A list is returned as-is; a tuple or a ``Color`` is
            repeated for every class; ``None`` or ``'random'`` gives
            reproducible random colors; ``'coco'``, ``'citys'`` and
            ``'voc'`` select the matching mmdet dataset palette; any other
            string is looked up as a color name.
        num_classes (int): the number of classes.

    Returns:
        list[tuple[int]]: A list of color tuples.

    Raises:
        TypeError: If ``palette`` has an unsupported type.
        AssertionError: If ``num_classes`` is not an int, or the resolved
            palette has fewer than ``num_classes`` entries.
    """
    assert isinstance(num_classes, int)

    if isinstance(palette, list):
        dataset_palette = palette
    elif isinstance(palette, tuple):
        dataset_palette = [palette] * num_classes
    elif palette is None:
        dataset_palette = _random_palette(num_classes)
    elif isinstance(palette, str):
        # String comparisons happen only on real strings, so array-like
        # inputs cannot trigger element-wise ``==`` evaluation.
        dataset_palette = _palette_from_name(palette, num_classes)
    elif isinstance(palette, mmcv.Color):
        # Same BGR -> RGB handling as for color-name strings.
        dataset_palette = [mmcv.color_val(palette)[::-1]] * num_classes
    else:
        raise TypeError(f'Invalid type for palette: {type(palette)}')

    assert len(dataset_palette) >= num_classes, \
        'The length of palette should not be less than `num_classes`.'
    return dataset_palette