# Very loosely inspired by indexed_dataset in Fairseq, Megatron
# https://github.com/NVIDIA/Megatron-LM/blob/main/megatron/data/indexed_dataset.py
"""Pack token arrays into fixed-size binary chunk files and stream them back as blocks.

Chunk file layout written by ``PackedDatasetBuilder`` (integers little-endian)::

    HDR_MAGIC      7 bytes   b"LITPKDS"
    version        8 bytes   "<Q", currently 1
    dtype code     1 byte    "<B", a key of ``dtypes``
    chunk_size     8 bytes   "<Q", number of elements in the payload
    payload        chunk_size * itemsize bytes, C order

so the header occupies ``HDR_SIZE`` (24) bytes.
"""

import os
import random
import struct

import numpy as np
import torch
from torch.utils.data import IterableDataset, get_worker_info

# One-byte dtype code stored in each chunk header -> numpy element type.
dtypes = {
    1: np.uint8,
    2: np.int8,
    3: np.int16,
    4: np.int32,
    5: np.int64,
    6: np.float32,
    7: np.float64,
    8: np.uint16,
}


def code(dtype):
    """Return the header code in ``dtypes`` for ``dtype``.

    ``dtype`` may be a numpy scalar type, an ``np.dtype`` instance, or a dtype
    string such as ``"uint16"``.

    Raises:
        ValueError: if ``dtype`` is not one of the types listed in ``dtypes``.
    """
    if dtype is None:
        # np.dtype(None) would silently mean float64.
        raise ValueError(dtype)
    try:
        wanted = np.dtype(dtype)
    except TypeError as err:
        raise ValueError(dtype) from err
    for key, candidate in dtypes.items():
        if np.dtype(candidate) == wanted:
            return key
    raise ValueError(dtype)


HDR_MAGIC = b"LITPKDS"
HDR_SIZE = 24  # bytes: len(HDR_MAGIC) + 8 (version) + 1 (dtype code) + 8 (chunk_size)


class PackedDataset(IterableDataset):
    """Iterable dataset yielding ``block_size``-long int64 tensors from chunk files.

    Files are split round-robin across every (process, DataLoader worker) pair.
    Only the first ``len(filenames) // num_shards * num_shards`` files are used,
    so every shard receives the same number of files.

    Args:
        filenames: chunk files written by ``PackedDatasetBuilder``.
        n_chunks: number of files memory-mapped (and shuffled across) at a time.
        block_size: number of elements per yielded block.
        seed: seed of the block shuffling RNG.
        shuffle: shuffle block order within each window of ``n_chunks`` files.
        wrap: restart from the first file instead of stopping when files run out.
        num_processes: number of distributed processes sharing ``filenames``.
        process_rank: rank of this process among ``num_processes``.
    """

    def __init__(
        self,
        filenames,
        n_chunks,
        block_size,
        seed=12345,
        shuffle=True,
        wrap=False,
        num_processes=1,
        process_rank=0,
    ):
        self._filenames = filenames
        self._n_chunks = n_chunks
        self._block_size = block_size
        self._seed = seed
        self._shuffle = shuffle
        self._wrap = wrap
        self._num_processes = num_processes
        self._process_rank = process_rank

    def __iter__(self):
        worker_info = get_worker_info()
        num_workers = worker_info.num_workers if worker_info is not None else 1
        worker_id = worker_info.id if worker_info is not None else 0
        num_shards = num_workers * self._num_processes
        shard_id = self._process_rank * num_workers + worker_id

        # Truncate to a multiple of num_shards so all shards are the same size.
        max_num_files = len(self._filenames) // num_shards * num_shards
        filenames = self._filenames[shard_id:max_num_files:num_shards]

        return PackedDatasetIterator(
            filenames=filenames,
            n_chunks=self._n_chunks,
            block_size=self._block_size,
            seed=self._seed,
            shuffle=self._shuffle,
            wrap=self._wrap,
        )


