import gc
import logging
import os
from typing import List, Optional, Tuple, TypeVar

import torch
import torch.distributed as dist

T = TypeVar("T")


log = logging.getLogger(__name__)


def seed_all(seed: int):
"""Seed all rng objects.""" |
|
import random |
|
|
|
import numpy as np |
|
|
|
if seed < 0 or seed > 2**32 - 1: |
|
raise ValueError(f"Seed {seed} is invalid. It must be on [0; 2^32 - 1]") |
|
random.seed(seed) |
|
np.random.seed(seed) |
|
torch.manual_seed(seed) |
|
|
|
|
|
torch.cuda.manual_seed_all(seed) |
|
|
|
|
|
def is_distributed() -> bool: |
|
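    """Return ``True`` if torch.distributed is both available and initialized."""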
    return dist.is_available() and dist.is_initialized()


def get_node_rank() -> int:
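    """Get the rank of this node, from the ``NODE_RANK`` environment variable if set,
    otherwise derived from the global rank, local rank, and local world size."""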
    return int(os.environ.get("NODE_RANK") or (get_global_rank() - get_local_rank()) // get_local_world_size())


def get_world_size() -> int:
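    """Get the total number of ranks in the process group, or 1 when not running distributed."""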
    if is_distributed():
        return dist.get_world_size()
    else:
        return 1


def get_local_world_size() -> int:
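    """Get the number of ranks running on this node, as given by the ``LOCAL_WORLD_SIZE``
    environment variable (defaults to 1)."""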
    return int(os.environ.get("LOCAL_WORLD_SIZE") or 1)


def get_global_rank() -> int:
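    """Get the global rank of this process, preferring the ``RANK`` environment variable and
    falling back to ``dist.get_rank()``. Returns 0 when not running distributed."""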
    if is_distributed():
        return int(os.environ.get("RANK") or dist.get_rank())
    else:
        return 0


def get_local_rank() -> int:
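    """Get this process's rank within its node, as given by the ``LOCAL_RANK``
    environment variable (defaults to 0)."""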
    return int(os.environ.get("LOCAL_RANK") or 0)


def get_fs_local_rank() -> int:
"""Get the local rank per filesystem, meaning that, regardless of the number of nodes, |
|
if all ranks share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_global_rank()`, |
|
but if nodes do not share the same filesystem then `get_fs_local_rank()` will be equivalent to `get_local_rank()`. |
|
""" |
|
if os.environ.get("OLMO_SHARED_FS"): |
|
return int(os.environ.get("FS_LOCAL_RANK") or get_global_rank()) |
|
else: |
|
return int(os.environ.get("FS_LOCAL_RANK") or get_local_rank()) |
|
|
|
|
|
def move_to_device(o: T, device: torch.device) -> T: |
|
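    """Recursively move ``o`` (a tensor or an arbitrarily nested dict/list/tuple of tensors)
    to ``device``. Non-tensor leaves are returned unchanged."""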
    if isinstance(o, torch.Tensor):
        return o.to(device)
    elif isinstance(o, dict):
        return {k: move_to_device(v, device) for k, v in o.items()}
    elif isinstance(o, list):
        return [move_to_device(x, device) for x in o]
    elif isinstance(o, tuple):
        return tuple(move_to_device(x, device) for x in o)
    else:
        return o


def ensure_finite_(x: torch.Tensor, check_neg_inf: bool = True, check_pos_inf: bool = False):
""" |
|
Modify ``x`` in place to replace ``float("-inf")`` with the minimum value of the dtype when ``check_neg_inf`` |
|
is ``True`` and to replace ``float("inf")`` with the maximum value of the dtype when ``check_pos_inf`` is ``True``. |
|
""" |
|
if check_neg_inf: |
|
x.masked_fill_(x == float("-inf"), torch.finfo(x.dtype).min) |
|
if check_pos_inf: |
|
x.masked_fill_(x == float("inf"), torch.finfo(x.dtype).max) |
|
|
|
|
|
def get_default_device() -> torch.device: |
|
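    """Return the ``cuda`` device if CUDA is available and already initialized, otherwise ``cpu``."""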
    if torch.cuda.is_available() and torch.cuda.is_initialized():
        return torch.device("cuda")
    else:
        return torch.device("cpu")


def barrier() -> None:
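    """Synchronize all ranks. A no-op when not running distributed."""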
    if is_distributed():
        dist.barrier()


def peak_gpu_memory(reset: bool = False) -> Optional[float]:
""" |
|
Get the peak GPU memory usage in MB across all ranks. |
|
Only rank 0 will get the final result. |
|
""" |
|
if not torch.cuda.is_available(): |
|
return None |
|
|
|
device = torch.device("cuda") |
|
peak_mb = torch.cuda.max_memory_allocated(device) / 1000000 |
|
if is_distributed(): |
|
peak_mb_tensor = torch.tensor(peak_mb, device=device) |
|
dist.reduce(peak_mb_tensor, 0, dist.ReduceOp.MAX) |
|
peak_mb = peak_mb_tensor.item() |
|
|
|
if reset: |
|
|
|
torch.cuda.reset_max_memory_allocated(device) |
|
|
|
return peak_mb |
|
|
|
|
|
V = TypeVar("V", bool, int, float) |
|
|
|
|
|
def synchronize_value(value: V, device: torch.device) -> V: |
|
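    """Broadcast ``value`` from rank 0 to all ranks and return the synchronized value.

    When not running distributed this is a no-op that returns ``value`` unchanged.
    """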
    if is_distributed():
        value_tensor = torch.tensor(value, device=device)
        dist.broadcast(value_tensor, 0)
        return value_tensor.item()
    else:
        return value


def synchronize_flag(flag: bool, device: torch.device) -> bool:
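    """Same as :func:`synchronize_value`, but typed for boolean flags."""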
    return synchronize_value(flag, device)


def gc_cuda():
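    """Run Python garbage collection and, if CUDA is available, release unoccupied memory
    held by the CUDA caching allocator."""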
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def listinstr(lst: List[str], s: str, delimiter: Optional[str] = None) -> bool:
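    """Return ``True`` if any item in ``lst`` occurs as a substring of ``s``.

    If ``delimiter`` is given, an item matches when every ``delimiter``-separated piece of it
    occurs somewhere in ``s``.
    """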
    assert isinstance(lst, list)
    for item in lst:
        if delimiter:
            if all(x in s for x in item.split(delimiter)):
                return True
        else:
            if item in s:
                return True
    return False


def freeze_module(module: torch.nn.Module, exclude_params: Optional[List[str]] = None):
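    """Freeze all parameters of ``module`` by setting ``requires_grad = False``.

    Parameters whose names contain any entry of ``exclude_params`` as a substring
    (matched via :func:`listinstr`) are left trainable.
    """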
    for name, param in module.named_parameters():
        if exclude_params is not None and listinstr(exclude_params, name):
            continue
        param.requires_grad = False


def freeze_parameters_by_name(model: torch.nn.Module, freeze_names: Tuple[str, ...]):
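    """Freeze the submodules or individual parameters of ``model`` whose fully qualified names
    appear in ``freeze_names``. Names that match neither a submodule nor a parameter are skipped
    with a warning.
    """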
    for name in freeze_names:
        try:
            module_or_param = model.get_submodule(name)
        except AttributeError:
            try:
                module_or_param = model.get_parameter(name)
            except AttributeError:
                log.warning(f"Could not find module or parameter with name {name}")
                continue
        if isinstance(module_or_param, torch.nn.Module):
            freeze_module(module_or_param)
        else:
            module_or_param.requires_grad = False