|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from functools import partial |
|
from typing import TYPE_CHECKING, Callable, Literal, Optional, Tuple |
|
|
|
from .processors.feedback import preprocess_feedback_dataset |
|
from .processors.pairwise import preprocess_pairwise_dataset, print_pairwise_dataset_example |
|
from .processors.pretrain import preprocess_pretrain_dataset |
|
from .processors.supervised import ( |
|
preprocess_packed_supervised_dataset, |
|
preprocess_supervised_dataset, |
|
print_supervised_dataset_example, |
|
) |
|
from .processors.unsupervised import preprocess_unsupervised_dataset, print_unsupervised_dataset_example |
|
|
|
|
|
if TYPE_CHECKING: |
|
from transformers import PreTrainedTokenizer, ProcessorMixin |
|
|
|
from ..hparams import DataArguments |
|
from .template import Template |
|
|
|
|
|
def _patch_optimized_typed_sequence_init() -> None:
    """Replace ``OptimizedTypedSequence.__init__`` with a thin ``TypedSequence`` initializer.

    The replacement forwards only ``type``, ``try_type`` and ``optimized_int_type``
    (each defaulting to ``None``) to ``TypedSequence.__init__``; any other keyword
    argument is dropped. The assignment is made on the class itself, so it stays
    in effect for the rest of the process.
    """
    # Imported lazily so `datasets.arrow_writer` is only touched when the patch is requested.
    from datasets.arrow_writer import OptimizedTypedSequence, TypedSequence

    # NOTE(review): presumably this sidesteps the extra logic in the stock
    # OptimizedTypedSequence initializer for neat packing - confirm against
    # the installed `datasets` version.
    def __init__(self, data, **kwargs):
        return TypedSequence.__init__(
            self,
            data,
            type=kwargs.get("type", None),
            try_type=kwargs.get("try_type", None),
            optimized_int_type=kwargs.get("optimized_int_type", None),
        )

    OptimizedTypedSequence.__init__ = __init__


def get_preprocess_and_print_func(
    data_args: "DataArguments",
    stage: Literal["pt", "sft", "rm", "ppo", "kto"],
    template: "Template",
    tokenizer: "PreTrainedTokenizer",
    processor: Optional["ProcessorMixin"],
    do_generate: bool = False,
) -> Tuple[Callable, Callable]:
    """Pick the preprocessing callable and the example printer for a stage.

    Dispatch rules:

    * ``"pt"``: pretrain preprocessor (bound to ``tokenizer`` and ``data_args``
      only) with the unsupervised example printer.
    * ``"sft"`` with ``do_generate`` false: packed supervised preprocessor when
      ``data_args.packing`` is truthy, otherwise the plain supervised one; both
      use the supervised example printer. When ``data_args.neat_packing`` is
      also truthy, ``OptimizedTypedSequence.__init__`` is monkeypatched first.
    * ``"rm"``: pairwise preprocessor and pairwise example printer.
    * ``"kto"``: feedback preprocessor and supervised example printer.
    * anything else (including ``"ppo"`` and ``"sft"`` with ``do_generate``
      true): unsupervised preprocessor and unsupervised example printer.

    Args:
        data_args: Data arguments; ``packing`` and ``neat_packing`` are read
            here, and the object is bound into the preprocessor.
        stage: Training stage identifier.
        template: Template bound into every preprocessor except the pretrain one.
        tokenizer: Tokenizer bound into both returned callables.
        processor: Optional processor bound into every preprocessor except the
            pretrain one.
        do_generate: When true, the ``"sft"`` stage falls through to the
            unsupervised branch.

    Returns:
        A ``(preprocess_func, print_function)`` pair of ``functools.partial``
        objects.
    """
    if stage == "pt":
        return (
            partial(preprocess_pretrain_dataset, tokenizer=tokenizer, data_args=data_args),
            partial(print_unsupervised_dataset_example, tokenizer=tokenizer),
        )

    # Keyword arguments common to every template-aware preprocessor.
    shared_kwargs = {
        "template": template,
        "tokenizer": tokenizer,
        "processor": processor,
        "data_args": data_args,
    }

    if stage == "sft" and not do_generate:
        if not data_args.packing:
            supervised_impl = preprocess_supervised_dataset
        else:
            if data_args.neat_packing:
                _patch_optimized_typed_sequence_init()

            supervised_impl = preprocess_packed_supervised_dataset

        return (
            partial(supervised_impl, **shared_kwargs),
            partial(print_supervised_dataset_example, tokenizer=tokenizer),
        )

    if stage == "rm":
        return (
            partial(preprocess_pairwise_dataset, **shared_kwargs),
            partial(print_pairwise_dataset_example, tokenizer=tokenizer),
        )

    if stage == "kto":
        return (
            partial(preprocess_feedback_dataset, **shared_kwargs),
            partial(print_supervised_dataset_example, tokenizer=tokenizer),
        )

    # Fallback: every remaining stage, plus "sft" when generating.
    return (
        partial(preprocess_unsupervised_dataset, **shared_kwargs),
        partial(print_unsupervised_dataset_example, tokenizer=tokenizer),
    )
|
|