|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import base64 |
|
import os |
|
from io import BytesIO |
|
from typing import TYPE_CHECKING, Dict, Iterable, List, Optional, Tuple, Union |
|
|
|
import numpy as np |
|
import requests |
|
from packaging import version |
|
|
|
from .utils import ( |
|
ExplicitEnum, |
|
is_jax_tensor, |
|
is_tf_tensor, |
|
is_torch_available, |
|
is_torch_tensor, |
|
is_vision_available, |
|
requires_backends, |
|
to_numpy, |
|
) |
|
from .utils.constants import ( |
|
IMAGENET_DEFAULT_MEAN, |
|
IMAGENET_DEFAULT_STD, |
|
IMAGENET_STANDARD_MEAN, |
|
IMAGENET_STANDARD_STD, |
|
OPENAI_CLIP_MEAN, |
|
OPENAI_CLIP_STD, |
|
) |
|
|
|
|
|
if is_vision_available():
    import PIL.Image
    import PIL.ImageOps

    # From Pillow 9.1.0 on, the resampling filters are taken from `PIL.Image.Resampling`; older releases fall back
    # to the `PIL.Image` module itself. The comparison uses the base version so that pre-release / dev suffixes in
    # `PIL.__version__` do not influence the check.
    PILImageResampling = (
        PIL.Image.Resampling
        if version.parse(version.parse(PIL.__version__).base_version) >= version.parse("9.1.0")
        else PIL.Image
    )
|
|
|
if TYPE_CHECKING: |
|
if is_torch_available(): |
|
import torch |
|
|
|
|
|
# Everything accepted as image input: a single PIL image, NumPy array or torch tensor, or a list of any one of them.
# PIL and torch are only referenced by name (forward references) since both are imported conditionally above.
ImageInput = Union[
    "PIL.Image.Image",
    np.ndarray,
    "torch.Tensor",
    List["PIL.Image.Image"],
    List[np.ndarray],
    List["torch.Tensor"],
]
|
|
|
|
|
class ChannelDimension(ExplicitEnum):
    """
    Position of the channel axis in an image array.
    """

    # Channel axis comes before the two spatial axes, e.g. (num_channels, height, width).
    FIRST = "channels_first"
    # Channel axis is the trailing axis, e.g. (height, width, num_channels).
    LAST = "channels_last"
|
|
|
|
|
def is_pil_image(img):
    """
    Tell whether `img` is a `PIL.Image.Image`. The result is falsy whenever the vision backend is unavailable, in
    which case `PIL` is never touched.
    """
    vision_ok = is_vision_available()
    if not vision_ok:
        return vision_ok
    return isinstance(img, PIL.Image.Image)
|
|
|
|
|
def is_valid_image(img):
    """
    Tell whether `img` is one of the supported image containers: a PIL image (only checked when the vision backend
    is available), a NumPy array, or a torch / TensorFlow / JAX tensor.
    """
    if is_vision_available() and isinstance(img, PIL.Image.Image):
        return True
    if isinstance(img, np.ndarray):
        return True
    # Framework tensors are probed last, in the order torch -> tf -> jax.
    return is_torch_tensor(img) or is_tf_tensor(img) or is_jax_tensor(img)
|
|
|
|
|
def valid_images(imgs):
    """
    Recursively check that `imgs` only contains valid images.

    Lists and tuples are walked (to any nesting depth) and every leaf must satisfy `is_valid_image`. Anything that is
    not a list or tuple is treated as a single image or a batched tensor of images. An empty list/tuple is valid.
    """
    if not isinstance(imgs, (list, tuple)):
        return bool(is_valid_image(imgs))
    return all(valid_images(item) for item in imgs)
|
|
|
|
|
def is_batched(img):
    """
    Tell whether `img` is a batch given as a list or tuple of images.

    Only the first element is inspected. An empty list/tuple holds no image and is therefore not considered a batch
    (previously this raised `IndexError`). Anything that is not a list or tuple yields `False`.
    """
    if isinstance(img, (list, tuple)) and len(img) > 0:
        return is_valid_image(img[0])
    return False
