import numbers
import os
import queue as Queue
import threading
from functools import partial
from typing import Iterable

import mxnet as mx
import numpy as np
import torch
from torch import distributed
from torch.utils.data import DataLoader
from torch.utils.data import Dataset
from torchvision import transforms
from torchvision.datasets import ImageFolder

from utils.utils_distributed_sampler import DistributedSampler
from utils.utils_distributed_sampler import get_dist_info
from utils.utils_distributed_sampler import worker_init_fn

def get_dataloader(
    root_dir,
    local_rank,
    batch_size,
    dali=False,
    seed=2048,
    num_workers=2,
) -> Iterable:
    """Return a training data iterator for synthetic, MXNet RecordIO, or ImageFolder data."""
    rec = os.path.join(root_dir, "train.rec")
    idx = os.path.join(root_dir, "train.idx")
    train_set = None

    # Synthetic data (random images); DALI is not applicable here.
    if root_dir == "synthetic":
        train_set = SyntheticDataset()
        dali = False

    # MXNet RecordIO (train.rec / train.idx present)
    elif os.path.exists(rec) and os.path.exists(idx):
        train_set = MXFaceDataset(root_dir=root_dir, local_rank=local_rank)

    # Image folder (one subdirectory per class)
    else:
        transform = transforms.Compose(
            [
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        train_set = ImageFolder(root_dir, transform)

    # DALI pipeline (only valid together with the RecordIO files above)
    if dali:
        return dali_data_iter(
            batch_size=batch_size, rec_file=rec, idx_file=idx, num_threads=2, local_rank=local_rank
        )

    rank, world_size = get_dist_info()
    train_sampler = DistributedSampler(train_set, num_replicas=world_size, rank=rank, shuffle=True, seed=seed)

    if seed is None:
        init_fn = None
    else:
        init_fn = partial(worker_init_fn, num_workers=num_workers, rank=rank, seed=seed)

    train_loader = DataLoaderX(
        local_rank=local_rank,
        dataset=train_set,
        batch_size=batch_size,
        sampler=train_sampler,
        num_workers=num_workers,
        pin_memory=True,
        drop_last=True,
        worker_init_fn=init_fn,
    )
    return train_loader

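# Usage sketch (illustrative only): building a loader for a single GPU process.
# The dataset path and batch size below are assumed values, not part of this repo,
# and torch.distributed must already be initialized so get_dist_info() can report
# the rank and world size.
#
#     train_loader = get_dataloader(
#         root_dir="/data/faces_emore",  # hypothetical directory containing train.rec / train.idx
#         local_rank=0,
#         batch_size=128,
#     )
#     for img, label in train_loader:
#         pass  # img is already on cuda:0, copied asynchronously by DataLoaderX
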
class BackgroundGenerator(threading.Thread):
    """Prefetch items from a generator into a bounded queue on a daemon thread."""

    def __init__(self, generator, local_rank, max_prefetch=6):
        super(BackgroundGenerator, self).__init__()
        self.queue = Queue.Queue(max_prefetch)
        self.generator = generator
        self.local_rank = local_rank
        self.daemon = True
        self.start()

    def run(self):
        torch.cuda.set_device(self.local_rank)
        for item in self.generator:
            self.queue.put(item)
        # Sentinel marking the end of the wrapped generator.
        self.queue.put(None)

    def next(self):
        next_item = self.queue.get()
        if next_item is None:
            raise StopIteration
        return next_item

    def __next__(self):
        return self.next()

    def __iter__(self):
        return self

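# Minimal sketch of BackgroundGenerator on its own (assumed values; any iterable
# works, and a visible CUDA device is required because run() calls set_device):
#
#     gen = BackgroundGenerator(iter(range(10)), local_rank=0, max_prefetch=4)
#     items = list(gen)  # -> [0, 1, ..., 9]; the None sentinel ends iteration
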
class DataLoaderX(DataLoader):
    """DataLoader that prefetches batches on a background thread and copies them to the
    local GPU on a side CUDA stream, overlapping host-to-device transfer with compute."""

    def __init__(self, local_rank, **kwargs):
        super(DataLoaderX, self).__init__(**kwargs)
        self.stream = torch.cuda.Stream(local_rank)
        self.local_rank = local_rank

    def __iter__(self):
        self.iter = super(DataLoaderX, self).__iter__()
        self.iter = BackgroundGenerator(self.iter, self.local_rank)
        self.preload()
        return self

    def preload(self):
        self.batch = next(self.iter, None)
        if self.batch is None:
            return None
        # Copy every tensor of the batch to the local GPU on the side stream.
        with torch.cuda.stream(self.stream):
            for k in range(len(self.batch)):
                self.batch[k] = self.batch[k].to(device=self.local_rank, non_blocking=True)

    def __next__(self):
        # Wait until the asynchronous copy issued in preload() has finished.
        torch.cuda.current_stream().wait_stream(self.stream)
        batch = self.batch
        if batch is None:
            raise StopIteration
        self.preload()
        return batch

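# DataLoaderX is used like a regular torch DataLoader apart from the extra
# local_rank argument; a sketch with assumed sizes (needs a CUDA device at index 0):
#
#     loader = DataLoaderX(local_rank=0, dataset=SyntheticDataset(),
#                          batch_size=64, num_workers=2, pin_memory=True)
#     images, labels = next(iter(loader))  # both tensors already live on cuda:0
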
class MXFaceDataset(Dataset):
    """Face dataset backed by an MXNet indexed RecordIO file pair (train.rec / train.idx)."""

    def __init__(self, root_dir, local_rank):
        super(MXFaceDataset, self).__init__()
        self.transform = transforms.Compose(
            [
                transforms.ToPILImage(),
                transforms.RandomHorizontalFlip(),
                transforms.ToTensor(),
                transforms.Normalize(mean=[0.5, 0.5, 0.5], std=[0.5, 0.5, 0.5]),
            ]
        )
        self.root_dir = root_dir
        self.local_rank = local_rank
        path_imgrec = os.path.join(root_dir, "train.rec")
        path_imgidx = os.path.join(root_dir, "train.idx")
        self.imgrec = mx.recordio.MXIndexedRecordIO(path_imgidx, path_imgrec, "r")
        # Record 0 may be a header whose label stores the range of data records.
        s = self.imgrec.read_idx(0)
        header, _ = mx.recordio.unpack(s)
        if header.flag > 0:
            self.header0 = (int(header.label[0]), int(header.label[1]))
            self.imgidx = np.array(range(1, int(header.label[0])))
        else:
            self.imgidx = np.array(list(self.imgrec.keys))

    def __getitem__(self, index):
        idx = self.imgidx[index]
        s = self.imgrec.read_idx(idx)
        header, img = mx.recordio.unpack(s)
        label = header.label
        if not isinstance(label, numbers.Number):
            label = label[0]
        label = torch.tensor(label, dtype=torch.long)
        sample = mx.image.imdecode(img).asnumpy()
        if self.transform is not None:
            sample = self.transform(sample)
        return sample, label

    def __len__(self):
        return len(self.imgidx)

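# Sketch (assumed path): reading one decoded, augmented sample directly.
#
#     dataset = MXFaceDataset(root_dir="/data/faces_emore", local_rank=0)
#     img, label = dataset[0]  # img: float tensor of shape (3, 112, 112), label: int64 scalar
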
class SyntheticDataset(Dataset):
    """Dataset that always returns the same random 112x112 image; useful for throughput tests."""

    def __init__(self):
        super(SyntheticDataset, self).__init__()
        img = np.random.randint(0, 255, size=(112, 112, 3), dtype=np.int32)
        img = np.transpose(img, (2, 0, 1))
        img = torch.from_numpy(img).squeeze(0).float()
        # Normalize to roughly [-1, 1], matching the transforms used above.
        img = ((img / 255) - 0.5) / 0.5
        self.img = img
        self.label = 1

    def __getitem__(self, index):
        return self.img, self.label

    def __len__(self):
        return 1000000

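# The synthetic path is selected by passing root_dir="synthetic" to get_dataloader;
# a quick check of what a single sample looks like:
#
#     img, label = SyntheticDataset()[0]
#     # img: (3, 112, 112) float tensor in roughly [-1, 1]; label == 1
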
def dali_data_iter(
    batch_size: int,
    rec_file: str,
    idx_file: str,
    num_threads: int,
    initial_fill=32768,
    random_shuffle=True,
    prefetch_queue_depth=1,
    local_rank=0,
    name="reader",
    mean=(127.5, 127.5, 127.5),
    std=(127.5, 127.5, 127.5),
):
    """Build an NVIDIA DALI pipeline that reads, decodes, and augments an MXNet RecordIO dataset.

    Parameters
    ----------
    initial_fill: int
        Size of the buffer that is used for shuffling. If random_shuffle is False, this parameter is ignored.
    """
    rank: int = distributed.get_rank()
    world_size: int = distributed.get_world_size()

    import nvidia.dali.fn as fn
    import nvidia.dali.types as types
    from nvidia.dali.pipeline import Pipeline
    from nvidia.dali.plugin.pytorch import DALIClassificationIterator

    pipe = Pipeline(
        batch_size=batch_size,
        num_threads=num_threads,
        device_id=local_rank,
        prefetch_queue_depth=prefetch_queue_depth,
    )
    condition_flip = fn.random.coin_flip(probability=0.5)
    with pipe:
        # Each rank reads its own shard of the RecordIO file.
        jpegs, labels = fn.readers.mxnet(
            path=rec_file,
            index_path=idx_file,
            initial_fill=initial_fill,
            num_shards=world_size,
            shard_id=rank,
            random_shuffle=random_shuffle,
            pad_last_batch=False,
            name=name,
        )
        # Decode JPEGs on the GPU ("mixed"), then normalize and randomly mirror.
        images = fn.decoders.image(jpegs, device="mixed", output_type=types.RGB)
        images = fn.crop_mirror_normalize(images, dtype=types.FLOAT, mean=mean, std=std, mirror=condition_flip)
        pipe.set_outputs(images, labels)
    pipe.build()
    return DALIWarper(
        DALIClassificationIterator(
            pipelines=[pipe],
            reader_name=name,
        )
    )

class DALIWarper(object):
    """Wrap a DALIClassificationIterator so it yields (image, label) tuples like the PyTorch loader."""

    def __init__(self, dali_iter):
        self.iter = dali_iter

    def __next__(self):
        data_dict = self.iter.__next__()[0]
        tensor_data = data_dict["data"].cuda()
        tensor_label: torch.Tensor = data_dict["label"].cuda().long()
        tensor_label.squeeze_()
        return tensor_data, tensor_label

    def __iter__(self):
        return self

    def reset(self):
        self.iter.reset()

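# Usage sketch for the DALI path (assumed paths and sizes; requires nvidia.dali to be
# installed and torch.distributed to be initialized so rank/world size are available):
#
#     train_iter = dali_data_iter(
#         batch_size=128,
#         rec_file="/data/faces_emore/train.rec",
#         idx_file="/data/faces_emore/train.idx",
#         num_threads=2,
#         local_rank=0,
#     )
#     img, label = next(train_iter)  # img: float GPU tensor, label: int64 GPU tensor
#     train_iter.reset()             # rewind the DALI reader at the end of an epoch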