import os

import numpy as np
import torch
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer

# Maps the ``lora_file_format`` argument to the NumPy dtype used to read the raw delta files.
_LORA_FILE_DTYPES = {"fp32": np.float32, "fp16": np.float16}


def _resolve_text_layer(root, name_parts):
    """Walk ``root`` along the underscore-split ``name_parts`` and return the target submodule.

    Attribute names may themselves contain underscores (e.g. ``self_attn``), so when a
    lookup fails the next part is glued onto the pending name with ``_`` and retried.

    Raises:
        IndexError: if the parts run out before a lookup succeeds.
    """
    parts = list(name_parts)
    layer = root
    pending = parts.pop(0)
    while True:
        try:
            # Explicit ``__getattr__`` as in the original implementation; on torch.nn.Module
            # this searches registered parameters, buffers and submodules.
            layer = layer.__getattr__(pending)
        except AttributeError as err:
            if not parts:
                raise IndexError(
                    f"could not resolve LoRA layer name {pending!r} "
                    f"under {type(layer).__name__}"
                ) from err
            nxt = parts.pop(0)
            pending = f"{pending}_{nxt}" if pending else nxt
            continue
        if not parts:
            return layer
        pending = parts.pop(0)


def _read_lora_delta(path, np_dtype, weight, dtype, device):
    """Read the raw delta stored at ``path`` and shape it like the tensor ``weight``.

    The file is read flat with ``np.fromfile`` and converted to ``dtype`` on ``device``.
    """
    flat = torch.from_numpy(np.fromfile(path, dtype=np_dtype)).to(device=device, dtype=dtype)
    if weight.dim() == 4:
        # NOTE(review): the permute(0, 3, 1, 2) implies 4-D (conv) deltas are stored
        # channels-last as (out, kH, kW, in); confirm against the exporter of these files.
        out_ch, in_ch, k_h, k_w = weight.shape
        return flat.reshape(out_ch, k_h, k_w, in_ch).permute(0, 3, 1, 2)
    return flat.reshape(weight.shape)


def add_text_lora_layer(clip_model, lora_model_path="Misaka.safetensors", alpha=1.0,
                        lora_file_format="fp32", device="cuda:0"):
    """Add pre-computed LoRA weight deltas into ``clip_model.text_model`` in place.

    Every regular file in the directory ``lora_model_path`` whose name contains ``"text"``
    is read as a flat raw array and treated as one weight delta. The target layer is
    derived from the file name: the part before the first ``.``, after the last
    ``text_model_``, split on ``_`` and resolved as an attribute path under
    ``clip_model.text_model``. ``alpha`` times the delta is added to that layer's
    ``weight.data``.

    Args:
        clip_model: Model exposing ``.text_model``, ``.dtype`` and ``.device``
            (presumably a transformers ``CLIPTextModel``).
        lora_model_path: Directory to scan. Despite the ``.safetensors`` default it is
            passed to ``os.scandir`` and must be a directory.
        alpha: Scale applied to each delta before it is added.
        lora_file_format: ``"fp32"`` or ``"fp16"``, the element type of the raw files.
        device: Unused. Deltas are placed on ``clip_model.device``. Kept for
            backward compatibility.

    Returns:
        A list with one dict per merged file: ``{"layer": module, "added_weight": tensor}``,
        where ``added_weight`` is the exact tensor that was added (presumably so a caller
        can subtract it to undo the merge).

    Raises:
        ValueError: for an unsupported ``lora_file_format``.
        IndexError: if a file name cannot be resolved to a layer.
    """
    np_dtype = _LORA_FILE_DTYPES.get(lora_file_format)
    if np_dtype is None:
        raise ValueError(f"unsupported model dtype: {lora_file_format}")

    merged = []
    # Context manager closes the directory handle even if a file fails to load.
    with os.scandir(lora_model_path) as entries:
        for entry in entries:
            if "text" not in entry.name or not entry.is_file():
                continue
            name_parts = entry.name.split(".")[0].split("text_model_")[-1].split("_")
            layer = _resolve_text_layer(clip_model.text_model, name_parts)
            delta = alpha * _read_lora_delta(
                entry.path, np_dtype, layer.weight.data, clip_model.dtype, clip_model.device
            )
            layer.weight.data += delta
            merged.append({"layer": layer, "added_weight": delta})
    return merged