""" Helpers for distributed training. """ import datetime import io import os import socket import blobfile as bf from pdb import set_trace as st # from mpi4py import MPI import torch as th import torch.distributed as dist # Change this to reflect your cluster layout. # The GPU for a given rank is (rank % GPUS_PER_NODE). GPUS_PER_NODE = 8 SETUP_RETRY_COUNT = 3 def get_rank(): if not dist.is_available(): return 0 if not dist.is_initialized(): return 0 return dist.get_rank() def synchronize(): if not dist.is_available(): return if not dist.is_initialized(): return world_size = dist.get_world_size() if world_size == 1: return dist.barrier() def get_world_size(): if not dist.is_available(): return 1 if not dist.is_initialized(): return 1 return dist.get_world_size() def setup_dist(args): """ Setup a distributed process group. """ if dist.is_initialized(): return # print(f"{os.environ['MASTER_ADDR']=} {args.master_port=}") # dist.init_process_group(backend='nccl', init_method='env://', rank=args.local_rank, world_size=th.cuda.device_count(), timeout=datetime.timedelta(seconds=5400)) # st() no mark dist.init_process_group(backend='nccl', init_method='env://', timeout=datetime.timedelta(seconds=54000)) print(f"{args.local_rank=} init complete") # synchronize() # extra memory on rank 0, why? th.cuda.empty_cache() def cleanup(): dist.destroy_process_group() def dev(): """ Get the device to use for torch.distributed. """ if th.cuda.is_available(): if get_world_size() > 1: return th.device(f"cuda:{get_rank() % GPUS_PER_NODE}") return th.device(f"cuda") return th.device("cpu") # def load_state_dict(path, submodule_name='', **kwargs): def load_state_dict(path, **kwargs): """ Load a PyTorch file without redundant fetches across MPI ranks. """ # chunk_size = 2 ** 30 # MPI has a relatively small size limit # if get_rank() == 0: # with bf.BlobFile(path, "rb") as f: # data = f.read() # num_chunks = len(data) // chunk_size # if len(data) % chunk_size: # num_chunks += 1 # MPI.COMM_WORLD.bcast(num_chunks) # for i in range(0, len(data), chunk_size): # MPI.COMM_WORLD.bcast(data[i : i + chunk_size]) # else: # num_chunks = MPI.COMM_WORLD.bcast(None) # data = bytes() # for _ in range(num_chunks): # data += MPI.COMM_WORLD.bcast(None) # return th.load(io.BytesIO(data), **kwargs) # with open(path) as f: ckpt = th.load(path, **kwargs) # if submodule_name != '': # assert submodule_name in ckpt # return ckpt[submodule_name] # else: return ckpt def sync_params(params): """ Synchronize a sequence of Tensors across ranks from rank 0. """ # for k, p in params: for p in params: with th.no_grad(): try: dist.broadcast(p, 0) except Exception as e: print(k, e) # print(e) def _find_free_port(): try: s = socket.socket(socket.AF_INET, socket.SOCK_STREAM) s.bind(("", 0)) s.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, 1) return s.getsockname()[1] finally: s.close() _num_moments = 3 # [num_scalars, sum_of_scalars, sum_of_squares] _reduce_dtype = th.float32 # Data type to use for initial per-tensor reduction. _counter_dtype = th.float64 # Data type to use for the internal counters. _rank = 0 # Rank of the current process. _sync_device = None # Device to use for multiprocess communication. None = single-process. _sync_called = False # Has _sync() been called yet? _counters = dict() # Running counters on each device, updated by report(): name => device => torch.Tensor _cumulative = dict() # Cumulative counters on the CPU, updated by _sync(): name => torch.Tensor def init_multiprocessing(rank, sync_device): r"""Initializes `utils.torch_utils.training_stats` for collecting statistics across multiple processes. This function must be called after `torch.distributed.init_process_group()` and before `Collector.update()`. The call is not necessary if multi-process collection is not needed. Args: rank: Rank of the current process. sync_device: PyTorch device to use for inter-process communication, or None to disable multi-process collection. Typically `torch.device('cuda', rank)`. """ global _rank, _sync_device assert not _sync_called _rank = rank _sync_device = sync_device