|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
""" |
|
Supervised fine-tuning script for decoder language models. |
|
""" |
|
|
|
import logging |
|
import random |
|
import sys |
|
|
|
import datasets |
|
import torch |
|
import transformers |
|
from transformers import set_seed |
|
|
|
from alignment import ( |
|
DataArguments, |
|
H4ArgumentParser, |
|
ModelArguments, |
|
SFTConfig, |
|
apply_chat_template, |
|
get_checkpoint, |
|
get_datasets, |
|
get_kbit_device_map, |
|
get_peft_config, |
|
get_quantization_config, |
|
get_tokenizer, |
|
) |
|
from trl import SFTTrainer |
|
|
|
|
|
# Module-wide logger; its level is set at startup from the training
# arguments' per-process log level.
logger = logging.getLogger(name=__name__)
|
|
|
|
|
def _setup_logging(training_args):
    """Configure stdout logging for this script, ``datasets`` and ``transformers``.

    The verbosity for all three is the per-process level returned by
    ``training_args.get_process_log_level()``.
    """
    logging.basicConfig(
        format="%(asctime)s - %(levelname)s - %(name)s - %(message)s",
        datefmt="%Y-%m-%d %H:%M:%S",
        handlers=[logging.StreamHandler(sys.stdout)],
    )
    log_level = training_args.get_process_log_level()
    logger.setLevel(log_level)
    datasets.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.set_verbosity(log_level)
    transformers.utils.logging.enable_default_handler()
    transformers.utils.logging.enable_explicit_format()


def _format_messages(example):
    """Convert a flat list of turns in ``example["data"]`` into chat ``messages``.

    Turns alternate starting with the user: even positions become ``"user"``
    messages and odd positions ``"assistant"`` messages. The example is
    updated in place and returned, as ``Dataset.map`` expects.

    Defined at module level (not as a closure) so it pickles cleanly when
    ``map`` runs with several worker processes.
    """
    example["messages"] = [
        {"content": message, "role": "user" if idx % 2 == 0 else "assistant"}
        for idx, message in enumerate(example["data"])
    ]
    return example


def _sample_indices(num_rows, k=3):
    """Return up to ``k`` distinct random row indices in ``range(num_rows)``.

    Clamps ``k`` to ``num_rows`` so tiny (or empty) splits do not make
    ``random.sample`` raise ``ValueError``.
    """
    return random.sample(range(num_rows), min(k, num_rows))


