import bisect
from typing import Iterable

from torch.utils.data import Dataset, IterableDataset


class ConcatRepeatDataset(Dataset):
    """Concatenates map-style datasets, repeating each one a fixed number of times.

    Useful for oversampling small datasets when mixing several data sources.
    """

    datasets: list[Dataset]
    cumulative_sizes: list[int]
    repeats: list[int]

    @staticmethod
    def cumsum(sequence, repeats):
        # Running total of len(dataset) * repeat, so cumulative_sizes[i] is the
        # index one past the last sample drawn from datasets[i].
        r, s = [], 0
        for dataset, repeat in zip(sequence, repeats):
            size = len(dataset) * repeat
            r.append(size + s)
            s += size
        return r

    def __init__(self, datasets: Iterable[Dataset], repeats: list[int]):
        super().__init__()
        self.datasets = list(datasets)
        self.repeats = repeats
        assert len(self.datasets) > 0, "datasets should not be an empty iterable"
        assert len(self.datasets) == len(
            repeats
        ), "datasets and repeats should have the same length"
        for d in self.datasets:
            assert not isinstance(
                d, IterableDataset
            ), "ConcatRepeatDataset does not support IterableDataset"
        self.cumulative_sizes = self.cumsum(self.datasets, self.repeats)

    def __len__(self):
        return self.cumulative_sizes[-1]

    def __getitem__(self, idx):
        if idx < 0:
            # Support negative indexing, mirroring torch's ConcatDataset;
            # otherwise a negative idx would silently index the first dataset.
            if -idx > len(self):
                raise ValueError(
                    "absolute value of index should not exceed dataset length"
                )
            idx = len(self) + idx
        # Locate which (repeated) dataset the global index falls into.
        dataset_idx = bisect.bisect_right(self.cumulative_sizes, idx)
        if dataset_idx == 0:
            sample_idx = idx
        else:
            sample_idx = idx - self.cumulative_sizes[dataset_idx - 1]
        dataset = self.datasets[dataset_idx]
        # Modulo wraps the local index back into the underlying dataset,
        # which is what implements the repeats.
        return dataset[sample_idx % len(dataset)]
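

# Minimal usage sketch: repeating the smaller dataset balances its contribution
# against the larger one. TensorDataset is used here purely as a convenient
# stand-in for any map-style dataset.
if __name__ == "__main__":
    import torch
    from torch.utils.data import TensorDataset

    small = TensorDataset(torch.arange(10))   # 10 samples
    large = TensorDataset(torch.arange(100))  # 100 samples

    # Repeat `small` 10x so it contributes as many samples as `large`.
    combined = ConcatRepeatDataset([small, large], repeats=[10, 1])

    print(len(combined))     # 200 == 10 * 10 + 100 * 1
    print(combined[15][0])   # tensor(5): index 15 wraps back to small[5]
    print(combined[105][0])  # tensor(5): indices >= 100 fall into `large`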