mpt-7b-8k-chat / config_utils.py
irenedea's picture
LLM-foundry update March 26, 2024 23:50:31
fdb2891 verified
raw
history blame
5.69 kB
import contextlib
import logging
import math
import warnings
from typing import Any, Dict, Literal, Mapping, Optional, Tuple, Union
from .utils import init_empty_weights
# Module-level logger, named after this module via ``__name__``.
log = logging.getLogger(__name__)
def pop_config(cfg: DictConfig, key: str, must_exist: bool=True, default_value: Any=None, convert: bool=False) -> Any:
    """Pop ``key`` from ``cfg`` and return its value.

    An entry whose value is ``None`` is treated exactly like a missing entry.

    Args:
        cfg: Config to pop from. ``key`` is removed from it as a side effect.
        key: Name of the entry to pop.
        must_exist: If True, a missing (or ``None``) entry raises ``NameError``.
            If False, ``default_value`` is returned instead.
        default_value: Value returned when the entry is missing and
            ``must_exist`` is False.
        convert: If True, convert the popped ``DictConfig``/``ListConfig`` into
            a plain Python object with ``OmegaConf.to_container``.

    Returns:
        The popped value (converted when ``convert`` is True), or
        ``default_value`` when the entry is missing and optional.

    Raises:
        NameError: If the entry is missing (or ``None``) and ``must_exist``
            is True.
        ValueError: If ``convert`` is True and the popped value is neither a
            ``DictConfig`` nor a ``ListConfig``.
    """
    value = cfg.pop(key, None)

    # Missing and explicitly-None entries are handled identically.
    if value is None:
        if must_exist:
            raise NameError(f'The {key} parameter is missing and must exist for execution. Please check your yaml.')
        return default_value

    if not convert:
        return value

    # Only OmegaConf containers can be turned into plain dicts/lists.
    if not isinstance(value, (DictConfig, ListConfig)):
        raise ValueError(f'The key {key} has a value of type {type(value)} that cannot be converted to a dict or list. Please check your yaml.')
    return om.to_container(value)
def calculate_batch_size_info(global_batch_size: int, device_microbatch_size: Union[int, Literal['auto']]) -> Tuple[int, Union[int, Literal['auto']], Union[int, Literal['auto']]]:
    """Split a global batch size into per-device batch, microbatch and grad-accum sizes.

    Args:
        global_batch_size: Total batch size across all ranks. It must be
            divisible by ``dist.get_world_size()``.
        device_microbatch_size: Per-device microbatch size (a positive int),
            or the string ``'auto'``.

    Returns:
        A ``(device_batch_size, device_microbatch_size, device_grad_accum)``
        tuple.
        - When ``device_microbatch_size`` is ``'auto'``, the second and third
          entries are both ``'auto'``.
        - An integer microbatch size larger than the device batch size is
          clamped down to the device batch size, with a warning logged.
        - ``device_grad_accum`` is ``ceil(device_batch_size / device_microbatch_size)``.

    Raises:
        ValueError: If ``global_batch_size`` is not divisible by the world
            size, if an integer ``device_microbatch_size`` is < 1, or if
            ``device_microbatch_size`` is neither an int nor ``'auto'``.
    """
    # Query the world size once rather than once per use.
    world_size = dist.get_world_size()
    if global_batch_size % world_size != 0:
        raise ValueError(f'Global batch size {global_batch_size} is not divisible by {world_size} ' + 'as a result, the batch size would be truncated, please adjust `global_batch_size` ' + f'to be divisible by world size, {world_size}.')
    device_batch_size = global_batch_size // world_size

    if device_microbatch_size == 'auto':
        device_grad_accum = 'auto'
    elif isinstance(device_microbatch_size, int):
        # A zero microbatch would otherwise surface as a bare ZeroDivisionError,
        # and a negative one would silently yield a negative grad-accum count.
        if device_microbatch_size < 1:
            raise ValueError(f'device_microbatch_size must be a positive integer or "auto", got {device_microbatch_size!r}.')
        if device_microbatch_size > device_batch_size:
            log.warning('device_microbatch_size > device_batch_size, will be reduced from %s -> %s.', device_microbatch_size, device_batch_size)
            device_microbatch_size = device_batch_size
        device_grad_accum = math.ceil(device_batch_size / device_microbatch_size)
    else:
        raise ValueError(f'Not sure how to parse device_microbatch_size={device_microbatch_size!r}')
    return (device_batch_size, device_microbatch_size, device_grad_accum)
def update_batch_size_info(cfg: DictConfig) -> DictConfig:
    """Fill in the per-device batch settings of ``cfg`` in place.

    Inputs read from ``cfg``:
    - ``cfg.global_train_batch_size``
    - ``cfg.device_train_microbatch_size``

    These are split with ``calculate_batch_size_info``. The following are then
    written back:
    - ``cfg.device_train_batch_size``
    - ``cfg.device_train_microbatch_size``
    - ``cfg.device_train_grad_accum``
    - ``cfg.n_gpus``, set to ``dist.get_world_size()``

    When ``cfg`` has no ``device_eval_batch_size``, it defaults to 1 if the
    train microbatch size is ``'auto'``, and to the train microbatch size
    otherwise.

    Returns:
        The same ``cfg`` object that was passed in, after mutation.
    """
    split = calculate_batch_size_info(cfg.global_train_batch_size, cfg.device_train_microbatch_size)
    cfg.n_gpus = dist.get_world_size()
    # Unpacking assigns the three attributes left to right.
    cfg.device_train_batch_size, cfg.device_train_microbatch_size, cfg.device_train_grad_accum = split

    if 'device_eval_batch_size' not in cfg:
        train_micro = cfg.device_train_microbatch_size
        cfg.device_eval_batch_size = 1 if train_micro == 'auto' else train_micro
    return cfg