class PackedDatasetBuilder:
    """Pack token arrays back to back into fixed-size chunk files.

    Elements accumulate in an in-memory buffer of ``chunk_size`` entries that is
    pre-filled with ``sep_token``. Every time the buffer fills up it is written to
    ``{outdir}/{prefix}_{counter:010d}.bin``. Call ``write_reminder`` at the end
    to flush the last, partially filled chunk; its unused tail keeps ``sep_token``.

    Args:
        outdir: directory the chunk files are written to.
        prefix: file-name prefix of every chunk.
        chunk_size: number of elements per chunk.
        sep_token: value used to fill unused positions of a chunk.
        dtype: element dtype, or ``"auto"`` for ``np.uint16`` when
            ``vocab_size < 65500`` and ``np.int32`` otherwise.
        vocab_size: required when ``dtype="auto"``.

    Raises:
        ValueError: if ``dtype="auto"`` and ``vocab_size`` is None.
    """

    def __init__(
        self,
        outdir,
        prefix,
        chunk_size,
        sep_token,
        dtype="auto",
        vocab_size=None,
    ):
        if isinstance(dtype, str) and dtype == "auto":
            if vocab_size is None:
                raise ValueError("vocab_size cannot be None when dtype='auto'")
            # uint16 holds ids up to 65535; the lower cutoff presumably leaves headroom.
            self._dtype = np.uint16 if vocab_size < 65500 else np.int32
        else:
            self._dtype = dtype
        self._counter = 0
        self._chunk_size = chunk_size
        self._outdir = outdir
        self._prefix = prefix
        self._sep_token = sep_token
        self._arr = np.zeros(self._chunk_size, dtype=self._dtype)
        self._arr.fill(self._sep_token)
        self._idx = 0  # next free position in self._arr
        self._version = 1
        self._filenames = []

    @property
    def dtype(self):
        """Element dtype of the chunks being written."""
        return self._dtype

    @property
    def filenames(self):
        """Copy of the list of chunk files written so far."""
        return self._filenames.copy()

    def _write_chunk(self):
        """Write the current buffer as the next chunk file and reset the buffer."""
        filename = os.path.join(self._outdir, f"{self._prefix}_{self._counter:010d}.bin")

        with open(filename, "wb") as f:
            f.write(HDR_MAGIC)
            f.write(struct.pack("<Q", self._version))
            f.write(struct.pack("<B", code(self._dtype)))
            f.write(struct.pack("<Q", self._chunk_size))
            f.write(self._arr.tobytes(order="C"))

        self._filenames.append(filename)
        self._counter += 1
        self._arr.fill(self._sep_token)
        self._idx = 0

    def add_array(self, arr):
        """Append the 1-D array ``arr``, writing out every chunk it completes.

        A chunk that becomes exactly full is written on the next call to
        ``add_array`` or ``write_reminder``.
        """
        while self._idx + arr.shape[0] > self._chunk_size:
            part_len = self._chunk_size - self._idx
            self._arr[self._idx : self._idx + part_len] = arr[:part_len]
            self._write_chunk()
            arr = arr[part_len:]

        arr_len = arr.shape[0]
        self._arr[self._idx : self._idx + arr_len] = arr
        self._idx += arr_len

    def write_reminder(self):
        """Flush the current chunk to disk.

        The chunk is written even if it holds no new data, in which case it is all
        ``sep_token``.
        """
        self._write_chunk()

    # Correctly spelled alias; ``write_reminder`` is kept for existing callers.
    write_remainder = write_reminder


