"""
Supervised fine-tuning / continued-training script built on TRL's SFTTrainer.

Install dependencies:

    pip install -U transformers accelerate trl wandb wheel packaging peft bitsandbytes liger-kernel flash_attn

Example invocation:

    python sft.py \
        --run_name="llama3.1-8b-continued2" \
        --model_name_or_path="meta-llama/Meta-Llama-3.1-8B" \
        --dataset_name="mlfoundations/dclm-baseline-1.0-parquet,mlabonne/FineTome-100k" \
        --report_to="wandb" \
        --optim="adamw_torch_fused" \
        --lr_scheduler_type="cosine" \
        --max_steps=10000000 \
        --max_seq_length=64000 \
        --learning_rate=0.0001 \
        --attn_implementation="flash_attention_2" \
        --save_strategy="steps" \
        --save_steps=50 \
        --save_total_limit=10 \
        --per_device_train_batch_size=1 \
        --gradient_accumulation_steps=8 \
        --logging_steps=1 \
        --num_train_epochs=1 \
        --load_in_4bit \
        --push_to_hub \
        --hub_model_id="ericflo/Llama-3.1-8B-ContinuedTraining2-LoRA" \
        --hub_strategy="all_checkpoints" \
        --gradient_checkpointing \
        --use_peft \
        --lora_r=128 \
        --lora_alpha=256 \
        --lora_dropout=0.05 \
        --use_liger=true \
        --packing=true \
        --torch_dtype="bfloat16" \
        --output_dir="continuedtraining2_output"
"""
|
|
|
import logging |
|
import os |
|
import random |
|
from contextlib import nullcontext |
|
|
|
from trl.commands.cli_utils import init_zero_verbose, SFTScriptArguments, TrlParser |
|
from trl.env_utils import strtobool |
|
|
|
# Opt-in switch for rich console output, read from the TRL_USE_RICH environment
# variable; an unset variable is treated as "0".
TRL_USE_RICH = strtobool(os.getenv("TRL_USE_RICH", "0"))

if TRL_USE_RICH:
    # Runs before the heavy third-party imports that follow in this file.
    # NOTE(review): presumably this quiets library logging/warnings so the rich
    # output stays readable -- confirm against trl's implementation.
    init_zero_verbose()

    # Message-only log format, consumed by logging.basicConfig further down.
    FORMAT = "%(message)s"

    # rich is imported only when rich output was requested.
    from rich.console import Console
    from rich.logging import RichHandler
|
|
|
import torch |
|
from datasets import load_dataset, interleave_datasets |
|
|
|
from tqdm.rich import tqdm |
|
from transformers import AutoTokenizer |
|
|
|
from trl import ( |
|
ModelConfig, |
|
RichProgressCallback, |
|
SFTConfig, |
|
SFTTrainer, |
|
get_peft_config, |
|
get_quantization_config, |
|
get_kbit_device_map, |
|
) |
|
|
|
# Register tqdm's pandas integration.
tqdm.pandas()

if TRL_USE_RICH:
    # Route log records through rich, using the message-only FORMAT defined
    # in the earlier TRL_USE_RICH branch.
    logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()], level=logging.INFO)

# Two tokenizers are loaded once at import time; formatting_prompts_func picks
# one of them at random to render conversation rows with its chat template.
# NOTE(review): the names suggest a Llama-3 style template vs. a ChatML style
# template -- confirm against the two model repositories.
print("Loading tokenizers...")
METAML_TOK = AutoTokenizer.from_pretrained("meta-llama/Meta-Llama-3.1-8B-Instruct")
CHATML_TOK = AutoTokenizer.from_pretrained("NousResearch/Hermes-3-Llama-3.1-8B")
print("Tokenizers loaded.")
|
|
|
def formatting_prompts_func(example):
    """Render one raw dataset row as a single training string.

    The row shape is detected from its fields, in this order of precedence:

    1. ``language``, ``url`` and ``text`` all truthy -> ``"{language} {url} {text}"``
    2. ``title``, ``url`` and ``text`` all truthy -> ``"{title} {url} {text}"``
    3. non-empty ``conversations`` (a list of ``{"from", "value"}`` turns) ->
       ``"{source} {chat}"`` where ``chat`` is the conversation rendered with
       the chat template of a randomly chosen tokenizer
       (``METAML_TOK`` or ``CHATML_TOK``).
    4. ``max_stars_repo_name`` present and not ``None`` ->
       ``"{repo_name} {repo_path} {star_count} {content}"``

    Args:
        example: Mapping holding the columns of one dataset row.

    Returns:
        The formatted text for the row.

    Raises:
        ValueError: If the row matches none of the shapes above.
        KeyError: If a conversation turn has a ``from`` value other than
            ``system``/``gpt``/``human``, or a repo row lacks one of its
            other ``max_stars_*``/``content`` keys.

    Any exception is printed before being re-raised so the failure is visible
    on stdout as well as in the traceback.
    """
    # Built once per call instead of once per conversation turn.
    role_by_sender = {"system": "system", "gpt": "assistant", "human": "user"}

    try:
        url = example.get("url")
        text = example.get("text")

        language = example.get("language")
        if language and url and text:
            return f"{language} {url} {text}"

        title = example.get("title")
        if title and url and text:
            return f"{title} {url} {text}"

        conversations = example.get("conversations")
        if conversations:
            source = example.get("source")
            rows = [
                {"role": role_by_sender[turn["from"]], "content": turn["value"]}
                for turn in conversations
            ]
            tok = random.choice([METAML_TOK, CHATML_TOK])
            return f"{source} {tok.apply_chat_template(rows, tokenize=False)}"

        # Test the value, not key membership: a row whose repo columns exist
        # but are None must not become the literal sample "None None None None".
        # NOTE(review): interleaving datasets with different schemas may pad
        # every row with all columns set to None -- confirm with `datasets`.
        repo_name = example.get("max_stars_repo_name")
        if repo_name is not None:
            return (
                f'{repo_name} {example["max_stars_repo_path"]} '
                f'{example["max_stars_count"]} {example["content"]}'
            )

        raise ValueError(f"Unknown example: {example}")
    except Exception as exc:
        # Print once, then re-raise the original exception unchanged.
        print(exc)
        raise
