court-records-htr / augments.py
MikkoLipsanen's picture
Upload 3 files
1838a16 verified
import torch
import random
import torchvision.transforms as T
import numpy as np
class RandAug:
    """Apply a randomly chosen subset of image augmentations.

    On every call, between one and all of the augmentation types listed in
    ``self.trans`` are selected at random. The selected ones are then applied
    in a fixed order: rotation, blur, color jitter, sharpness.
    """

    def __init__(self):
        # Augmentation options
        self.trans = ['rotation', 'blur', 'color', 'sharpness']

    def __call__(self, img):
        """Return ``img`` with the randomly selected augmentations applied.

        Args:
            img: Image in a format accepted by the torchvision transforms
                used here (presumably a PIL image or an image tensor --
                confirm against the callers).

        Returns:
            The augmented image. Because rotation is performed with
            ``expand=True``, the output size may differ from the input size.
        """
        # Randomly choose the number of augmentations used for input image
        n_transforms = random.randint(1, len(self.trans))
        # Randomly choose the augmentation types
        chosen = random.sample(self.trans, n_transforms)

        # Implement the augmentations sequentially
        if 'rotation' in chosen:
            angle = random.randint(-10, 10)
            # expand=True enlarges the canvas so the rotated image is not
            # cropped; the area outside the rotated image is filled with 255.
            img = T.functional.rotate(img=img, angle=angle, expand=True, fill=255)

        if 'blur' in chosen:
            kernel = random.choice([1, 3, 5])
            blur = T.GaussianBlur(kernel, sigma=(0.1, 2.0))
            img = blur(img)

        if 'color' in chosen:
            # The drawn values are handed to ColorJitter as its jitter
            # parameters; ColorJitter itself samples the actual factors.
            rand_brightness = random.uniform(0, 0.3)
            rand_hue = random.uniform(0, 0.5)
            rand_contrast = random.uniform(0, 0.5)
            rand_saturation = random.uniform(0, 0.5)
            jitter = T.ColorJitter(
                brightness=rand_brightness,
                contrast=rand_contrast,
                saturation=rand_saturation,
                hue=rand_hue,
            )
            img = jitter(img)

        if 'sharpness' in chosen:
            # Factor is always >= 1 because the exponential draw is non-negative.
            sharpness = 1 + (np.random.exponential() / 2)
            # p=1 makes the adjustment unconditional once this branch is taken.
            sharpen = T.RandomAdjustSharpness(sharpness, p=1)
            # Apply the sharpness transform itself. Previously a stale
            # transform from an earlier branch was invoked here, so sharpening
            # never happened and the call could fail with an unbound local.
            img = sharpen(img)

        return img