import re
from typing import TYPE_CHECKING

import torch
from peft import LoraConfig, LoraModel, PeftModel, TaskType, get_peft_model
from transformers.integrations import is_deepspeed_zero3_enabled
from transformers.modeling_utils import is_fsdp_enabled

from ..extras.logging import get_logger
from .model_utils.misc import find_all_linear_modules, find_expanded_modules
from .model_utils.quantization import QuantizationMethod
from .model_utils.unsloth import get_unsloth_peft_model, load_unsloth_peft_model
from .model_utils.visual import get_forbidden_modules, patch_target_modules


if TYPE_CHECKING:
    from transformers import PretrainedConfig, PreTrainedModel

    from ..hparams import FinetuningArguments, ModelArguments


logger = get_logger(__name__)


def _setup_full_tuning(
    model: "PreTrainedModel",
    finetuning_args: "FinetuningArguments",
    is_trainable: bool,
    cast_trainable_params_to_fp32: bool,
) -> None:
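    r"""
    Sets up full-parameter fine-tuning: parameters outside the forbidden modules
    (e.g. a vision tower that must stay frozen) remain trainable and are optionally
    cast to float32, while forbidden parameters have their gradients disabled.
    """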
    if not is_trainable:
        return

    logger.info("Fine-tuning method: Full")
    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
    for name, param in model.named_parameters():
        if not any(forbidden_module in name for forbidden_module in forbidden_modules):
            if cast_trainable_params_to_fp32:
                param.data = param.data.to(torch.float32)
        else:
            param.requires_grad_(False)


def _setup_freeze_tuning(
    model: "PreTrainedModel",
    finetuning_args: "FinetuningArguments",
    is_trainable: bool,
    cast_trainable_params_to_fp32: bool,
) -> None:
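    r"""
    Sets up freeze (partial-parameter) fine-tuning: only the selected decoder layers
    and any requested extra modules stay trainable; all other parameters are frozen.
    """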
    if not is_trainable:
        return

    logger.info("Fine-tuning method: Freeze")
    if hasattr(model.config, "text_config"):
        config = getattr(model.config, "text_config")
    else:
        config = model.config

    num_layers = (
        getattr(config, "num_hidden_layers", None)
        or getattr(config, "num_layers", None)
        or getattr(config, "n_layer", None)
    )
    if not num_layers:
        raise ValueError("Current model does not support freeze tuning.")

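    # Select which decoder layers stay trainable:
    # - LLaMA Pro: every `stride`-th layer, matching the expanded blocks.
    # - freeze_trainable_layers > 0: the last k layers.
    # - freeze_trainable_layers <= 0: the first |k| layers (none if k == 0).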
    if finetuning_args.use_llama_pro:
        if num_layers % finetuning_args.freeze_trainable_layers != 0:
            raise ValueError(
                "`num_layers` {} should be divisible by `freeze_trainable_layers` {}.".format(
                    num_layers, finetuning_args.freeze_trainable_layers
                )
            )

        stride = num_layers // finetuning_args.freeze_trainable_layers
        trainable_layer_ids = range(stride - 1, num_layers + stride - 1, stride)
    elif finetuning_args.freeze_trainable_layers > 0:
        trainable_layer_ids = range(max(0, num_layers - finetuning_args.freeze_trainable_layers), num_layers)
    else:
        trainable_layer_ids = range(min(-finetuning_args.freeze_trainable_layers, num_layers))

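    # Scan parameter names to learn this architecture's module vocabulary:
    # `hidden_modules` collects per-layer submodule names (read from layers 0/1),
    # `non_hidden_modules` collects modules outside the layer stack (e.g. embeddings, lm_head).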
    hidden_modules = set()
    non_hidden_modules = set()
    for name, _ in model.named_parameters():
        if ".0." in name:
            hidden_modules.add(name.split(".0.")[-1].split(".")[0])
        elif ".1." in name:
            hidden_modules.add(name.split(".1.")[-1].split(".")[0])

        if re.search(r"\.\d+\.", name) is None:
            non_hidden_modules.add(name.split(".")[-2])

    trainable_layers = []
    for module_name in finetuning_args.freeze_trainable_modules:
        if module_name != "all" and module_name not in hidden_modules:
            raise ValueError(
                "Module {} is not found, please choose from {}".format(module_name, ", ".join(hidden_modules))
            )

        for idx in trainable_layer_ids:
            trainable_layers.append(".{:d}.{}".format(idx, module_name if module_name != "all" else ""))

    if finetuning_args.freeze_extra_modules:
        for module_name in finetuning_args.freeze_extra_modules:
            if module_name not in non_hidden_modules:
                raise ValueError(
                    "Module {} is not found, please choose from {}".format(module_name, ", ".join(non_hidden_modules))
                )

            trainable_layers.append(module_name)

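    # Freeze every parameter that is not in a selected trainable layer, and skip
    # forbidden modules (e.g. a vision tower that must stay frozen) even if selected.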
    forbidden_modules = get_forbidden_modules(model.config, finetuning_args)
    for name, param in model.named_parameters():
        if any(trainable_layer in name for trainable_layer in trainable_layers) and not any(
            forbidden_module in name for forbidden_module in forbidden_modules
        ):
            if cast_trainable_params_to_fp32:
                param.data = param.data.to(torch.float32)
        else:
            param.requires_grad_(False)

    logger.info("Set trainable layers: {}".format(",".join(trainable_layers)))


def _setup_lora_tuning(
    config: "PretrainedConfig",
    model: "PreTrainedModel",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: bool,
    cast_trainable_params_to_fp32: bool,
) -> "PeftModel":
    if is_trainable:
        logger.info("Fine-tuning method: {}".format("DoRA" if finetuning_args.use_dora else "LoRA"))

    adapter_to_resume = None

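    # With existing adapters: either merge them all into the base model, or merge all but
    # the last and keep the last one to resume training from, depending on
    # `create_new_adapter` and on whether merging is possible at all
    # (it is not for quantized, DeepSpeed ZeRO-3, or Unsloth models).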
    if model_args.adapter_name_or_path is not None:
        is_mergeable = True
        if getattr(model, "quantization_method", None):
            assert len(model_args.adapter_name_or_path) == 1, "Quantized model only accepts a single adapter."
            is_mergeable = False

        if is_deepspeed_zero3_enabled():
            assert len(model_args.adapter_name_or_path) == 1, "Cannot use multiple adapters in DeepSpeed ZeRO-3."
            is_mergeable = False

        if model_args.use_unsloth:
            assert len(model_args.adapter_name_or_path) == 1, "Unsloth model only accepts a single adapter."
            is_mergeable = False

        if (is_trainable and not finetuning_args.create_new_adapter) or (not is_mergeable):
            adapter_to_merge = model_args.adapter_name_or_path[:-1]
            adapter_to_resume = model_args.adapter_name_or_path[-1]
        else:
            adapter_to_merge = model_args.adapter_name_or_path

        init_kwargs = {
            "subfolder": model_args.adapter_folder,
            "offload_folder": model_args.offload_folder,
            "cache_dir": model_args.cache_dir,
            "revision": model_args.model_revision,
            "token": model_args.hf_hub_token,
        }

        for adapter in adapter_to_merge:
            model: "LoraModel" = PeftModel.from_pretrained(model, adapter, **init_kwargs)
            model = model.merge_and_unload()

        if len(adapter_to_merge) > 0:
            logger.info("Merged {} adapter(s).".format(len(adapter_to_merge)))

        if adapter_to_resume is not None:
            if model_args.use_unsloth:
                model = load_unsloth_peft_model(config, model_args, is_trainable=is_trainable)
            else:
                model = PeftModel.from_pretrained(model, adapter_to_resume, is_trainable=is_trainable, **init_kwargs)

        logger.info("Loaded adapter(s): {}".format(",".join(model_args.adapter_name_or_path)))

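    # No adapter to resume while training: create a fresh LoRA adapter from scratch.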
    if is_trainable and adapter_to_resume is None:
        if len(finetuning_args.lora_target) == 1 and finetuning_args.lora_target[0] == "all":
            target_modules = find_all_linear_modules(model, finetuning_args.freeze_vision_tower)
        else:
            target_modules = finetuning_args.lora_target

        if finetuning_args.use_llama_pro:
            target_modules = find_expanded_modules(model, target_modules, finetuning_args.freeze_trainable_layers)

        target_modules = patch_target_modules(model.config, finetuning_args, target_modules)

        if (
            finetuning_args.use_dora
            and getattr(model, "quantization_method", None) is not None
            and getattr(model, "quantization_method", None) != QuantizationMethod.BITS_AND_BYTES
        ):
            raise ValueError("DoRA is not compatible with PTQ-quantized models.")

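        # If the vocabulary was resized, the embedding and output layers must also be
        # trained and saved with the adapter, so add them to `modules_to_save`.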
        if model_args.resize_vocab and finetuning_args.additional_target is None:
            input_embeddings = model.get_input_embeddings()
            output_embeddings = model.get_output_embeddings()
            module_names = set()
            for name, module in model.named_modules():
                if module in [input_embeddings, output_embeddings]:
                    module_names.add(name.split(".")[-1])

            finetuning_args.additional_target = module_names
            logger.warning("Vocab has been resized, adding {} to trainable params.".format(",".join(module_names)))

        peft_kwargs = {
            "r": finetuning_args.lora_rank,
            "target_modules": target_modules,
            "lora_alpha": finetuning_args.lora_alpha,
            "lora_dropout": finetuning_args.lora_dropout,
            "use_rslora": finetuning_args.use_rslora,
            "use_dora": finetuning_args.use_dora,
            "modules_to_save": finetuning_args.additional_target,
        }

        if model_args.use_unsloth:
            model = get_unsloth_peft_model(model, model_args, peft_kwargs)
        else:
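            # PiSSA initializes the adapter from the principal singular components of the
            # base weights; `pissa_niter_{k}` uses k fast-SVD iterations instead of a full SVD.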
            if finetuning_args.pissa_init:
                if finetuning_args.pissa_iter == -1:
                    logger.info("Using PiSSA initialization.")
                    peft_kwargs["init_lora_weights"] = "pissa"
                else:
                    logger.info("Using PiSSA initialization with FSVD steps {}.".format(finetuning_args.pissa_iter))
                    peft_kwargs["init_lora_weights"] = "pissa_niter_{}".format(finetuning_args.pissa_iter)

            lora_config = LoraConfig(
                task_type=TaskType.CAUSAL_LM,
                inference_mode=False,
                **peft_kwargs,
            )
            model = get_peft_model(model, lora_config)

    if is_trainable and cast_trainable_params_to_fp32:
        for param in filter(lambda p: p.requires_grad, model.parameters()):
            param.data = param.data.to(torch.float32)

    return model


def init_adapter(
    config: "PretrainedConfig",
    model: "PreTrainedModel",
    model_args: "ModelArguments",
    finetuning_args: "FinetuningArguments",
    is_trainable: bool,
) -> "PreTrainedModel":
    r"""
    Initializes the adapters.

    Supports full-parameter, freeze, and LoRA training.

    Note that the trainable parameters must be cast to float32,
    unless pure bf16/BAdam training or DeepSpeed ZeRO-3/FSDP (without quantization) is used.
    """
    if is_trainable and getattr(model, "quantization_method", None) is not None:
        if finetuning_args.finetuning_type != "lora":
            raise ValueError("Quantized models can only be used for LoRA tuning.")

        if finetuning_args.pissa_init:
            raise ValueError("Cannot initialize PiSSA adapter on quantized models.")

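    # Decide whether trainable parameters should be upcast to float32:
    # pure bf16 / BAdam keep them in half precision, and DeepSpeed ZeRO-3 / FSDP
    # (without quantization) already maintain float32 master weights.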
    cast_trainable_params_to_fp32 = False
    if not is_trainable:
        pass
    elif finetuning_args.pure_bf16 or finetuning_args.use_badam:
        logger.info("Pure bf16 / BAdam detected, keeping trainable params in half precision.")
    elif model_args.quantization_bit is None and (is_deepspeed_zero3_enabled() or is_fsdp_enabled()):
        logger.info("ZeRO3 / FSDP detected, keeping trainable params in float32.")
    else:
        logger.info("Upcasting trainable params to float32.")
        cast_trainable_params_to_fp32 = True

    if finetuning_args.finetuning_type == "full":
        _setup_full_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
    elif finetuning_args.finetuning_type == "freeze":
        _setup_freeze_tuning(model, finetuning_args, is_trainable, cast_trainable_params_to_fp32)
    elif finetuning_args.finetuning_type == "lora":
        model = _setup_lora_tuning(
            config, model, model_args, finetuning_args, is_trainable, cast_trainable_params_to_fp32
        )
    else:
        raise NotImplementedError("Unknown finetuning type: {}.".format(finetuning_args.finetuning_type))

    return model