# mpt-7b-8k-chat / packing.py
import logging
import tempfile
from typing import Callable, Dict, Iterable, List, Literal, Optional, Tuple

import numpy as np
import torch
from composer.utils import dist, get_device, reproducibility
from omegaconf import DictConfig
from transformers import PreTrainedTokenizerBase

log = logging.getLogger(__name__)
class BinPackCollator:
    """Utility collator for packing to reduce padding."""

    def __init__(self,
                 collator: Callable,
                 target_batch_size: int,
                 max_seq_len: int,
                 pad_token_id: int,
                 padding_side: Literal['left', 'right'],
                 max_leftover_bins_to_keep: Optional[int] = None):
        self.base_collator = collator
        self.out_size = int(target_batch_size)
        self.max_seq_len = int(max_seq_len)
        self.pad_token_id = int(pad_token_id)
        self.padding_side = padding_side

        if self.out_size <= 0:
            raise ValueError(f'target_batch_size={target_batch_size!r} must be >0.')
        if self.max_seq_len <= 0:
            raise ValueError(f'max_seq_len={max_seq_len!r} must be >0.')
        if self.pad_token_id < 0:
            raise ValueError(f'pad_token_id={pad_token_id!r} must be >=0.')
        if max_leftover_bins_to_keep is not None and max_leftover_bins_to_keep < 0:
            raise ValueError(
                f'max_leftover_bins_to_keep={max_leftover_bins_to_keep!r} must be >=0 or None.')
        self.max_leftover_bins_to_keep = max_leftover_bins_to_keep

        # Cumulative statistics backing the `waste` and `efficiency` properties.
        self.n_packed_tokens = 0
        self.n_total_tokens = 0
        self.n_packed_examples = 0

        self._leftover_bins: List[Tuple[int, Dict[str, torch.Tensor]]] = []

    @property
    def waste(self) -> float:
        return 1 - self.n_packed_tokens / self.n_total_tokens

    @property
    def efficiency(self) -> float:
        return self.n_packed_tokens / (self.max_seq_len * self.n_packed_examples)

    def __call__(self, examples: List[Dict[str, torch.Tensor]]) -> Dict[str, torch.Tensor]:
        batch = self.base_collator(examples)
        return self.pack(batch)

    def pack(self, batch: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
        assert 'attention_mask' in batch
        assert 'input_ids' in batch
        for key in batch.keys():
            assert key in ['input_ids', 'labels', 'attention_mask', 'sequence_id']
        # Strip padding from every example before packing.
        sizes, trimmed_examples = _trim_batch(batch)
        return self._pack_trimmed_examples(trimmed_examples, sizes)

    def _pack_trimmed_examples(self, trimmed_examples: List[Dict[str, torch.Tensor]],
                               sizes: List[int]) -> Dict[str, torch.Tensor]:
        """Packs trimmed examples into fixed-size bins and repads them.

        Args:
            trimmed_examples (List[Dict[str, torch.Tensor]]): A list of trimmed examples.
            sizes (List[int]): The sizes of the trimmed examples.

        Returns:
            Dict[str, torch.Tensor]: A batch of repadded examples ready for processing.
        """
        # Pack the trimmed examples into `out_size` bins of length `max_seq_len`.
        packed_examples, n_packed_tokens, n_total_tokens, leftover_bins = _first_fit_bin_packing(
            sizes=sizes,
            examples=trimmed_examples,
            num_bins=self.out_size,
            max_bin_size=self.max_seq_len,
            existing_bins=self._leftover_bins,
        )
        self.n_packed_tokens += n_packed_tokens
        self.n_total_tokens += n_total_tokens
        self.n_packed_examples += self.out_size
        self._leftover_bins = leftover_bins[:self.max_leftover_bins_to_keep]

        # Re-pad the packed examples back up to max_seq_len and stack into a batch.
        batch = _repad(packed_examples,
                       max_seq_len=self.max_seq_len,
                       pad_token_id=self.pad_token_id,
                       padding_side=self.padding_side)
        return batch
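
# Usage sketch (illustrative only, not part of the original module): wrapping a HuggingFace
# collator with BinPackCollator. The tokenizer, collator choice, and `examples` variable
# below are assumptions for demonstration.
#
#   from transformers import AutoTokenizer, DataCollatorForLanguageModeling
#
#   tokenizer = AutoTokenizer.from_pretrained('mosaicml/mpt-7b-8k-chat')
#   base_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
#   packer = BinPackCollator(collator=base_collator,
#                            target_batch_size=8,
#                            max_seq_len=2048,
#                            pad_token_id=0,
#                            padding_side='left')
#   packed = packer(examples)  # examples: list of tokenized dicts; returns tensors of shape [8, 2048]
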
def _trim_batch(batch: Dict[str, torch.Tensor]) -> Tuple[List[int], List[Dict[str, torch.Tensor]]]:
    """Trims padding off all examples in batch.

    Args:
        batch (Dict[str, torch.Tensor]): Batch of padded data with tensors as values.

    Returns:
        A tuple with unpadded lengths of examples and a list of each trimmed example from the batch.
    """
    sizes, trimmed_examples = [], []
    for idx in range(batch['attention_mask'].shape[0]):
        size, trimmed_example = _extract_trim_batch_idx(batch, idx)
        sizes.append(size)
        trimmed_examples.append(trimmed_example)
    return sizes, trimmed_examples
def _extract_trim_batch_idx(batch: Dict[str, torch.Tensor], idx: int) -> Tuple[int, Dict[str, torch.Tensor]]:
    """Extracts example `idx` from the batch and strips its padding tokens."""
    example = {k: v[idx] for k, v in batch.items()}

    keep = example['attention_mask'] == 1
    size = int(keep.sum())
    trim_example = {k: v[keep] for k, v in example.items()}
    # Every trimmed example starts as sequence 0; `_combine_in_place` offsets the ids of
    # any sequences appended to it later.
    trim_example['sequence_id'] = torch.zeros_like(trim_example['input_ids'])

    return size, trim_example
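
# For illustration (hypothetical values): given attention_mask == [0, 0, 1, 1, 1] and
# input_ids == [9, 9, 11, 12, 13], _extract_trim_batch_idx returns size 3 and a trimmed
# example with input_ids == [11, 12, 13] and sequence_id == [0, 0, 0].
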
def _combine_in_place(example: Dict[str, torch.Tensor], add_on: Dict[str, torch.Tensor]) -> Dict[str, torch.Tensor]:
    """Appends `add_on` to `example`, masking the boundary label and offsetting sequence ids."""
    if 'labels' in add_on:
        # Mask the first label of the appended sequence so the loss never crosses a sequence boundary.
        add_on['labels'][0] = -100

    for k in example.keys():
        if k == 'sequence_id':
            # Shift the appended sequence ids so they continue after the ids already in the bin.
            example[k] = torch.cat([example[k], add_on[k] + 1 + torch.max(example[k])])
        else:
            example[k] = torch.cat([example[k], add_on[k]])
    return example
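
# For illustration (hypothetical values): combining an example with sequence_id == [0, 0, 0]
# and an add_on with sequence_id == [0, 0] yields sequence_id == [0, 0, 0, 1, 1], with the
# label at the boundary (index 3 of the combined tensor) set to -100 and the remaining keys
# simply concatenated.
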
def _first_fit_bin_packing(
    sizes: List[int],
    examples: List[Dict[str, torch.Tensor]],
    num_bins: int,
    max_bin_size: int,
    existing_bins: List[Tuple[int, Dict[str, torch.Tensor]]],
) -> Tuple[List[Dict[str, torch.Tensor]], int, int, List[Tuple[int, Dict[str, torch.Tensor]]]]:
    """Packs examples into at most `num_bins` bins of `max_bin_size` tokens using a first-fit-decreasing heuristic.

    Returns the packed examples, the number of tokens packed into the returned bins,
    the total number of tokens seen, and any leftover bins.
    """
    # Each bin is a tuple of (total token count, packed example).
    bins: List[Tuple[int, Dict[str, torch.Tensor]]] = existing_bins

    starting_total_bin_sizes = sum([bin_size for bin_size, _ in bins])

    # Process examples from longest to shortest.
    sizes_and_examples = [(size, example) for size, example in zip(sizes, examples)]
    sorted_sizes_and_examples = sorted(sizes_and_examples, key=lambda x: x[0], reverse=True)

    required_num_examples = max(0, num_bins - len(bins))
    num_examples = len(sizes)
    if num_examples < required_num_examples:
        # Not enough examples to fill the requested number of bins: every example gets its own bin.
        for size, example in sorted_sizes_and_examples:
            bins.append((size, example))

        total_bin_sizes = sum([bin_size for bin_size, _ in bins])
        total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
        total_example_sizes = sum(sizes)
        if total_new_bin_sizes != total_example_sizes:
            raise AssertionError(
                f'Error in packing. total_example_sizes={total_example_sizes!r} does not equal '
                f'total_new_bin_sizes={total_new_bin_sizes!r}.')

        sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
        bin_sizes, packed_examples = [], []
        for bin_size, packed_example in sorted_bins:
            bin_sizes.append(bin_size)
            packed_examples.append(packed_example)

        # Return the num_bins fullest bins, the tokens packed into them, the total tokens seen,
        # and the leftover bins.
        return (packed_examples[:num_bins], sum(bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:])

    for i, (size, example) in enumerate(sorted_sizes_and_examples):
        required_num_examples = max(0, num_bins - len(bins))
        n_remaining = num_examples - i
        assert n_remaining >= required_num_examples
        if n_remaining == required_num_examples:
            # Can't keep packing: every remaining example must open its own bin.
            bins.append((size, example))
            continue

        # First fit: add the example to the first bin it fits in.
        added = False
        for bidx in range(len(bins)):
            if bins[bidx][0] + size <= max_bin_size:
                bin_size, packed_example = bins.pop(bidx)
                bin_size = bin_size + size
                packed_example = _combine_in_place(packed_example, example)
                bins.append((bin_size, packed_example))
                added = True
                break
        # If it did not fit anywhere, open a new bin.
        if not added:
            bins.append((size, example))

    total_bin_sizes = sum([bin_size for bin_size, _ in bins])
    total_new_bin_sizes = total_bin_sizes - starting_total_bin_sizes
    total_example_sizes = sum(sizes)
    if total_new_bin_sizes != total_example_sizes:
        raise AssertionError(
            f'Error in packing. total_example_sizes={total_example_sizes!r} does not equal '
            f'total_new_bin_sizes={total_new_bin_sizes!r}.')

    sorted_bins = sorted(bins, key=lambda x: x[0], reverse=True)
    bin_sizes, packed_examples = [], []
    for bin_size, packed_example in sorted_bins:
        bin_sizes.append(bin_size)
        packed_examples.append(packed_example)

    # Return the num_bins fullest bins, the tokens packed into them, the total tokens seen,
    # and the leftover bins.
    return (packed_examples[:num_bins], sum(bin_sizes[:num_bins]), sum(sizes), sorted_bins[num_bins:])
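
# Worked toy example (hypothetical sizes): with sizes == [5, 4, 3, 2], num_bins == 2,
# max_bin_size == 8, and no existing bins, first-fit-decreasing proceeds as:
#   5 -> new bin (5); 4 -> does not fit (5 + 4 > 8), new bin (4);
#   3 -> fits the first bin (5 + 3 == 8); 2 -> fits the second bin (4 + 2 == 6).
# The two returned bins hold 8 and 6 tokens, so all 14 tokens are packed with zero waste.
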
def _repad(packed_examples: List[Dict[str, torch.Tensor]], max_seq_len: int, pad_token_id: int,
           padding_side: str) -> Dict[str, torch.Tensor]:
    """Pads each packed example back up to max_seq_len and stacks them into a batch."""

    def pad_tensor(tensor: torch.Tensor, pad_value: int):
        if len(tensor) == max_seq_len:
            return tensor
        t = torch.full((max_seq_len,), pad_value, dtype=tensor.dtype, device=tensor.device)
        if padding_side == 'left':
            t[-len(tensor):] = tensor
        elif padding_side == 'right':
            t[:len(tensor)] = tensor
        else:
            raise ValueError(f'Unknown padding_side={padding_side!r}')
        return t

    pad_vals = {
        'input_ids': pad_token_id,
        'labels': -100,
        'attention_mask': 0,
        'sequence_id': -1,
    }
    keys = packed_examples[0].keys()
    batch = {}
    for key in keys:
        batch[key] = torch.stack(
            [pad_tensor(example[key], pad_vals[key]) for example in packed_examples])
    return batch
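
# For illustration (hypothetical values): with max_seq_len == 6, pad_token_id == 0, and
# input_ids == [11, 12, 13, 14], padding_side == 'left' yields [0, 0, 11, 12, 13, 14] and
# padding_side == 'right' yields [11, 12, 13, 14, 0, 0]; labels pad with -100,
# attention_mask with 0, and sequence_id with -1.
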
def auto_packing_ratio(dataloader_cfg: DictConfig,
                       tokenizer: PreTrainedTokenizerBase,
                       device_batch_size: int,
                       num_packing_ratios: int = 20) -> float:
    """Find a packing ratio that minimizes padding with zero waste.

    By packing examples, we can increase training efficiency, training on more data with fewer batches.
    However, in practice, the selected packing_ratio may produce some waste because profiling is done on
    only a subset of the dataset.

    We select a min_ratio of 1 and a max_ratio of max_seq_len / 100, and profile up to
    num_packing_ratios packing ratios between min_ratio and max_ratio, inclusive.
    When a packing_ratio with non-zero waste is found, we stop and select the previous ratio,
    which has zero waste.

    Args:
        dataloader_cfg (DictConfig): The dataloader configuration for profiling.
        tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
        device_batch_size (int): The size of the batches (number of examples) per device.
        num_packing_ratios (int): The number of packing ratios to try.

    Returns:
        A packing ratio that minimizes padding while maintaining zero waste.
    """
    # Stash the rng state to restore later, and seed so that auto packing is deterministic.
    rng_state = reproducibility.get_rng_state()
    reproducibility.seed_all(0)

    max_seq_len = dataloader_cfg.dataset.max_seq_len
    # If max_seq_len is very small, skip profiling and select a packing ratio of 1.
    if max_seq_len <= 100:
        return 1

    min_ratio = 1
    max_ratio = max_seq_len / 100
    profiling_results = profile_packing(dataloader_cfg, tokenizer, min_ratio, max_ratio,
                                        num_packing_ratios, device_batch_size)

    # Select the largest packing_ratio candidate that still has zero waste.
    packing_ratio = 1
    for packing_ratio_candidate, _, waste in profiling_results:
        if waste is None or waste > 0:
            break
        packing_ratio = packing_ratio_candidate

    # Take the minimum packing ratio across all ranks so every rank packs consistently.
    if dist.is_available() and dist.is_initialized():
        device = get_device(None)
        packing_ratio_tensor = device.tensor_to_device(torch.tensor(packing_ratio))
        dist.all_reduce(packing_ratio_tensor, reduce_operation='MIN')
        packing_ratio = packing_ratio_tensor.item()

    # Restore the rng state.
    reproducibility.load_rng_state(rng_state)

    return packing_ratio
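
# For example (hypothetical numbers): with max_seq_len == 2048 the candidate ratios span
# [1, 20.48]; with num_packing_ratios == 20 the profiled candidates are roughly
# 1.0, 2.0, 3.1, 4.1, ..., 20.5 after rounding to one decimal place. If waste first becomes
# non-zero at the 4.1 candidate, the previous candidate, 3.1, is selected.
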
def profile_packing(
    dataloader_cfg: DictConfig,
    tokenizer: PreTrainedTokenizerBase,
    min_ratio: float,
    max_ratio: float,
    num_packing_ratios: int,
    device_batch_size: int,
) -> Iterable[Tuple[float, Optional[float], Optional[float]]]:
    """Generator function that profiles example packing across packing ratios.

    Args:
        dataloader_cfg (DictConfig): The dataloader configuration for profiling.
        tokenizer (PreTrainedTokenizerBase): The tokenizer for profiling.
        min_ratio (float): Smallest packing_ratio to test. Must be >=1.
        max_ratio (float): Largest packing_ratio to test. Must be larger than `min_ratio`.
        num_packing_ratios (int): Number of packing_ratio values (spaced between `min_ratio` and `max_ratio`) to try.
        device_batch_size (int): The size of the batches (number of examples) per device.

    Returns:
        An iterable of tuples of packing ratio, padding, and waste, sorted by smallest to largest packing ratio.
    """
    import copy

    from .dataloader import build_dataloader

    max_seq_len = dataloader_cfg.dataset.get('max_seq_len')
    max_leftovers_to_keep = dataloader_cfg.dataset.get('max_leftovers_to_keep', None)

    # Profile on a copy of the config with packing disabled and single-process loading.
    dataloader_cfg = copy.deepcopy(dataloader_cfg)
    dataloader_cfg.dataset.packing_ratio = 1.0
    dataloader_cfg.drop_last = False
    dataloader_cfg.num_workers = 0
    dataloader_cfg.prefetch_factor = None
    dataloader_cfg.persistent_workers = False

    # If using a streaming (remote) dataset, stream into a temporary local folder for profiling.
    if dataloader_cfg.dataset.get('remote') is not None:
        dataloader_cfg.dataset.local = tempfile.TemporaryDirectory().name

    # Determine the packing_ratio values to try, skipping duplicates produced by rounding.
    packing_ratios, raw_batch_sizes = [], []
    for packing_ratio in np.linspace(min_ratio, max_ratio, num_packing_ratios, endpoint=True):
        packing_ratio = np.round(10 * packing_ratio) / 10
        raw_batch_size = int(packing_ratio * device_batch_size)
        if raw_batch_size not in raw_batch_sizes:
            packing_ratios.append(packing_ratio)
            raw_batch_sizes.append(raw_batch_size)

    n_profile_examples = max(raw_batch_sizes) * 100

    train_dataspec = build_dataloader(dataloader_cfg, tokenizer, n_profile_examples)
    train_dataloader = train_dataspec.dataloader

    # Grab one large batch of raw examples and strip their padding.
    big_batch = next(iter(train_dataloader))
    sizes, trimmed_examples = _trim_batch(big_batch)

    def profile(raw_batch_size: int) -> Tuple[Optional[float], Optional[float]]:
        # Copy the trimmed examples so the dicts are not shared between profiling runs.
        trimmed_examples_copy = [te.copy() for te in trimmed_examples]

        packer = BinPackCollator(collator=lambda x: x,
                                 target_batch_size=device_batch_size,
                                 max_seq_len=max_seq_len,
                                 pad_token_id=0,
                                 padding_side='left',
                                 max_leftover_bins_to_keep=max_leftovers_to_keep)

        # Simulate feeding the packing collator raw batches of size raw_batch_size.
        for idx in range(0, len(trimmed_examples_copy), raw_batch_size):
            batch = trimmed_examples_copy[idx:idx + raw_batch_size]
            if len(batch) < device_batch_size:
                continue
            packer._pack_trimmed_examples(batch, sizes[idx:idx + raw_batch_size])

        if packer.n_packed_examples == 0:
            log.debug('No examples packed during profiling. Dataset is smaller than device batch size.')
            return None, None

        # Report padding and waste as percentages of the simulated run.
        padding_percent = 100 * (1 - packer.efficiency)
        waste_percent = 100 * packer.waste
        return padding_percent, waste_percent

    for packing_ratio, raw_batch_size in zip(packing_ratios, raw_batch_sizes):
        padding, waste = profile(raw_batch_size)
        yield (packing_ratio, padding, waste)
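
# Usage sketch (illustrative only): auto_packing_ratio is typically driven by an omegaconf
# dataloader config. The config keys and values below are assumptions for demonstration;
# real configs come from the finetuning YAML used to build the dataloader.
#
#   from omegaconf import OmegaConf
#
#   dataloader_cfg = OmegaConf.create({
#       'name': 'finetuning',
#       'dataset': {'max_seq_len': 2048, 'packing_ratio': 'auto'},
#       'drop_last': True,
#       'num_workers': 8,
#   })
#   ratio = auto_packing_ratio(dataloader_cfg, tokenizer, device_batch_size=8)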