import copy
import logging
from dataclasses import dataclass, field
from typing import Optional

import numpy as np
import torch
from transformers import AutoModelForCausalLM, AutoTokenizer

from alignment import H4ArgumentParser, ModelArguments, get_kbit_device_map, get_quantization_config
from huggingface_hub import create_repo, upload_folder


logger = logging.getLogger(__name__)


@dataclass
class InitializationArguments(ModelArguments):
    """Arguments controlling student initialisation, in addition to the inherited ``ModelArguments`` fields."""

    output_dir: str = field(
        default="./checkpoint",
        metadata={"help": "The output directory where the model predictions and checkpoints will be written."},
    )
    num_hidden_layers: int = field(
        default=6,
        metadata={"help": "The number of hidden layers in the Transformer decoder."},
    )
    push_to_hub: Optional[bool] = field(
        default=False, metadata={"help": "Whether or not to upload the trained model to the model hub after training."}
    )
    hub_model_id: Optional[str] = field(
        default=None, metadata={"help": "The name of the repository to keep in sync with the local `output_dir`."}
    )
    low_cpu_mem_usage: Optional[bool] = field(
        default=True,
        metadata={
            "help": "Create the teacher model as an empty shell, and only materialize its parameters when the pretrained weights are loaded. "
            "Significantly benefits loading time and RAM consumption."
        },
    )


def _build_decoder_map(teacher_hidden_layers: int, student_hidden_layers: int) -> dict:
    """Map teacher layer indices to the student layer index each one initialises.

    Teacher layers are picked at evenly spaced positions between the first and
    the last teacher layer (``np.linspace`` truncated to integers).

    Args:
        teacher_hidden_layers: Number of hidden layers in the teacher model.
        student_hidden_layers: Number of hidden layers requested for the student.

    Returns:
        A dict ``{teacher_layer_index: student_layer_index}`` with plain ``int`` keys.

    Raises:
        ValueError: If the student depth is smaller than 1 or larger than the
            teacher depth. Neither case can be initialised from the teacher.
    """
    if not 1 <= student_hidden_layers <= teacher_hidden_layers:
        raise ValueError(
            f"`num_hidden_layers` must be between 1 and the number of teacher layers "
            f"({teacher_hidden_layers}), got {student_hidden_layers}."
        )

    decoder_mapping = np.linspace(0, teacher_hidden_layers - 1, student_hidden_layers, dtype=int)
    # With a single student layer linspace only yields index 0; pin the last slot
    # to the teacher's final layer so that layer is always carried over.
    decoder_mapping[-1] = teacher_hidden_layers - 1
    return {int(teacher_layer): student_layer for student_layer, teacher_layer in enumerate(decoder_mapping)}


def main():
    """Create a shallower student model from a pretrained teacher and save it.

    Steps:
      1. Parse ``InitializationArguments`` and load the teacher model and tokenizer.
      2. Copy the teacher config, overriding ``num_hidden_layers`` for the student.
      3. Instantiate the student from that config, load every teacher weight that
         fits (``strict=False``), then overwrite each student layer with the
         teacher layer selected by ``_build_decoder_map``.
      4. Save model, tokenizer and generation config to ``output_dir`` and,
         if requested, upload that folder to the Hub.

    Raises:
        ValueError: If the requested student depth is not in ``[1, teacher depth]``.
        RuntimeError: If the student expects weights the teacher does not provide,
            or if equal-depth models still report unexpected decoder-layer keys.
    """
    parser = H4ArgumentParser([InitializationArguments])
    model_args = parser.parse()

    logger.info("Model parameters %s", model_args)

    logger.info("*** Load pretrained teacher model ***")
    # "auto" and None are passed through as-is; any other value names a torch dtype attribute
    torch_dtype = (
        model_args.torch_dtype if model_args.torch_dtype in ["auto", None] else getattr(torch, model_args.torch_dtype)
    )
    quantization_config = get_quantization_config(model_args)

    model_kwargs = dict(
        revision=model_args.model_revision,
        trust_remote_code=model_args.trust_remote_code,
        torch_dtype=torch_dtype,
        # a device map is only supplied when the teacher is loaded quantized
        device_map=get_kbit_device_map() if quantization_config is not None else None,
        quantization_config=quantization_config,
        low_cpu_mem_usage=model_args.low_cpu_mem_usage,
    )

    teacher_model = AutoModelForCausalLM.from_pretrained(model_args.model_name_or_path, **model_kwargs)
    tokenizer = AutoTokenizer.from_pretrained(model_args.model_name_or_path)
    generation_config = teacher_model.generation_config
    teacher_config = teacher_model.config

    logger.info("*** Teacher model loaded! ***")

    student_config = copy.deepcopy(teacher_config)
    student_config.num_hidden_layers = model_args.num_hidden_layers
    teacher_hidden_layers = teacher_config.num_hidden_layers

    # validate the requested depth before paying for student instantiation
    decoder_map = _build_decoder_map(teacher_hidden_layers, student_config.num_hidden_layers)

    # init the student params from the teacher model
    logger.info("*** Load and initialise student model ***")
    student_model = AutoModelForCausalLM.from_config(student_config)
    # strict=False: a shallower student has no slot for the teacher's extra layers
    missing_keys, unexpected_keys = student_model.load_state_dict(teacher_model.state_dict(), strict=False)
    if missing_keys:
        raise RuntimeError(
            f"Error(s) in loading state_dict for {student_model.__class__.__name__}. \n"
            f"Missing key(s) in state_dict: {missing_keys}"
        )
    if student_config.num_hidden_layers == teacher_hidden_layers:
        # with equal depth every teacher layer must have found a home in the student
        decoder_keys = [key for key in unexpected_keys if "model.layers" in key]
        if decoder_keys:
            raise RuntimeError(
                f"Error(s) in loading state_dict for {student_model.__class__.__name__}. \n"
                f"Unexpected key(s) in state_dict: {decoder_keys}"
            )

    # re-introduce pre-defined layers from the teacher
    for teacher_layer, student_layer in decoder_map.items():
        student_model.model.layers[student_layer].load_state_dict(
            teacher_model.model.layers[teacher_layer].state_dict()
        )

    logger.info("*** Student model loaded! ***")

    # remove the teacher params and model
    del teacher_model

    # save the converted weights and model
    if model_args.output_dir is not None:
        student_model.save_pretrained(model_args.output_dir)
        # we also need to correctly save the processor and generation config
        tokenizer.save_pretrained(model_args.output_dir)
        generation_config.save_pretrained(model_args.output_dir)

    if model_args.push_to_hub:
        repo_id = model_args.hub_model_id or model_args.output_dir
        # upload_folder commits to an existing repo, so create it first if needed;
        # the returned repo_id is used for the upload
        repo_id = create_repo(repo_id, exist_ok=True).repo_id
        upload_folder(
            repo_id=repo_id,
            folder_path=model_args.output_dir,
            commit_description="Uploading initialised weights and configs",
        )


if __name__ == "__main__":
    main()