|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import json |
|
import logging |
|
import os |
|
import signal |
|
import sys |
|
import time |
|
from concurrent.futures import ThreadPoolExecutor |
|
from datetime import timedelta |
|
from typing import TYPE_CHECKING, Any, Dict, Optional |
|
|
|
import torch |
|
import transformers |
|
from peft import PeftModel |
|
from transformers import PreTrainedModel, ProcessorMixin, TrainerCallback |
|
from transformers.trainer_utils import PREFIX_CHECKPOINT_DIR, has_length |
|
from transformers.utils import ( |
|
SAFE_WEIGHTS_NAME, |
|
WEIGHTS_NAME, |
|
is_safetensors_available, |
|
) |
|
from typing_extensions import override |
|
|
|
from ..extras.constants import TRAINER_LOG, V_HEAD_SAFE_WEIGHTS_NAME, V_HEAD_WEIGHTS_NAME |
|
from ..extras.logging import LoggerHandler, get_logger |
|
from ..extras.misc import get_peak_memory |
|
|
|
|
|
if is_safetensors_available(): |
|
from safetensors import safe_open |
|
from safetensors.torch import save_file |
|
|
|
if TYPE_CHECKING: |
|
from transformers import TrainerControl, TrainerState, TrainingArguments |
|
from trl import AutoModelForCausalLMWithValueHead |
|
|
|
|
|
logger = get_logger(__name__) |
|
|
|
|
|
def fix_valuehead_checkpoint(
    model: "AutoModelForCausalLMWithValueHead", output_dir: str, safe_serialization: bool
) -> None:
    r"""
    Splits a saved value-head checkpoint into decoder weights and value-head weights.

    The model is already unwrapped. The single weights file written by the trainer in `output_dir` is read,
    deleted, and replaced by (a) the decoder weights saved via `model.pretrained_model.save_pretrained` and
    (b) the `v_head.*` tensors saved to a dedicated value-head file.

    There are three cases:
    1. full tuning without ds_zero3: state_dict = {"model.layers.*": ..., "v_head.summary.*": ...}
    2. lora tuning without ds_zero3: state_dict = {"v_head.summary.*": ...}
    3. under deepspeed zero3: state_dict = {"pretrained_model.model.layers.*": ..., "v_head.summary.*": ...}

    We assume `stage3_gather_16bit_weights_on_model_save=true`.

    Args:
        model: the unwrapped value-head model whose checkpoint should be fixed.
        output_dir: directory that contains the checkpoint written by the trainer.
        safe_serialization: read/write safetensors files when True, torch pickles otherwise.
    """
    if not isinstance(model.pretrained_model, (PreTrainedModel, PeftModel)):
        return

    # Load the combined checkpoint into memory on CPU.
    if safe_serialization:
        checkpoint_file = os.path.join(output_dir, SAFE_WEIGHTS_NAME)
        with safe_open(checkpoint_file, framework="pt", device="cpu") as reader:
            state_dict: Dict[str, torch.Tensor] = {}
            for key in reader.keys():
                state_dict[key] = reader.get_tensor(key)
    else:
        checkpoint_file = os.path.join(output_dir, WEIGHTS_NAME)
        state_dict: Dict[str, torch.Tensor] = torch.load(checkpoint_file, map_location="cpu")

    # The combined file is replaced by the split files written below.
    os.remove(checkpoint_file)

    v_head_state_dict: Dict[str, torch.Tensor] = {}
    decoder_state_dict: Dict[str, torch.Tensor] = {}
    for key, tensor in state_dict.items():
        if not key.startswith("v_head."):
            # Strip the wrapper prefix seen under ZeRO-3 (case 3) so keys match the bare decoder.
            decoder_state_dict[key.replace("pretrained_model.", "", 1)] = tensor
            continue
        v_head_state_dict[key] = tensor

    # An empty decoder dict (LoRA, case 2) becomes None so the model provides its own state dict.
    model.pretrained_model.save_pretrained(
        output_dir,
        state_dict=decoder_state_dict if decoder_state_dict else None,
        safe_serialization=safe_serialization,
    )

    if safe_serialization:
        v_head_file = os.path.join(output_dir, V_HEAD_SAFE_WEIGHTS_NAME)
        save_file(v_head_state_dict, v_head_file, metadata={"format": "pt"})
    else:
        v_head_file = os.path.join(output_dir, V_HEAD_WEIGHTS_NAME)
        torch.save(v_head_state_dict, v_head_file)

    logger.info("Value head model saved at: {}".format(output_dir))
