TRL documentation

Command Line Interfaces (CLIs)

This page documents TRL v0.12.0.

You can use the TRL CLIs to fine-tune a language model with Supervised Fine-Tuning (SFT) or Direct Preference Optimization (DPO), or even to chat with your model.

Currently supported CLIs are:

  • trl sft: fine-tune an LLM on a text/instruction dataset
  • trl dpo: fine-tune an LLM with DPO on a preference dataset
  • trl chat: quickly spin up an LLM fine-tuned for chat and talk to it
  • trl env: print system information

Fine-tuning with the CLI

Before getting started, pick a language model from the Hugging Face Hub. Supported models can be found by filtering models by the “text-generation” task. Also make sure to pick a dataset relevant to your task.

Before using the sft or dpo commands, make sure to run:

accelerate config

and choose the right configuration for your training setup (single GPU or multi-GPU, DeepSpeed, etc.). Make sure to complete all steps of accelerate config before running any CLI command.

We also recommend passing a YAML config file to configure your training run. Below is a simple example of a YAML file that you can use to train models with the trl sft command.

model_name_or_path: trl-internal-testing/tiny-random-LlamaForCausalLM
dataset_name: stanfordnlp/imdb
report_to: none
learning_rate: 0.0001
lr_scheduler_type: cosine

Save that config in a .yaml file and you can get started immediately. An example CLI config is available at examples/cli_configs/example_config.yaml. Note that you can override arguments from the config file by passing them explicitly to the CLI. For example, from the root folder of the repository:

trl sft --config examples/cli_configs/example_config.yaml --output_dir test-trl-cli --lr_scheduler_type cosine_with_restarts

will use cosine_with_restarts for lr_scheduler_type, regardless of the value in the config file.
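The override rule can be pictured as "config values become defaults, explicit flags win". Below is a minimal sketch with the standard library's argparse, assuming the YAML file has already been loaded into a dict; resolve_args and the short argument list are illustrative only and do not reproduce TRL's own parser.

```python
import argparse

# Assumption: the YAML file above has already been parsed into this dict.
config_from_yaml = {
    "model_name_or_path": "trl-internal-testing/tiny-random-LlamaForCausalLM",
    "dataset_name": "stanfordnlp/imdb",
    "learning_rate": 0.0001,
    "lr_scheduler_type": "cosine",
}


def resolve_args(cli_args, config):
    """Hypothetical helper: merge config values with CLI flags, CLI flags win."""
    parser = argparse.ArgumentParser()
    parser.add_argument("--model_name_or_path")
    parser.add_argument("--dataset_name")
    parser.add_argument("--learning_rate", type=float)
    parser.add_argument("--lr_scheduler_type", default="linear")
    parser.add_argument("--output_dir")
    # Config-file values replace the built-in defaults...
    parser.set_defaults(**config)
    # ...and anything passed explicitly on the command line overrides them.
    return parser.parse_args(cli_args)


args = resolve_args(
    ["--output_dir", "test-trl-cli", "--lr_scheduler_type", "cosine_with_restarts"],
    config_from_yaml,
)
print(args.lr_scheduler_type)  # cosine_with_restarts
print(args.learning_rate)      # 0.0001
```

Values that are not passed on the command line (here learning_rate and dataset_name) keep whatever the config file set.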

Supported Arguments

All arguments from transformers.TrainingArguments are supported. For loading your model, all arguments from trl.ModelConfig are supported:

class trl.ModelConfig(
    model_name_or_path: Optional = None, model_revision: str = 'main', torch_dtype: Optional = None,
    trust_remote_code: bool = False, attn_implementation: Optional = None,
    use_peft: bool = False, lora_r: int = 16, lora_alpha: int = 32, lora_dropout: float = 0.05,
    lora_target_modules: Optional = None, lora_modules_to_save: Optional = None,
    lora_task_type: str = 'CAUSAL_LM', use_rslora: bool = False,
    load_in_8bit: bool = False, load_in_4bit: bool = False,
    bnb_4bit_quant_type: Literal = 'nf4', use_bnb_nested_quant: bool = False,
)

Parameters

  • model_name_or_path (Optional[str], optional, defaults to None) — Model checkpoint for weights initialization.
  • model_revision (str, optional, defaults to "main") — Specific model version to use. It can be a branch name, a tag name, or a commit id.
  • torch_dtype (Optional[Literal["auto", "bfloat16", "float16", "float32"]], optional, defaults to None) — Override the default torch.dtype and load the model under this dtype. Possible values are:

    • "bfloat16": torch.bfloat16
    • "float16": torch.float16
    • "float32": torch.float32
    • "auto": Automatically derive the dtype from the model’s weights.
  • trust_remote_code (bool, optional, defaults to False) — Whether to allow for custom models defined on the Hub in their own modeling files. This option should only be set to True for repositories you trust and in which you have read the code, as it will execute code present on the Hub on your local machine.
  • attn_implementation (Optional[str], optional, defaults to None) — Which attention implementation to use. You can pass --attn_implementation=flash_attention_2, in which case you must install flash-attn manually by running pip install flash-attn --no-build-isolation.
  • use_peft (bool, optional, defaults to False) — Whether to use PEFT for training.
  • lora_r (int, optional, defaults to 16) — LoRA rank (r).
  • lora_alpha (int, optional, defaults to 32) — LoRA alpha.
  • lora_dropout (float, optional, defaults to 0.05) — LoRA dropout.
  • lora_target_modules (Optional[Union[str, List[str]]], optional, defaults to None) — LoRA target modules.
  • lora_modules_to_save (Optional[List[str]], optional, defaults to None) — Model layers to unfreeze & train.
  • lora_task_type (str, optional, defaults to "CAUSAL_LM") — Task type to pass for LoRA (use "SEQ_CLS" for reward modeling).
  • use_rslora (bool, optional, defaults to False) — Whether to use Rank-Stabilized LoRA, which sets the adapter scaling factor to lora_alpha/√r, instead of the original default value of lora_alpha/r.
  • load_in_8bit (bool, optional, defaults to False) — Whether to use 8-bit precision for the base model. Works only with LoRA.
  • load_in_4bit (bool, optional, defaults to False) — Whether to use 4-bit precision for the base model. Works only with LoRA.
  • bnb_4bit_quant_type (str, optional, defaults to "nf4") — Quantization type ("fp4" or "nf4").
  • use_bnb_nested_quant (bool, optional, defaults to False) — Whether to use nested quantization.
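The difference between the two scaling rules behind use_rslora is easiest to see numerically. This is a minimal sketch of the arithmetic only; lora_scaling is a hypothetical helper, not a TRL or PEFT function, and the numbers use the ModelConfig defaults lora_alpha=32 and lora_r=16.

```python
import math


def lora_scaling(lora_alpha: float, lora_r: int, use_rslora: bool = False) -> float:
    """Scaling factor applied to the low-rank LoRA update (B @ A)."""
    if use_rslora:
        # Rank-Stabilized LoRA: lora_alpha / sqrt(r)
        return lora_alpha / math.sqrt(lora_r)
    # Original LoRA: lora_alpha / r
    return lora_alpha / lora_r


print(lora_scaling(32, 16))                   # 2.0
print(lora_scaling(32, 16, use_rslora=True))  # 8.0
```

The gap grows with the rank: at lora_r=64 the original factor drops to 32/64 = 0.5 while the rsLoRA factor is 32/8 = 4.0, which is the usual motivation for rsLoRA at larger ranks.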

Configuration class for the models.

Using HfArgumentParser, this class can be turned into argparse arguments that can be specified on the command line.

You can pass any of these arguments either to the CLI or the YAML file.
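To illustrate the idea, here is a simplified stand-in for HfArgumentParser built on plain argparse: one --flag per dataclass field, with the field's default. MiniModelConfig and dataclass_to_parser are hypothetical; the real transformers.HfArgumentParser supports many more field types and boolean spellings.

```python
import argparse
from dataclasses import dataclass, fields
from typing import Optional


@dataclass
class MiniModelConfig:
    # Hypothetical subset of ModelConfig's fields, for illustration only.
    model_name_or_path: Optional[str] = None
    model_revision: str = "main"
    use_peft: bool = False
    lora_r: int = 16
    lora_alpha: int = 32
    lora_dropout: float = 0.05


def dataclass_to_parser(cls) -> argparse.ArgumentParser:
    """Create one --flag per dataclass field, using the field's default."""
    parser = argparse.ArgumentParser()
    for f in fields(cls):
        if f.type is bool:
            # Booleans become switches: passing --use_peft sets the field to True.
            parser.add_argument(f"--{f.name}", action="store_true", default=f.default)
        elif f.type in (int, float, str):
            parser.add_argument(f"--{f.name}", type=f.type, default=f.default)
        else:
            # Optional[str] and other annotations fall back to plain strings here.
            parser.add_argument(f"--{f.name}", type=str, default=f.default)
    return parser


parser = dataclass_to_parser(MiniModelConfig)
ns = parser.parse_args(
    ["--model_name_or_path", "facebook/opt-125m", "--use_peft", "--lora_r", "8"]
)
config = MiniModelConfig(**vars(ns))
print(config)
```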

Supervised Fine-tuning (SFT)

Follow the basic instructions above and run trl sft --output_dir <output_dir> <*args>:

trl sft --model_name_or_path facebook/opt-125m --dataset_name stanfordnlp/imdb --output_dir opt-sft-imdb

The SFT CLI is based on the examples/scripts/sft.py script.

Direct Preference Optimization (DPO)

To use the DPO CLI, you need a dataset in the TRL format.

These datasets always have at least three columns, prompt, chosen and rejected:

  • prompt is a list of strings.
  • chosen is the chosen response in chat format
  • rejected is the rejected response in chat format
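Before launching a run, it can help to check rows against this description. A minimal sketch, assuming each row is a plain dict and that "chat format" means a list of {"role": ..., "content": ...} messages; check_preference_example and the sample row are hypothetical, and the type of the prompt column is deliberately not checked.

```python
REQUIRED_COLUMNS = ("prompt", "chosen", "rejected")


def is_chat_format(messages) -> bool:
    """True if `messages` is a list of dicts with 'role' and 'content' keys."""
    return isinstance(messages, list) and all(
        isinstance(m, dict) and "role" in m and "content" in m for m in messages
    )


def check_preference_example(example: dict) -> list:
    """Return the problems found in one row (an empty list means it looks OK)."""
    problems = [f"missing column: {c}" for c in REQUIRED_COLUMNS if c not in example]
    for column in ("chosen", "rejected"):
        if column in example and not is_chat_format(example[column]):
            problems.append(f"{column} is not in chat format")
    return problems


# Hypothetical row: a plain-string prompt, full conversations for chosen/rejected.
row = {
    "prompt": "What color is the sky?",
    "chosen": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "It is usually blue."},
    ],
    "rejected": [
        {"role": "user", "content": "What color is the sky?"},
        {"role": "assistant", "content": "Green."},
    ],
}
print(check_preference_example(row))  # []
```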

To do a quick start, you can run the following command:

trl dpo --model_name_or_path facebook/opt-125m --output_dir trl-hh-rlhf --dataset_name trl-internal-testing/hh-rlhf-helpful-base-trl-style

The DPO CLI is based on the examples/scripts/dpo.py script.

Custom preference dataset

Convert your dataset into the TRL format. You can adapt the examples/datasets/anthropic_hh.py script:

python examples/datasets/anthropic_hh.py --push_to_hub --hf_entity your-hf-org
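If your raw data stores each conversation as an Anthropic HH-style transcript (turns prefixed with "Human:" and "Assistant:"), the conversion roughly amounts to splitting each transcript into chat messages. The sketch below is a hypothetical reimplementation for illustration, not the contents of examples/datasets/anthropic_hh.py; in particular, using the first user turn as the prompt is an assumption.

```python
import re

# Each turn: a speaker tag, then text up to the next tag or the end of the string.
TURN_PATTERN = r"\n\n(Human|Assistant): (.*?)(?=\n\n(?:Human|Assistant): |\Z)"


def extract_dialogue(transcript: str) -> list:
    """Split an HH-style transcript into a list of chat messages."""
    messages = []
    for speaker, text in re.findall(TURN_PATTERN, transcript, flags=re.DOTALL):
        role = "user" if speaker == "Human" else "assistant"
        messages.append({"role": role, "content": text.strip()})
    return messages


def to_trl_format(raw: dict) -> dict:
    chosen = extract_dialogue(raw["chosen"])
    rejected = extract_dialogue(raw["rejected"])
    # Assumption: the prompt is the first user turn, shared by both responses.
    return {"prompt": chosen[0]["content"], "chosen": chosen, "rejected": rejected}


# Hypothetical raw row in HH-style transcript form.
raw = {
    "chosen": "\n\nHuman: What color is the sky?\n\nAssistant: It is usually blue.",
    "rejected": "\n\nHuman: What color is the sky?\n\nAssistant: Green.",
}
print(to_trl_format(raw))
```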

Chat interface

The chat CLI lets you quickly load the model and talk to it. Simply run the following:

$ trl chat --model_name_or_path Qwen/Qwen1.5-0.5B-Chat 
<quentin_gallouedec>:
What is the best programming language?

<Qwen/Qwen1.5-0.5B-Chat>:
There isn't a "best" programming language, as everyone has different style preferences, needs, and preferences. However, some people commonly use   
languages like Python, Java, C++, and JavaScript, which are popular among developers for a variety of reasons, including readability, flexibility,  
and scalability. Ultimately, it depends on personal preference, needs, and goals.

Note that the chat interface relies on the tokenizer’s chat template to format the inputs for the model. Make sure your tokenizer has a chat template defined.

Besides talking to the model, there are a few commands you can use:

  • clear: clears the current conversation and starts a new one
  • example {NAME}: load example named {NAME} from the config and use it as the user input
  • set {SETTING_NAME}={SETTING_VALUE};: change the system prompt or generation settings (multiple settings are separated by a ;).
  • reset: same as clear but also resets the generation configs to defaults if they have been changed by set
  • save or save {SAVE_NAME}: save the current chat and settings to a file, by default ./chat_history/{MODEL_NAME}/chat_{DATETIME}.yaml, or to {SAVE_NAME} if provided
  • exit: closes the interface
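The semantics of clear, reset and set can be sketched as a small state machine. This is a hypothetical illustration, not the TRL chat implementation: the setting names in DEFAULT_SETTINGS, the ChatSession class and parse_set_command are invented for the example, and values are kept as strings.

```python
# Hypothetical defaults; the real chat CLI's setting names and values may differ.
DEFAULT_SETTINGS = {"max_new_tokens": "256", "system_prompt": ""}


def parse_set_command(command: str) -> dict:
    """Parse 'set NAME=VALUE;NAME2=VALUE2;' into {'NAME': 'VALUE', ...}."""
    body = command[len("set "):]
    updates = {}
    for item in body.split(";"):
        item = item.strip()
        if not item:
            continue  # tolerate the trailing ';' and empty segments
        name, sep, value = item.partition("=")
        if not sep:
            raise ValueError(f"expected NAME=VALUE, got {item!r}")
        updates[name.strip()] = value.strip()
    return updates


class ChatSession:
    """Toy session state showing how clear, reset and set interact."""

    def __init__(self):
        self.history = []
        self.settings = dict(DEFAULT_SETTINGS)

    def handle(self, user_input: str) -> bool:
        """Apply a command and return True, or record a chat message and return False."""
        if user_input == "clear":
            self.history = []  # new conversation, settings untouched
        elif user_input == "reset":
            self.history = []
            self.settings = dict(DEFAULT_SETTINGS)  # also undo any earlier `set`
        elif user_input.startswith("set "):
            self.settings.update(parse_set_command(user_input))
        else:
            self.history.append({"role": "user", "content": user_input})
            return False
        return True


session = ChatSession()
session.handle("What is the best programming language?")
session.handle("set max_new_tokens=128;system_prompt=Be brief;")
session.handle("clear")
print(session.history, session.settings["max_new_tokens"])  # [] 128
session.handle("reset")
print(session.settings["max_new_tokens"])  # 256
```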

The default examples are defined in examples/scripts/config/default_chat_config.yaml, but you can pass your own with --config CONFIG_FILE, where you can also specify the default generation parameters.

Getting the system information

You can get the system information by running the following command:

trl env

This will print the system information, including the GPU information, the CUDA version, the PyTorch version, the transformers version, the TRL version, and any optional dependencies that are installed.

Copy and paste the output of trl env when reporting an issue. Example output:

- Platform: Linux-5.15.0-1048-aws-x86_64-with-glibc2.31
- Python version: 3.11.9
- PyTorch version: 2.4.1
- CUDA device: NVIDIA H100 80GB HBM3
- Transformers version: 4.45.0.dev0
- Accelerate version: 0.34.2
- Accelerate config: 
  - compute_environment: LOCAL_MACHINE
  - distributed_type: DEEPSPEED
  - mixed_precision: no
  - use_cpu: False
  - debug: False
  - num_processes: 4
  - machine_rank: 0
  - num_machines: 1
  - rdzv_backend: static
  - same_network: True
  - main_training_function: main
  - enable_cpu_affinity: False
  - deepspeed_config: {'gradient_accumulation_steps': 4, 'offload_optimizer_device': 'none', 'offload_param_device': 'none', 'zero3_init_flag': False, 'zero_stage': 2}
  - downcast_bf16: no
  - tpu_use_cluster: False
  - tpu_use_sudo: False
  - tpu_env: []
- Datasets version: 3.0.0
- HF Hub version: 0.24.7
- TRL version: 0.12.0.dev0+acb4d70
- bitsandbytes version: 0.41.1
- DeepSpeed version: 0.15.1
- Diffusers version: 0.30.3
- Liger-Kernel version: 0.3.0
- LLM-Blender version: 0.0.2
- OpenAI version: 1.46.0
- PEFT version: 0.12.0

This information is required when reporting an issue.
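For a rough idea of how such a report is assembled, here is a stdlib-only sketch that collects the platform, the Python version and installed package versions via importlib.metadata. It is an approximation under stated assumptions, not the trl env implementation: it does not report GPU, CUDA or Accelerate config details, and the package list is hypothetical.

```python
import platform
from importlib.metadata import PackageNotFoundError, version

# Hypothetical selection of packages to report; the real command reports more.
DEFAULT_PACKAGES = ("torch", "transformers", "accelerate", "datasets", "trl", "peft")


def package_version(name: str) -> str:
    """Installed version of a distribution, or 'not installed'."""
    try:
        return version(name)
    except PackageNotFoundError:
        return "not installed"


def system_info(packages=DEFAULT_PACKAGES) -> str:
    lines = [
        f"- Platform: {platform.platform()}",
        f"- Python version: {platform.python_version()}",
    ]
    for name in packages:
        lines.append(f"- {name} version: {package_version(name)}")
    return "\n".join(lines)


print(system_info())
```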
