import random

import numpy as np
import torch
import torchvision.transforms as T


class RandAug:
    """Apply a random subset of image augmentations to an image.

    Each call picks between one and all of the available augmentations
    ('rotation', 'blur', 'color', 'sharpness') and applies the selected
    ones in that fixed order, each with freshly sampled parameters.
    """

    def __init__(self):
        # Augmentation options
        self.trans = ['rotation', 'blur', 'color', 'sharpness']

    def __call__(self, img):
        """Return ``img`` after a random selection of augmentations.

        Args:
            img: Image in a format accepted by the torchvision transforms
                used below (presumably a PIL image or an image tensor --
                confirm against the caller).

        Returns:
            The augmented image. Because rotation uses ``expand=True``,
            the output size can differ from the input size.
        """
        # Randomly choose the number of augmentations used for input image
        n_transforms = random.randint(1, len(self.trans))
        # Randomly choose the augmentation types; the sampled order is not
        # used, the checks below fix the application order.
        selected = random.sample(self.trans, n_transforms)

        if 'rotation' in selected:
            angle = random.randint(-10, 10)
            # expand=True grows the canvas so no corner is cropped; the
            # exposed area is filled with 255 (presumably white for 8-bit
            # images -- confirm for other value ranges).
            img = T.functional.rotate(img=img, angle=angle, expand=True,
                                      fill=255)

        if 'blur' in selected:
            # Kernel size must be odd; a 1x1 kernel leaves the image as is.
            kernel = random.choice([1, 3, 5])
            blur = T.GaussianBlur(kernel, sigma=(0.1, 2.0))
            img = blur(img)

        if 'color' in selected:
            rand_brightness = random.uniform(0, 0.3)
            rand_hue = random.uniform(0, 0.5)
            rand_contrast = random.uniform(0, 0.5)
            rand_saturation = random.uniform(0, 0.5)
            jitter = T.ColorJitter(brightness=rand_brightness,
                                   contrast=rand_contrast,
                                   saturation=rand_saturation,
                                   hue=rand_hue)
            img = jitter(img)

        if 'sharpness' in selected:
            # Factor is always >= 1, so this only ever sharpens.
            sharpness = 1 + (np.random.exponential() / 2)
            # p=1 makes the adjustment unconditional.
            sharpen = T.RandomAdjustSharpness(sharpness, p=1)
            # Apply the sharpness transform itself. Previously this line
            # called a stale `transform` (blur/color jitter), or raised
            # UnboundLocalError when neither of those had been selected.
            img = sharpen(img)

        return img