"""Dataset wrappers that feed seqio/tf.data multimodal pipelines into PyTorch."""
import logging
import math
import multiprocessing
import os
import pickle
import queue
import socket
import time
from concurrent.futures import ProcessPoolExecutor, ThreadPoolExecutor
from multiprocessing.managers import BaseManager
from multiprocessing.shared_memory import SharedMemory
from os.path import exists
from pathlib import Path
from typing import Any, Dict, Iterator, List, Optional, Sequence, Union

import psutil
import tensorflow as tf
import numpy as np
import torch
import torch.utils.data
import torch.distributed as dist
import clu
from clu.data.dataset_iterator import Element

from .aliases import PathOrStr
from .torch_util import (
    barrier,
    get_fs_local_rank,
    get_global_rank,
    get_world_size,
    get_node_rank,
    get_local_world_size,
    get_local_rank,
    move_to_device,
)
from .util import roundrobin, threaded_generator
from .data_factory import SeqioDataset
from .multimodal_preprocessor import MultiModalPreprocessor
from .preprocesssors import rename
from . import tasks

__all__ = ["MMIterableDataset"]

log = logging.getLogger(__name__)


def batch_fn(batch, for_inference):
    if for_inference:
        out = {}
        for k, v in batch.items():
            if k.startswith("metadata/"):
                out[k] = v
            else:
                out[k] = torch.from_numpy(v)
        return out
    else:
        out = {k: torch.from_numpy(v) for k, v in batch.items() if not k.startswith("metadata/")}
        out["metadata"] = [{} for _ in out["input_ids"]]
        return out
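

# Example (sketch): with for_inference=True, array values become torch tensors
# while "metadata/..." entries pass through unchanged, e.g.
#
#   batch = {"input_ids": np.zeros((2, 8), np.int32), "metadata/example_id": ["a", "b"]}
#   out = batch_fn(batch, for_inference=True)
#   # out["input_ids"] -> torch.Tensor, out["metadata/example_id"] -> ["a", "b"]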


class PyTorchDatasetIterator(clu.data.dataset_iterator.TfDatasetIterator):
    """Checkpointable iterator over a tf.data.Dataset that yields torch batches."""

    def __init__(self, dataset, *, checkpoint: bool, for_inference: bool):
        self.for_inference = for_inference
        super().__init__(dataset, checkpoint=checkpoint)

    def __next__(self) -> Element:
        batch = {k: v.numpy() for k, v in next(self.iterator).items()}
        return batch_fn(batch, self.for_inference)

    def __len__(self) -> int:
        return len(self._dataset)


class MMIterableDataset(torch.utils.data.IterableDataset[Dict[str, Any]]):
    """IterableDataset that runs a SeqioDataset pipeline sharded by global rank."""

    def __init__(
        self,
        dataset: SeqioDataset,
        preprocessor: MultiModalPreprocessor,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
    ):
        self.preprocessor = preprocessor
        self.rank = rank if rank is not None else get_global_rank()
        self.world_size = world_size if world_size is not None else get_world_size()
        self.dataset_config = dataset

        data_iter = dataset.build(
            self.preprocessor,
            self.rank,
            self.world_size,
        )

        data_iter: tf.data.Dataset = rename(input_ids="input_tokens", labels="target_tokens")(data_iter)
        self.dataset = data_iter
        self.data_iter = PyTorchDatasetIterator(
            data_iter, checkpoint=True, for_inference=dataset.for_inference)

    def reset(self):
        self.data_iter.reset()

    def save(self, filename: PathOrStr):
        self.data_iter.save(filename)

    def restore(self, filename: PathOrStr):
        self.data_iter.restore(filename)

    def __iter__(self) -> Iterator[Dict[str, Any]]:
        return self.data_iter
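

# Usage sketch (hypothetical names, assuming a SeqioDataset config and a
# MultiModalPreprocessor are already constructed elsewhere):
#
#   dataset = MMIterableDataset(seqio_dataset, preprocessor)
#   for batch in dataset:
#       input_ids = batch["input_ids"]   # torch tensors, renamed from "input_tokens"
#       labels = batch["labels"]         # renamed from "target_tokens"
#       ...
#   dataset.save("data_iter.ckpt")       # checkpoint/restore go through the clu iterator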


def _split_batch(batch, n):
    subbatches = [{} for _ in range(n)]
    for k, v in batch.items():
        assert len(v) % n == 0, f"n={n} but {k} has {len(v)}"
        subbatch_dim = len(v) // n
        for i, subbatch in enumerate(subbatches):
            subbatch[k] = v[i * subbatch_dim:(i + 1) * subbatch_dim]
    return subbatches
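

# Example (sketch): splitting a per-node batch of 8 across 4 local ranks gives
# 4 sub-batches of 2 along the leading dimension:
#
#   parts = _split_batch({"input_ids": torch.zeros(8, 16)}, 4)
#   # len(parts) == 4 and parts[0]["input_ids"].shape == (2, 16)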


def tf_to_torch_dtype(tf_dtype):
    dtype_mapping = {
        tf.float16: torch.float16,
        tf.float32: torch.float32,
        tf.float64: torch.float64,
        tf.int8: torch.int8,
        tf.uint8: torch.uint8,
        tf.int16: torch.int16,
        tf.int32: torch.int32,
        tf.int64: torch.int64,
        tf.bool: torch.bool,
    }
    return dtype_mapping[tf_dtype]
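

# Example (sketch): tf_to_torch_dtype(tf.int32) -> torch.int32; dtypes outside
# the mapping (e.g. tf.string) raise KeyError.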


class PeerToPeer(torch.utils.data.IterableDataset[Dict[str, Any]]):
    """
    This dataloader runs the tf.data.Dataset on one process per node, and then
    transfers each device batch to the other processes on that node. For the 7B
    model this costs about a 10% performance penalty despite my attempts to make
    it asynchronous.

    The advantage is that it avoids the overhead of running multiple tf.data.Dataset
    pipelines on one node.
    """

    def __init__(
        self,
        dataset: SeqioDataset,
        preprocessor: MultiModalPreprocessor,
        world_size: Optional[int] = None,
        rank: Optional[int] = None,
        device=None
    ):
        assert get_world_size() % get_local_world_size() == 0
        self.device = device
        self.device_batch_size = dataset.global_batch_size // get_world_size()

        self.preprocessor = preprocessor
        self.seqio_dataset = dataset

        lws = get_local_world_size()

        if get_local_rank() == 0:
            # Only local rank 0 builds the tf.data pipeline; it reads one shard per node.
            tf_dataset = dataset.build(
                self.preprocessor,
                get_node_rank(),
                get_world_size() // lws,
            )

            tf_dataset = rename(input_ids="input_tokens", labels="target_tokens")(tf_dataset)
            self.dataset = tf_dataset
            # Per-device shapes/dtypes: the node-level batch is split evenly across local ranks.
            device_spec = {k: ((v.shape[0] // lws,) + tuple(v.shape[1:]), tf_to_torch_dtype(v.dtype))
                           for k, v in tf_dataset.element_spec.items()}
        else:
            self.dataset = None
            device_spec = None

        # Share the device spec so receiving ranks can pre-allocate recv buffers.
        broadcast = [device_spec]
        torch.distributed.broadcast_object_list(broadcast)
        self.device_spec = broadcast[0]

        self._node_group_ranks = [i + get_node_rank() * lws for i in range(lws)]
        if get_local_rank() == 0:
            assert get_global_rank() == self._node_group_ranks[0]
        # Keys are sent/received in sorted order so both sides pair tensors consistently.
        self._keys = sorted(self.device_spec)
        self.multithread_pin = False

    def _pin(self, it, on):
        # Pull the next node-level batch, convert to torch, split it per local rank,
        # and pin the memory so the later device copies can be asynchronous.
        batch = next(it)
        batch = {k: torch.from_numpy(v) for k, v in batch.items()}
        batch = _split_batch(batch, len(self._node_group_ranks))
        return [{k: v.pin_memory() for k, v in subbatch.items()} for subbatch in batch]

    def _send_pinned(self, batch):
        # Asynchronously send every sub-batch except our own to the other local ranks.
        requests = []
        for rank_ix, rank in enumerate(self._node_group_ranks[1:], start=1):
            for k in self._keys:
                batch[rank_ix][k] = batch[rank_ix][k].to(self.device, non_blocking=True)
                requests.append(dist.P2POp(dist.isend, batch[rank_ix][k], rank))
        ops = dist.batch_isend_irecv(requests)
        return batch[0], ops

    def _send(self, it, on):
        if get_local_rank() == 0:
            try:
                batch = next(it)
                batch = {k: torch.from_numpy(v) for k, v in batch.items()}
                batch = _split_batch(batch, len(self._node_group_ranks))
            except StopIteration:
                # The dataset is exhausted: send a sentinel batch filled with -10 so the
                # receiving ranks know to stop. Build it on the CPU so it can be pinned
                # like a regular batch.
                batch = [
                    {k: torch.full(sh, -10, dtype=dtype)
                     for k, (sh, dtype) in self.device_spec.items()}
                    for _ in range(len(self._node_group_ranks))
                ]

            batch = [{k: v.pin_memory() for k, v in subbatch.items()}
                     for subbatch in batch]

            requests = []
            for rank_ix, rank in enumerate(self._node_group_ranks[1:], start=1):
                for k in self._keys:
                    batch[rank_ix][k] = batch[rank_ix][k].to(self.device, non_blocking=True)
                    requests.append(dist.P2POp(dist.isend, batch[rank_ix][k], rank))
            ops = dist.batch_isend_irecv(requests)
            batch = batch[0]
        else:
            # Non-zero local ranks allocate empty device buffers and post matching receives.
            batch = {k: torch.zeros(sh, dtype=dtype, device=self.device)
                     for k, (sh, dtype) in self.device_spec.items()}
            requests = []
            for k in self._keys:
                requests.append(dist.P2POp(dist.irecv, batch[k], self._node_group_ranks[0]))
            ops = dist.batch_isend_irecv(requests)
        return batch, ops

    def __iter__(self):
        on = 0
        if get_local_rank() == 0:
            it = iter(self.dataset.as_numpy_iterator())
        else:
            it = None

        if get_local_rank() == 0 and self.multithread_pin:
            # Overlap pinning of the next batch with sending the current one by
            # pinning in a background thread. Note this path has no StopIteration
            # handling, so it assumes the dataset repeats indefinitely.
            with ThreadPoolExecutor(max_workers=1) as pool:
                _is_sending = self._send_pinned(self._pin(it, on))
                _is_pinning = pool.submit(self._pin, it, on)
                on += 1
                while True:
                    result = _is_sending
                    _is_sending = self._send_pinned(_is_pinning.result())
                    _is_pinning = pool.submit(self._pin, it, on)
                    on += 1
                    for op in result[1]:
                        op.wait()
                    yield result[0]
        else:
            # Keep one batch in flight: start sending/receiving the next batch
            # before yielding the current one.
            _in_flight = self._send(it, on)
            on += 1
            while True:
                on += 1
                next_batch = self._send(it, on)
                for op in _in_flight[1]:
                    op.wait()
                # A batch filled with the -10 sentinel means the upstream iterator
                # is exhausted (see _send), so stop iterating.
                if _in_flight[0]["input_ids"][0, 0] == -10:
                    return
                yield _in_flight[0]
                _in_flight = next_batch
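

# Usage sketch (hypothetical setup, assuming torch.distributed is already
# initialized and seqio_dataset / preprocessor are built elsewhere): every rank
# constructs the loader and iterates in lockstep, since __init__ and __iter__
# both issue collective/P2P calls.
#
#   loader = PeerToPeer(seqio_dataset, preprocessor, device=torch.device("cuda"))
#   for batch in loader:
#       batch = move_to_device(batch, torch.device("cuda"))
#       ...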