|
|
|
|
|
def is_scaled_image(image: np.ndarray) -> bool:
    """
    Checks to see whether the pixel values have already been rescaled to [0, 1].

    `uint8` arrays are always reported as not rescaled, whatever values they hold.
    """
    if image.dtype == np.uint8:
        return False

    # Other dtypes (e.g. floats) may still hold values in [0, 255], so the actual value range decides.
    non_negative = np.min(image) >= 0
    return non_negative and np.max(image) <= 1
|
|
|
|
|
def make_list_of_images(images, expected_ndims: int = 3) -> List[ImageInput]:
    """
    Ensure that the input is a list of images. If the input is a single image, it is converted to a list of length 1.
    If the input is a batch of images, it is converted to a list of images.

    Args:
        images (`ImageInput`):
            Image or images to turn into a list of images.
        expected_ndims (`int`, *optional*, defaults to 3):
            Expected number of dimensions for a single input image. If the input image has a different number of
            dimensions, an error is raised.

    Returns:
        The input itself when it already is a list/tuple of images, otherwise a new list of images.

    Raises:
        ValueError: If `images` is not a supported image type, or if an array/tensor input has neither
            `expected_ndims` nor `expected_ndims + 1` dimensions.
    """
    if is_batched(images):
        return images

    # A lone PIL image is never batched. `PIL` is only imported when the vision backend is available, so the
    # isinstance check is guarded to keep array/tensor inputs working without Pillow installed.
    if is_vision_available() and isinstance(images, PIL.Image.Image):
        return [images]

    if is_valid_image(images):
        if images.ndim == expected_ndims + 1:
            # One extra leading axis: a batched array/tensor, split along that axis.
            images = list(images)
        elif images.ndim == expected_ndims:
            # A single array/tensor image.
            images = [images]
        else:
            raise ValueError(
                f"Invalid image shape. Expected either {expected_ndims + 1} or {expected_ndims} dimensions, but got"
                f" {images.ndim} dimensions."
            )
        return images
    raise ValueError(
        "Invalid image type. Expected either PIL.Image.Image, numpy.ndarray, torch.Tensor, tf.Tensor or "
        f"jax.ndarray, but got {type(images)}."
    )
|
|
|
|
|
def to_numpy_array(img) -> np.ndarray:
    """
    Convert a supported image (see `is_valid_image`) to a NumPy array.

    PIL images go through `np.array`; every other supported input is delegated to `to_numpy`.

    Raises:
        ValueError: If `img` is not a supported image type.
    """
    if not is_valid_image(img):
        raise ValueError(f"Invalid image type: {type(img)}")

    is_pil = is_vision_available() and isinstance(img, PIL.Image.Image)
    return np.array(img) if is_pil else to_numpy(img)
|
|
|
|
|
def infer_channel_dimension_format(
    image: np.ndarray, num_channels: Optional[Union[int, Tuple[int, ...]]] = None
) -> ChannelDimension:
    """
    Infers the channel dimension format of `image`.

    Args:
        image (`np.ndarray`):
            The image to infer the channel dimension of. Must have 3 dimensions, or 4 when a leading batch axis is
            present.
        num_channels (`int` or `Tuple[int, ...]`, *optional*, defaults to `(1, 3)`):
            The number of channels of the image.

    Returns:
        The channel dimension of the image. The channels-first position is tested before the channels-last one, so
        an image matching both is reported as `ChannelDimension.FIRST`.

    Raises:
        ValueError: If the image is neither 3- nor 4-dimensional, or if neither candidate axis has a size listed in
            `num_channels`.
    """
    if num_channels is None:
        num_channels = (1, 3)
    elif isinstance(num_channels, int):
        num_channels = (num_channels,)

    if image.ndim not in (3, 4):
        raise ValueError(f"Unsupported number of image dimensions: {image.ndim}")

    # With 4 dimensions the leading axis is skipped, shifting both candidate positions by one.
    first_dim, last_dim = image.ndim - 3, image.ndim - 1

    for axis, layout in ((first_dim, ChannelDimension.FIRST), (last_dim, ChannelDimension.LAST)):
        if image.shape[axis] in num_channels:
            return layout
    raise ValueError("Unable to infer channel dimension format")
