|
import torch |
|
import random |
|
import torchvision.transforms as T |
|
import numpy as np |
|
|
|
class RandAug:
    """Randomly chosen image augmentations.

    Every call samples a non-empty random subset of the augmentation names
    stored in ``self.trans`` and applies the selected ones in a fixed order:
    rotation, Gaussian blur, color jitter, sharpness.
    """

    def __init__(self):
        # Names of the augmentations ``__call__`` may pick from.
        self.trans = ['rotation', 'blur', 'color', 'sharpness']

    def __call__(self, img):
        """Apply a random subset of the augmentations to ``img``.

        Args:
            img: Image to augment.
                NOTE(review): presumably a PIL image or image tensor accepted
                by the torchvision transforms used below -- confirm with callers.

        Returns:
            The augmented image. Rotation is called with ``expand=True``, so
            the result may not have the same size as the input.
        """
        # Decide how many augmentations to apply (at least one), then which.
        n_transforms = random.randint(1, len(self.trans))
        selected = random.sample(self.trans, n_transforms)

        if 'rotation' in selected:
            angle = random.randint(-10, 10)
            # fill=255 pads the corners exposed by the rotation
            # (presumably white for 8-bit images -- TODO confirm).
            img = T.functional.rotate(img=img, angle=angle, expand=True, fill=255)

        if 'blur' in selected:
            kernel_size = random.choice([1, 3, 5])
            blur = T.GaussianBlur(kernel_size, sigma=(0.1, 2.0))
            img = blur(img)

        if 'color' in selected:
            # The jitter strengths are themselves re-sampled on every call.
            rand_brightness = random.uniform(0, 0.3)
            rand_hue = random.uniform(0, 0.5)
            rand_contrast = random.uniform(0, 0.5)
            rand_saturation = random.uniform(0, 0.5)
            jitter = T.ColorJitter(
                brightness=rand_brightness,
                contrast=rand_contrast,
                saturation=rand_saturation,
                hue=rand_hue,
            )
            img = jitter(img)

        if 'sharpness' in selected:
            # 1 + Exp/2 keeps the factor at or above 1.
            sharpness = 1 + (np.random.exponential() / 2)
            # p=1 makes the adjustment unconditional once this branch is taken.
            # The transform built here must be the one that is applied: calling
            # a stale or undefined name would skip the sharpness adjustment or
            # raise NameError.
            sharpen = T.RandomAdjustSharpness(sharpness, p=1)
            img = sharpen(img)

        return img