import contextlib
import functools
import logging
import os
import re
from collections import OrderedDict
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, Union

import torch
# Composer / OmegaConf imports for names used below. The exact module paths are
# assumed from upstream Composer; adjust if these are vendored locally.
from composer.core import Algorithm, Callback, Evaluator
from composer.datasets.in_context_learning_evaluation import get_icl_task_dataloader
from composer.loggers import LoggerDestination
from composer.models import ComposerModel
from composer.optim.scheduler import ComposerScheduler
from composer.utils import dist
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from .llmfoundry import registry
from .callbacks import EvalGauntlet
from .dataloader import build_dataloader
from .tiktoken import TiktokenTokenizerWrapper
from .registry_utils import construct_from_registry

log = logging.getLogger(__name__)


def build_evaluators(
    eval_loader_config: Optional[Union[DictConfig, ListConfig]],
    icl_tasks_config: Optional[Union[str, ListConfig]],
    eval_gauntlet_config: Optional[Union[str, DictConfig]],
    *,
    tokenizer: PreTrainedTokenizerBase,
    device_eval_batch_size: int,
    icl_seq_len: int,
    icl_subset_num_batches: Optional[int],
) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
    evaluators = []
    if eval_loader_config is not None:
        evaluators = build_eval_loaders(eval_loader_config, tokenizer,
                                        device_eval_batch_size)

    logger_keys = []
    eval_gauntlet_callback = None
    if icl_tasks_config is not None:
        icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(
            icl_tasks_config, eval_gauntlet_config, tokenizer,
            device_eval_batch_size, icl_seq_len, icl_subset_num_batches)
        evaluators.extend(icl_evaluators)

    return evaluators, logger_keys, eval_gauntlet_callback


def build_eval_loaders(
    eval_loader_config: Union[DictConfig, ListConfig],
    tokenizer: PreTrainedTokenizerBase,
    device_eval_batch_size: int,
) -> List[Evaluator]:
    evaluators: List[Evaluator] = []
    if isinstance(eval_loader_config, ListConfig):
        eval_configs: ListConfig = eval_loader_config
        is_multi_eval = True
    else:
        eval_configs = ListConfig([eval_loader_config])
        is_multi_eval = False

    for eval_config in eval_configs:
        eval_dataloader = build_dataloader(eval_config, tokenizer,
                                           device_eval_batch_size)
        eval_loader: Evaluator = Evaluator(
            label=f'eval/{eval_config.label}' if is_multi_eval else 'eval',
            dataloader=eval_dataloader,
            # Metric names are filled in later by add_metrics_to_eval_loaders,
            # once the model (and therefore its eval metrics) has been built.
            metric_names=[],
        )
        evaluators.append(eval_loader)
    return evaluators


def add_metrics_to_eval_loaders(
    evaluators: List[Evaluator],
    metric_names: List[str],
) -> List[Evaluator]:
    eval_loaders, other_evaluators = [], []
    for evaluator in evaluators:
        if evaluator.metric_names == []:
            evaluator.metric_names = metric_names
            eval_loaders.append(evaluator)
        else:
            other_evaluators.append(evaluator)

    # Put the plain eval loaders first, ahead of the ICL evaluators.
    return eval_loaders + other_evaluators


def build_icl_data_and_gauntlet(
    icl_tasks_config: Union[str, ListConfig],
    eval_gauntlet_config: Optional[Union[str, DictConfig]],
    tokenizer: PreTrainedTokenizerBase,
    device_eval_batch_size: int,
    icl_seq_len: int,
    icl_subset_num_batches: Optional[int] = None,
) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
    icl_evaluators, logger_keys = build_icl_evaluators(
        icl_tasks_config,
        tokenizer,
        icl_seq_len,
        device_eval_batch_size,
        icl_subset_num_batches=icl_subset_num_batches)

    eval_gauntlet_cb = None
    if eval_gauntlet_config is not None:
        if isinstance(eval_gauntlet_config, str):
            with open(eval_gauntlet_config, 'r') as icl_f:
                eval_gauntlet_cfg = om.load(icl_f)
            eval_gauntlet = eval_gauntlet_cfg.eval_gauntlet
        elif isinstance(eval_gauntlet_config, DictConfig):
            eval_gauntlet = eval_gauntlet_config
        else:
            raise ValueError(
                f'Got invalid type for eval_gauntlet_config: {type(eval_gauntlet_config)}'
            )

        eval_gauntlet.logger_keys = logger_keys
        eval_gauntlet.benchmark_sizes = {
            e.label: e.dataloader.num_samples for e in icl_evaluators
        }
        eval_gauntlet_cb = EvalGauntlet(**eval_gauntlet)
    return icl_evaluators, logger_keys, eval_gauntlet_cb
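
# Usage sketch: assembling evaluators from a single eval loader config plus an ICL
# task file. This is a minimal illustration, not part of the builder API; the config
# fields, bucket path, and YAML path are hypothetical placeholders and must match
# whatever build_dataloader and the ICL task schema actually expect.
def _example_build_evaluators(tokenizer: PreTrainedTokenizerBase) -> None:
    eval_loader_config = DictConfig({
        'name': 'text',  # hypothetical loader name
        'label': 'c4-val',
        'dataset': {  # hypothetical dataset location
            'remote': 's3://my-bucket/c4/val',
            'split': 'val',
        },
    })
    evaluators, logger_keys, gauntlet_cb = build_evaluators(
        eval_loader_config,
        icl_tasks_config='eval/yamls/tasks.yaml',  # hypothetical ICL task YAML
        eval_gauntlet_config=None,
        tokenizer=tokenizer,
        device_eval_batch_size=8,
        icl_seq_len=1024,
        icl_subset_num_batches=None,
    )
    log.info(f'Built {len(evaluators)} evaluators, gauntlet={gauntlet_cb}, keys={logger_keys}')
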
def build_composer_model(
    name: str,
    cfg: DictConfig,
    tokenizer: PreTrainedTokenizerBase,
    init_context: Optional[ContextManager] = None,
    master_weights_dtype: Optional[str] = None,
) -> ComposerModel:
    """Builds a ComposerModel from the registry.

    Args:
        name (str): Name of the model to build.
        cfg (DictConfig): Configuration for the model.
        tokenizer (PreTrainedTokenizerBase): Tokenizer to use.
        init_context (Optional[ContextManager], optional): Context manager to use
            for initialization. Defaults to None.
        master_weights_dtype (Optional[str], optional): Master weights dtype.
            Defaults to None.

    Returns:
        ComposerModel: The built model, optionally cast to the requested master
            weights dtype.
    """
    if init_context is None:
        init_context = contextlib.nullcontext()

    with init_context:
        model = construct_from_registry(
            name=name,
            registry=registry.models,
            pre_validation_function=ComposerModel,
            post_validation_function=None,
            kwargs={
                'om_model_config': cfg,
                'tokenizer': tokenizer,
            },
        )

    str_dtype_to_torch_dtype = {
        'f16': torch.float16,
        'float16': torch.float16,
        'bf16': torch.bfloat16,
        'bfloat16': torch.bfloat16,
    }

    if master_weights_dtype is not None:
        if master_weights_dtype not in str_dtype_to_torch_dtype:
            raise ValueError(
                f'Invalid master_weights_dtype: {master_weights_dtype}. ' +
                f'Valid options are: {list(str_dtype_to_torch_dtype.keys())}.')
        dtype = str_dtype_to_torch_dtype[master_weights_dtype]
        model = model.to(dtype=dtype)

    return model


def build_callback(
    name: str,
    kwargs: Optional[Dict[str, Any]] = None,
    config: Any = None,
) -> Callback:
    """Builds a callback from the registry."""
    registry_to_use = registry.callbacks
    if name in registry.callbacks_with_config:
        if kwargs is None:
            kwargs = {}
        if 'config' in kwargs:
            raise ValueError(
                '`config` is a reserved keyword for callbacks with config. ' +
                'Please remove it from the kwargs.')
        kwargs['config'] = config
        registry_to_use = registry.callbacks_with_config

    return construct_from_registry(
        name=name,
        registry=registry_to_use,
        partial_function=True,
        pre_validation_function=Callback,
        post_validation_function=None,
        kwargs=kwargs,
    )


def build_logger(
    name: str,
    kwargs: Optional[Dict[str, Any]] = None,
) -> LoggerDestination:
    """Builds a logger from the registry."""
    return construct_from_registry(
        name=name,
        registry=registry.loggers,
        partial_function=True,
        pre_validation_function=LoggerDestination,
        post_validation_function=None,
        kwargs=kwargs,
    )


def build_algorithm(
    name: str,
    kwargs: Optional[Dict[str, Any]] = None,
) -> Algorithm:
    """Builds an algorithm from the registry."""
    return construct_from_registry(
        name=name,
        registry=registry.algorithms,
        partial_function=True,
        pre_validation_function=Algorithm,
        post_validation_function=None,
        kwargs=kwargs,
    )


def build_metric(name: str, kwargs: Optional[Dict[str, Any]] = None) -> Metric:
    """Builds a metric from the registry."""
    return construct_from_registry(
        name=name,
        registry=registry.metrics,
        partial_function=True,
        pre_validation_function=Metric,
        post_validation_function=None,
        kwargs=kwargs,
    )
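
# Usage sketch: building a model and a callback through the registry helpers above.
# The registry keys ('mpt_causal_lm', 'lr_monitor') and the model config fields are
# assumptions for illustration; they only resolve if matching entries exist in
# registry.models / registry.callbacks.
def _example_build_model_and_callback(tokenizer: PreTrainedTokenizerBase) -> ComposerModel:
    model_cfg = DictConfig({
        'd_model': 768,  # hypothetical model hyperparameters
        'n_heads': 12,
        'n_layers': 12,
    })
    model = build_composer_model(
        name='mpt_causal_lm',  # assumed registry key
        cfg=model_cfg,
        tokenizer=tokenizer,
        master_weights_dtype='bf16',  # cast master weights to bfloat16 after build
    )
    _ = build_callback('lr_monitor')  # assumed registry key; plain callback, no config
    return model
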
def _extract_param_groups(
    model: torch.nn.Module,
    optimizer_config: Optional[Dict[str, Any]] = None,
) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
    """Extracts parameter groups defined in the optimizer config.

    The optimizer_config defines the optimizer args. It can additionally have key
    `disable_grad` which is a string or list of strings. If a string matches a
    parameter name, then that parameter will have `requires_grad=False`. This is
    useful for freezing parameters.

    It can additionally have a key `param_groups` which is a list of dicts. In
    this dict, key `param_str_match` defines a string; if a parameter name
    contains this string, then it will be in this parameter group. This is useful
    for grouping parameters together. The dict can also contain any other key
    that is a valid optimizer arg.
    Note: to handle name overlap conflicts, params are assigned to parameter
    groups and added to `param_groups` in the order that `param_str_match`
    appears in `param_groups`.

    Usage
    To disable gradient for all parameters that contain the string "norm" or
    "bias":
    ```
    optimizer_config: {
        "name": "decoupled_lionw",
        "lr": 1e-3,
        "weight_decay": 1e-2,
        "betas": [0.9, 0.999],
        "eps": 1e-8,
        "disable_grad": ["norm", "bias"]
    }
    ```

    To create and modify the optimizer parameters for all parameters that contain
    the string "norm" and "bias" separately:
    ```
    optimizer_config: {
        "name": "decoupled_lionw",
        "lr": 1e-3,
        "weight_decay": 1e-2,
        "betas": [0.9, 0.999],
        "eps": 1e-8,
        "param_groups": [
            {
                "param_str_match": "norm",
                "lr": 1e-4,
                "weight_decay": 0.0,
            },
            {
                "param_str_match": "bias",
                "lr": 5e-4,
                "weight_decay": 0.0,
            },
        ],
    }
    ```

    Args:
        model (torch.nn.Module): model to extract parameters from
        optimizer_config (Dict[str, Any]): optimizer config

    Returns:
        Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of
            torch.Tensor's or dict's. Specifies what Tensors should be optimized
            and their param groupings.
    """
    if optimizer_config is None:
        return model.parameters()

    if 'disable_grad' in optimizer_config.keys():
        str_matches = optimizer_config.pop('disable_grad')
        if isinstance(str_matches, str):
            str_matches = [str_matches]
        for str_match in str_matches:
            for n, p in model.named_parameters():
                if re.search(str_match, n):
                    p.requires_grad = False
                    log.debug(f'Setting `{n}.requires_grad = False`.')

    param_groups_config = optimizer_config.pop('param_groups', None)
    if param_groups_config is not None:
        params = []
        param_dict = OrderedDict((n, p) for n, p in model.named_parameters())

        log.debug(f'Default optimizer settings: {optimizer_config}.')
        for param_group_config in param_groups_config:
            str_match = param_group_config.pop('param_str_match')
            filter_fn = functools.partial(re.search, str_match)
            param_names = [n for n in param_dict.keys() if filter_fn(n)]
            group_params = {'params': [param_dict.pop(n) for n in param_names]}
            group_params.update(param_group_config)

            log.debug(
                f'Creating optimizer param_group with parameters: {param_names} ' +
                f'(extracted using str_match={str_match!r}). The param_group optimizer ' +
                f'setting overrides are: {param_group_config}.')

            params.append(group_params)

        # Any parameters not claimed by a param_group fall into the default group.
        params.insert(0, {'params': param_dict.values()})

        return params

    return model.parameters()
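
# Usage sketch: what _extract_param_groups produces for a small module. The regexes
# and hyperparameters are illustrative. Note that the helper mutates the passed
# config: it pops 'disable_grad' and 'param_groups', leaving only optimizer kwargs.
def _example_extract_param_groups() -> List[Dict[str, Any]]:
    model = torch.nn.Sequential(torch.nn.Linear(4, 4), torch.nn.LayerNorm(4))
    optimizer_config: Dict[str, Any] = {
        'lr': 1e-3,
        'disable_grad': ['bias'],  # freeze every parameter whose name matches 'bias'
        'param_groups': [{
            'param_str_match': r'1\.weight',  # the LayerNorm weight in this Sequential
            'weight_decay': 0.0,
        }],
    }
    params = _extract_param_groups(model, optimizer_config)
    # params[0] holds the ungrouped parameters; params[1] is the LayerNorm weight
    # group with its weight_decay override. optimizer_config is now just {'lr': 1e-3}.
    return list(params)
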
def build_optimizer(
    model: torch.nn.Module,
    name: str,
    optimizer_config: Optional[Dict[str, Any]] = None,
) -> Optimizer:
    params = _extract_param_groups(model, optimizer_config)
    kwargs = optimizer_config
    if kwargs is None:
        kwargs = {}
    if 'params' in kwargs:
        raise ValueError(
            'The `params` will be automatically extracted from the model and ' +
            'optimizer config. Please remove it from the optimizer config kwargs.')

    kwargs['params'] = params
    return construct_from_registry(
        name=name,
        registry=registry.optimizers,
        partial_function=True,
        pre_validation_function=Optimizer,
        post_validation_function=None,
        kwargs=kwargs,
    )


def build_scheduler(
    name: str,
    scheduler_config: Optional[Dict[str, Any]] = None,
) -> ComposerScheduler:
    return construct_from_registry(
        name=name,
        registry=registry.schedulers,
        partial_function=True,
        pre_validation_function=ComposerScheduler,
        post_validation_function=None,
        kwargs=scheduler_config,
    )


def build_tokenizer(
    tokenizer_name: str,
    tokenizer_kwargs: Dict[str, Any],
) -> PreTrainedTokenizerBase:
    os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'

    signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'

    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        # Local rank 0 downloads and sets up the tokenizer first; the other ranks
        # wait on the signal file to avoid concurrent downloads.
        with dist.local_rank_zero_download_and_wait(signal_file_path):
            pass

    if tokenizer_name.startswith('tiktoken'):
        tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
    else:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)

        # Hugging Face may clamp model_max_length to the value in the tokenizer
        # config, so set it explicitly from the kwargs here.
        tokenizer.model_max_length = tokenizer_kwargs.get('model_max_length', int(1e30))

    if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None:
        raise ValueError(f'The tokenizer {tokenizer_name} must have an eos_token.')

    if dist.is_available() and dist.is_initialized() and dist.get_world_size() > 1:
        if dist.get_local_rank() == 0:
            with open(signal_file_path, 'wb') as f:
                f.write(b'local_rank0_completed_tokenizer_setup')

        dist.barrier()

        if dist.get_local_rank() == 0:
            os.remove(signal_file_path)

    return tokenizer
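
# Usage sketch: the two tokenizer paths handled by build_tokenizer. A name starting
# with 'tiktoken' is routed to TiktokenTokenizerWrapper; anything else is treated as
# a Hugging Face repo id or local path. The tiktoken kwargs below are assumptions
# and depend on what TiktokenTokenizerWrapper accepts in this codebase.
def _example_build_tokenizers() -> PreTrainedTokenizerBase:
    hf_tokenizer = build_tokenizer('gpt2', {'model_max_length': 2048})
    tiktoken_tokenizer = build_tokenizer(
        'tiktoken',
        {'encoding_name': 'cl100k_base', 'model_max_length': 2048},  # assumed wrapper kwargs
    )
    log.info(f'Built {type(hf_tokenizer).__name__} and {type(tiktoken_tokenizer).__name__}')
    return hf_tokenizer
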
def build_icl_evaluators(
    icl_tasks: Union[str, ListConfig],
    tokenizer: PreTrainedTokenizerBase,
    default_max_seq_len: int,
    default_batch_size: int,
    destination_dir: Optional[str] = None,
    icl_subset_num_batches: Optional[int] = None,
) -> Tuple[List[Evaluator], List[str]]:
    if destination_dir is None:
        destination_dir = os.getcwd()

    evaluators = []
    logger_keys = []

    icl_tasks_list = None
    if isinstance(icl_tasks, str):
        log.info(f'Extracting ICL task config from path: {icl_tasks}')
        with open(icl_tasks, 'r') as icl_f:
            icl_task_cfg = om.load(icl_f)
        icl_tasks_list = icl_task_cfg.icl_tasks
    else:
        icl_tasks_list = icl_tasks

    def _validate_cfg(icl_cfg: DictConfig):
        assert 'label' in icl_cfg
        assert 'dataset_uri' in icl_cfg and icl_cfg.dataset_uri is not None
        assert 'icl_task_type' in icl_cfg
        assert 'num_fewshot' in icl_cfg

        if 'metric_names' not in icl_cfg:
            if icl_cfg.icl_task_type == 'language_modeling':
                icl_cfg.metric_names = ['InContextLearningLMAccuracy']
            elif icl_cfg.icl_task_type == 'multiple_choice':
                icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
            elif icl_cfg.icl_task_type == 'schema':
                icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
            elif icl_cfg.icl_task_type == 'question_answering':
                icl_cfg.metric_names = ['InContextLearningQAAccuracy']
            elif icl_cfg.icl_task_type == 'code_evaluation':
                icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
            else:
                raise ValueError(
                    'No metric_names defined, unable to build default metrics for ' +
                    f'icl_task_type={icl_cfg.icl_task_type}.')

        if 'prompt_string' not in icl_cfg:
            icl_cfg.prompt_string = ''
        if 'example_delimiter' not in icl_cfg:
            icl_cfg.example_delimiter = '\n'
        if 'continuation_delimiter' not in icl_cfg:
            icl_cfg.continuation_delimiter = ' '
        if 'max_seq_len' not in icl_cfg:
            icl_cfg.max_seq_len = default_max_seq_len
        if 'batch_size' not in icl_cfg:
            icl_cfg.batch_size = default_batch_size
        if 'pass_at_k' not in icl_cfg:
            icl_cfg.pass_at_k = 1
        if 'fewshot_random_seed' not in icl_cfg:
            icl_cfg.fewshot_random_seed = 1234
        if 'generations_per_sample' not in icl_cfg:
            icl_cfg.generations_per_sample = 1

        if 'num_beams' in icl_cfg:
            raise ValueError(
                'num_beams is no longer supported as a top level icl_task parameter. ' +
                'Please use generation_kwargs.num_beams instead.')

    for icl_cfg in icl_tasks_list:
        assert isinstance(icl_cfg, DictConfig)
        _validate_cfg(icl_cfg)
        for num_fewshot in list(icl_cfg.num_fewshot):
            if tokenizer.pad_token_id is None:
                # Use the eos token as padding if the tokenizer has no pad token.
                pad_tok_id = tokenizer.eos_token_id
            else:
                pad_tok_id = tokenizer.pad_token_id
            label = f'{icl_cfg.label}/{num_fewshot}-shot'
            metric_names = list(icl_cfg.metric_names)
            # The dataset is materialized locally; remove any stale copy first.
            destination_path = f'{destination_dir}/{icl_cfg.label}-{num_fewshot}.jsonl'
            if dist.get_local_rank() == 0 and os.path.exists(destination_path):
                os.remove(destination_path)
            dist.barrier()

            hf_parsing_map = icl_cfg.get('hf_parsing_map', {})
            hf_loading_vars = icl_cfg.get('hf_loading_vars', {})

            early_stopping_criteria = icl_cfg.get('early_stopping_criteria', None)
            if isinstance(early_stopping_criteria, ListConfig):
                early_stopping_criteria = om.to_container(early_stopping_criteria)
            assert early_stopping_criteria is None or isinstance(
                early_stopping_criteria, list)

            dataloaders = get_icl_task_dataloader(
                icl_cfg.icl_task_type,
                icl_cfg.dataset_uri,
                tokenizer,
                batch_size=icl_cfg.batch_size,
                max_seq_len=icl_cfg.max_seq_len,
                pad_tok_id=pad_tok_id,
                num_fewshot=num_fewshot,
                prompt_string=icl_cfg.prompt_string,
                example_delimiter=icl_cfg.example_delimiter,
                hf_loading_vars=hf_loading_vars,
                hf_parsing_map=hf_parsing_map,
                continuation_delimiter=icl_cfg.continuation_delimiter,
                question_prelimiter=icl_cfg.get('question_prelimiter', ''),
                destination_path=destination_path,
                fewshot_random_seed=icl_cfg.fewshot_random_seed,
                pass_at_k=icl_cfg.pass_at_k,
                generations_per_sample=icl_cfg.generations_per_sample,
                has_categories=icl_cfg.get('has_categories', False),
                cot_delimiter=icl_cfg.get('cot_delimiter', ''),
                generation_kwargs=icl_cfg.get('generation_kwargs', {}),
                early_stopping_criteria=early_stopping_criteria,
                do_normalization=icl_cfg.get('do_normalization', True))

            if hasattr(icl_cfg, 'has_categories') and icl_cfg.has_categories and isinstance(dataloaders, dict):
                # Category-split tasks produce one dataloader (and evaluator) per category.
                for category in dataloaders.keys():
                    logger_keys.extend(
                        [f'metrics/{label}/{category}/{m}' for m in metric_names])
                    evaluators.append(
                        Evaluator(label=f'{label}/{category}',
                                  dataloader=dataloaders[category],
                                  metric_names=metric_names))
            else:
                logger_keys.extend(
                    [f'metrics/{label}/{m}' for m in metric_names])
                evaluators.append(
                    Evaluator(label=label,
                              dataloader=dataloaders,
                              metric_names=metric_names,
                              subset_num_batches=icl_subset_num_batches))

    return evaluators, logger_keys
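
# Usage sketch: a minimal in-memory ICL task list. Real runs typically load a YAML
# whose top-level `icl_tasks` key is read above; the label and dataset_uri here are
# hypothetical placeholders and must point at a real JSONL for the build to succeed.
def _example_build_icl_evaluators(tokenizer: PreTrainedTokenizerBase) -> None:
    icl_tasks = ListConfig([{
        'label': 'lambada',  # hypothetical task label
        'dataset_uri': 'eval/local_data/lambada.jsonl',  # hypothetical dataset path
        'icl_task_type': 'language_modeling',
        'num_fewshot': [0, 5],
    }])
    evaluators, logger_keys = build_icl_evaluators(
        icl_tasks,
        tokenizer,
        default_max_seq_len=1024,
        default_batch_size=4,
    )
    log.info(f'Built {len(evaluators)} ICL evaluators; logger keys: {logger_keys}')
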