from torch.utils.data import BatchSampler, DataLoader, IterableDataset


# Default values for every keyword argument accepted by `torch.utils.data.DataLoader`.
# They are used below to copy the configuration of an existing dataloader onto a new one.
_PYTORCH_DATALOADER_KWARGS = {
    "batch_size": 1,
    "shuffle": False,
    "sampler": None,
    "batch_sampler": None,
    "num_workers": 0,
    "collate_fn": None,
    "pin_memory": False,
    "drop_last": False,
    "timeout": 0,
    "worker_init_fn": None,
    "multiprocessing_context": None,
    "generator": None,
    "prefetch_factor": 2,
    "persistent_workers": False,
}
|
|
class SkipBatchSampler(BatchSampler):
    """
    A `torch.utils.data.BatchSampler` that skips the first `n` batches of another
    `torch.utils.data.BatchSampler`.
    """

    def __init__(self, batch_sampler, skip_batches=0):
        self.batch_sampler = batch_sampler
        self.skip_batches = skip_batches

    def __iter__(self):
        # Iterate the wrapped sampler, but only start yielding once `skip_batches`
        # batches have been consumed.
        for index, samples in enumerate(self.batch_sampler):
            if index >= self.skip_batches:
                yield samples

    @property
    def total_length(self):
        # Length of the underlying sampler, i.e. including the skipped batches.
        return len(self.batch_sampler)

    def __len__(self):
        return len(self.batch_sampler) - self.skip_batches
|
|
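# A minimal sketch of how `SkipBatchSampler` behaves, assuming a sequential sampler
# over 10 samples with batches of 4 (the names below are illustrative, not part of
# this module):
#
#     from torch.utils.data import BatchSampler, SequentialSampler
#
#     base = BatchSampler(SequentialSampler(range(10)), batch_size=4, drop_last=False)
#     skipped = SkipBatchSampler(base, skip_batches=1)
#     list(skipped)   # [[4, 5, 6, 7], [8, 9]] -- the first batch [0, 1, 2, 3] is gone
#     len(skipped)    # 2 (3 batches in `base`, minus the 1 skipped)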
|
|
class SkipDataLoader(DataLoader):
    """
    Subclass of a PyTorch `DataLoader` that will skip the first batches.

    Args:
        dataset (`torch.utils.data.dataset.Dataset`):
            The dataset to use to build this dataloader.
        skip_batches (`int`, *optional*, defaults to 0):
            The number of batches to skip at the beginning.
        kwargs:
            All other keyword arguments to pass to the regular `DataLoader` initialization.
    """

    def __init__(self, dataset, skip_batches=0, **kwargs):
        super().__init__(dataset, **kwargs)
        self.skip_batches = skip_batches

    def __iter__(self):
        # Iterate normally, but drop the first `skip_batches` batches. Unlike with
        # `SkipBatchSampler`, the skipped batches are still fetched and collated,
        # just not yielded.
        for index, batch in enumerate(super().__iter__()):
            if index >= self.skip_batches:
                yield batch
|
|
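# A minimal sketch of `SkipDataLoader` usage, assuming a simple list-backed dataset
# (illustrative only):
#
#     loader = SkipDataLoader(list(range(16)), skip_batches=2, batch_size=4)
#     [batch.tolist() for batch in loader]
#     # [[8, 9, 10, 11], [12, 13, 14, 15]] -- batches [0..3] and [4..7] are skipped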
|
|
def skip_first_batches(dataloader, num_batches=0):
    """
    Creates a `torch.utils.data.DataLoader` that will efficiently skip the first `num_batches` batches.
    """
    dataset = dataloader.dataset
    sampler_is_batch_sampler = False
    if isinstance(dataset, IterableDataset):
        # An iterable dataset has no sampler to wrap, so batches have to be
        # skipped inside the dataloader itself.
        new_batch_sampler = None
    else:
        sampler_is_batch_sampler = isinstance(dataloader.sampler, BatchSampler)
        batch_sampler = dataloader.sampler if sampler_is_batch_sampler else dataloader.batch_sampler
        new_batch_sampler = SkipBatchSampler(batch_sampler, skip_batches=num_batches)

    # These kwargs are all handled by `new_batch_sampler`, so they must not be
    # passed to the new dataloader as well.
    ignore_kwargs = [
        "batch_size",
        "shuffle",
        "sampler",
        "batch_sampler",
        "drop_last",
    ]

    # Copy the remaining configuration of the original dataloader, falling back to
    # the PyTorch defaults for any attribute it does not expose.
    kwargs = {
        k: getattr(dataloader, k, _PYTORCH_DATALOADER_KWARGS[k])
        for k in _PYTORCH_DATALOADER_KWARGS
        if k not in ignore_kwargs
    }

    if new_batch_sampler is None:
        # With no batch sampler, `batch_size` and `drop_last` need to be passed
        # explicitly, and skipping has to happen manually inside the dataloader.
        kwargs["drop_last"] = dataloader.drop_last
        kwargs["batch_size"] = dataloader.batch_size
        dataloader = SkipDataLoader(dataset, skip_batches=num_batches, **kwargs)
    elif sampler_is_batch_sampler:
        # The original dataloader received its batch sampler through the `sampler`
        # argument, so the wrapped one must be passed the same way.
        dataloader = DataLoader(
            dataset,
            sampler=new_batch_sampler,
            batch_size=dataloader.batch_size,
            **kwargs,
        )
    else:
        dataloader = DataLoader(dataset, batch_sampler=new_batch_sampler, **kwargs)

    return dataloader
|
|
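# A minimal sketch of `skip_first_batches` on both kinds of datasets. This demo is
# an illustrative assumption, not part of the original module.
if __name__ == "__main__":
    # Map-style dataset: skipping is done by wrapping the batch sampler, so the
    # skipped batches are never even assembled.
    loader = DataLoader(list(range(12)), batch_size=4)
    skipped = skip_first_batches(loader, num_batches=1)
    print([b.tolist() for b in skipped])  # [[4, 5, 6, 7], [8, 9, 10, 11]]

    class Stream(IterableDataset):
        # Tiny iterable dataset; such datasets have no sampler, so
        # `skip_first_batches` falls back to `SkipDataLoader`.
        def __iter__(self):
            return iter(range(12))

    stream_loader = DataLoader(Stream(), batch_size=4)
    skipped_stream = skip_first_batches(stream_loader, num_batches=1)
    print([b.tolist() for b in skipped_stream])  # [[4, 5, 6, 7], [8, 9, 10, 11]]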