Iterative Trainer
Iterative fine-tuning is a training method that lets you perform custom actions (for example, generation and filtering) between optimization steps. TRL provides an easy-to-use API to fine-tune your models iteratively in just a few lines of code.
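The cycle of generating, filtering, and then optimizing can be sketched without loading a model. This is a framework-free illustration under stated assumptions, not TRL code. generate, keep, and StubTrainer are hypothetical stand-ins for a real sampling call, a filtering rule, and an IterativeSFTTrainer whose step method is called with texts.

```python
from typing import Any, Callable, Dict, List


class StubTrainer:
    """Hypothetical stand-in for IterativeSFTTrainer; it only records step() calls."""

    def __init__(self) -> None:
        self.seen: List[List[str]] = []

    def step(self, texts: List[str]) -> Dict[str, Any]:
        self.seen.append(list(texts))
        return {"num_texts": len(texts)}


def iterative_finetune(
    trainer: StubTrainer,
    prompts: List[str],
    generate: Callable[[str], str],
    keep: Callable[[str], bool],
    rounds: int,
) -> List[Dict[str, Any]]:
    """Alternate custom actions (generation, filtering) with optimization steps."""
    history = []
    for _ in range(rounds):
        # Custom action 1: sample one completion per prompt from the current model
        candidates = [generate(prompt) for prompt in prompts]
        # Custom action 2: keep only the samples that pass the filter
        selected = [text for text in candidates if keep(text)]
        if not selected:
            continue  # nothing worth training on this round
        # Optimization step on the filtered samples only
        history.append(trainer.step(texts=selected))
    return history


trainer = StubTrainer()
stats = iterative_finetune(
    trainer,
    prompts=["a", "bb", "ccc"],
    generate=lambda prompt: prompt + "!",  # toy "generation"
    keep=lambda text: len(text) > 2,       # toy filter
    rounds=2,
)
print(stats)  # [{'num_texts': 2}, {'num_texts': 2}]
```

In a real setup, trainer would be the IterativeSFTTrainer instance created in the Usage section, and generate would wrap the model's own generation. The generations would then also change from round to round as the model is updated.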
Usage
To get started quickly, instantiate a model and a tokenizer, then pass them to IterativeSFTTrainer.
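from transformers import AutoModelForCausalLM, AutoTokenizer
from trl import IterativeSFTTrainer
model_name = "gpt2"  # example checkpoint; substitute your own
model = AutoModelForCausalLM.from_pretrained(model_name)
tokenizer = AutoTokenizer.from_pretrained(model_name)
# Many causal LM tokenizers ship without a pad token; reuse EOS for padding
if tokenizer.pad_token is None:
    tokenizer.pad_token = tokenizer.eos_token
# Pass by keyword: the second positional parameter is args, not tokenizer
trainer = IterativeSFTTrainer(
    model=model,
    tokenizer=tokenizer,
)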
You can pass either a list of strings or a list of tensors to the step method.
Using a list of tensors as input:
# input_ids and attention_mask: lists of torch.LongTensor, one tensor per example
inputs = {
    "input_ids": input_ids,
    "attention_mask": attention_mask
}
trainer.step(**inputs)
Using a list of strings as input:
# texts: a list of str, encoded with the trainer's tokenizer
inputs = {
    "texts": texts
}
trainer.step(**inputs)
For causal language models, labels are created automatically from input_ids or from texts. For sequence-to-sequence models, you have to provide your own labels or texts_labels.
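These label defaults can be summarized in a small helper. resolve_labels is hypothetical and only mirrors the documented behavior: for causal models labels default to input_ids and texts_labels default to texts, while sequence-to-sequence models need explicit targets. It is not TRL's internal code. As a simplification, it also rejects calls that pass both input_ids and texts.

```python
from typing import List, Optional, Sequence, Tuple


def resolve_labels(
    input_ids: Optional[List[Sequence[int]]] = None,
    labels: Optional[List[Sequence[int]]] = None,
    texts: Optional[List[str]] = None,
    texts_labels: Optional[List[str]] = None,
    is_encoder_decoder: bool = False,
) -> Tuple[Optional[List[Sequence[int]]], Optional[List[str]]]:
    """Hypothetical helper mirroring the documented label defaults of step()."""
    if (input_ids is None) == (texts is None):
        raise ValueError("Provide exactly one of input_ids or texts.")
    if is_encoder_decoder:
        # Seq2seq models: targets differ from inputs, so they must be supplied
        if input_ids is not None and labels is None:
            raise ValueError("Seq2seq models need explicit labels.")
        if texts is not None and texts_labels is None:
            raise ValueError("Seq2seq models need explicit texts_labels.")
        return labels, texts_labels
    # Causal LMs: the targets default to the inputs themselves
    if input_ids is not None:
        return (labels if labels is not None else input_ids), None
    return None, (texts_labels if texts_labels is not None else texts)


print(resolve_labels(texts=["Hello world"]))  # (None, ['Hello world'])
```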
IterativeTrainer
class trl.IterativeSFTTrainer
(model: Optional = None, args: Optional = None, tokenizer: Optional = None, optimizers: Tuple = (None, None), data_collator: Optional = None, eval_dataset: Union = None, max_length: Optional = None, truncation_mode: Optional = 'keep_end', preprocess_logits_for_metrics: Optional = None, compute_metrics: Optional = None, optimize_device_cache: Optional = False)
Parameters
- model (PreTrainedModel): Model to be optimized, either an AutoModelForCausalLM or an AutoModelForSeq2SeqLM. Check the documentation of PreTrainedModel for more details.
- args (transformers.TrainingArguments): The arguments to use for training.
- tokenizer (PreTrainedTokenizerBase): Tokenizer to be used for encoding the data. Check the documentation of transformers.PreTrainedTokenizer and transformers.PreTrainedTokenizerFast for more details.
- optimizers (Tuple[torch.optim.Optimizer, torch.optim.lr_scheduler.LambdaLR]): The optimizer and scheduler to use for training.
- data_collator (Union[DataCollatorForLanguageModeling, DataCollatorForSeq2Seq], optional): Data collator to be used for training and passed along to the dataloader.
- eval_dataset (datasets.Dataset): The dataset to use for evaluation.
- max_length (int, defaults to None): The maximum length of the input.
- truncation_mode (str, defaults to keep_end): The truncation mode to use, either keep_end or keep_start.
- preprocess_logits_for_metrics (Callable[[torch.Tensor, torch.Tensor], torch.Tensor]): The function to use to preprocess the logits before computing the metrics.
- compute_metrics (Callable[[EvalPrediction], Dict], optional): The function to use to compute the metrics. Must take an EvalPrediction and return a dictionary mapping strings to metric values.
- optimize_device_cache (bool, optional, defaults to False): Optimize CUDA cache for slightly more memory-efficient training.
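The interaction between max_length and truncation_mode can be illustrated on a single token sequence. This sketch assumes, as the mode names suggest, that keep_end keeps the last max_length tokens and keep_start keeps the first max_length tokens. truncate is a hypothetical function, not TRL's implementation. Whatever slice is taken from input_ids must also be applied to attention_mask and labels to keep them aligned.

```python
from typing import List, Optional, Sequence


def truncate(
    ids: Sequence[int], max_length: Optional[int], truncation_mode: str = "keep_end"
) -> List[int]:
    """Hypothetical sketch of the two truncation modes on one token sequence."""
    ids = list(ids)
    if max_length is None or len(ids) <= max_length:
        return ids  # no limit, or already short enough
    if truncation_mode == "keep_end":
        return ids[len(ids) - max_length:]  # drop tokens from the start
    if truncation_mode == "keep_start":
        return ids[:max_length]  # drop tokens from the end
    raise ValueError(f"Unknown truncation_mode: {truncation_mode!r}")


print(truncate([1, 2, 3, 4, 5], 3, "keep_end"))    # [3, 4, 5]
print(truncate([1, 2, 3, 4, 5], 3, "keep_start"))  # [1, 2, 3]
```

Keeping the end is a natural default for generate-then-train loops, where the most recent tokens (the completion) are usually the part worth learning from.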
The IterativeSFTTrainer can be used to fine-tune models with methods that require custom actions between optimization steps.
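As an illustration of the compute_metrics parameter, the sketch below computes next-token accuracy from an EvalPrediction. It makes three assumptions. First, a preprocess_logits_for_metrics function has already reduced the logits to predicted token IDs (for example with logits.argmax(-1)). Second, ignored label positions are marked with -100, the convention used by the Hugging Face data collators. Third, because a causal model's prediction at position i targets the token at position i + 1, predictions and labels are compared with a one-token shift. The function only reads the .predictions and .label_ids attributes, so a SimpleNamespace stands in for EvalPrediction here.

```python
from types import SimpleNamespace
from typing import Dict

IGNORE_INDEX = -100  # label value for positions excluded from the loss


def compute_metrics(eval_pred) -> Dict[str, float]:
    """Next-token accuracy for a causal LM, skipping ignored label positions."""
    correct = total = 0
    for pred_row, label_row in zip(eval_pred.predictions, eval_pred.label_ids):
        # Prediction at position i is compared with the label at position i + 1
        for pred, label in zip(pred_row[:-1], label_row[1:]):
            if label == IGNORE_INDEX:
                continue
            total += 1
            correct += int(pred == label)
    return {"token_accuracy": correct / total if total else 0.0}


# Stand-in for transformers.EvalPrediction with already-argmaxed predictions
fake = SimpleNamespace(
    predictions=[[5, 6, 7, 0]],
    label_ids=[[4, 5, 9, IGNORE_INDEX]],
)
print(compute_metrics(fake))  # {'token_accuracy': 0.5}
```

Plain Python iteration is used so the function works on both nested lists and the NumPy arrays that Trainer normally passes in. A vectorized NumPy version would be faster on large evaluation sets.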
step
(input_ids: Optional = None, attention_mask: Optional = None, labels: Optional = None, texts: Optional = None, texts_labels: Optional = None) → dict[str, Any]
Parameters
- input_ids (List[torch.LongTensor], optional): List of tensors containing the input_ids (if not provided, texts will be used).
- attention_mask (List[torch.LongTensor], optional): List of tensors containing the attention_mask.
- labels (List[torch.LongTensor], optional): List of tensors containing the labels (if set to None, will default to input_ids).
- texts (List[str], optional): List of strings containing the text input (if not provided, input_ids will be used directly).
- texts_labels (List[str], optional): List of strings containing the text labels (if set to None, will default to texts).
Returns
dict[str, Any]
A summary of the training statistics
Runs an optimization step given a list of input_ids, attention_mask, and labels, or a list of texts and texts_labels.