def process_init_device(model_cfg: DictConfig, fsdp_config: Optional[Dict]):
    """Validate ``model_cfg.init_device``, adjust the FSDP config, and return the init context.

    Args:
        model_cfg: Model config. If it has ``init_device``, that value must be
            one of ``'meta'``, ``'cpu'`` or ``'mixed'``. Its optional
            ``master_weights_dtype`` entry is also read.
        fsdp_config: FSDP settings, or ``None`` when FSDP is not in use.
            Mutated in place as described below.

    Behavior and side effects:
        * ``init_device='meta'`` without an FSDP config: a warning is issued
          and ``model_cfg.init_device`` is rewritten to ``'cpu'``.
        * ``init_device='mixed'``:
          - Requires an FSDP config.
          - Forces ``fsdp_config['sync_module_states'] = True``, warning if it
            was not already set.
          - Defaults ``use_orig_params`` to False.
          - Defaults ``load_monolith_rank0_only`` to True.
        * When an FSDP config is given and ``master_weights_dtype`` is one of
          the names in ``low_precision``, ``fsdp_config['mixed_precision']`` is
          replaced by a dict with:
          - ``param_dtype=None``;
          - ``keep_low_precision_grads=True``;
          - ``reduce_dtype`` and ``buffer_dtype`` carried over from a previous
            mapping-valued ``mixed_precision``, or None otherwise.

    Returns:
        ``init_empty_weights()`` (from ``.utils``) when the final
        ``init_device`` is ``'meta'``; otherwise ``contextlib.nullcontext()``.

    Raises:
        AssertionError: If ``init_device`` is present but not an accepted
            value. It is an ``assert``, so the check is skipped under
            ``python -O``.
        NotImplementedError: If ``init_device`` is ``'mixed'`` and
            ``fsdp_config`` is None.
    """
    init_context = contextlib.nullcontext()

    if 'init_device' in model_cfg:
        assert model_cfg.init_device in ['meta', 'cpu', 'mixed']
        has_fsdp = fsdp_config is not None

        if not has_fsdp and model_cfg.init_device == 'meta':
            warnings.warn("Using `cfg.model.init_device='meta'` is only valid when using FSDP! "
                          "Reverting to `cfg.model.init_device='cpu'`.")
            model_cfg.init_device = 'cpu'

        device = model_cfg.init_device
        if device == 'meta':
            init_context = init_empty_weights()
        elif device == 'mixed':
            if not has_fsdp:
                raise NotImplementedError('Using init_device `mixed` is only supported with FSDP. '
                                          'Please add a FSDP config.')
            if not fsdp_config.get('sync_module_states', False):
                warnings.warn('Setting `sync_module_states = True` for FSDP. This is required when using mixed initialization.')
                fsdp_config['sync_module_states'] = True
            for option, fallback in (('use_orig_params', False), ('load_monolith_rank0_only', True)):
                fsdp_config.setdefault(option, fallback)

    low_precision = ('bf16', 'fp16', 'float16', 'bfloat16', 'amp_fp16', 'amp_bf16')
    master_dtype = model_cfg.get('master_weights_dtype')
    if fsdp_config and master_dtype in low_precision:
        # NOTE(review): param_dtype=None presumably because the master weights
        # are already stored in this low-precision dtype — confirm.
        previous = fsdp_config.get('mixed_precision')
        if isinstance(previous, Mapping):
            reduce_dtype, buffer_dtype = previous.get('reduce_dtype'), previous.get('buffer_dtype')
        else:
            reduce_dtype = buffer_dtype = None
        fsdp_config['mixed_precision'] = {
            'param_dtype': None,
            'reduce_dtype': reduce_dtype,
            'buffer_dtype': buffer_dtype,
            'keep_low_precision_grads': True,
        }

    return init_context
def log_config(cfg: DictConfig) -> None:
    """Print the config as YAML and push it to the active wandb / MLflow runs.

    The resolved config (``om.to_container(cfg, resolve=True)``) is sent to:
    - ``wandb.config`` when ``'wandb'`` is listed under ``cfg.loggers`` and a
      wandb run is active;
    - ``mlflow.log_params`` when ``'mlflow'`` is listed and an MLflow run is
      active.

    This function can be called multiple times to update the wandb and MLflow
    config with different variables.

    Raises:
        ImportError: If a logger is configured but its package is not installed.
    """
    print(om.to_yaml(cfg))

    # `loggers` may be absent or explicitly null in the YAML. Treat both as
    # "no loggers" so the membership tests below never run against None.
    loggers = cfg.get('loggers') or {}

    # Imports are local so each package is only required when its logger is
    # actually configured.
    if 'wandb' in loggers:
        import wandb
        if wandb.run:
            wandb.config.update(om.to_container(cfg, resolve=True))

    if 'mlflow' in loggers:
        import mlflow
        if mlflow.active_run():
            mlflow.log_params(params=om.to_container(cfg, resolve=True))