Supervised Fine-Tuning of Llama 3 8B on one AWS Trainium instance

Note: The complete script for this tutorial can be downloaded here.

This tutorial will teach you how to fine-tune open source LLMs like Llama 3 on AWS Trainium. In our example, we are going to leverage the Optimum Neuron, Transformers and Datasets libraries.

You will learn how to:

  1. Set up the AWS environment
  2. Load and prepare the dataset
  3. Supervised Fine-Tuning of Llama on AWS Trainium with the NeuronSFTTrainer
  4. Launch Training
  5. Evaluate and test fine-tuned Llama model

While we will use Llama-3 8B in this tutorial, it is entirely possible to use other models, simply by switching the model_id.

1. Set up the AWS Environment

Before starting this tutorial, you will need to set up your environment:

  1. Create an AWS Trainium instance. You will need a trn1.32xlarge, which contains 16 Neuron Devices. You can follow this guide to create one.
  2. Make sure you are logged in on the Hugging Face Hub:
huggingface-cli login --token YOUR_TOKEN
  3. Check that you have access to the model. Some open source models are gated, meaning that users need to apply to the model owner to be able to use the model weights. Here we will be training Llama-3 8B, for which there are two possibilities:
  4. Clone the Optimum Neuron repository, which contains the complete script described in this tutorial:
git clone https://github.com/huggingface/optimum-neuron.git

2. Load and prepare the dataset

For this tutorial, we will use Dolly, an open source dataset of instruction-following records on categories outlined in the InstructGPT paper, including brainstorming, classification, closed QA, generation, information extraction, open QA, and summarization.

Example:

{
  "instruction": "What is world of warcraft",
  "context": "",
  "response": (
        "World of warcraft is a massive online multi player role playing game. "
        "It was released in 2004 by blizarre entertainment"
    )
}

We can use the load_dataset() function from the 🤗 Datasets library to load the Dolly dataset.

from datasets import load_dataset
from random import randrange

# Load dataset from the hub
dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

print(f"dataset size: {len(dataset)}")
print(dataset[randrange(len(dataset))])
# dataset size: 15011

To instruction-tune our model we need to:

  1. Convert our structured examples into a collection of tasks described via instructions.

  2. (Optional) Pack multiple examples into one sequence for more efficient training. In other words, we concatenate multiple examples into a single sequence and separate them with the EOS token.

We could do this manually, but we will use the NeuronSFTTrainer instead.
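To make the optional packing step concrete, here is a minimal sketch in plain Python. It is not how NeuronSFTTrainer implements packing: the pack_examples helper, the toy token ids and the eos_id value are all hypothetical and only illustrate the idea of joining examples with an EOS token and cutting the result into fixed-length sequences.

```python
def pack_examples(tokenized_examples, eos_id, max_seq_length):
    """Concatenate tokenized examples, separated by EOS, into fixed-size chunks."""
    stream = []
    for token_ids in tokenized_examples:
        stream.extend(token_ids)
        stream.append(eos_id)  # EOS marks the boundary between two examples

    # Cut the stream into full sequences; an incomplete tail is dropped here.
    n_full = len(stream) // max_seq_length
    return [
        stream[i * max_seq_length:(i + 1) * max_seq_length]
        for i in range(n_full)
    ]


# Toy token ids standing in for three tokenized examples (hypothetical values).
examples = [[11, 12, 13], [21, 22], [31, 32, 33, 34]]
packed = pack_examples(examples, eos_id=0, max_seq_length=4)
print(packed)
```

Note that a packed sequence can start or end in the middle of an example, which is the price paid for never spending compute on padding tokens.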

3. Supervised Fine-Tuning of Llama on AWS Trainium with the NeuronSFTTrainer

Normally you would use the SFTConfig and SFTTrainer classes to perform supervised fine-tuning of PyTorch-based transformer models.

Instead, here we will be using the NeuronSFTConfig and NeuronSFTTrainer. These classes replicate the ones from the trl library while making sure they work properly on Neuron cores.

Formatting our dataset

There are multiple ways to give a dataset to the NeuronSFTTrainer, and one of them is to provide a formatting function. For Dolly, without packing the examples, it looks as follows:

def format_dolly(examples):
    output_text = []
    for i in range(len(examples["instruction"])):
        instruction = f"### Instruction\n{examples['instruction'][i]}"
        # The context section is skipped when the record has no context
        context = f"### Context\n{examples['context'][i]}" if len(examples["context"][i]) > 0 else None
        response = f"### Answer\n{examples['response'][i]}"
        prompt = "\n\n".join([part for part in [instruction, context, response] if part is not None])
        output_text.append(prompt)
    return output_text
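Because format_dolly indexes examples["instruction"][i], it expects a batch in which every field maps to a list of values rather than a single record. The sketch below shows that column-oriented layout, assuming the trainer calls the formatting function on batches; the two records are made up for illustration and only the field names come from the Dolly example above.

```python
# Hypothetical records; only the field names match the Dolly dataset.
records = [
    {"instruction": "What is world of warcraft", "context": "", "response": "An online game."},
    {"instruction": "Summarize the text", "context": "Some text.", "response": "A summary."},
]

# Row-oriented records -> column-oriented batch (one list per field).
batch = {key: [record[key] for record in records] for key in records[0]}

batch_size = len(batch["instruction"])
for i in range(batch_size):
    # Mirrors the check used to decide whether a "### Context" section is emitted.
    has_context = len(batch["context"][i]) > 0
    print(i, batch["instruction"][i], "| context section:", has_context)
```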

Preparing the model

Llama-3 8B is a big model: fully fine-tuning it will not fit on a single trn1.32xlarge instance, even with distributed training. To fine-tune an 8B model using only one Trainium instance, we need to use both LoRA and distributed training.

If you want to know more about distributed training you can take a look at the documentation.
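As a rough intuition for tensor parallelism, the toy sketch below splits the weight matrix of a single linear layer across two workers along its output dimension, lets each worker compute its slice, and gathers the results. It uses plain Python lists and made-up numbers; the matvec and shard_rows helpers are hypothetical and unrelated to the Optimum Neuron implementation, which also has to handle communication between devices.

```python
def matvec(weight, x):
    """Multiply a (rows x cols) weight matrix by a vector of length cols."""
    return [sum(w * v for w, v in zip(row, x)) for row in weight]


def shard_rows(weight, tensor_parallel_size):
    """Split the output dimension (rows) evenly across workers."""
    rows_per_rank = len(weight) // tensor_parallel_size
    return [
        weight[rank * rows_per_rank:(rank + 1) * rows_per_rank]
        for rank in range(tensor_parallel_size)
    ]


weight = [
    [1, 0, 2],
    [0, 1, 1],
    [3, 1, 0],
    [2, 2, 2],
]  # 4 output features, 3 input features
x = [1, 2, 3]

shards = shard_rows(weight, tensor_parallel_size=2)

# Each worker only stores and multiplies its own shard of the weights.
partial_outputs = [matvec(shard, x) for shard in shards]

# Gathering the partial outputs reproduces the unsharded result.
gathered = [value for part in partial_outputs for value in part]
print(gathered)
```

Each worker holds only half of the weights, which is why sharding reduces the memory needed per device.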

Here, we will use tensor parallelism in conjunction with LoRA. Our training code will look as follows:

from datasets import load_dataset
from peft import LoraConfig
from transformers import AutoModelForCausalLM, AutoTokenizer
from optimum.neuron import NeuronSFTConfig, NeuronSFTTrainer
from optimum.neuron.distributed import lazy_load_for_parallelism

# Define the tensor_parallel_size
tensor_parallel_size = 2

dataset = load_dataset("databricks/databricks-dolly-15k", split="train")

model_id = "meta-llama/Meta-Llama-3-8B"

tokenizer = AutoTokenizer.from_pretrained(model_id)
tokenizer.pad_token = tokenizer.eos_token

with lazy_load_for_parallelism(tensor_parallel_size=tensor_parallel_size):
    model = AutoModelForCausalLM.from_pretrained(model_id)

config = LoraConfig(
    r=16,
    lora_alpha=16,
    lora_dropout=0.05,
    target_modules=[
        "q_proj",
        "gate_proj",
        "v_proj",
        "o_proj",
        "k_proj",
        "up_proj",
        "down_proj"
    ],
    bias="none",
    task_type="CAUSAL_LM",
)

# training_args is an instance of NeuronTrainingArguments
args = training_args.to_dict()
sft_config = NeuronSFTConfig(
    max_seq_length=1024,
    packing=False,
    **args,
)

trainer = NeuronSFTTrainer(
    args=sft_config,
    model=model,
    peft_config=config,
    tokenizer=tokenizer,
    train_dataset=dataset,
    formatting_func=format_dolly,
)

# Start training
trainer.train()

trainer.save_model()  # Saves the tokenizer too for easy upload

The key points here are:

  • We use the lazy_load_for_parallelism context manager to lazily load the model. Instead of loading the full model weights on each worker, it loads only the weights that worker requires (sharded or full). This is much more memory efficient, and often mandatory.
  • We define a LoraConfig that specifies which layers should have adapters, and the hyperparameters for these adapters.
  • We create a NeuronSFTConfig from regular NeuronTrainingArguments. Here we specify that we do not want to pack our examples, and that the max sequence length should be 1024, meaning that every example will be either padded or truncated to a length of 1024.
  • We use the NeuronSFTTrainer to perform training. It takes the lazily loaded model, along with the LoRA config (config), sft_config and format_dolly, and prepares the dataset and model for supervised fine-tuning.
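To get a feel for how small the trainable part is with this LoraConfig, the following back-of-the-envelope sketch counts the adapter parameters. The layer shapes are assumptions taken from the published Llama-3 8B configuration (hidden size 4096, MLP size 14336, 8 key/value heads of dimension 128, 32 layers), and the count covers only the LoRA matrices, with no bias terms.

```python
# Llama-3 8B layer shapes (assumed from the published model configuration).
HIDDEN = 4096
INTERMEDIATE = 14336
KV_DIM = 1024  # 8 key/value heads x 128 dimensions per head
NUM_LAYERS = 32
RANK = 16  # `r` in the LoraConfig

# (in_features, out_features) for every module listed in target_modules.
target_shapes = {
    "q_proj": (HIDDEN, HIDDEN),
    "k_proj": (HIDDEN, KV_DIM),
    "v_proj": (HIDDEN, KV_DIM),
    "o_proj": (HIDDEN, HIDDEN),
    "gate_proj": (HIDDEN, INTERMEDIATE),
    "up_proj": (HIDDEN, INTERMEDIATE),
    "down_proj": (INTERMEDIATE, HIDDEN),
}


def lora_params(in_features, out_features, rank):
    """LoRA adds two matrices: A (rank x in_features) and B (out_features x rank)."""
    return rank * (in_features + out_features)


per_layer = sum(lora_params(i, o, RANK) for i, o in target_shapes.values())
total = per_layer * NUM_LAYERS
print(f"{per_layer:,} LoRA parameters per layer, {total:,} in total")
```

Under these assumptions the adapters add up to roughly 42 million parameters, about half a percent of the 8B base model, which is what makes fine-tuning on a single instance feasible.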

4. Launch Training

We prepared a script called sft_lora_finetune_llm.py summing up everything mentioned in this tutorial.

PyTorch Neuron uses torch_xla. It evaluates operations lazily during the execution of the training loops, which means it builds a symbolic graph in the background, and the graph is executed on the hardware only when the tensor is printed, transferred to CPU, or when xm.mark_step() is called. During execution, multiple graphs can be built depending on control-flow, and it can take time to compile each graph sequentially. To alleviate that, the Neuron SDK provides neuron_parallel_compile, a tool which performs a fast trial run that builds all the graphs and compiles them in parallel. This step is usually called precompilation.
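The lazy evaluation described above can be illustrated with a toy example. The LazyValue class below is purely hypothetical and has nothing to do with the torch_xla API: it only shows the pattern of recording operations into a graph first and running them later, when a result is actually requested.

```python
class LazyValue:
    """Records operations instead of running them immediately."""

    def __init__(self, compute, description):
        self._compute = compute
        self.description = description

    def __add__(self, other):
        return LazyValue(
            lambda: self._compute() + other._compute(),
            f"({self.description} + {other.description})",
        )

    def __mul__(self, other):
        return LazyValue(
            lambda: self._compute() * other._compute(),
            f"({self.description} * {other.description})",
        )

    def materialize(self):
        """Stand-in for the point where the recorded graph is executed."""
        return self._compute()


def constant(value):
    return LazyValue(lambda: value, str(value))


a, b, c = constant(2), constant(3), constant(4)
result = a * b + c            # nothing is computed yet, only recorded
print(result.description)     # the recorded graph
print(result.materialize())   # execution happens here
```

In the real stack, each distinct recorded graph has to be compiled before it can run, which is why collecting and compiling all graphs up front saves time.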

Precompilation

When training models on AWS Trainium we first need to compile our model with our training arguments.

To ease this step, we added a model cache repository, which allows us to use precompiled models from the Hugging Face Hub to skip the compilation step. But be careful: every change in the model configuration might lead to a new compilation, which could result in some cache misses.

To learn more about the caching system, and how you can create your own private cache repository, check this guide.

The compilation command simply consists of passing your training command to the neuron_parallel_compile utility:

#!/bin/bash
set -ex

export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
export MALLOC_ARENA_MAX=64
export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"

PROCESSES_PER_NODE=8

NUM_EPOCHS=1
TP_DEGREE=2
PP_DEGREE=1
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=1
MODEL_NAME="meta-llama/Meta-Llama-3-8B"
OUTPUT_DIR=output-$SLURM_JOB_ID

if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
    MAX_STEPS=$((LOGGING_STEPS + 5))
else
    MAX_STEPS=-1
fi


XLA_USE_BF16=1 neuron_parallel_compile torchrun --nproc_per_node $PROCESSES_PER_NODE docs/source/training_tutorials/sft_lora_finetune_llm.py \
  --model_id $MODEL_NAME \
  --num_train_epochs $NUM_EPOCHS \
  --do_train \
  --learning_rate 5e-5 \
  --warmup_ratio 0.03 \
  --max_steps $MAX_STEPS \
  --per_device_train_batch_size $BS \
  --per_device_eval_batch_size $BS \
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
  --gradient_checkpointing true \
  --bf16 \
  --zero_1 false \
  --tensor_parallel_size $TP_DEGREE \
  --pipeline_parallel_size $PP_DEGREE \
  --logging_steps $LOGGING_STEPS \
  --save_total_limit 1 \
  --output_dir $OUTPUT_DIR \
  --lr_scheduler_type "constant" \
  --overwrite_output_dir

Make sure to run this precompilation phase for around 10 training steps. It is usually enough to accumulate and compile all the graphs that will be needed during the actual training.
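The values set in the script above also determine the effective batch size. The sketch below works through the arithmetic, assuming the usual relation in which the worker processes are divided between tensor, pipeline and data parallelism; the resulting step count is an estimate, since the exact number depends on how the trainer handles the last incomplete batch.

```python
import math

processes_per_node = 8           # PROCESSES_PER_NODE
tensor_parallel_size = 2         # TP_DEGREE
pipeline_parallel_size = 1       # PP_DEGREE
per_device_batch_size = 1        # BS
gradient_accumulation_steps = 8  # GRADIENT_ACCUMULATION_STEPS
dataset_size = 15011             # size of the Dolly dataset printed earlier

# Workers not used for tensor or pipeline parallelism act as data-parallel replicas.
data_parallel_size = processes_per_node // (tensor_parallel_size * pipeline_parallel_size)

# Number of examples contributing to one optimizer step.
effective_batch_size = (
    per_device_batch_size * gradient_accumulation_steps * data_parallel_size
)

# Rough number of optimizer steps per epoch without packing.
steps_per_epoch = math.ceil(dataset_size / effective_batch_size)

print(data_parallel_size, effective_batch_size, steps_per_epoch)
```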

Note: Compiling without a cache can take a while. It will also create dummy files in the dolly_llama_sharded directory during compilation, which you will have to remove afterwards. We also need to set MALLOC_ARENA_MAX=64 to limit the CPU allocation and avoid potential crashes, so do not remove it for now.

# remove dummy artifacts which are created by the precompilation command
rm -rf dolly_llama

Actual Training

After compilation is done, we can start the actual training with a similar command; we just need to remove neuron_parallel_compile.

We will use torchrun to launch our training script. torchrun is a launcher that starts the script as several worker processes so that training is distributed across multiple accelerators. We can pass the number of processes with the nproc_per_node argument, alongside our hyperparameters.

Apart from dropping neuron_parallel_compile, the command is the same as the compilation command. When NEURON_EXTRACT_GRAPHS_ONLY is not set to 1, the script sets MAX_STEPS to -1, so training runs for the full NUM_EPOCHS rather than a handful of steps.

Launch the training with the following command:

#!/bin/bash
set -ex

export NEURON_FUSE_SOFTMAX=1
export NEURON_RT_ASYNC_EXEC_MAX_INFLIGHT_REQUESTS=3
export MALLOC_ARENA_MAX=64
export NEURON_CC_FLAGS="--model-type=transformer --distribution-strategy=llm-training --enable-saturate-infinity --cache_dir=/home/ubuntu/cache_dir_neuron/"

PROCESSES_PER_NODE=8

NUM_EPOCHS=1
TP_DEGREE=2
PP_DEGREE=1
BS=1
GRADIENT_ACCUMULATION_STEPS=8
LOGGING_STEPS=1
MODEL_NAME="meta-llama/Meta-Llama-3-8B"
OUTPUT_DIR=output-$SLURM_JOB_ID

if [ "$NEURON_EXTRACT_GRAPHS_ONLY" = "1" ]; then
    MAX_STEPS=$((LOGGING_STEPS + 5))
else
    MAX_STEPS=-1
fi


XLA_USE_BF16=1 torchrun --nproc_per_node $PROCESSES_PER_NODE docs/source/training_tutorials/sft_lora_finetune_llm.py \
  --model_id $MODEL_NAME \
  --num_train_epochs $NUM_EPOCHS \
  --do_train \
  --learning_rate 5e-5 \
  --warmup_ratio 0.03 \
  --max_steps $MAX_STEPS \
  --per_device_train_batch_size $BS \
  --per_device_eval_batch_size $BS \
  --gradient_accumulation_steps $GRADIENT_ACCUMULATION_STEPS \
  --gradient_checkpointing true \
  --bf16 \
  --zero_1 false \
  --tensor_parallel_size $TP_DEGREE \
  --pipeline_parallel_size $PP_DEGREE \
  --logging_steps $LOGGING_STEPS \
  --save_total_limit 1 \
  --output_dir $OUTPUT_DIR \
  --lr_scheduler_type "constant" \
  --overwrite_output_dir

That’s it, we successfully trained Llama-3 8B on AWS Trainium!

But before we can share and test our model, we need to consolidate it: since we used tensor parallelism during training, the saved checkpoints are sharded.

Consolidate the Checkpoint

The Optimum CLI provides a way of doing that very easily via the optimum-cli neuron consolidate [sharded_checkpoint] [output_dir] command:

optimum-cli neuron consolidate dolly_llama dolly_llama

5. Evaluate and test fine-tuned Llama model

As for training, to be able to run inference on AWS Trainium or AWS Inferentia2 we need to compile our model. In this case, we will use our Trainium instance for the inference test, but we recommend switching to Inferentia2 (inf2.24xlarge) for inference.

Optimum Neuron implements classes similar to the Transformers AutoModel classes for easy inference. We will use the NeuronModelForCausalLM class to load our vanilla transformers checkpoint and convert it to Neuron.

from optimum.neuron import NeuronModelForCausalLM
from transformers import AutoTokenizer

compiler_args = {"num_cores": 2, "auto_cast_type": 'fp16'}
input_shapes = {"batch_size": 1, "sequence_length": 2048}

tokenizer = AutoTokenizer.from_pretrained("dolly_llama")
model = NeuronModelForCausalLM.from_pretrained(
        "dolly_llama",
        export=True,
        **compiler_args,
        **input_shapes)

Note: Inference compilation can take ~25 minutes. Luckily, you only need to run this once, since you can save the model afterwards. If you are going to run on Inferentia2, you need to recompile, because the compilation is parameter and hardware specific.

# Uncomment the next line if you want to save the compiled model
# model.save_pretrained("compiled_dolly_llama")

We can now test inference, but we have to make sure we format our input with the same prompt format we used for fine-tuning. Therefore, we created a helper function that accepts a dict with our instruction and, optionally, a context.

def format_dolly_inference(sample):
    instruction = f"### Instruction\n{sample['instruction']}"
    context = f"### Context\n{sample['context']}" if "context" in sample else None
    response = "### Answer\n"
    prompt = "\n\n".join([part for part in [instruction, context, response] if part is not None])
    return prompt


def generate(sample):
    prompt = format_dolly_inference(sample)
    inputs = tokenizer(prompt, return_tensors="pt")
    outputs = model.generate(
        **inputs,
        max_new_tokens=512,
        do_sample=True,
        temperature=0.9,
        top_k=50,
        top_p=0.9
    )
    return tokenizer.decode(outputs[0], skip_special_tokens=False)[len(prompt):]

Let’s test inference. First we test without a context.

Note: Inference on AWS Trainium using 2 cores is not expected to be fast. For inference, we recommend using Inferentia2.

prompt = {
  "instruction": "Can you tell me something about AWS?"
}
res = generate(prompt)

print(res)

AWS stands for Amazon Web Services. AWS is a suite of remote computing services offered by Amazon. The most widely used of these include Amazon Elastic Compute Cloud (Amazon EC2), which provides resizable compute capacity in the cloud; Amazon Simple Storage Service (Amazon S3), which is an object storage service; and Amazon Elastic Block Store (Amazon EBS), which is designed to provide high performance, durable block storage volumes for use with AWS instances. AWS also provides other services, such as AWS Identity and Access Management (IAM), a service that enables organizations to control access to their AWS resources, and AWS Key Management Service (AWS KMS), which helps customers create and control the use of encryption keys.

That looks correct. Now, let's add some context, as you would do for RAG applications:

prompt = {
  "instruction": "How can I train models on AWS Trainium?",
  "context": "🤗 Optimum Neuron is the interface between the 🤗 Transformers library and AWS Accelerators including [AWS Trainium](https://aws.amazon.com/machine-learning/trainium/?nc1=h_ls) and [AWS Inferentia](https://aws.amazon.com/machine-learning/inferentia/?nc1=h_ls). It provides a set of tools enabling easy model loading, training and inference on single- and multi-Accelerator settings for different downstream tasks."
}
res = generate(prompt)

print(res)

You can use the Optimum Neuron interface to train models on AWS Trainium.

Awesome, our model also correctly uses the provided context. We are done. Congrats on fine-tuning Llama on AWS Trainium.