GCycleGAN / utils /dataloader.py
Egrt's picture
init
95e767b
raw
history blame contribute delete
No virus
1.79 kB
import numpy as np
import torch
from PIL import Image
from torch.utils.data.dataset import Dataset
from utils.utils import cvtColor, preprocess_input
class CycleGanDataset(Dataset):
    """Dataset yielding one image from set A and one from set B per index.

    The two sets may differ in size: the dataset length is the larger of
    the two, and each set is indexed modulo its own length, so the shorter
    set wraps around while the longer one is traversed once.

    Each annotation line is parsed as ``<field>;<path> ...``: the image path
    is the first whitespace-separated token of the second ``;``-separated
    field.
    """

    def __init__(self, annotation_lines_A, annotation_lines_B, input_shape):
        """Store the annotation lists and the target image size.

        Args:
            annotation_lines_A: Sequence of annotation strings for set A.
            annotation_lines_B: Sequence of annotation strings for set B.
            input_shape: Indexable pair; element 0 is passed to PIL as the
                resize height and element 1 as the resize width.
        """
        super().__init__()
        self.annotation_lines_A = annotation_lines_A
        self.annotation_lines_B = annotation_lines_B
        self.length_A = len(self.annotation_lines_A)
        self.length_B = len(self.annotation_lines_B)
        self.input_shape = input_shape

    def __len__(self):
        """Return the size of the larger of the two image sets."""
        return max(self.length_A, self.length_B)

    def __getitem__(self, index):
        """Return the ``(image_A, image_B)`` pair of arrays for ``index``.

        Raises:
            ZeroDivisionError: If either annotation list is empty.
        """
        # Wrap each index independently so the shorter set is reused.
        image_A = self._load_image(self.annotation_lines_A[index % self.length_A])
        image_B = self._load_image(self.annotation_lines_B[index % self.length_B])
        return image_A, image_B

    def _load_image(self, annotation_line):
        """Load, resize and preprocess the image named by one annotation line."""
        path = annotation_line.split(';')[1].split()[0]
        # PIL's resize expects (width, height), hence the swapped indices.
        size = [self.input_shape[1], self.input_shape[0]]
        # Decode and resize while the file is open; ``resize`` returns a new
        # in-memory image, so the handle can be released right afterwards.
        with Image.open(path) as raw:
            resized = cvtColor(raw).resize(size, Image.BICUBIC)
        pixels = np.array(resized, dtype=np.float32)
        # Move the last axis to the front (channels-first, assuming cvtColor
        # yields an H x W x C image -- confirm against utils.utils).
        return np.transpose(preprocess_input(pixels), (2, 0, 1))
def CycleGan_dataset_collate(batch):
    """Collate ``(image_A, image_B)`` samples into two float32 tensors.

    Args:
        batch: Iterable of 2-item samples, each holding one array per set.

    Returns:
        Tuple ``(images_A, images_B)`` of ``torch.Tensor`` objects, each
        built by stacking the corresponding sample entries along a new
        leading axis and casting to float32.
    """
    # Materialise once so the batch can be traversed for each side.
    samples = list(batch)
    stacked_a = np.array([first for first, _ in samples], np.float32)
    stacked_b = np.array([second for _, second in samples], np.float32)
    return torch.from_numpy(stacked_a), torch.from_numpy(stacked_b)