# Spaces: Runtime error
# (Page-status text captured together with this file; it is not part of the module.)
import torchvision.transforms as T | |
# Registry mapping a transform factory's ``__name__`` to the factory itself.
# Filled by ``register_transform`` and read by ``get_transform``.
TRANSFORMS = {}
def register_transform(transform):
    """Register a transform factory in ``TRANSFORMS`` under its ``__name__``.

    Args:
        transform: Callable to register. It is stored as-is and later looked
            up by name in ``TRANSFORMS``.

    Returns:
        The same ``transform`` object, so this function can also be applied
        as a decorator without rebinding the decorated name to ``None``.

    Raises:
        RuntimeError: If a transform with the same name is already registered.
    """
    name = transform.__name__
    if name in TRANSFORMS:
        raise RuntimeError(f'Transform {name} has already registered.')
    TRANSFORMS[name] = transform
    return transform
def get_transform(type, resolution):
    """Build the composed transform pipeline registered under ``type``.

    Args:
        type: Name of a factory present in ``TRANSFORMS``. The parameter
            name shadows the builtin; it is kept so existing keyword callers
            keep working.
        resolution: Value passed to the factory and also stored on the
            returned object as ``image_size``.

    Returns:
        A ``T.Compose`` wrapping whatever the factory returned, with an
        extra ``image_size`` attribute set to ``resolution``.

    Raises:
        KeyError: If ``type`` is not a registered transform name. The message
            lists the names that are registered.
    """
    try:
        factory = TRANSFORMS[type]
    except KeyError:
        available = ', '.join(sorted(TRANSFORMS)) or '<none>'
        # Still a KeyError (as before), but one that says what was asked for
        # and what is actually available.
        raise KeyError(
            f'Unknown transform {type!r}; registered transforms: {available}'
        ) from None
    pipeline = T.Compose(factory(resolution))
    # Stash the target resolution on the pipeline so callers can read it back.
    pipeline.image_size = resolution
    return pipeline
def default_train(n_px):
    """Return the default list of torchvision transforms for size ``n_px``.

    The steps, in order: convert the image to RGB, resize to ``n_px``,
    center-crop to ``n_px``, convert to a tensor, and normalize with mean
    ``[.5]`` and std ``[.5]``.

    Args:
        n_px: Size passed to both ``T.Resize`` and ``T.CenterCrop``.

    Returns:
        A plain list of transform objects (not yet wrapped in ``T.Compose``).
    """
    transform = [
        T.Lambda(lambda img: img.convert('RGB')),
        # Uses T.Resize's default interpolation. An earlier note here
        # mentioned Image.BICUBIC - TODO confirm which one is intended.
        T.Resize(n_px),
        T.CenterCrop(n_px),
        # NOTE: no random horizontal flip is applied; that step was disabled.
        T.ToTensor(),
        T.Normalize([.5], [.5]),
    ]
    return transform