def main():
    """Run supervised fine-tuning end to end.

    Parses ``ModelArguments``, ``DataArguments`` and ``SFTConfig``, prepares
    chat-templated datasets, trains a ``trl.SFTTrainer`` with packing,
    optionally evaluates, saves the model/config/model card, and optionally
    pushes the result to the Hub.

    Raises:
        ValueError: if evaluation is requested but no ``"test"`` split was loaded.
    """
    parser = H4ArgumentParser((ModelArguments, DataArguments, SFTConfig))
    model_args, data_args, training_args = parser.parse()

    # Seed before any random sampling or model initialisation.
    set_seed(training_args.seed)

    _setup_logging(training_args)

    logger.warning(
        f"Process rank: {training_args.local_rank}, device: {training_args.device}, n_gpu: {training_args.n_gpu}"
        + f" distributed training: {training_args.local_rank != -1}, 16-bits training: {training_args.fp16}"
    )
    logger.info(f"Model parameters {model_args}")
    logger.info(f"Data parameters {data_args}")
    logger.info(f"Training/evaluation parameters {training_args}")

    # Detect an existing checkpoint in the output directory.
    last_checkpoint = get_checkpoint(training_args)
    if last_checkpoint is not None and training_args.resume_from_checkpoint is None:
        logger.info(f"Checkpoint detected, resuming training at {last_checkpoint=}.")

    # ---- Load datasets -------------------------------------------------------
    raw_datasets = get_datasets(data_args, splits=data_args.dataset_splits)
    logger.info(
        f"Training on the following datasets and their proportions: {[split + ' : ' + str(dset.num_rows) for split, dset in raw_datasets.items()]}"
    )
    column_names = list(raw_datasets["train"].features)
    if "messages" not in column_names:
        # Datasets without a "messages" column are expected to hold the
        # conversation as an alternating list of turns in a "data" column.
        with training_args.main_process_first(desc="Formatting messages"):
            raw_datasets = raw_datasets.map(
                _format_messages,
                desc="Formatting messages",
                num_proc=data_args.preprocessing_num_workers,
            )
        # Refresh so the derived "messages" column is removed by the chat-template
        # map below as well, giving the same schema as native "messages" datasets.
        column_names = list(raw_datasets["train"].features)

    # ---- Tokenizer & chat template -------------------------------------------
    tokenizer = get_tokenizer(model_args, data_args)

    with training_args.main_process_first():
        raw_datasets = raw_datasets.map(
            apply_chat_template,
            fn_kwargs={"tokenizer": tokenizer, "task": "sft"},
            num_proc=data_args.preprocessing_num_workers,
            remove_columns=column_names,
            desc="Applying chat template",
        )

    train_dataset = raw_datasets["train"]
    # The "test" split is optional; only insist on it when evaluation is requested.
    eval_dataset = raw_datasets.get("test")
    if training_args.do_eval and eval_dataset is None:
        raise ValueError(
            "Evaluation was requested but no 'test' split was loaded; "
            "add a test split to `dataset_splits` or disable evaluation."
        )

    with training_args.main_process_first(desc="Log a few random samples from the processed training set"):
        for index in _sample_indices(len(train_dataset)):
            logger.info(f"Sample {index} of the processed training set:\n\n{train_dataset[index]['text']}")

    # ---- Model kwargs (model is instantiated by SFTTrainer) ------------------
    logger.info("*** Load pretrained model ***")
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        use_flash_attention_2=model_args.use_flash_attention_2,
        torch_dtype=torch_dtype,
        # KV cache is disabled while gradient checkpointing; it is re-enabled
        # on the saved config after training.
        use_cache=not training_args.gradient_checkpointing,
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
    )

    trainer = SFTTrainer(
        model=model_args.model_name_or_path,
        model_init_kwargs=model_kwargs,
        args=training_args,
        train_dataset=train_dataset,
        eval_dataset=eval_dataset,
        dataset_text_field="text",
        max_seq_length=training_args.max_seq_length,
        tokenizer=tokenizer,
        packing=True,
        peft_config=get_peft_config(model_args),
    )
    # The model is built from the name + model_init_kwargs inside SFTTrainer,
    # so it only exists once the trainer has been constructed.
    logger.info("*** Model loaded! ***")

    # ---- Training loop -------------------------------------------------------
    logger.info("*** Train ***")
    # An explicit resume path wins over an auto-detected checkpoint.
    checkpoint = (
        training_args.resume_from_checkpoint
        if training_args.resume_from_checkpoint is not None
        else last_checkpoint
    )
    train_result = trainer.train(resume_from_checkpoint=checkpoint)
    metrics = train_result.metrics
    metrics["train_samples"] = len(train_dataset)
    trainer.log_metrics("train", metrics)
    trainer.save_metrics("train", metrics)
    trainer.save_state()

    # ---- Evaluation ----------------------------------------------------------
    if training_args.do_eval:
        logger.info("*** Evaluate ***")
        metrics = trainer.evaluate()
        metrics["eval_samples"] = len(eval_dataset)
        trainer.log_metrics("eval", metrics)
        trainer.save_metrics("eval", metrics)

    # ---- Save model and model card -------------------------------------------
    logger.info("*** Save model ***")
    trainer.save_model(training_args.output_dir)
    logger.info(f"Model saved to {training_args.output_dir}")

    kwargs = {
        "finetuned_from": model_args.model_name_or_path,
        "dataset": list(data_args.dataset_mixer.keys()),
        "dataset_tags": list(data_args.dataset_mixer.keys()),
        "tags": ["alignment-handbook"],
    }
    # Only one process writes the model card and config to output_dir.
    if trainer.accelerator.is_main_process:
        trainer.create_model_card(**kwargs)
        # Restore the KV cache for inference before persisting the config.
        trainer.model.config.use_cache = True
        trainer.model.config.save_pretrained(training_args.output_dir)

    if training_args.push_to_hub:
        logger.info("Pushing to hub...")
        trainer.push_to_hub(**kwargs)

    logger.info("*** Training complete ***")
|
|
|
|
|
if __name__ == "__main__":
    # Run fine-tuning only when executed as a script, not when imported.
    main()
|
|
|
|