import logging
import math
import multiprocessing
import os
import pickle
import queue
import socket
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from multiprocessing.managers import BaseManager
from multiprocessing.shared_memory import SharedMemory
from os.path import exists
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union

import psutil
import tensorflow as tf
import numpy as np
import torch
import torch.utils.data
import torch.distributed as dist

import clu
from clu.data.dataset_iterator import Element

from .aliases import PathOrStr
from .torch_util import barrier, get_fs_local_rank, get_global_rank, get_world_size, get_node_rank, \
    get_local_world_size, get_local_rank, move_to_device
from .util import roundrobin, threaded_generator
from .data_factory import SeqioDataset
from .multimodal_preprocessor import MultiModalPreprocessor
from .preprocesssors import rename
from . import tasks

__all__ = ["MMIterableDataset"]

log = logging.getLogger(__name__)


def batch_fn(batch, for_inference):
    """Convert a numpy batch into torch tensors, leaving metadata fields untouched."""
    if for_inference:
        out = {}
        for k, v in batch.items():
            if k.startswith("metadata/"):
                out[k] = v
            else:
                out[k] = torch.from_numpy(v)
        return out
    else:
        out = {k: torch.from_numpy(v) for k, v in batch.items() if not k.startswith("metadata/")}
        out["metadata"] = [{} for _ in out["input_ids"]]
        return out


class PyTorchDatasetIterator(clu.data.dataset_iterator.TfDatasetIterator):
    def __init__(self, dataset, *, checkpoint: bool, for_inference: bool):
        self.for_inference = for_inference
        super().__init__(dataset, checkpoint=checkpoint)

    def __next__(self) -> Element:
        batch = {k: v.numpy() for k, v in next(self.iterator).items()}
        return batch_fn(batch, self.for_inference)

    def __len__(self) -> int:
        return len(self._dataset)


class MMIterableDataset(torch.utils.data.IterableDataset[Dict[str, Any]]):
    def __init__(
        self,
        dataset: SeqioDataset,
        preprocessor: MultiModalPreprocessor,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
    ):
        self.preprocessor = preprocessor
        self.rank = rank if rank is not None else get_global_rank()
        self.world_size = world_size if world_size is not None else get_world_size()
        self.dataset_config = dataset
        data_iter = dataset.build(
            self.preprocessor,
            self.rank,
            self.world_size,
        )
        data_iter: tf.data.Dataset = rename(input_ids="input_tokens", labels="target_tokens")(data_iter)
        self.dataset = data_iter
        self.data_iter = PyTorchDatasetIterator(
            data_iter, checkpoint=True, for_inference=dataset.for_inference)

    def reset(self):
        self.data_iter.reset()

    def save(self, filename: PathOrStr):
        self.data_iter.save(filename)

    def restore(self, filename: PathOrStr):
        self.data_iter.restore(filename)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self.data_iter


def _split_batch(batch, n):
    """Split every value in `batch` along the leading (batch) dimension into `n` equal sub-batches."""
    subbatches = [{} for _ in range(n)]
    for k, v in batch.items():
        assert len(v) % n == 0, f"n={n} but {k} has {len(v)}"
        subbatch_dim = len(v) // n
        for i, subbatch in enumerate(subbatches):
            subbatch[k] = v[i * subbatch_dim:(i + 1) * subbatch_dim]
    return subbatches


def tf_to_torch_dtype(tf_dtype):
    dtype_mapping = {
        tf.float16: torch.float16,
        tf.float32: torch.float32,
        tf.float64: torch.float64,
        tf.int8: torch.int8,
        tf.uint8: torch.uint8,
        tf.int16: torch.int16,
        tf.int32: torch.int32,
        tf.int64: torch.int64,
        tf.bool: torch.bool,
    }
    return dtype_mapping[tf_dtype]


class PeerToPeer(torch.utils.data.IterableDataset[Dict[str, Any]]):
    """
    This dataloader runs the tf.data.Dataset on one process per node, and then transfers each batch to the other
    processes on that node. For the 7B model this costs roughly a 10% performance penalty, despite my attempts to
    make the transfer asynchronous.

    The advantage is that it avoids the overhead of running multiple tf.data.Dataset pipelines on one node.
    """

    def __init__(
        self,
        dataset: SeqioDataset,
        preprocessor: MultiModalPreprocessor,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        device=None,
    ):
        assert get_world_size() % get_local_world_size() == 0
        self.device = device
        self.device_batch_size = dataset.global_batch_size // get_world_size()
        self.preprocessor = preprocessor
        self.seqio_dataset = dataset
        lws = get_local_world_size()
        if get_local_rank() == 0:
            # Only local rank 0 builds the tf.data pipeline; it is sharded across nodes, not ranks
            tf_dataset = dataset.build(
                self.preprocessor,
                get_node_rank(),
                get_world_size() // lws,
            )
            tf_dataset = rename(input_ids="input_tokens", labels="target_tokens")(tf_dataset)
            self.dataset = tf_dataset
            # Per-device shapes/dtypes: the node-level batch gets split evenly across the local ranks
            device_spec = {
                k: ((v.shape[0] // lws,) + tuple(v.shape[1:]), tf_to_torch_dtype(v.dtype))
                for k, v in tf_dataset.element_spec.items()
            }
        else:
            self.dataset = None
            device_spec = None
        broadcast = [device_spec]
        torch.distributed.broadcast_object_list(broadcast)
        self.device_spec = broadcast[0]
        self._node_group_ranks = [i + get_node_rank() * lws for i in range(lws)]
        if get_local_rank() == 0:
            assert get_global_rank() == self._node_group_ranks[0]
        self._keys = sorted(self.device_spec)
        self.multithread_pin = False

    def _pin(self, it, on):
        batch = next(it)
        batch = {k: torch.from_numpy(v) for k, v in batch.items()}
        batch = _split_batch(batch, len(self._node_group_ranks))
        return [{k: v.pin_memory() for k, v in subbatch.items()} for subbatch in batch]

    def _send_pinned(self, batch):
        requests = []
        for rank_ix, rank in enumerate(self._node_group_ranks[1:], start=1):
            for k in self._keys:
                batch[rank_ix][k] = batch[rank_ix][k].to(self.device, non_blocking=True)
                requests.append(dist.P2POp(dist.isend, batch[rank_ix][k], rank))
        ops = dist.batch_isend_irecv(requests)
        return batch[0], ops

    def _send(self, it, on):
        if get_local_rank() == 0:
            try:
                batch = next(it)
                batch = {k: torch.from_numpy(v) for k, v in batch.items()}
                batch = _split_batch(batch, len(self._node_group_ranks))
            except StopIteration:
                # Special batch to indicate iteration is done; built on the CPU so it can be
                # pinned and transferred just like a regular batch
                batch = [
                    {k: torch.full(sh, -10, dtype=dtype)
                     for k, (sh, dtype) in self.device_spec.items()}
                    for _ in range(len(self._node_group_ranks))
                ]
            # pin_memory so the device transfer can be non_blocking
            batch = [{k: v.pin_memory() for k, v in subbatch.items()} for subbatch in batch]
            requests = []
            for rank_ix, rank in enumerate(self._node_group_ranks[1:], start=1):
                for k in self._keys:
                    batch[rank_ix][k] = batch[rank_ix][k].to(self.device, non_blocking=True)
                    requests.append(dist.P2POp(dist.isend, batch[rank_ix][k], rank))
            ops = dist.batch_isend_irecv(requests)
            batch = batch[0]
        else:
            # Non-zero local ranks receive their slice of the node batch from local rank 0
            batch = {k: torch.zeros(sh, dtype=dtype, device=self.device)
                     for k, (sh, dtype) in self.device_spec.items()}
            requests = []
            for k in self._keys:
                requests.append(dist.P2POp(dist.irecv, batch[k], self._node_group_ranks[0]))
            ops = dist.batch_isend_irecv(requests)
        return batch, ops

    def __iter__(self):
        on = 0
        if get_local_rank() == 0:
            it = iter(self.dataset.as_numpy_iterator())
        else:
            it = None
        if get_local_rank() == 0 and self.multithread_pin:
            # Try to be clever and do memory pinning in a separate thread; in practice it
            # didn't seem to help much, so it is off by default for now.
            # Currently does not support finite datasets.
            with ThreadPoolExecutor(max_workers=1) as pool:
                _is_sending = self._send_pinned(self._pin(it, on))
                _is_pinning = pool.submit(self._pin, it, on)
                on += 1
                while True:
                    result = _is_sending
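                    # Double-buffering: issue sends for the batch that was pinned in the
                    # background thread, then immediately queue pinning of the following
                    # batch so host-side pinning overlaps with the in-flight P2P sends.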
                    _is_sending = self._send_pinned(_is_pinning.result())
                    _is_pinning = pool.submit(self._pin, it, on)
                    on += 1
                    for op in result[1]:
                        op.wait()
                    yield result[0]
        else:
            _in_flight = self._send(it, on)
            on += 1
            while True:
                on += 1
                next_batch = self._send(it, on)  # queue up the next batch
                for op in _in_flight[1]:  # wait for the current batch
                    op.wait()
                if _in_flight[0]["input_ids"].flatten()[0] == -10:
                    # The sentinel batch (filled with -10) indicates there is no more data
                    return
                yield _in_flight[0]
                _in_flight = next_batch
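

# ---------------------------------------------------------------------------
# Illustrative sketch, not used by the training code: a minimal, single-process
# demonstration of how `_split_batch` carves a node-level batch into per-rank
# sub-batches and how `tf_to_torch_dtype` maps element-spec dtypes. The tensor
# names and sizes below are invented for the example only.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    _demo_batch = {
        "input_ids": torch.arange(8 * 4).reshape(8, 4),  # node-level batch of 8 sequences
        "images": torch.zeros(8, 3, 2, 2),
    }
    _per_rank = _split_batch(_demo_batch, 4)  # e.g. 4 local ranks -> 4 sub-batches
    assert len(_per_rank) == 4
    assert _per_rank[0]["input_ids"].shape == (8 // 4, 4)  # each rank gets 8 // 4 examples
    assert tf_to_torch_dtype(tf.int32) is torch.int32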