|
|
|
|
|
class FixValueHeadModelCallback(TrainerCallback):
    r"""
    A callback for fixing the checkpoint for valuehead models.

    After every checkpoint save it calls `fix_valuehead_checkpoint` on the
    `<output_dir>/<PREFIX_CHECKPOINT_DIR>-<global_step>` directory.
    """

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called after a checkpoint save.
        """
        if not args.should_save:
            return

        checkpoint_name = "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)
        checkpoint_dir = os.path.join(args.output_dir, checkpoint_name)
        fix_valuehead_checkpoint(
            model=kwargs.pop("model"),
            output_dir=checkpoint_dir,
            safe_serialization=args.save_safetensors,
        )
|
|
|
|
|
class SaveProcessorCallback(TrainerCallback):
    r"""
    A callback for saving the processor.

    Only the processor's `image_processor` attribute is saved: into each checkpoint directory on save,
    and into `args.output_dir` when training ends.
    """

    def __init__(self, processor: "ProcessorMixin") -> None:
        self.processor = processor

    def _save_image_processor(self, target_dir: str) -> None:
        r"""Writes the processor's image processor to `target_dir`."""
        self.processor.image_processor.save_pretrained(target_dir)

    @override
    def on_save(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        checkpoint_name = "{}-{}".format(PREFIX_CHECKPOINT_DIR, state.global_step)
        self._save_image_processor(os.path.join(args.output_dir, checkpoint_name))

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        self._save_image_processor(args.output_dir)
|
|
|
|
|
class PissaConvertCallback(TrainerCallback):
    r"""
    A callback for converting the PiSSA adapter to a normal one.

    At train begin the initial adapter is saved to `<output_dir>/pissa_init`. At train end a backup of the
    trained adapter is saved to `<output_dir>/pissa_backup`, a converted copy is saved to
    `<output_dir>/pissa_converted`, and the backup is loaded back into the live model. Models that are not
    `PeftModel` instances are left untouched, and nothing happens unless `args.should_save` is true.
    """

    @staticmethod
    def _save_with_default_init(model: "PeftModel", save_dir: str, safe_serialization: bool) -> None:
        r"""
        Saves `model` to `save_dir` with the "default" adapter's `init_lora_weights` temporarily set to True.

        The previous `init_lora_weights` value is restored afterwards, even if saving raises.
        """
        adapter_config = model.peft_config["default"]
        init_lora_weights = getattr(adapter_config, "init_lora_weights")
        setattr(adapter_config, "init_lora_weights", True)
        try:
            model.save_pretrained(save_dir, safe_serialization=safe_serialization)
        finally:
            setattr(adapter_config, "init_lora_weights", init_lora_weights)

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the beginning of training.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            logger.info("Initial PiSSA adapter will be saved at: {}.".format(pissa_init_dir))
            if isinstance(model, PeftModel):
                self._save_with_default_init(model, pissa_init_dir, args.save_safetensors)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        r"""
        Event called at the end of training.
        """
        if args.should_save:
            model = kwargs.pop("model")
            pissa_init_dir = os.path.join(args.output_dir, "pissa_init")
            pissa_backup_dir = os.path.join(args.output_dir, "pissa_backup")
            pissa_convert_dir = os.path.join(args.output_dir, "pissa_converted")
            logger.info("Converted PiSSA adapter will be saved at: {}.".format(pissa_convert_dir))
            # 1. save a pissa backup with init_lora_weights: True
            # 2. save a converted lora with the original init_lora_weights (presumably the PiSSA setting)
            # 3. load the pissa backup back into the "default" adapter
            # 4. drop the leftover initial adapter, if any, and restore init_lora_weights
            if isinstance(model, PeftModel):
                init_lora_weights = getattr(model.peft_config["default"], "init_lora_weights")
                self._save_with_default_init(model, pissa_backup_dir, args.save_safetensors)
                # NOTE(review): newer peft versions name this argument
                # `path_initial_model_for_weight_conversion` — confirm against the installed peft.
                model.save_pretrained(
                    pissa_convert_dir, safe_serialization=args.save_safetensors, convert_pissa_to_lora=pissa_init_dir
                )
                model.load_adapter(pissa_backup_dir, "default", is_trainable=True)
                model.set_adapter("default")
                # A "pissa_init" adapter may be left behind by the conversion (likely older peft) — remove it.
                if "pissa_init" in model.peft_config.keys():
                    model.delete_adapter("pissa_init")

                # Must stay inside the PeftModel branch: `init_lora_weights` is only bound here.
                setattr(model.peft_config["default"], "init_lora_weights", init_lora_weights)
|
|
|
|
|
class LogCallback(TrainerCallback):
    r"""
    A callback for logging training and evaluation status.

    Progress records are only produced on processes where `args.should_save` is true. They are appended as
    JSON lines to `TRAINER_LOG` under `args.output_dir` via a thread pool with a single worker.

    When the `LLAMABOARD_ENABLED` environment variable is "true" or "1" (web UI mode), a SIGABRT handler is
    installed that marks the run as aborted. A `LoggerHandler` built from `LLAMABOARD_WORKDIR` is also
    attached to both the root logger and the transformers logger.
    """

    def __init__(self) -> None:
        # progress bookkeeping
        self.start_time = 0
        self.cur_steps = 0
        self.max_steps = 0
        self.elapsed_time = ""
        self.remaining_time = ""
        self.thread_pool: Optional["ThreadPoolExecutor"] = None
        # run status
        self.aborted = False
        self.do_train = False
        # web UI integration
        llamaboard_flag = os.environ.get("LLAMABOARD_ENABLED", "0").lower()
        self.webui_mode = llamaboard_flag in ["true", "1"]
        if not self.webui_mode:
            return

        signal.signal(signal.SIGABRT, self._set_abort)
        self.logger_handler = LoggerHandler(os.environ.get("LLAMABOARD_WORKDIR"))
        logging.root.addHandler(self.logger_handler)
        transformers.logging.add_handler(self.logger_handler)

    def _set_abort(self, signum, frame) -> None:
        r"""SIGABRT handler: flag the run so the step hooks request a stop."""
        self.aborted = True

    def _reset(self, max_steps: int = 0) -> None:
        r"""Restart the progress clock and counters for a run of `max_steps` steps."""
        self.start_time = time.time()
        self.cur_steps, self.max_steps = 0, max_steps
        self.elapsed_time = self.remaining_time = ""

    def _timing(self, cur_steps: int) -> None:
        r"""Record `cur_steps` and refresh the elapsed/remaining time strings (H:MM:SS)."""
        elapsed = time.time() - self.start_time
        per_step = 0 if cur_steps == 0 else elapsed / cur_steps
        remaining = per_step * (self.max_steps - cur_steps)
        self.cur_steps = cur_steps
        self.elapsed_time = str(timedelta(seconds=int(elapsed)))
        self.remaining_time = str(timedelta(seconds=int(remaining)))

    def _progress_percentage(self):
        r"""Completed share of `max_steps` in percent, rounded to 2 decimals; 100 when `max_steps` is 0."""
        if self.max_steps == 0:
            return 100
        return round(self.cur_steps / self.max_steps * 100, 2)

    def _write_log(self, output_dir: str, logs: Dict[str, Any]) -> None:
        r"""Append `logs` as one JSON line to the trainer log file in `output_dir`."""
        log_path = os.path.join(output_dir, TRAINER_LOG)
        with open(log_path, "a", encoding="utf-8") as writer:
            writer.write("{}\n".format(json.dumps(logs)))

    def _create_thread_pool(self, output_dir: str) -> None:
        r"""Ensure `output_dir` exists and start the single-worker log writer."""
        os.makedirs(output_dir, exist_ok=True)
        self.thread_pool = ThreadPoolExecutor(max_workers=1)

    def _close_thread_pool(self) -> None:
        r"""Wait for pending log writes, then drop the writer (no-op if none is running)."""
        if self.thread_pool is None:
            return
        self.thread_pool.shutdown(wait=True)
        self.thread_pool = None

    def _stop_if_aborted(self, control: "TrainerControl") -> None:
        r"""Ask the trainer to stop the epoch and training once an abort was signalled."""
        if not self.aborted:
            return
        control.should_epoch_stop = True
        control.should_training_stop = True

    @override
    def on_init_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        log_file = os.path.join(args.output_dir, TRAINER_LOG)
        if os.path.exists(log_file) and args.overwrite_output_dir:
            logger.warning("Previous trainer log in this folder will be deleted.")
            os.remove(log_file)

    @override
    def on_train_begin(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        self.do_train = True
        self._reset(max_steps=state.max_steps)
        self._create_thread_pool(output_dir=args.output_dir)

    @override
    def on_train_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        self._close_thread_pool()

    @override
    def on_substep_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        self._stop_if_aborted(control)

    @override
    def on_step_end(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        self._stop_if_aborted(control)

    @override
    def on_evaluate(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        # During training the pool stays open for further logs; it is closed in on_train_end.
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_predict(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not self.do_train:
            self._close_thread_pool()

    @override
    def on_log(self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs):
        if not args.should_save:
            return

        self._timing(cur_steps=state.global_step)
        latest = state.log_history[-1]
        logs = {
            "current_steps": self.cur_steps,
            "total_steps": self.max_steps,
            "loss": latest.get("loss", None),
            "eval_loss": latest.get("eval_loss", None),
            "predict_loss": latest.get("predict_loss", None),
            "reward": latest.get("reward", None),
            "accuracy": latest.get("rewards/accuracies", None),
            "learning_rate": latest.get("learning_rate", None),
            "epoch": latest.get("epoch", None),
            "percentage": self._progress_percentage(),
            "elapsed_time": self.elapsed_time,
            "remaining_time": self.remaining_time,
        }

        tokens_seen = state.num_input_tokens_seen
        if tokens_seen:
            logs["throughput"] = round(tokens_seen / (time.time() - self.start_time), 2)
            logs["total_tokens"] = tokens_seen

        if os.environ.get("RECORD_VRAM", "0").lower() in ["true", "1"]:
            allocated, reserved = get_peak_memory()
            logs["vram_allocated"] = round(allocated / 1024 / 1024 / 1024, 2)
            logs["vram_reserved"] = round(reserved / 1024 / 1024 / 1024, 2)

        # Only keep metrics that were actually reported.
        logs = {key: value for key, value in logs.items() if value is not None}
        if self.webui_mode and {"loss", "learning_rate", "epoch"}.issubset(logs):
            message = "{{'loss': {:.4f}, 'learning_rate': {:2.4e}, 'epoch': {:.2f}, 'throughput': {}}}".format(
                logs["loss"], logs["learning_rate"], logs["epoch"], logs.get("throughput", "N/A")
            )
            logger.info(message)

        if self.thread_pool is not None:
            self.thread_pool.submit(self._write_log, args.output_dir, logs)

    @override
    def on_prediction_step(
        self, args: "TrainingArguments", state: "TrainerState", control: "TrainerControl", **kwargs
    ):
        # Progress is only tracked per prediction step for non-training (eval/predict-only) runs.
        if self.do_train:
            return

        if self.aborted:
            sys.exit(0)

        if not args.should_save:
            return

        eval_dataloader = kwargs.pop("eval_dataloader", None)
        if not has_length(eval_dataloader):
            return

        # NOTE(review): on_evaluate/on_predict close the pool in non-training runs. max_steps stays non-zero,
        # so a later prediction pass neither resets nor recreates it, and its progress is not written.
        # Confirm whether that is intended.
        if self.max_steps == 0:
            self._reset(max_steps=len(eval_dataloader))
            self._create_thread_pool(output_dir=args.output_dir)

        self._timing(cur_steps=self.cur_steps + 1)
        # Write a progress record every 5 steps.
        if self.cur_steps % 5 != 0 or self.thread_pool is None:
            return

        progress = {
            "current_steps": self.cur_steps,
            "total_steps": self.max_steps,
            "percentage": self._progress_percentage(),
            "elapsed_time": self.elapsed_time,
            "remaining_time": self.remaining_time,
        }
        self.thread_pool.submit(self._write_log, args.output_dir, progress)
|
|