import contextlib
import functools
import logging
import os
import re
from collections import OrderedDict
from typing import Any, ContextManager, Dict, Iterable, List, Optional, Tuple, Union
import torch
from torch.optim.optimizer import Optimizer
from torchmetrics import Metric
from transformers import AutoTokenizer, PreTrainedTokenizerBase
# Composer and OmegaConf imports needed by the builders below. Note that
# `get_icl_task_dataloader` is imported from Composer here, which assumes a
# Composer release that still ships its in-context-learning eval datasets.
from composer.core import Algorithm, Callback, Evaluator
from composer.datasets.in_context_learning_evaluation import get_icl_task_dataloader
from composer.loggers import LoggerDestination
from composer.models import ComposerModel
from composer.optim.scheduler import ComposerScheduler
from composer.utils import dist
from omegaconf import DictConfig, ListConfig
from omegaconf import OmegaConf as om
from .llmfoundry import registry
from .callbacks import EvalGauntlet
from .dataloader import build_dataloader
from .tiktoken import TiktokenTokenizerWrapper
from .registry_utils import construct_from_registry
log = logging.getLogger(__name__)

def build_evaluators(
    eval_loader_config: Optional[Union[DictConfig, ListConfig]],
    icl_tasks_config: Optional[Union[str, ListConfig]],
    eval_gauntlet_config: Optional[Union[str, DictConfig]],
    *,
    tokenizer: PreTrainedTokenizerBase,
    device_eval_batch_size: int,
    icl_seq_len: int,
    icl_subset_num_batches: Optional[int],
) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
    """Builds eval dataloader evaluators, ICL task evaluators, and the optional
    EvalGauntlet callback.
    """
    evaluators = []
    if eval_loader_config is not None:
        evaluators = build_eval_loaders(eval_loader_config, tokenizer, device_eval_batch_size)
    logger_keys = []
    eval_gauntlet_callback = None
    if icl_tasks_config is not None:
        icl_evaluators, logger_keys, eval_gauntlet_callback = build_icl_data_and_gauntlet(icl_tasks_config, eval_gauntlet_config, tokenizer, device_eval_batch_size, icl_seq_len, icl_subset_num_batches)
        evaluators.extend(icl_evaluators)
    return (evaluators, logger_keys, eval_gauntlet_callback)
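
# Illustrative usage (a sketch, not part of the original module): the YAML paths
# and batch/sequence sizes below are placeholders for demonstration only.
#
#   evaluators, logger_keys, gauntlet_cb = build_evaluators(
#       eval_loader_config=None,
#       icl_tasks_config='eval/yamls/tasks.yaml',          # hypothetical path
#       eval_gauntlet_config='eval/yamls/gauntlet.yaml',   # hypothetical path
#       tokenizer=tokenizer,
#       device_eval_batch_size=4,
#       icl_seq_len=2048,
#       icl_subset_num_batches=None,
#   )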

def build_eval_loaders(
    eval_loader_config: Union[DictConfig, ListConfig],
    tokenizer: PreTrainedTokenizerBase,
    device_eval_batch_size: int,
) -> List[Evaluator]:
    """Builds one Evaluator per eval dataloader config.

    Metric names are left empty here; they are filled in later by
    `add_metrics_to_eval_loaders` once the model's metrics are known.
    """
    evaluators: List[Evaluator] = []
    if isinstance(eval_loader_config, ListConfig):
        eval_configs: ListConfig = eval_loader_config
        is_multi_eval = True
    else:
        eval_configs = ListConfig([eval_loader_config])
        is_multi_eval = False
    for eval_config in eval_configs:
        eval_dataloader = build_dataloader(eval_config, tokenizer, device_eval_batch_size)
        eval_loader: Evaluator = Evaluator(label=f'eval/{eval_config.label}' if is_multi_eval else 'eval', dataloader=eval_dataloader, metric_names=[])
        evaluators.append(eval_loader)
    return evaluators

def add_metrics_to_eval_loaders(evaluators: List[Evaluator], metric_names: List[str]) -> List[Evaluator]:
    """Assigns `metric_names` to evaluators that were built with empty metric lists.

    Plain eval dataloader evaluators (identified by an empty `metric_names`) receive
    the given metric names and are placed ahead of the remaining (e.g. ICL)
    evaluators in the returned list.
    """
    eval_loaders, other_evaluators = ([], [])
    for evaluator in evaluators:
        if evaluator.metric_names == []:
            evaluator.metric_names = metric_names
            eval_loaders.append(evaluator)
        else:
            other_evaluators.append(evaluator)
    return eval_loaders + other_evaluators
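
# Illustrative flow (a sketch, not from the original source): eval dataloaders are
# typically built before the model, so their metric names are attached afterwards.
# `model.get_metrics(...)` below is Composer's ComposerModel API; the surrounding
# variable names are placeholders.
#
#   evaluators = build_eval_loaders(eval_loader_config, tokenizer, device_eval_batch_size)
#   # ... build the model ...
#   eval_metric_names = list(model.get_metrics(is_train=False).keys())
#   evaluators = add_metrics_to_eval_loaders(evaluators, eval_metric_names)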

def build_icl_data_and_gauntlet(
    icl_tasks_config: Union[str, ListConfig],
    eval_gauntlet_config: Optional[Union[str, DictConfig]],
    tokenizer: PreTrainedTokenizerBase,
    device_eval_batch_size: int,
    icl_seq_len: int,
    icl_subset_num_batches: Optional[int] = None,
) -> Tuple[List[Evaluator], List[str], Optional[EvalGauntlet]]:
    """Builds ICL task evaluators and, if configured, the EvalGauntlet callback."""
    icl_evaluators, logger_keys = build_icl_evaluators(icl_tasks_config, tokenizer, icl_seq_len, device_eval_batch_size, icl_subset_num_batches=icl_subset_num_batches)
    eval_gauntlet_cb = None
    if eval_gauntlet_config is not None:
        if isinstance(eval_gauntlet_config, str):
            with open(eval_gauntlet_config, 'r') as icl_f:
                eval_gauntlet_cfg = om.load(icl_f)
            eval_gauntlet = eval_gauntlet_cfg.eval_gauntlet
        elif isinstance(eval_gauntlet_config, DictConfig):
            eval_gauntlet = eval_gauntlet_config
        else:
            raise ValueError(f'Got invalid type for eval_gauntlet_config: {type(eval_gauntlet_config)}')
        eval_gauntlet.logger_keys = logger_keys
        eval_gauntlet.benchmark_sizes = {e.label: e.dataloader.num_samples for e in icl_evaluators}
        eval_gauntlet_cb = EvalGauntlet(**eval_gauntlet)
    return (icl_evaluators, logger_keys, eval_gauntlet_cb)

def build_composer_model(
    name: str,
    cfg: DictConfig,
    tokenizer: PreTrainedTokenizerBase,
    init_context: Optional[ContextManager] = None,
    master_weights_dtype: Optional[str] = None,
) -> ComposerModel:
    """Builds a ComposerModel from the registry.

    Args:
        name (str): Name of the model to build.
        cfg (DictConfig): Configuration for the model.
        tokenizer (PreTrainedTokenizerBase): Tokenizer to use.
        init_context (Optional[ContextManager], optional): Context manager to use for initialization. Defaults to None.
        master_weights_dtype (Optional[str], optional): Master weights dtype. Defaults to None.

    Returns:
        ComposerModel: The built model, optionally cast to the requested master weights dtype.
    """
    if init_context is None:
        init_context = contextlib.nullcontext()
    with init_context:
        model = construct_from_registry(name=name, registry=registry.models, pre_validation_function=ComposerModel, post_validation_function=None, kwargs={'om_model_config': cfg, 'tokenizer': tokenizer})
    str_dtype_to_torch_dtype = {'f16': torch.float16, 'float16': torch.float16, 'bf16': torch.bfloat16, 'bfloat16': torch.bfloat16}
    if master_weights_dtype is not None:
        if master_weights_dtype not in str_dtype_to_torch_dtype:
            raise ValueError(f'Invalid master_weights_dtype: {master_weights_dtype}. ' + f'Valid options are: {list(str_dtype_to_torch_dtype.keys())}.')
        dtype = str_dtype_to_torch_dtype[master_weights_dtype]
        model = model.to(dtype=dtype)
    return model
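
# Illustrative usage (a sketch): the registry name and config keys below are
# assumptions for demonstration, not guaranteed registry entries.
#
#   model_cfg = DictConfig({'d_model': 768, 'n_heads': 12})   # hypothetical config
#   model = build_composer_model(
#       name='mpt_causal_lm',          # assumed model registry key
#       cfg=model_cfg,
#       tokenizer=tokenizer,
#       master_weights_dtype='bf16',
#   )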

def build_callback(name: str, kwargs: Optional[Dict[str, Any]]=None, config: Any=None) -> Callback:
    """Builds a callback from the registry."""
    registry_to_use = registry.callbacks
    if name in registry.callbacks_with_config:
        if kwargs is None:
            kwargs = {}
        if 'config' in kwargs:
            raise ValueError(f'`config` is a reserved keyword for callbacks with config. Please remove it from the kwargs.')
        kwargs['config'] = config
        registry_to_use = registry.callbacks_with_config
    return construct_from_registry(name=name, registry=registry_to_use, partial_function=True, pre_validation_function=Callback, post_validation_function=None, kwargs=kwargs)
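
# Illustrative usage (a sketch): callbacks registered under
# `registry.callbacks_with_config` also receive the full train config through the
# reserved `config` argument. The names and kwargs below are placeholders, not
# guaranteed registry entries.
#
#   cb = build_callback('lr_monitor')                          # assumed plain callback
#   ckpt_cb = build_callback(
#       'hf_checkpointer',                                     # assumed callback-with-config
#       kwargs={'save_folder': 's3://my-bucket/checkpoints'},  # hypothetical value
#       config=train_cfg,
#   )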

def build_logger(name: str, kwargs: Optional[Dict[str, Any]]=None) -> LoggerDestination:
    """Builds a logger from the registry."""
    return construct_from_registry(name=name, registry=registry.loggers, partial_function=True, pre_validation_function=LoggerDestination, post_validation_function=None, kwargs=kwargs)

def build_algorithm(name: str, kwargs: Optional[Dict[str, Any]]=None) -> Algorithm:
    """Builds an algorithm from the registry."""
    return construct_from_registry(name=name, registry=registry.algorithms, partial_function=True, pre_validation_function=Algorithm, post_validation_function=None, kwargs=kwargs)

def build_metric(name: str, kwargs: Optional[Dict[str, Any]]=None) -> Metric:
    """Builds a metric from the registry."""
    return construct_from_registry(name=name, registry=registry.metrics, partial_function=True, pre_validation_function=Metric, post_validation_function=None, kwargs=kwargs)

def _extract_param_groups(model: torch.nn.Module, optimizer_config: Optional[Dict[str, Any]]=None) -> Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]:
    """Extracts parameter groups defined in the optimizer config.

    The optimizer_config defines the optimizer args. It can additionally have key
    `disable_grad` which is a string or list of strings. If a string matches a
    parameter name, then that parameter will have `requires_grad=False`. This is
    useful for freezing parameters. It can additionally have a key
    `param_groups` which is a list of dicts. In this dict, key `param_str_match`
    defines a string; if a parameter name contains this string, then it will be
    in this parameter group. This is useful for grouping parameters together.
    The dict can also contain any other key that is a valid optimizer arg.
    Note: to handle name overlap conflicts, params are assigned to parameter
    groups and added to `param_groups` in the order that `param_str_match` appear
    in `param_groups`.

    Usage
    To disable gradient for all parameters that contain the string "norm" or "bias":
    ```
    optimizer_config: {
        "name": "decoupled_lionw",
        "lr": 1e-3,
        "weight_decay": 1e-2,
        "betas": [0.9, 0.999],
        "eps": 1e-8,
        "disable_grad": ["norm", "bias"]
    }
    ```

    To create and modify the optimizer parameters for all parameters that contain
    the string "norm" and "bias" separately:
    ```
    optimizer_config: {
        "name": "decoupled_lionw",
        "lr": 1e-3,
        "weight_decay": 1e-2,
        "betas": [0.9, 0.999],
        "eps": 1e-8,
        "param_groups": [
            {
                "param_str_match": "norm",
                "lr": 1e-4,
                "weight_decay": 0.0,
            },
            {
                "param_str_match": "bias",
                "lr": 5e-4,
                "weight_decay": 0.0,
            },
        ],
    }
    ```

    Args:
        model (torch.nn.Module): model to extract parameters from
        optimizer_config (Dict[str, Any]): optimizer config

    Returns:
        Union[Iterable[torch.Tensor], Iterable[Dict[str, Any]]]: an iterable of
            torch.Tensor's or dict's. Specifies what Tensors should be optimized
            and their param groupings.
    """
    if optimizer_config is None:
        return model.parameters()
    if 'disable_grad' in optimizer_config.keys():
        str_matches = optimizer_config.pop('disable_grad')
        if isinstance(str_matches, str):
            str_matches = [str_matches]
        for str_match in str_matches:
            for n, p in model.named_parameters():
                if re.search(str_match, n):
                    p.requires_grad = False
                    log.debug(f'Setting `{n}.requires_grad = False`.')
    param_groups_config = optimizer_config.pop('param_groups', None)
    if param_groups_config is not None:
        params = []
        param_dict = OrderedDict(((n, p) for n, p in model.named_parameters()))
        log.debug(f'Default optimizer settings: {optimizer_config}.')
        for param_group_config in param_groups_config:
            str_match = param_group_config.pop('param_str_match')
            filter_fn = functools.partial(re.search, str_match)
            param_names = [n for n in param_dict.keys() if filter_fn(n)]
            group_params = {'params': [param_dict.pop(n) for n in param_names]}
            group_params.update(param_group_config)
            log.debug(f'Creating optimizer param_group with parameters: {param_names} ' + f'(extracted using str_match={str_match!r}). The param_group optimizer ' + f'setting overrides are: {param_group_config}.')
            params.append(group_params)
        params.insert(0, {'params': param_dict.values()})
        return params
    return model.parameters()

def build_optimizer(model: torch.nn.Module, name: str, optimizer_config: Optional[Dict[str, Any]]=None) -> Optimizer:
    """Builds an optimizer from the registry.

    Parameter groups (and any `disable_grad` settings) are resolved from the
    optimizer config by `_extract_param_groups`.
    """
    params = _extract_param_groups(model, optimizer_config)
    kwargs = optimizer_config
    if kwargs is None:
        kwargs = {}
    if 'params' in kwargs:
        raise ValueError('The `params` will be automatically extracted from the model and ' + 'optimizer config. Please remove it from the optimizer config kwargs.')
    kwargs['params'] = params
    return construct_from_registry(name=name, registry=registry.optimizers, partial_function=True, pre_validation_function=Optimizer, post_validation_function=None, kwargs=kwargs)
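
# Illustrative usage (a sketch): the optimizer name and hyperparameters below are
# placeholders. `disable_grad` and `param_groups` are consumed by
# `_extract_param_groups`; every other key is passed through to the optimizer.
#
#   optimizer = build_optimizer(
#       model,
#       name='decoupled_adamw',          # assumed optimizer registry key
#       optimizer_config={
#           'lr': 6e-4,
#           'weight_decay': 1e-2,
#           'disable_grad': ['bias'],    # freeze params whose names match 'bias'
#           'param_groups': [{'param_str_match': 'norm', 'weight_decay': 0.0}],
#       },
#   )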

def build_scheduler(name: str, scheduler_config: Optional[Dict[str, Any]]=None) -> ComposerScheduler:
    return construct_from_registry(name=name, registry=registry.schedulers, partial_function=True, pre_validation_function=ComposerScheduler, post_validation_function=None, kwargs=scheduler_config)

def build_tokenizer(tokenizer_name: str, tokenizer_kwargs: Dict[str, Any]) -> PreTrainedTokenizerBase:
    """Builds a Hugging Face (or tiktoken-wrapped) tokenizer.

    In multi-process runs, a per-node signal file ensures local rank zero finishes
    downloading and constructing the tokenizer before the other ranks proceed.
    """
    os.environ['TRANSFORMERS_NO_ADVISORY_WARNINGS'] = '1'
    os.environ['TOKENIZERS_PARALLELISM'] = 'false'
    signal_file_path = f'.node_{dist.get_node_rank()}_local_rank0_completed_tokenizer_setup'
    if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
        with dist.local_rank_zero_download_and_wait(signal_file_path):
            pass
    if tokenizer_name.startswith('tiktoken'):
        tokenizer = TiktokenTokenizerWrapper(**tokenizer_kwargs)
    else:
        tokenizer = AutoTokenizer.from_pretrained(tokenizer_name, **tokenizer_kwargs)
        tokenizer.model_max_length = tokenizer_kwargs.get('model_max_length', int(1e+30))
    if not hasattr(tokenizer, 'eos_token') or tokenizer.eos_token is None:
        raise ValueError(f'The tokenizer {tokenizer_name} must have an eos_token.')
    if dist.is_available() and dist.is_initialized() and (dist.get_world_size() > 1):
        if dist.get_local_rank() == 0:
            with open(signal_file_path, 'wb') as f:
                f.write(b'local_rank0_completed_tokenizer_setup')
        dist.barrier()
        if dist.get_local_rank() == 0:
            os.remove(signal_file_path)
    return tokenizer
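
# Illustrative usage (a sketch): the tokenizer names and kwargs below are
# placeholders; in particular, the tiktoken constructor arguments are assumptions
# about `TiktokenTokenizerWrapper`, not a documented interface.
#
#   tokenizer = build_tokenizer(
#       'EleutherAI/gpt-neox-20b',                     # any Hugging Face tokenizer name
#       tokenizer_kwargs={'model_max_length': 2048},
#   )
#   tiktoken_tok = build_tokenizer('tiktoken', tokenizer_kwargs={'encoding_name': 'cl100k_base'})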

def build_icl_evaluators(
    icl_tasks: Union[str, ListConfig],
    tokenizer: PreTrainedTokenizerBase,
    default_max_seq_len: int,
    default_batch_size: int,
    destination_dir: Optional[str] = None,
    icl_subset_num_batches: Optional[int] = None,
) -> Tuple[List[Evaluator], List[str]]:
    """Builds ICL task evaluators, one per task and `num_fewshot` value, from a
    YAML path or a ListConfig of task configs.
    """
    if destination_dir is None:
        destination_dir = os.getcwd()
    evaluators = []
    logger_keys = []
    icl_tasks_list = None
    if isinstance(icl_tasks, str):
        log.info(f'Extracting ICL task config from path: {icl_tasks}')
        with open(icl_tasks, 'r') as icl_f:
            icl_task_cfg = om.load(icl_f)
        icl_tasks_list = icl_task_cfg.icl_tasks
    else:
        icl_tasks_list = icl_tasks

    def _validate_cfg(icl_cfg: DictConfig):
        assert 'label' in icl_cfg
        assert 'dataset_uri' in icl_cfg and icl_cfg.dataset_uri is not None
        assert 'icl_task_type' in icl_cfg
        assert 'num_fewshot' in icl_cfg
        if 'metric_names' not in icl_cfg:
            if icl_cfg.icl_task_type == 'language_modeling':
                icl_cfg.metric_names = ['InContextLearningLMAccuracy']
            elif icl_cfg.icl_task_type == 'multiple_choice':
                icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
            elif icl_cfg.icl_task_type == 'schema':
                icl_cfg.metric_names = ['InContextLearningMultipleChoiceAccuracy']
            elif icl_cfg.icl_task_type == 'question_answering':
                icl_cfg.metric_names = ['InContextLearningQAAccuracy']
            elif icl_cfg.icl_task_type == 'code_evaluation':
                icl_cfg.metric_names = ['InContextLearningCodeEvalAccuracy']
            else:
                raise ValueError(f'No metric_names defined, unable to build default metrics for icl_task_type={icl_cfg.icl_task_type}.')
        if 'prompt_string' not in icl_cfg:
            icl_cfg.prompt_string = ''
        if 'example_delimiter' not in icl_cfg:
            icl_cfg.example_delimiter = '\n'
        if 'continuation_delimiter' not in icl_cfg:
            icl_cfg.continuation_delimiter = ' '
        if 'max_seq_len' not in icl_cfg:
            icl_cfg.max_seq_len = default_max_seq_len
        if 'batch_size' not in icl_cfg:
            icl_cfg.batch_size = default_batch_size
        if 'pass_at_k' not in icl_cfg:
            icl_cfg.pass_at_k = 1
        if 'fewshot_random_seed' not in icl_cfg:
            icl_cfg.fewshot_random_seed = 1234
        if 'generations_per_sample' not in icl_cfg:
            icl_cfg.generations_per_sample = 1
        if 'num_beams' in icl_cfg:
            raise ValueError('num_beams is no longer supported as a top level icl_task parameter. ' + 'Please use generation_kwargs.num_beams instead.')
    for icl_cfg in icl_tasks_list:
        assert isinstance(icl_cfg, DictConfig)
        _validate_cfg(icl_cfg)
        for num_fewshot in list(icl_cfg.num_fewshot):
            if tokenizer.pad_token_id is None:
                pad_tok_id = tokenizer.eos_token_id
            else:
                pad_tok_id = tokenizer.pad_token_id
            label = f'{icl_cfg.label}/{num_fewshot}-shot'
            metric_names = list(icl_cfg.metric_names)
            destination_path = f'{destination_dir}/{icl_cfg.label}-{num_fewshot}.jsonl'
            if dist.get_local_rank() == 0 and os.path.exists(destination_path):
                os.remove(destination_path)
            dist.barrier()
            hf_parsing_map = icl_cfg.get('hf_parsing_map', {})
            hf_loading_vars = icl_cfg.get('hf_loading_vars', {})
            early_stopping_criteria = icl_cfg.get('early_stopping_criteria', None)
            if isinstance(early_stopping_criteria, ListConfig):
                early_stopping_criteria = om.to_container(early_stopping_criteria)
            assert early_stopping_criteria is None or isinstance(early_stopping_criteria, list)
            dataloaders = get_icl_task_dataloader(
                icl_cfg.icl_task_type,
                icl_cfg.dataset_uri,
                tokenizer,
                batch_size=icl_cfg.batch_size,
                max_seq_len=icl_cfg.max_seq_len,
                pad_tok_id=pad_tok_id,
                num_fewshot=num_fewshot,
                prompt_string=icl_cfg.prompt_string,
                example_delimiter=icl_cfg.example_delimiter,
                hf_loading_vars=hf_loading_vars,
                hf_parsing_map=hf_parsing_map,
                continuation_delimiter=icl_cfg.continuation_delimiter,
                question_prelimiter=icl_cfg.get('question_prelimiter', ''),
                destination_path=destination_path,
                fewshot_random_seed=icl_cfg.fewshot_random_seed,
                pass_at_k=icl_cfg.pass_at_k,
                generations_per_sample=icl_cfg.generations_per_sample,
                has_categories=icl_cfg.get('has_categories', False),
                cot_delimiter=icl_cfg.get('cot_delimiter', ''),
                generation_kwargs=icl_cfg.get('generation_kwargs', {}),
                early_stopping_criteria=early_stopping_criteria,
                do_normalization=icl_cfg.get('do_normalization', True),
            )
            if hasattr(icl_cfg, 'has_categories') and icl_cfg.has_categories and isinstance(dataloaders, dict):
                for category in dataloaders.keys():
                    logger_keys.extend([f'metrics/{label}/{category}/{m}' for m in metric_names])
                    evaluators.append(Evaluator(label=f'{label}/{category}', dataloader=dataloaders[category], metric_names=metric_names))
            else:
                logger_keys.extend([f'metrics/{label}/{m}' for m in metric_names])
                evaluators.append(Evaluator(label=label, dataloader=dataloaders, metric_names=metric_names, subset_num_batches=icl_subset_num_batches))
    return (evaluators, logger_keys)
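
# Illustrative usage (a sketch): ICL tasks are usually listed under an `icl_tasks`
# key in a YAML file, and each task yields one Evaluator per `num_fewshot` value.
# The YAML path and sizes below are placeholders.
#
#   icl_evaluators, logger_keys = build_icl_evaluators(
#       'eval/yamls/tasks.yaml',         # hypothetical path to an icl_tasks YAML
#       tokenizer,
#       default_max_seq_len=2048,
#       default_batch_size=4,
#   )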