# Source: JoJoGan-powerhow2/e4e/datasets/inference_dataset.py
# Last commit: 3d37b6e by Sanket (639 bytes)
from torch.utils.data import Dataset
from PIL import Image
from utils import data_utils
class InferenceDataset(Dataset):
    """Dataset yielding one item per file path collected under ``root``.

    Paths come from ``data_utils.make_dataset(root)`` and are sorted so that
    item indices are stable from run to run.
    """

    def __init__(self, root, opts, transform=None, preprocess=None):
        """Collect and sort the file paths under ``root``.

        Args:
            root: Location passed to ``data_utils.make_dataset``.
            opts: Options object; stored as ``self.opts`` but not read by
                this class.
            transform: Optional callable applied to each loaded item.
            preprocess: Optional callable taking a file path and returning the
                loaded item; when given it replaces the default PIL loading.
        """
        self.paths = sorted(data_utils.make_dataset(root))
        self.transform = transform
        self.preprocess = preprocess
        self.opts = opts

    def __len__(self):
        """Return the number of collected paths."""
        return len(self.paths)

    def __getitem__(self, index):
        """Load the item at ``index`` and apply ``transform`` if one is set.

        If ``preprocess`` was provided, it is called with the file path.
        Otherwise the file is opened with PIL and converted to RGB.
        """
        from_path = self.paths[index]
        if self.preprocess is not None:
            from_im = self.preprocess(from_path)
        else:
            # Close the file handle deterministically: Pillow only closes it
            # automatically after load() for single-frame images. convert()
            # returns a new, fully loaded image that stays valid after the
            # source image is closed.
            with Image.open(from_path) as src_im:
                from_im = src_im.convert('RGB')
        # Compare against None (as with preprocess) so that a callable
        # transform that happens to be falsy is still applied.
        if self.transform is not None:
            from_im = self.transform(from_im)
        return from_im