from itertools import chain
from typing import TYPE_CHECKING, Any, Dict, List


if TYPE_CHECKING:
    from transformers import PreTrainedTokenizer

    from ...hparams import DataArguments


def preprocess_pretrain_dataset(
    examples: Dict[str, List[Any]], tokenizer: "PreTrainedTokenizer", data_args: "DataArguments"
) -> Dict[str, List[Any]]:
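    """Tokenize plain-text corpora for pretraining.

    Each entry of ``examples["_prompt"]`` is a list of messages; only the first
    message's content is used, with an EOS token appended. Without packing, each
    text is tokenized and truncated to ``cutoff_len`` tokens; with packing, all
    texts are concatenated and split into blocks of exactly ``cutoff_len`` tokens.
    """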
    # llama3 marks document boundaries with <|end_of_text|>; other templates use the tokenizer's EOS token
    eos_token = "<|end_of_text|>" if data_args.template == "llama3" else tokenizer.eos_token
    text_examples = [messages[0]["content"] + eos_token for messages in examples["_prompt"]]

    if not data_args.packing:
        if data_args.template == "gemma":
            # gemma expects a leading BOS token, which add_special_tokens=False would otherwise omit
            text_examples = [tokenizer.bos_token + example for example in text_examples]

        result = tokenizer(text_examples, add_special_tokens=False, truncation=True, max_length=data_args.cutoff_len)
    else:
        # packing: tokenize all texts, concatenate them, and split into cutoff_len-sized blocks
        tokenized_examples = tokenizer(text_examples, add_special_tokens=False)
        concatenated_examples = {k: list(chain(*tokenized_examples[k])) for k in tokenized_examples.keys()}
        total_length = len(concatenated_examples[list(concatenated_examples.keys())[0]])
        block_size = data_args.cutoff_len
        # drop the trailing remainder so every block holds exactly block_size tokens
        total_length = (total_length // block_size) * block_size
        result = {
            k: [t[i : i + block_size] for i in range(0, total_length, block_size)]
            for k, t in concatenated_examples.items()
        }
        if data_args.template == "gemma":
            # overwrite the first token of each packed block with BOS for gemma
            for i in range(len(result["input_ids"])):
                result["input_ids"][i][0] = tokenizer.bos_token_id

    return result
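

# Hedged usage sketch (illustration only, not part of the library API): shows the
# expected ``examples`` layout, assuming a Hugging Face tokenizer is available.
# ``SimpleNamespace`` merely stands in for the real ``DataArguments`` object; only
# the fields read above (template, packing, cutoff_len) are provided.
if __name__ == "__main__":
    from types import SimpleNamespace

    from transformers import AutoTokenizer

    demo_tokenizer = AutoTokenizer.from_pretrained("gpt2")
    demo_args = SimpleNamespace(template="default", packing=False, cutoff_len=32)
    demo_examples = {"_prompt": [[{"role": "user", "content": "Hello world."}]]}
    # prints a dict with "input_ids" and "attention_mask", one row per input example
    print(preprocess_pretrain_dataset(demo_examples, demo_tokenizer, demo_args))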