"""
Helpers for distributed training.
"""

import datetime
import io
import os
import socket

import blobfile as bf
import torch as th
import torch.distributed as dist

GPUS_PER_NODE = 8
SETUP_RETRY_COUNT = 3


def get_rank():
    if not dist.is_available():
        return 0
    if not dist.is_initialized():
        return 0
    return dist.get_rank()


def synchronize():
    if not dist.is_available():
        return
    if not dist.is_initialized():
        return
    world_size = dist.get_world_size()
    if world_size == 1:
        return
    dist.barrier()


def get_world_size():
    if not dist.is_available():
        return 1
    if not dist.is_initialized():
        return 1
    return dist.get_world_size()


def setup_dist(args):
    """
    Set up a distributed process group.
    """
    if dist.is_initialized():
        return
    # The env:// rendezvous reads MASTER_ADDR, MASTER_PORT, RANK, and
    # WORLD_SIZE from the environment (typically exported by torchrun).
    dist.init_process_group(
        backend="nccl",
        init_method="env://",
        timeout=datetime.timedelta(seconds=54000),
    )
    print(f"{args.local_rank=} init complete")
    th.cuda.empty_cache()


def cleanup():
    dist.destroy_process_group()


def dev():
    """
    Get the device to use for torch.distributed.
    """
    if th.cuda.is_available():
        if get_world_size() > 1:
            # One process per GPU: map the global rank to a GPU on this node.
            return th.device(f"cuda:{get_rank() % GPUS_PER_NODE}")
        return th.device("cuda")
    return th.device("cpu")


def load_state_dict(path, **kwargs):
    """
    Load a PyTorch checkpoint file.

    Note: every rank reads the file independently; there is no rank-0 fetch
    plus broadcast here.
    """
    return th.load(path, **kwargs)


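# A possible extension (a sketch, not part of the original module; the name and
# broadcast strategy below are assumptions): read the checkpoint bytes on rank 0
# only and broadcast them, so a large or remote file is fetched once per job
# rather than once per process.
def _load_state_dict_broadcast(path, **kwargs):
    if get_rank() == 0:
        with bf.BlobFile(path, "rb") as f:
            payload = [f.read()]  # raw checkpoint bytes, read once
    else:
        payload = [None]
    if get_world_size() > 1:
        # Pickles the byte string on rank 0 and sends it to every other rank.
        dist.broadcast_object_list(payload, src=0)
    return th.load(io.BytesIO(payload[0]), **kwargs)

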
def sync_params(params):
    """
    Synchronize a sequence of Tensors across ranks from rank 0.
    """
    for p in params:
        with th.no_grad():
            try:
                dist.broadcast(p, 0)
            except Exception as e:
                print(f"broadcast failed for tensor of shape {tuple(p.shape)}: {e}")


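# Typical usage: call sync_params(model.parameters()) right after building the
# model on every rank, so all ranks start from rank 0's initial weights.

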
def _find_free_port():
    # Bind to port 0 so the OS picks an unused port, then release the socket.
    with socket.socket(socket.AF_INET, socket.SOCK_STREAM) as s:
        s.bind(("", 0))
        s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1)
        return s.getsockname()[1]


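# Sketch of how _find_free_port() might be used (this helper is hypothetical and
# is not called elsewhere in this module): the launcher picks one free port up
# front and every spawned worker exports the same env:// rendezvous variables
# before calling setup_dist().
def _export_rendezvous_env(rank, world_size, master_port):
    os.environ["MASTER_ADDR"] = "localhost"
    os.environ["MASTER_PORT"] = str(master_port)  # e.g. _find_free_port() in the launcher
    os.environ["RANK"] = str(rank)
    os.environ["WORLD_SIZE"] = str(world_size)

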
_num_moments = 3
_reduce_dtype = th.float32
_counter_dtype = th.float64
_rank = 0
_sync_device = None
_sync_called = False
_counters = dict()
_cumulative = dict()


def init_multiprocessing(rank, sync_device):
    r"""Initializes `utils.torch_utils.training_stats` for collecting statistics
    across multiple processes.

    This function must be called after `torch.distributed.init_process_group()`
    and before `Collector.update()`. The call is not necessary if multi-process
    collection is not needed.

    Args:
        rank:        Rank of the current process.
        sync_device: PyTorch device to use for inter-process communication, or
                     None to disable multi-process collection. Typically
                     `torch.device('cuda', rank)`.
    """
    global _rank, _sync_device
    assert not _sync_called
    _rank = rank
    _sync_device = sync_device
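

# Example wiring (a sketch; it assumes an `args` object carrying `local_rank`,
# as setup_dist() already does):
#
#     setup_dist(args)
#     device = dev()
#     init_multiprocessing(rank=get_rank(),
#                          sync_device=device if get_world_size() > 1 else None)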