|
|
|
if __name__ == "__main__":
    # Parse CLI flags / config file into script, training and model settings.
    parser = TrlParser((SFTScriptArguments, SFTConfig, ModelConfig))
    args, training_args, model_config = parser.parse_args_and_config()

    if TRL_USE_RICH:
        # tqdm bars are turned off when rich output is enabled; the
        # RichProgressCallback is passed to the trainer below in that case.
        training_args.disable_tqdm = True
        console = Console()

    ################
    # Model & tokenizer
    ################
    # LoRA target modules are fixed here and override whatever was passed
    # on the command line.
    model_config.lora_target_modules = [
        "q_proj",
        "k_proj",
        "v_proj",
        "o_proj",
        "gate_proj",
        "up_proj",
        "down_proj",
    ]
    quantization_config = get_quantization_config(model_config)
    model_kwargs = dict(
        revision=model_config.model_revision,
        trust_remote_code=model_config.trust_remote_code,
        attn_implementation=model_config.attn_implementation,
        torch_dtype=model_config.torch_dtype,
        # The cache is disabled whenever gradient checkpointing is on.
        use_cache=not training_args.gradient_checkpointing,
        # A k-bit device map is only requested when quantization is configured.
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )
    # The trainer receives the model by name and instantiates it with these kwargs.
    training_args.model_init_kwargs = model_kwargs
    tokenizer = AutoTokenizer.from_pretrained(
        model_config.model_name_or_path,
        trust_remote_code=model_config.trust_remote_code,
        use_fast=True,
    )
    # Reuse EOS as the padding token.
    tokenizer.pad_token = tokenizer.eos_token

    ################
    # Dataset
    ################
    # --dataset_name is a comma-separated list; surrounding whitespace and
    # empty entries (e.g. a trailing comma) are ignored.
    dataset_names = [name.strip() for name in args.dataset_name.split(",") if name.strip()]
    train_datasets = [load_dataset(name, split="train", streaming=True) for name in dataset_names]

    # Fixed extra sources appended after the CLI-selected ones, in this order:
    # Python code, then English, Spanish and French Wikipedia.
    train_datasets.append(
        load_dataset("bigcode/starcoderdata", data_dir="python", split="train", streaming=True)
    )
    for wiki_lang in ("en", "es", "fr"):
        train_datasets.append(
            load_dataset("wikimedia/wikipedia", f"20231101.{wiki_lang}", split="train", streaming=True)
        )

    interleaved_dataset = interleave_datasets(train_datasets)

    # The first rows of the interleaved stream are held out for evaluation;
    # training consumes everything after them.
    eval_example_count = 100
    eval_dataset = interleaved_dataset.take(eval_example_count)
    train_dataset = interleaved_dataset.skip(eval_example_count)

    print(train_dataset)
    print(eval_dataset)

    ################
    # Optional rich context managers
    ################
    # `console` only exists when TRL_USE_RICH is set, hence the guards.
    init_context = nullcontext() if not TRL_USE_RICH else console.status("[bold green]Initializing the SFTTrainer...")
    save_context = (
        nullcontext()
        if not TRL_USE_RICH
        else console.status(f"[bold green]Training completed! Saving the model to {training_args.output_dir}")
    )

    ################
    # Training
    ################
    with init_context:
        trainer = SFTTrainer(
            model=model_config.model_name_or_path,
            args=training_args,
            train_dataset=train_dataset,
            eval_dataset=eval_dataset,
            tokenizer=tokenizer,
            peft_config=get_peft_config(model_config),
            callbacks=[RichProgressCallback] if TRL_USE_RICH else None,
            formatting_func=formatting_prompts_func,
        )

    trainer.train()

    with save_context:
        trainer.save_model(training_args.output_dir)