JoJoGan-powerhow2 / e4e /datasets /gt_res_dataset.py
Sanket
.
3d37b6e
raw
history blame
890 Bytes
#!/usr/bin/python
# encoding: utf-8
import os
from torch.utils.data import Dataset
from PIL import Image
import torch
class GTResDataset(Dataset):
    """Dataset of (image, ground-truth image) pairs matched by file name.

    Every ``.jpg`` / ``.png`` file found directly in ``root_path`` is paired
    with a file of the same base name in ``gt_dir``. For ``.png`` inputs the
    ground-truth file is looked up with a ``.jpg`` extension instead.

    Args:
        root_path: Directory containing the input images.
        gt_dir: Directory containing the ground-truth images. Required in
            practice, even though the signature keeps a ``None`` default.
        transform: Optional callable applied to both images of a pair.
        transform_train: Stored on the instance; not used by this class.

    Raises:
        TypeError: If ``gt_dir`` is ``None``.
    """

    def __init__(self, root_path, gt_dir=None, transform=None, transform_train=None):
        if gt_dir is None:
            # os.path.join(None, ...) would fail with a cryptic TypeError on
            # the first directory entry; fail early with a clear message.
            raise TypeError("gt_dir must be a path to the ground-truth directory, not None")

        self.pairs = []
        # os.listdir returns entries in arbitrary order; sort so that sample
        # indices are reproducible across runs and machines.
        for f in sorted(os.listdir(root_path)):
            if not f.endswith((".jpg", ".png")):
                continue
            # Swap only the trailing extension of the file name. Replacing
            # '.png' across the whole joined path would also rewrite
            # directory names that happen to contain '.png'.
            gt_name = f[:-len(".png")] + ".jpg" if f.endswith(".png") else f
            image_path = os.path.join(root_path, f)
            gt_path = os.path.join(gt_dir, gt_name)
            # Third slot is always None; kept so the shape of `pairs` entries
            # is unchanged for existing consumers.
            self.pairs.append([image_path, gt_path, None])
        self.transform = transform
        self.transform_train = transform_train

    def __len__(self):
        """Return the number of image pairs."""
        return len(self.pairs)

    def __getitem__(self, index):
        """Load the pair at ``index`` as RGB and return ``(from_im, to_im)``.

        If a transform was supplied it is applied to the ground-truth image
        first and then to the input image.
        """
        from_path, to_path, _ = self.pairs[index]
        # convert() returns an independent image, so the underlying files can
        # be closed as soon as the conversion is done.
        with Image.open(from_path) as src:
            from_im = src.convert('RGB')
        with Image.open(to_path) as gt:
            to_im = gt.convert('RGB')
        if self.transform:
            to_im = self.transform(to_im)
            from_im = self.transform(from_im)
        return from_im, to_im