|
|
|
|
|
def get_channel_dimension_axis(
    image: np.ndarray, input_data_format: Optional[Union[ChannelDimension, str]] = None
) -> int:
    """
    Returns the channel dimension axis of the image.

    Args:
        image (`np.ndarray`):
            The image to get the channel dimension axis of.
        input_data_format (`ChannelDimension` or `str`, *optional*):
            The channel dimension format of the image. If `None`, will infer the channel dimension from the image.

    Returns:
        The channel dimension axis of the image: `image.ndim - 3` for channels-first and `image.ndim - 1` for
        channels-last.

    Raises:
        ValueError: If the data format is neither channels-first nor channels-last.
    """
    if input_data_format is None:
        input_data_format = infer_channel_dimension_format(image)

    # Count the channel axis from the end so the same rule holds with or without a leading batch axis.
    if input_data_format == ChannelDimension.FIRST:
        offset_from_end = 3
    elif input_data_format == ChannelDimension.LAST:
        offset_from_end = 1
    else:
        raise ValueError(f"Unsupported data format: {input_data_format}")
    return image.ndim - offset_from_end
|
|
|
|
|
def get_image_size(image: np.ndarray, channel_dim: ChannelDimension = None) -> Tuple[int, int]:
    """
    Returns the (height, width) dimensions of the image.

    Args:
        image (`np.ndarray`):
            The image to get the dimensions of.
        channel_dim (`ChannelDimension`, *optional*):
            Which dimension the channel dimension is in. If `None`, will infer the channel dimension from the image.

    Returns:
        A tuple of the image's height and width.

    Raises:
        ValueError: If `channel_dim` is neither channels-first nor channels-last.
    """
    if channel_dim is None:
        channel_dim = infer_channel_dimension_format(image)

    # Spatial axes are addressed from the end of the shape, so a leading batch axis does not matter.
    if channel_dim == ChannelDimension.FIRST:
        height_axis, width_axis = -2, -1
    elif channel_dim == ChannelDimension.LAST:
        height_axis, width_axis = -3, -2
    else:
        raise ValueError(f"Unsupported data format: {channel_dim}")

    shape = image.shape
    return shape[height_axis], shape[width_axis]
