""" Adapted from https://github.com/nv-nguyen/template-pose/blob/main/src/utils/augmentation.py """ from torchvision import transforms from PIL import ImageEnhance, ImageFilter, Image import numpy as np import random import logging from torchvision.transforms import RandomResizedCrop, ToTensor class PillowRGBAugmentation: def __init__(self, pillow_fn, p, factor_interval): self._pillow_fn = pillow_fn self.p = p self.factor_interval = factor_interval def __call__(self, PIL_image): if random.random() <= self.p: factor = random.uniform(*self.factor_interval) if PIL_image.mode != "RGB": logging.warning( f"Error when apply data aug, image mode: {PIL_image.mode}" ) imgs = imgs.convert("RGB") logging.warning(f"Success to change to {PIL_image.mode}") PIL_image = (self._pillow_fn(PIL_image).enhance(factor=factor)).convert( "RGB" ) return PIL_image class PillowSharpness(PillowRGBAugmentation): def __init__( self, p=0.3, factor_interval=(0, 40.0), ): super().__init__( pillow_fn=ImageEnhance.Sharpness, p=p, factor_interval=factor_interval, ) class PillowContrast(PillowRGBAugmentation): def __init__( self, p=0.3, factor_interval=(0.5, 1.6), ): super().__init__( pillow_fn=ImageEnhance.Contrast, p=p, factor_interval=factor_interval, ) class PillowBrightness(PillowRGBAugmentation): def __init__( self, p=0.5, factor_interval=(0.5, 2.0), ): super().__init__( pillow_fn=ImageEnhance.Brightness, p=p, factor_interval=factor_interval, ) class PillowColor(PillowRGBAugmentation): def __init__( self, p=1, factor_interval=(0.0, 20.0), ): super().__init__( pillow_fn=ImageEnhance.Color, p=p, factor_interval=factor_interval, ) class PillowBlur: def __init__(self, p=0.4, factor_interval=(1, 3)): self.p = p self.k = random.randint(*factor_interval) def __call__(self, PIL_image): if random.random() <= self.p: PIL_image = PIL_image.filter(ImageFilter.GaussianBlur(self.k)) return PIL_image class NumpyGaussianNoise: def __init__(self, p, factor_interval=(0.01, 0.3)): self.noise_ratio = random.uniform(*factor_interval) self.p = p def __call__(self, img): if random.random() <= self.p: img = np.copy(img) noisesigma = random.uniform(0, self.noise_ratio) gauss = np.random.normal(0, noisesigma, img.shape) * 255 img = img + gauss img[img > 255] = 255 img[img < 0] = 0 return Image.fromarray(np.uint8(img)) class StandardAugmentation: def __init__( self, names, brightness, contrast, sharpness, color, blur, gaussian_noise ): self.brightness = brightness self.contrast = contrast self.sharpness = sharpness self.color = color self.blur = blur self.gaussian_noise = gaussian_noise # define a dictionary of augmentation functions to be applied self.names = names.split(",") self.augmentations = { "brightness": self.brightness, "contrast": self.contrast, "sharpness": self.sharpness, "color": self.color, "blur": self.blur, "gaussian_noise": self.gaussian_noise, } def __call__(self, img): for name in self.names: img = self.augmentations[name](img) return img class GeometricAugmentation: def __init__( self, names, random_resized_crop, random_horizontal_flip, random_vertical_flip, random_rotation, ): self.random_resized_crop = random_resized_crop self.random_horizontal_flip = random_horizontal_flip self.random_vertical_flip = random_vertical_flip self.random_rotation = random_rotation self.names = names.split(",") self.augmentations = { "random_resized_crop": self.random_resized_crop, "random_horizontal_flip": self.random_horizontal_flip, "random_vertical_flip": self.random_vertical_flip, "random_rotation": self.random_rotation, } def __call__(self, img): for name in self.names: img = self.augmentations[name](img) return img class ImageAugmentation: def __init__( self, names, clip_transform, standard_augmentation, geometric_augmentation ): self.clip_transform = clip_transform self.standard_augmentation = standard_augmentation self.geometric_augmentation = geometric_augmentation self.names = names.split(",") self.transforms = { "clip_transform": self.clip_transform, "standard_augmentation": self.standard_augmentation, "geometric_augmentation": self.geometric_augmentation, } print(f"Image augmentation: {self.names}") def __call__(self, img): for name in self.names: img = self.transforms[name](img) return img if __name__ == "__main__": # sanity check import glob import torchvision.transforms as transforms from torchvision.utils import save_image from omegaconf import DictConfig, OmegaConf from hydra.utils import instantiate import torch from PIL import Image augmentation_config = OmegaConf.load( "./configs/dataset/train_transform/augmentation.yaml" ) augmentation_config.names = "standard_augmentation,geometric_augmentation" augmentation_transform = instantiate(augmentation_config) img_paths = glob.glob("./datasets/osv5m/test/images/*.jpg") num_try = 20 num_try_per_image = 8 num_imgs = 8 for idx in range(num_try): imgs = [] for idx_img in range(num_imgs): img = Image.open(img_paths[idx_img]) for idx_try in range(num_try_per_image): if idx_try == 0: imgs.append(ToTensor()(img.resize((224, 224)))) img_aug = augmentation_transform(img.copy()) img_aug = ToTensor()(img_aug) imgs.append(img_aug) imgs = torch.stack(imgs) save_image(imgs, f"augmentation_{idx:03d}.png", nrow=9)