import contextlib
import functools
import logging
import os
import re
from collections import OrderedDict
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, Union

import torch
# NOTE: the composer / omegaconf imports below are inferred from how their names are
# used in this module (Evaluator, ComposerModel, dist, om, etc.).
from composer.core import Algorithm, Callback, Evaluator
from composer.datasets.in_context_learning_evaluation import get_icl_task_dataloader
from composer.loggers import LoggerDestination
from composer.models import ComposerModel
from composer.optim.scheduler import ComposerScheduler
from composer.utils import dist
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase

from .llmfoundry import registry
from .callbacks import EvalGauntlet
from .dataloader import build_dataloader
from .tiktoken import TiktokenTokenizerWrapper
from .registry_utils import construct_from_registry
log = logging.getLogger(__name__)
def build_evaluators(eval_loader_config: Optional[Union[DictConfig, ListConfig]], icl_tasks_config: Optional[Union[str, ListConfig]], eval_gauntlet_config: Optional[Union[str, DictConfig]], *, tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, icl_seq_len: int, icl_subset_num_batches: Optional[int]) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
evaluators = []
if eval_loader_config is not None:
evaluators = build_eval_loaders(eval_loader_config, tokenizer, device_eval_batch_size)
logger_keys = []
eval_gauntlet_callback = None
if icl_tasks_config is not None:
icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(icl_tasks_config, eval_gauntlet_config, tokenizer, device_eval_batch_size, icl_seq_len, icl_subset_num_batches)
evaluators.extend(icl_evaluators)
return (evaluators, logger_keys, eval_gauntlet_callback)
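# Usage sketch for build_evaluators (the YAML paths and batch size below are
# illustrative, not defaults of this module):
#
#   evaluators, logger_keys, gauntlet_cb = build_evaluators(
#       eval_loader_config=None,
#       icl_tasks_config='eval/yamls/tasks.yaml',
#       eval_gauntlet_config='eval/yamls/eval_gauntlet.yaml',
#       tokenizer=tokenizer,
#       device_eval_batch_size=8,
#       icl_seq_len=1024,
#       icl_subset_num_batches=None,
#   )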
def build_eval_loaders(eval_loader_config: Union[DictConfig, ListConfig], tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int) -> List[Evaluator]:
evaluators: List[Evaluator] = []
if isinstance(eval_loader_config, ListConfig):
eval_configs: ListConfig = eval_loader_config
is_multi_eval = True
else:
eval_configs = ListConfig([eval_loader_config])
is_multi_eval = False
for eval_config in eval_configs:
eval_dataloader = build_dataloader(eval_config, tokenizer, device_eval_batch_size)
eval_loader: Evaluator = Evaluator(label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', dataloader=eval_dataloader, metric_names=[])
evaluators.append(eval_loader)
return evaluators
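# Sketch of a single eval loader config (field names other than `label` depend on the
# dataloader selected by `name`; the values here are illustrative). Passing a ListConfig
# of such entries yields one Evaluator per entry, labeled 'eval/<label>'; a single
# DictConfig yields one Evaluator labeled 'eval':
#
#   eval_loader_cfg = DictConfig({'name': 'text', 'label': 'c4', 'dataset': {...}})
#   evaluators = build_eval_loaders(ListConfig([eval_loader_cfg]), tokenizer, device_eval_batch_size=8)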
def add_metrics_to_eval_loaders(evaluators: List[Evaluator], metric_names: List[str]) -> List[Evaluator]:
eval_loaders, other_evaluators = ([], [])
for evaluator in evaluators:
if evaluator.metric_names == []:
evaluator.metric_names = metric_names
eval_loaders.append(evaluator)
else:
other_evaluators.append(evaluator)
return eval_loaders + other_evaluators
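# Evaluators coming out of build_eval_loaders intentionally carry empty metric_names;
# once the model (and hence its metric set) is known, they are filled in here. A sketch
# (obtaining the metric names from the model this way is an assumption, not part of
# this module):
#
#   evaluators = add_metrics_to_eval_loaders(evaluators, list(model.get_metrics(is_train=False).keys()))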
def build_icl_data_and_gauntlet(icl_tasks_config: Union[str, ListConfig], eval_gauntlet_config: Optional[Union[str, DictConfig]], tokenizer: PreTrainedTokenizerBase, device_eval_batch_size: int, icl_seq_len: int, icl_subset_num_batches: Optional[int]=None) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
icl_evaluators, logger_keys = build_icl_evaluators(icl_tasks_config, tokenizer, icl_seq_len, device_eval_batch_size, icl_subset_num_batches=icl_subset_num_batches)
eval_gauntlet_cb = None
if eval_gauntlet_config is not None:
if isinstance(eval_gauntlet_config, str):
with open(eval_gauntlet_config, 'r') as icl_f:
eval_gauntlet_cfg = om.load(icl_f)
eval_gauntlet = eval_gauntlet_cfg.eval_gauntlet
elif isinstance(eval_gauntlet_config, DictConfig):
eval_gauntlet = eval_gauntlet_config
else:
raise ValueError(f'Got invalid type for eval_gauntlet_config: {type(eval_gauntlet_config)}')
eval_gauntlet.logger_keys = logger_keys
eval_gauntlet.benchmark_sizes = {e.label: e.dataloader.num_samples for e in icl_evaluators}
eval_gauntlet_cb = EvalGauntlet(**eval_gauntlet)
return (icl_evaluators, logger_keys, eval_gauntlet_cb)
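# When eval_gauntlet_config is a string it is treated as a path to a YAML file whose
# top-level `eval_gauntlet` key holds the EvalGauntlet kwargs; `logger_keys` and
# `benchmark_sizes` are filled in automatically above. Illustrative call (paths assumed):
#
#   icl_evaluators, keys, gauntlet_cb = build_icl_data_and_gauntlet(
#       'eval/yamls/tasks.yaml', 'eval/yamls/eval_gauntlet.yaml',
#       tokenizer, device_eval_batch_size=8, icl_seq_len=1024)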
def build_composer_model(name: str, cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, init_context: Optional[ContextManager]=None, master_weights_dtype: Optional[str]=None) -> ComposerModel:
"""Builds a ComposerModel from the registry.
Args:
name (str): Name of the model to build.
cfg (DictConfig): Configuration for the model.
tokenizer (PreTrainedTokenizerBase): Tokenizer to use.
init_context (Optional[ContextManager], optional): Context manager to use for initialization. Defaults to None.
master_weights_dtype (Optional[str], optional): Master weights dtype. Defaults to None.
Returns:
ComposerModel: The constructed ComposerModel, optionally cast to the requested master weights dtype.
"""
if init_context is None:
init_context = contextlib.nullcontext()
with init_context:
model = construct_from_registry(name=name, registry=registry.models, pre_validation_function=ComposerModel, post_validation_function=None, kwargs={'om_model_config': cfg, 'tokenizer': tokenizer})
str_dtype_to_torch_dtype = {'f16': torch.float16, 'float16': torch.float16, 'bf16': torch.bfloat16, 'bfloat16': torch.bfloat16}
if master_weights_dtype is not None:
if master_weights_dtype not in str_dtype_to_torch_dtype:
raise ValueError(f'Invalid master_weights_dtype: {master_weights_dtype}. ' + f'Valid options are: {list(str_dtype_to_torch_dtype.keys())}.')
dtype = str_dtype_to_torch_dtype[master_weights_dtype]
model = model.to(dtype=dtype)
return model
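# Usage sketch (the registry name and config are illustrative; cfg must contain whatever
# fields the selected model class expects):
#
#   model = build_composer_model(
#       name='mpt_causal_lm',
#       cfg=model_cfg,
#       tokenizer=tokenizer,
#       master_weights_dtype='bf16',  # casts master weights to torch.bfloat16
#   )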
def build_callback(name: str, kwargs: Optional[Dict[str, Any]]=None, config: Any=None) -> Callback:
"""Builds a callback from the registry."""
registry_to_use = registry.callbacks
if name in registry.callbacks_with_config:
if kwargs is None:
kwargs = {}
if 'config' in kwargs:
raise ValueError(f'`config` is a reserved keyword for callbacks with config. Please remove it from the kwargs.')
kwargs['config'] = config
registry_to_use = registry.callbacks_with_config
return construct_from_registry(name=name, registry=registry_to_use, partial_function=True, pre_validation_function=Callback, post_validation_function=None, kwargs=kwargs)
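# Callbacks registered under `callbacks_with_config` additionally receive the full train
# config via the reserved `config` kwarg. Illustrative calls (the callback names and
# kwargs are assumptions about what is registered):
#
#   cb = build_callback('lr_monitor')
#   cb = build_callback('hf_checkpointer', kwargs={'save_interval': '1ep'}, config=train_cfg)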
def build_logger(name: str, kwargs: Optional[Dict[str, Any]]=None) -> LoggerDestination:
"""Builds a logger from the registry."""
return construct_from_registry(name=name, registry=registry.loggers, partial_function=True, pre_validation_function=LoggerDestination, post_validation_function=None, kwargs=kwargs)
def build_algorithm(name: str, kwargs: Optional[Dict[str, Any]]=None) -> Algorithm:
"""Builds an algorithm from the registry."""
return construct_from_registry(name=name, registry=registry.algorithms, partial_function=True, pre_validation_function=Algorithm, post_validation_function=None, kwargs=kwargs)
def build_metric(name: str, kwargs: Optional[Dict[str, Any]]=None) -> Metric:
"""Builds a metric from the registry."""
return construct_from_registry(name=name, registry=registry.metrics, partial_function=True, pre_validation_function=Metric, post_validation_function=None, kwargs=kwargs)
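# build_logger, build_algorithm and build_metric are thin wrappers over
# construct_from_registry that differ only in the registry used and the base type
# validated. Illustrative calls (the registry entry names are assumptions):
#
#   logger = build_logger('wandb', {'project': 'my-project'})
#   algo = build_algorithm('gradient_clipping', {'clipping_type': 'norm', 'clipping_threshold': 1.0})
#   metric = build_metric('language_cross_entropy')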
def _extract_param_groups(model: torch.nn.Module, optimizer_config: Optional[Dict[str, Any]]=None) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
"""Extracts parameter groups defined in the optimizer config.
The optimizer_config defines the optimizer args. It can additionally have key
`disable_grad` which is a string or list of strings. If a string matches a
parameter name, then that parameter will have `requires_grad=False`. This is
useful for freezing parameters. It can additionally have a key
`param_groups` which is a list of dicts. In this dict, key `param_str_match`
defines a string; if a parameter name contains this string, then it will be
in this parameter group. This is useful for grouping parameters together.
The dict can also contain any other key that is a valid optimizer arg.
Note: to handle name-overlap conflicts, params are assigned to parameter groups, and
added to `param_groups`, in the order in which their `param_str_match` entries appear
in `param_groups`.
Usage
To disable gradient for all parameters that contain the string "norm" or "bias":
```
optimizer_config: {
"name": "decoupled_lionw",
"lr": 1e-3,
"weight_decay": 1e-2,
"betas": [0.9, 0.999],
"eps": 1e-8,
"disable_grad": ["norm", "bias"]
}
```
To create separate parameter groups, each with its own optimizer settings, for all
parameters that contain the string "norm" and for all that contain "bias":
```
optimizer_config: {
"name": "decoupled_lionw",
"lr": 1e-3,
"weight_decay": 1e-2,
"betas": [0.9, 0.999],
"eps": 1e-8,
"param_groups": [
{
"param_str_match": "norm",
"lr": 1e-4,
"weight_decay": 0.0,
},
{
"param_str_match": "bias",
"lr": 5e-4,
"weight_decay": 0.0,
},
],
}
```
Args:
model (torch.nn.Module): model to extract parameters from
optimizer_config (Dict[str, Any]): optimizer config
Returns:
Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of
torch.Tensor's or dict's. Specifies what Tensors should be optimized
and their param groupings.
"""
if optimizer_config is None:
return model.parameters()
if 'disable_grad' in optimizer_config.keys():
str_matches = optimizer_config.pop('disable_grad')
if isinstance(str_matches, str):
str_matches = [str_matches]
for str_match in str_matches:
for n, p in model.named_parameters():
if re.search(str_match, n):
p.requires_grad = False
log.debug(f'Setting `{n}.requires_grad = False`.')
param_groups_config = optimizer_config.pop('param_groups', None)
if param_groups_config is not None:
params = []
param_dict = OrderedDict(((n, p) for n, p in model.named_parameters()))
log.debug(f'Default optimizer settings: {optimizer_config}.')
for param_group_config in param_groups_config:
str_match = param_group_config.pop('param_str_match')
filter_fn = functools.partial(re.search, str_match)
param_names = [n for n in param_dict.keys() if filter_fn(n)]
group_params = {'params': [param_dict.pop(n) for n in param_names]}
group_params.update(param_group_config)
log.debug(f'Creating optimizer param_group with parameters: {param_names} ' + f'(extracted using str_match={str_match!r}). The param_group optimizer ' + f'setting overrides are: {param_group_config}.')
params.append(group_params)
params.insert(0, {'params': param_dict.values()})
return params
return model.parameters()
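# For the `param_groups` example in the docstring above, the returned structure places
# the unmatched parameters first, followed by one dict per configured group (a sketch):
#
#   [
#       {'params': <all parameters not matched by any param_str_match>},
#       {'params': [...], 'lr': 1e-4, 'weight_decay': 0.0},  # "norm" group
#       {'params': [...], 'lr': 5e-4, 'weight_decay': 0.0},  # "bias" group
#   ]
#
# which is the standard per-group options format accepted by torch optimizers.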
def build_optimizer(model: torch.nn.Module, name: str, optimizer_config: Optional[Dict[str, Any]]=None) -> Optimizer:
params = _extract_param_groups(model, optimizer_config)
kwargs = optimizer_config
if kwargs is None:
kwargs = {}
if 'params' in kwargs:
raise ValueError('The `params` will be automatically extracted from the model and ' + 'optimizer config. Please remove it from the optimizer config kwargs.')
kwargs['params'] = params
return construct_from_registry(name=name, registry=registry.optimizers, partial_function=True, pre_validation_function=Optimizer, post_validation_function=None, kwargs=kwargs)
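# Usage sketch, reusing the optimizer name from the _extract_param_groups docstring
# (the hyperparameter values are illustrative):
#
#   optimizer = build_optimizer(model, 'decoupled_lionw',
#                               {'lr': 1e-3, 'weight_decay': 1e-2, 'disable_grad': ['bias']})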
def build_scheduler(name: str, scheduler_config: Optional[Dict[str, Any]]=None) -> ComposerScheduler:
return construct_from_registry(name=name, registry=registry.schedulers, partial_function=True, pre_validation_function=ComposerScheduler, post_validation_function=None, kwargs=scheduler_config)
def build_tokenizer(tokenizer_name: str, tokenizer_kwargs: Dict[str, Any]) -> PreTrainedTokenizerBase:
os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
os.environ['TOKENIZERS_PARALLELISM'] = 'false'
signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'
if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
with dist.local_rank_zero_download_and_wait(signal_file_path):
pass
if tokenizer_name.startswith('tiktoken'):
tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
else:
tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
tokenizer.model_max_length = tokenizer_kwargs.get('model_max_length', int(1e+30))
if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None:
raise ValueError(f'The tokenizer {tokenizer_name} must have an eos_token.')
if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
if dist.get_local_rank() == 0:
with open(signal_file_path, 'wb') as f:
f.write(b'local_rank0_completed_tokenizer_setup')
dist.barrier()
if dist.get_local_rank() == 0:
os.remove(signal_file_path)
return tokenizer
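# A tokenizer_name prefixed with 'tiktoken' routes to TiktokenTokenizerWrapper; anything
# else is loaded via AutoTokenizer.from_pretrained. Usage sketch (the tokenizer name is
# illustrative):
#
#   tokenizer = build_tokenizer('EleutherAI/gpt-neox-20b', {'model_max_length': 2048})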
def build_icl_evaluators(icl_tasks: Union[str, ListConfig], tokenizer: PreTrainedTokenizerBase, default_max_seq_len: int, default_batch_size: int, destination_dir: Optional[str]=None, icl_subset_num_batches: Optional[int]=None) -> Tuple[List[Evaluator], List[str]]:
if destination_dir is None:
destination_dir = os.getcwd()
evaluators = []
logger_keys = []
icl_tasks_list = None
if isinstance(icl_tasks, str):
log.info(f'Extracting ICL task config from path: {icl_tasks}')
with open(icl_tasks, 'r') as icl_f:
icl_task_cfg = om.load(icl_f)
icl_tasks_list = icl_task_cfg.icl_tasks
else:
icl_tasks_list = icl_tasks
def _validate_cfg(icl_cfg: DictConfig):
assert 'label' in icl_cfg
assert 'dataset_uri' in icl_cfg and icl_cfg.dataset_uri is not None
assert 'icl_task_type' in icl_cfg
assert 'num_fewshot' in icl_cfg
if 'metric_names' not in icl_cfg:
if icl_cfg.icl_task_type == 'language_modeling':
icl_cfg.metric_names = ['InContextLearningLMAccuracy']
elif icl_cfg.icl_task_type == 'multiple_choice':
icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
elif icl_cfg.icl_task_type == 'schema':
icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
elif icl_cfg.icl_task_type == 'question_answering':
icl_cfg.metric_names = ['InContextLearningQAAccuracy']
elif icl_cfg.icl_task_type == 'code_evaluation':
icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
else:
raise ValueError(f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.')
if 'prompt_string' not in icl_cfg:
icl_cfg.prompt_string = ''
if 'example_delimiter' not in icl_cfg:
icl_cfg.example_delimiter = '\n'
if 'continuation_delimiter' not in icl_cfg:
icl_cfg.continuation_delimiter = ' '
if 'max_seq_len' not in icl_cfg:
icl_cfg.max_seq_len = default_max_seq_len
if 'batch_size' not in icl_cfg:
icl_cfg.batch_size = default_batch_size
if 'pass_at_k' not in icl_cfg:
icl_cfg.pass_at_k = 1
if 'fewshot_random_seed' not in icl_cfg:
icl_cfg.fewshot_random_seed = 1234
if 'generations_per_sample' not in icl_cfg:
icl_cfg.generations_per_sample = 1
if 'num_beams' in icl_cfg:
raise ValueError('num_beams is no longer supported as a top level icl_task parameter. ' + 'Please use generation_kwargs.num_beams instead.')
for icl_cfg in icl_tasks_list:
assert isinstance(icl_cfg, DictConfig)
_validate_cfg(icl_cfg)
for num_fewshot in list(icl_cfg.num_fewshot):
if tokenizer.pad_token_id is None:
pad_tok_id = tokenizer.eos_token_id
else:
pad_tok_id = tokenizer.pad_token_id
label = f'{icl_cfg.label}/{num_fewshot}-shot'
metric_names = list(icl_cfg.metric_names)
destination_path = f'{destination_dir}/{icl_cfg.label}-{num_fewshot}.jsonl'
if dist.get_local_rank() == 0 and os.path.exists(destination_path):
os.remove(destination_path)
dist.barrier()
hf_parsing_map = icl_cfg.get('hf_parsing_map', {})
hf_loading_vars = icl_cfg.get('hf_loading_vars', {})
early_stopping_criteria = icl_cfg.get('early_stopping_criteria', None)
if isinstance(early_stopping_criteria, ListConfig):
early_stopping_criteria = om.to_container(early_stopping_criteria)
assert early_stopping_criteria is None or isinstance(early_stopping_criteria, list)
dataloaders = get_icl_task_dataloader(icl_cfg.icl_task_type, icl_cfg.dataset_uri, tokenizer, batch_size=icl_cfg.batch_size, max_seq_len=icl_cfg.max_seq_len, pad_tok_id=pad_tok_id, num_fewshot=num_fewshot, prompt_string=icl_cfg.prompt_string, example_delimiter=icl_cfg.example_delimiter, hf_loading_vars=hf_loading_vars, hf_parsing_map=hf_parsing_map, continuation_delimiter=icl_cfg.continuation_delimiter, question_prelimiter=icl_cfg.get('question_prelimiter', ''), destination_path=destination_path, fewshot_random_seed=icl_cfg.fewshot_random_seed, pass_at_k=icl_cfg.pass_at_k, generations_per_sample=icl_cfg.generations_per_sample, has_categories=icl_cfg.get('has_categories', False), cot_delimiter=icl_cfg.get('cot_delimiter', ''), generation_kwargs=icl_cfg.get('generation_kwargs', {}), early_stopping_criteria=early_stopping_criteria, do_normalization=icl_cfg.get('do_normalization', True))
if hasattr(icl_cfg, 'has_categories') and icl_cfg.has_categories and isinstance(dataloaders, dict):
for category in dataloaders.keys():
logger_keys.extend([f'metrics/{label}/{category}/{m}' for m in metric_names])
evaluators.append(Evaluator(label=f'{label}/{category}', dataloader=dataloaders[category], metric_names=metric_names))
else:
logger_keys.extend([f'metrics/{label}/{m}' for m in metric_names])
evaluators.append(Evaluator(label=label, dataloader=dataloaders, metric_names=metric_names, subset_num_batches=icl_subset_num_batches))
return (evaluators, logger_keys)
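# A minimal icl_tasks entry covering the keys required by _validate_cfg above; the label
# and dataset_uri are illustrative:
#
#   icl_tasks:
#   - label: lambada
#     dataset_uri: eval/local_data/lambada_openai.jsonl
#     num_fewshot: [0]
#     icl_task_type: language_modeling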