|
|
|
|
|
def is_valid_annotation_coco_detection(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """
    Check that `annotation` has the outline of a COCO detection annotation: a dict with an `"image_id"` key and an
    `"annotations"` key holding a list or tuple that is either empty or starts with a dict.
    """
    if not isinstance(annotation, dict):
        return False
    if "image_id" not in annotation or "annotations" not in annotation:
        return False

    objects = annotation["annotations"]
    if not isinstance(objects, (list, tuple)):
        return False

    # An image can have no annotations at all; otherwise only the first entry is inspected.
    return len(objects) == 0 or isinstance(objects[0], dict)
|
|
|
|
|
def is_valid_annotation_coco_panoptic(annotation: Dict[str, Union[List, Tuple]]) -> bool:
    """
    Check that `annotation` has the outline of a COCO panoptic annotation: a dict with `"image_id"`,
    `"segments_info"` and `"file_name"` keys, where `"segments_info"` is a list or tuple that is either empty or
    starts with a dict.
    """
    if not isinstance(annotation, dict):
        return False
    if any(key not in annotation for key in ("image_id", "segments_info", "file_name")):
        return False

    segments = annotation["segments_info"]
    if not isinstance(segments, (list, tuple)):
        return False

    # An image can have no segments at all; otherwise only the first entry is inspected.
    return len(segments) == 0 or isinstance(segments[0], dict)
|
|
|
|
|
def valid_coco_detection_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    """
    Return `True` when every item of `annotations` passes `is_valid_annotation_coco_detection`. Checking stops at
    the first invalid item; an empty iterable is valid.
    """
    for candidate in annotations:
        if not is_valid_annotation_coco_detection(candidate):
            return False
    return True
|
|
|
|
|
def valid_coco_panoptic_annotations(annotations: Iterable[Dict[str, Union[List, Tuple]]]) -> bool:
    """
    Return `True` when every item of `annotations` passes `is_valid_annotation_coco_panoptic`. Checking stops at
    the first invalid item; an empty iterable is valid.
    """
    for candidate in annotations:
        if not is_valid_annotation_coco_panoptic(candidate):
            return False
    return True
|
|
|
|
|
def load_image(image: Union[str, "PIL.Image.Image"], timeout: Optional[float] = None) -> "PIL.Image.Image":
    """
    Loads `image` to a PIL Image.

    Args:
        image (`str` or `PIL.Image.Image`):
            The image to convert to the PIL Image format. A string is tried, in this order, as an `http://` /
            `https://` URL, as a path to an existing local file, and finally as base64-encoded image data
            (optionally given as a `data:image/...` URI).
        timeout (`float`, *optional*):
            The timeout value in seconds for the URL request.

    Returns:
        `PIL.Image.Image`: A PIL Image, transposed according to its EXIF orientation and converted to RGB.

    Raises:
        ValueError: If `image` is neither a string nor a PIL image, or if a string that is neither a URL nor an
            existing file cannot be decoded as a base64 image.
    """
    requires_backends(load_image, ["vision"])
    if isinstance(image, str):
        if image.startswith(("http://", "https://")):
            # A real protocol prefix is required, otherwise a local file whose name merely starts with "http"
            # could never be loaded.
            # The fully downloaded, decoded body (`.content`) is used rather than `.raw`: the raw stream does not
            # undo any `Content-Encoding` (e.g. gzip), which would make PIL fail to identify the image.
            image = PIL.Image.open(BytesIO(requests.get(image, timeout=timeout).content))
        elif os.path.isfile(image):
            image = PIL.Image.open(image)
        else:
            if image.startswith("data:image/"):
                # Keep only the payload that follows the `data:image/...;base64,` header.
                image = image.split(",")[1]

            # Last resort: treat the string as base64 data. Any failure is reported as a single ValueError that
            # keeps the underlying exception as its cause.
            try:
                b64 = base64.b64decode(image, validate=True)
                image = PIL.Image.open(BytesIO(b64))
            except Exception as e:
                raise ValueError(
                    f"Incorrect image source. Must be a valid URL starting with `http://` or `https://`, a valid path to an image file, or a base64 encoded string. Got {image}. Failed with {e}"
                ) from e
    elif not isinstance(image, PIL.Image.Image):
        raise ValueError(
            "Incorrect format used for image. Should be an url linking to an image, a base64 string, a local path, or a PIL image."
        )
    image = PIL.ImageOps.exif_transpose(image)
    image = image.convert("RGB")
    return image
|
|
|
|
|
|
|
class ImageFeatureExtractionMixin:
    """
    Mixin that contains utilities for preparing image features.
    """

    def _ensure_format_supported(self, image):
        """
        Raise a `ValueError` unless `image` is a `PIL.Image.Image`, a `np.ndarray` or a `torch.Tensor`.
        """
        # Cheapest check first: NumPy arrays are accepted without consulting the optional backends.
        if isinstance(image, np.ndarray):
            return
        if isinstance(image, PIL.Image.Image) or is_torch_tensor(image):
            return
        raise ValueError(
            f"Got type {type(image)} which is not supported, only `PIL.Image.Image`, `np.array` and "
            "`torch.Tensor` are."
        )

    def to_pil_image(self, image, rescale=None):
        """
        Converts `image` to a PIL Image. Optionally rescales it and puts the channel dimension back as the last axis if
        needed.

        Args:
            image (`PIL.Image.Image` or `numpy.ndarray` or `torch.Tensor`):
                The image to convert to the PIL Image format.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values integers between 0 and 255). Will
                default to `True` if the image type is a floating type, `False` otherwise.

        Returns:
            A `PIL.Image.Image`. PIL inputs are returned unchanged.
        """
        self._ensure_format_supported(image)

        if is_torch_tensor(image):
            image = image.numpy()

        if isinstance(image, np.ndarray):
            if rescale is None:
                # Rescale by default only when the array holds floating point values.
                rescale = isinstance(image.flat[0], np.floating)

            # If the channel axis has been moved to the first position, put it back at the end.
            if image.ndim == 3 and image.shape[0] in [1, 3]:
                image = image.transpose(1, 2, 0)

            # `PIL.Image.fromarray` has no mode for a trailing singleton channel axis, so single-channel images
            # are handed over as plain 2D (height, width) arrays.
            if image.ndim == 3 and image.shape[-1] == 1:
                image = image.squeeze(axis=-1)

            if rescale:
                image = image * 255
            image = image.astype(np.uint8)
            return PIL.Image.fromarray(image)
        return image

    def convert_rgb(self, image):
        """
        Converts `PIL.Image.Image` to RGB format.

        Args:
            image (`PIL.Image.Image`):
                The image to convert. Arrays and tensors are returned unchanged.
        """
        self._ensure_format_supported(image)
        if not isinstance(image, PIL.Image.Image):
            return image

        return image.convert("RGB")

    def rescale(self, image: np.ndarray, scale: Union[float, int]) -> np.ndarray:
        """
        Rescale a numpy image by scale amount, i.e. multiply every value by `scale`.
        """
        self._ensure_format_supported(image)
        return image * scale

    def to_numpy_array(self, image, rescale=None, channel_first=True):
        """
        Converts `image` to a numpy array. Optionally rescales it and puts the channel dimension as the first
        dimension.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to convert to a NumPy array.
            rescale (`bool`, *optional*):
                Whether or not to apply the scaling factor (to make pixel values floats between 0. and 1.). Will
                default to `True` if the image is a PIL Image or an array/tensor of integers, `False` otherwise.
            channel_first (`bool`, *optional*, defaults to `True`):
                Whether or not to permute the dimensions of the image to put the channel dimension first. Any
                3-dimensional input is transposed with `(2, 0, 1)`, i.e. it is assumed to be channels-last.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = np.array(image)

        if is_torch_tensor(image):
            image = image.numpy()

        # Default: rescale exactly when the array holds integer values.
        rescale = isinstance(image.flat[0], np.integer) if rescale is None else rescale

        if rescale:
            image = self.rescale(image.astype(np.float32), 1 / 255.0)

        if channel_first and image.ndim == 3:
            image = image.transpose(2, 0, 1)

        return image

    def expand_dims(self, image):
        """
        Expands 2-dimensional `image` to 3 dimensions by inserting a new leading axis.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to expand. PIL images are returned unchanged.
        """
        self._ensure_format_supported(image)

        # Nothing to do for PIL images.
        if isinstance(image, PIL.Image.Image):
            return image

        if is_torch_tensor(image):
            image = image.unsqueeze(0)
        else:
            image = np.expand_dims(image, axis=0)
        return image

    def normalize(self, image, mean, std, rescale=False):
        """
        Normalizes `image` with `mean` and `std`. Note that this will trigger a conversion of `image` to a NumPy array
        if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to normalize.
            mean (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The mean (per channel) to use for normalization.
            std (`List[float]` or `np.ndarray` or `torch.Tensor`):
                The standard deviation (per channel) to use for normalization.
            rescale (`bool`, *optional*, defaults to `False`):
                Whether or not to rescale the image to be between 0 and 1. If a PIL image is provided, scaling will
                happen automatically.

        Returns:
            `(image - mean) / std` as an array or tensor, with `mean` and `std` broadcast over the spatial axes when
            the image is 3-dimensional with 1 or 3 leading channels.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            # PIL images are always rescaled while being converted.
            image = self.to_numpy_array(image, rescale=True)
        elif rescale:
            # Arrays and tensors are only rescaled on request.
            if isinstance(image, np.ndarray):
                image = self.rescale(image.astype(np.float32), 1 / 255.0)
            elif is_torch_tensor(image):
                image = self.rescale(image.float(), 1 / 255.0)

        # Bring `mean` / `std` into the same container type as the image.
        if isinstance(image, np.ndarray):
            if not isinstance(mean, np.ndarray):
                mean = np.array(mean).astype(image.dtype)
            if not isinstance(std, np.ndarray):
                std = np.array(std).astype(image.dtype)
        elif is_torch_tensor(image):
            import torch

            if not isinstance(mean, torch.Tensor):
                mean = torch.tensor(mean)
            if not isinstance(std, torch.Tensor):
                std = torch.tensor(std)

        if image.ndim == 3 and image.shape[0] in [1, 3]:
            # Channels-first image: reshape the per-channel statistics to (C, 1, 1) for broadcasting.
            return (image - mean[:, None, None]) / std[:, None, None]
        else:
            return (image - mean) / std

    def resize(self, image, size, resample=None, default_to_square=True, max_size=None):
        """
        Resizes `image`. Enforces conversion of input to PIL.Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to resize.
            size (`int` or `Tuple[int, int]`):
                The size to use for resizing the image. If `size` is a sequence like (h, w), output size will be
                matched to this.

                If `size` is an int and `default_to_square` is `True`, then image will be resized to (size, size). If
                `size` is an int and `default_to_square` is `False`, then smaller edge of the image will be matched to
                this number. i.e, if height > width, then image will be rescaled to (size * height / width, size).

                NOTE(review): a 2-element `size` is forwarded unchanged to `PIL.Image.Image.resize`, which reads it
                as (width, height) - confirm the intended order against callers.
            resample (`int`, *optional*, defaults to `PILImageResampling.BILINEAR`):
                The filter to use for resampling.
            default_to_square (`bool`, *optional*, defaults to `True`):
                How to convert `size` when it is a single int. If set to `True`, the `size` will be converted to a
                square (`size`,`size`). If set to `False`, will replicate
                [`torchvision.transforms.Resize`](https://pytorch.org/vision/stable/transforms.html#torchvision.transforms.Resize)
                with support for resizing only the smallest edge and providing an optional `max_size`.
            max_size (`int`, *optional*, defaults to `None`):
                The maximum allowed for the longer edge of the resized image: if the longer edge of the image is
                greater than `max_size` after being resized according to `size`, then the image is resized again so
                that the longer edge is equal to `max_size`. As a result, `size` might be overruled, i.e the smaller
                edge may be shorter than `size`. Only used if `default_to_square` is `False`.

        Returns:
            image: A resized `PIL.Image.Image`.

        Raises:
            ValueError: If `max_size` is not strictly greater than the requested size of the smaller edge.
        """
        resample = resample if resample is not None else PILImageResampling.BILINEAR

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        if isinstance(size, list):
            size = tuple(size)

        if isinstance(size, int) or len(size) == 1:
            if default_to_square:
                size = (size, size) if isinstance(size, int) else (size[0], size[0])
            else:
                width, height = image.size
                # The requested size applies to the shorter edge only.
                short, long = (width, height) if width <= height else (height, width)
                requested_new_short = size if isinstance(size, int) else size[0]

                if short == requested_new_short:
                    return image

                new_short, new_long = requested_new_short, int(requested_new_short * long / short)

                if max_size is not None:
                    if max_size <= requested_new_short:
                        raise ValueError(
                            f"max_size = {max_size} must be strictly greater than the requested "
                            f"size for the smaller edge size = {size}"
                        )
                    if new_long > max_size:
                        new_short, new_long = int(max_size * new_short / new_long), max_size

                size = (new_short, new_long) if width <= height else (new_long, new_short)

        return image.resize(size, resample=resample)

    def center_crop(self, image, size):
        """
        Crops `image` to the given size using a center crop. Note that if the image is too small to be cropped to the
        size given, it will be padded (so the returned result has the size asked).

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape (n_channels, height, width) or (height, width, n_channels)):
                The image to resize.
            size (`int` or `Tuple[int, int]`):
                The size to which crop the image, as (height, width). A single int (or a 1-element sequence)
                requests a square crop. Lists are accepted as well as tuples.

        Returns:
            new_image: A center cropped `PIL.Image.Image` or `np.ndarray` or `torch.Tensor` of shape: (n_channels,
            height, width).
        """
        self._ensure_format_supported(image)

        # Accept lists exactly like `resize` does; any other non-tuple value is a single edge length.
        if isinstance(size, list):
            size = tuple(size)
        if not isinstance(size, tuple):
            size = (size, size)
        elif len(size) == 1:
            size = (size[0], size[0])

        # After the format check above, anything that is not an array/tensor is a PIL image.
        is_array = isinstance(image, np.ndarray) or is_torch_tensor(image)

        # PIL `Image.size` is (width, height) whereas arrays and tensors are indexed (height, width).
        if is_array:
            if image.ndim == 2:
                image = self.expand_dims(image)
            image_shape = image.shape[1:] if image.shape[0] in [1, 3] else image.shape[:2]
        else:
            image_shape = (image.size[1], image.size[0])

        top = (image_shape[0] - size[0]) // 2
        bottom = top + size[0]
        left = (image_shape[1] - size[1]) // 2
        right = left + size[1]

        # PIL images can be cropped directly.
        if not is_array:
            return image.crop((left, top, right, bottom))

        # Arrays/tensors are either (n_channels, height, width) or (height, width, n_channels).
        channel_first = image.shape[0] in [1, 3]

        # Bring channels-last inputs to (n_channels, height, width).
        if not channel_first:
            if isinstance(image, np.ndarray):
                image = image.transpose(2, 0, 1)
            else:
                image = image.permute(2, 0, 1)

        # The crop window lies fully inside the image: a plain slice is enough.
        if top >= 0 and bottom <= image_shape[0] and left >= 0 and right <= image_shape[1]:
            return image[..., top:bottom, left:right]

        # Otherwise the image is first centered on a zero canvas large enough to hold the crop window.
        new_shape = image.shape[:-2] + (max(size[0], image_shape[0]), max(size[1], image_shape[1]))
        if isinstance(image, np.ndarray):
            new_image = np.zeros_like(image, shape=new_shape)
        else:
            new_image = image.new_zeros(new_shape)

        top_pad = (new_shape[-2] - image_shape[0]) // 2
        bottom_pad = top_pad + image_shape[0]
        left_pad = (new_shape[-1] - image_shape[1]) // 2
        right_pad = left_pad + image_shape[1]
        new_image[..., top_pad:bottom_pad, left_pad:right_pad] = image

        # Express the crop window in canvas coordinates.
        top += top_pad
        bottom += top_pad
        left += left_pad
        right += left_pad

        new_image = new_image[
            ..., max(0, top) : min(new_image.shape[-2], bottom), max(0, left) : min(new_image.shape[-1], right)
        ]

        return new_image

    def flip_channel_order(self, image):
        """
        Flips the channel order of `image` from RGB to BGR, or vice versa. Note that this will trigger a conversion of
        `image` to a NumPy array if it's a PIL Image.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image whose color channels to flip. If `np.ndarray` or `torch.Tensor`, the channel dimension should
                be first.
        """
        self._ensure_format_supported(image)

        if isinstance(image, PIL.Image.Image):
            image = self.to_numpy_array(image)

        # Reverse along the first (channel) axis.
        return image[::-1, :, :]

    def rotate(self, image, angle, resample=None, expand=0, center=None, translate=None, fillcolor=None):
        """
        Returns a rotated copy of `image`. This method returns a copy of `image`, rotated the given number of degrees
        counter clockwise around its centre.

        Args:
            image (`PIL.Image.Image` or `np.ndarray` or `torch.Tensor`):
                The image to rotate. If `np.ndarray` or `torch.Tensor`, will be converted to `PIL.Image.Image` before
                rotating.
            angle, resample, expand, center, translate, fillcolor:
                Forwarded unchanged to `PIL.Image.Image.rotate`. `resample` defaults to
                `PILImageResampling.NEAREST` when `None`.

        Returns:
            image: A rotated `PIL.Image.Image`.
        """
        # Go through the version-aware `PILImageResampling` alias, as `resize` does, instead of the module-level
        # `PIL.Image.NEAREST` constant.
        resample = resample if resample is not None else PILImageResampling.NEAREST

        self._ensure_format_supported(image)

        if not isinstance(image, PIL.Image.Image):
            image = self.to_pil_image(image)

        return image.rotate(
            angle, resample=resample, expand=expand, center=center, translate=translate, fillcolor=fillcolor
        )
|
|