# Purple11's picture
# Upload 746 files
# 713f5bd
# raw
# history blame
# 998 Bytes
import os
import numpy as np
import albumentations
from torch.utils.data import Dataset
from taming.data.base import ImagePaths, NumpyPaths, ConcatDatasetWithIndex
class CustomBase(Dataset):
    """Minimal ``Dataset`` that forwards ``len()`` and indexing to ``self.data``.

    Subclasses are expected to assign ``self.data`` (any object that supports
    ``len()`` and ``[]``) in their own ``__init__``. Until that happens
    ``self.data`` is ``None``, so ``len(dataset)`` raises ``TypeError``.
    """

    def __init__(self, *args, **kwargs):
        # Extra positional / keyword arguments are accepted but not used.
        super().__init__()
        self.data = None

    def __len__(self):
        return len(self.data)

    def __getitem__(self, i):
        # Items are handed back exactly as the wrapped container stores them.
        return self.data[i]
class CustomTrain(CustomBase):
    """Training split built from a text file that lists one image path per line.

    Args:
        size: Forwarded unchanged to ``ImagePaths`` (presumably the target
            image side length -- confirm against ``taming.data.base``).
        training_images_list_file: Path to a text file containing one image
            path per line. Blank / whitespace-only lines are ignored.
        random_crop: Forwarded to ``ImagePaths``. Defaults to ``False``, which
            matches the previously hard-coded value, so existing callers and
            configs behave exactly as before.
    """

    def __init__(self, size, training_images_list_file, *, random_crop=False):
        super().__init__()
        with open(training_images_list_file, "r") as f:
            # Skip blank / whitespace-only lines (e.g. from a trailing "\n\n").
            # Otherwise they become "" entries that inflate len(self) and
            # presumably fail when ImagePaths tries to load them as images.
            paths = [line for line in f.read().splitlines() if line.strip()]
        self.data = ImagePaths(paths=paths, size=size, random_crop=random_crop)
class CustomTest(CustomBase):
    """Evaluation split built from a text file that lists one image path per line.

    Args:
        size: Forwarded unchanged to ``ImagePaths`` (presumably the target
            image side length -- confirm against ``taming.data.base``).
        test_images_list_file: Path to a text file containing one image path
            per line. Blank / whitespace-only lines are ignored.
    """

    def __init__(self, size, test_images_list_file):
        super().__init__()
        with open(test_images_list_file, "r") as f:
            # Skip blank / whitespace-only lines so they do not turn into ""
            # entries that inflate len(self) and presumably cannot be loaded.
            paths = [line for line in f.read().splitlines() if line.strip()]
        # random_crop stays fixed at False for the evaluation split.
        self.data = ImagePaths(paths=paths, size=size, random_crop=False)