class PackedDatasetIterator:
    """Iterator over blocks from a list of chunk files.

    Files are memory-mapped ``n_chunks`` at a time. Each window's blocks are
    yielded in order, or in a permutation drawn from a seeded RNG when ``shuffle``
    is set. Loading is deferred to the first ``__next__``, so an empty file list
    simply yields nothing.

    Raises:
        ValueError: if ``n_chunks`` or ``block_size`` is smaller than 1.
    """

    def __init__(self, filenames, n_chunks, block_size, seed, shuffle, wrap):
        if n_chunks < 1:
            raise ValueError(f"n_chunks must be >= 1, got {n_chunks}")
        if block_size < 1:
            raise ValueError(f"block_size must be >= 1, got {block_size}")
        self._seed = seed
        self._shuffle = shuffle
        self._rng = np.random.default_rng(seed) if shuffle else None
        self._wrap = wrap

        # TODO: instead of filenames, we could have a single text stream
        # (or text file) with the sequence of all files to be
        # fetched/loaded.
        self._filenames = filenames
        self._file_idx = 0  # index of the first file of the next window

        self._n_chunks = n_chunks
        self._dtype = None  # read from the first file's header
        self._chunk_size = None
        self._block_size = block_size
        self._n_blocks = None  # blocks per chunk

        self._mmaps = []
        self._buffers = []

        # Empty, so the first __next__ loads the first window.
        self._block_idxs = []
        self._curr_idx = 0

    def _read_header(self, path):
        """Parse and validate the chunk header of ``path``.

        Returns:
            ``(dtype, chunk_size)`` as stored in the header.

        Raises:
            ValueError: on a wrong magic, unsupported version, or unknown dtype code.
        """
        with open(path, "rb") as f:
            magic = f.read(len(HDR_MAGIC))
            if magic != HDR_MAGIC:
                raise ValueError("File doesn't match expected format.")
            (version,) = struct.unpack("<Q", f.read(8))
            if version != 1:
                raise ValueError(f"Unsupported packed dataset version {version} in {path}")
            (dtype_code,) = struct.unpack("<B", f.read(1))
            if dtype_code not in dtypes:
                raise ValueError(f"Unknown dtype code {dtype_code} in {path}")
            (chunk_size,) = struct.unpack("<Q", f.read(8))
        return dtypes[dtype_code], chunk_size

    def _close_mmaps(self):
        """Drop this iterator's references to the mapped chunks.

        Dropping the last references lets numpy unmap the files. The memmaps'
        private handles are deliberately not closed here, because views may still
        point into them.
        """
        self._buffers = []
        self._mmaps = []

    def _load_n_chunks(self):
        """Memory-map the next ``n_chunks`` files and reset the block order.

        Raises:
            StopIteration: when fewer than ``n_chunks`` files remain and ``wrap``
                is False (trailing files are skipped), or when there are no files.
            ValueError: when wrapping with fewer than ``n_chunks`` files in total,
                or when a chunk holds fewer elements than ``block_size``.
        """
        self._close_mmaps()

        if self._n_chunks > len(self._filenames) - self._file_idx:
            if not self._wrap or not self._filenames:
                raise StopIteration
            if self._n_chunks > len(self._filenames):
                raise ValueError(
                    f"n_chunks={self._n_chunks} exceeds the number of files "
                    f"({len(self._filenames)}); cannot wrap"
                )
            self._file_idx = 0

        for i in range(self._n_chunks):
            filename = self._filenames[self._file_idx + i]
            if self._dtype is None:
                dtype, chunk_size = self._read_header(filename)
                n_blocks = chunk_size // self._block_size
                if n_blocks == 0:
                    raise ValueError(
                        f"block_size={self._block_size} is larger than the chunk size "
                        f"{chunk_size} of {filename}"
                    )
                self._dtype, self._chunk_size, self._n_blocks = dtype, chunk_size, n_blocks
            # TODO: check header matches with previous files
            mmap = np.memmap(filename, mode="r", order="C", offset=HDR_SIZE)
            self._mmaps.append(mmap)
            self._buffers.append(memoryview(mmap))

        self._file_idx += self._n_chunks
        n_all_blocks = self._n_chunks * self._n_blocks

        self._block_idxs = (
            self._rng.permutation(n_all_blocks) if self._shuffle else range(n_all_blocks)
        )
        self._curr_idx = 0

    def __del__(self):
        # __init__ may have raised before the buffer lists were created.
        if hasattr(self, "_buffers"):
            self._close_mmaps()

    def __iter__(self):
        return self

    def __next__(self):
        """Return the next block as a 1-D int64 tensor of length ``block_size``."""
        if self._curr_idx >= len(self._block_idxs):
            self._load_n_chunks()
            # TODO: trigger fetching next next n_chunks if remote
        block_idx = self._block_idxs[self._curr_idx]
        chunk_id, block_in_chunk = divmod(block_idx, self._n_blocks)
        buffer = self._buffers[chunk_id]
        elem_id = block_in_chunk * self._block_size
        offset = np.dtype(self._dtype).itemsize * elem_id
        arr = np.frombuffer(buffer, dtype=self._dtype, count=self._block_size, offset=offset)
        self._curr_idx += 1
        # astype copies, so the returned tensor does not reference the mapping.
        return torch.from_numpy(arr.astype(np.int64))


class CombinedDataset(IterableDataset):
    """Iterable dataset drawing each element from one of several datasets.

    The source dataset is drawn at random according to ``weights``, which are
    uniform when not given.

    Raises:
        ValueError: if ``datasets`` is empty.
    """

    def __init__(self, datasets, seed, weights=None):
        if len(datasets) == 0:
            raise ValueError("CombinedDataset requires at least one dataset")
        self._seed = seed
        self._datasets = datasets
        self._weights = weights
        n_datasets = len(datasets)
        if weights is None:
            self._weights = [1 / n_datasets] * n_datasets

    def __iter__(self):
        return CombinedDatasetIterator(self._datasets, self._seed, self._weights)


class CombinedDatasetIterator:
    """Iterator that picks a sub-iterator with seeded weighted sampling per element.

    Iteration ends as soon as the chosen sub-iterator is exhausted.
    """

    def __init__(self, datasets, seed, weights):
        self._datasets = [iter(el) for el in datasets]
        self._weights = weights
        self._rng = random.Random(seed)

    def __iter__(self):
        # Required by the iterator protocol (e.g. for itertools.islice / iter()).
        return self

    def __next__(self):
        (dataset,) = self._rng.choices(self._datasets, weights=self._weights, k=1)
        return next(dataset)