import os
import torch
import numpy as np
from types import SimpleNamespace

from safetensors.torch import load_file
from diffusers.loaders.lora_conversion_utils import (
    _maybe_map_sgm_blocks_to_diffusers,
    _convert_kohya_lora_to_diffusers,
)

LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

LORA_UNET_LAYERS = [
    'lora_unet_down_blocks_0_attentions_0', 'lora_unet_down_blocks_0_attentions_1',
    'lora_unet_down_blocks_1_attentions_0', 'lora_unet_down_blocks_1_attentions_1',
    'lora_unet_down_blocks_2_attentions_0', 'lora_unet_down_blocks_2_attentions_1',
    'lora_unet_mid_block_attentions_0',
    'lora_unet_up_blocks_1_attentions_0', 'lora_unet_up_blocks_1_attentions_1', 'lora_unet_up_blocks_1_attentions_2',
    'lora_unet_up_blocks_2_attentions_0', 'lora_unet_up_blocks_2_attentions_1', 'lora_unet_up_blocks_2_attentions_2',
    'lora_unet_up_blocks_3_attentions_0', 'lora_unet_up_blocks_3_attentions_1', 'lora_unet_up_blocks_3_attentions_2',
]

def add_text_lora_layer(clip_model, lora_model_path="Misaka.safetensors", alpha=1.0, lora_file_format="fp32", device="cuda:0"):
    if lora_file_format == "fp32":
        model_dtype = np.float32
    elif lora_file_format == "fp16":
        model_dtype = np.float16
    else:
        raise ValueError(f"unsupported model dtype: {lora_file_format}")

    all_files = os.scandir(lora_model_path)
    unload_dict = []

    for file in all_files:
        # Only the per-layer weight files for the text encoder are handled here.
        if 'text' in file.name:
            layer_infos = file.name.split('.')[0].split('text_model_')[-1].split('_')
            curr_layer = clip_model.text_model
        else:
            continue

        # Walk down the module tree. Attribute names that themselves contain
        # underscores (e.g. "self_attn") are reassembled piece by piece when a
        # lookup fails.
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                else:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += '_' + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(
            clip_model.dtype).to(clip_model.device).reshape(curr_layer.weight.data.shape)
        # 4-D (conv) weights are stored channels-last on disk and are permuted
        # back; linear weights are added as-is.
        if len(curr_layer.weight.data.shape) == 4:
            adding_weight = alpha * data.permute(0, 3, 1, 2)
        else:
            adding_weight = alpha * data
        curr_layer.weight.data += adding_weight

        curr_layer_unload_data = {
            "layer": curr_layer,
            "added_weight": adding_weight,
        }
        unload_dict.append(curr_layer_unload_data)
    return unload_dict

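# A minimal sketch of the reverse operation, assuming the unload_dict returned
# by add_text_lora_layer (or add_xltext_lora_layer below). Each entry records
# the exact tensor that was added, so unloading is a plain subtraction. The
# helper name is hypothetical, not part of the original API.
def remove_text_lora_layer(unload_dict):
    for entry in unload_dict:
        entry["layer"].weight.data -= entry["added_weight"]

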
def add_xltext_lora_layer(clip_model, clip_model_2, lora_model_path, alpha=1.0, lora_file_format="fp32", device="cuda:0"):
    if lora_file_format == "fp32":
        model_dtype = np.float32
    elif lora_file_format == "fp16":
        model_dtype = np.float16
    else:
        raise ValueError(f"unsupported model dtype: {lora_file_format}")

    all_files = os.scandir(lora_model_path)
    unload_dict = []

    for file in all_files:
        if 'text' in file.name:
            layer_infos = file.name.split('.')[0].split('text_model_')[-1].split('_')
            # SDXL has two text encoders; the filename says which one a weight
            # belongs to. Check "text_encoder_2" first because "text_encoder"
            # is a substring of it.
            if "text_encoder_2" in file.name:
                curr_layer = clip_model_2.text_model
            elif "text_encoder" in file.name:
                curr_layer = clip_model.text_model
            else:
                raise ValueError(
                    f"Cannot identify clip model, need text_encoder or text_encoder_2 in filename, found: {file.name}")
        else:
            continue

        # Same attribute walk as in add_text_lora_layer.
        temp_name = layer_infos.pop(0)
        while True:
            try:
                curr_layer = curr_layer.__getattr__(temp_name)
                if len(layer_infos) > 0:
                    temp_name = layer_infos.pop(0)
                else:
                    break
            except Exception:
                if len(temp_name) > 0:
                    temp_name += '_' + layer_infos.pop(0)
                else:
                    temp_name = layer_infos.pop(0)

        data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(
            clip_model.dtype).to(clip_model.device).reshape(curr_layer.weight.data.shape)
        if len(curr_layer.weight.data.shape) == 4:
            adding_weight = alpha * data.permute(0, 3, 1, 2)
        else:
            adding_weight = alpha * data
        curr_layer.weight.data += adding_weight

        curr_layer_unload_data = {
            "layer": curr_layer,
            "added_weight": adding_weight,
        }
        unload_dict.append(curr_layer_unload_data)
    return unload_dict


def lora_trans(state_dict):
    # Rewrite a kohya-style checkpoint into the key layout expected by
    # add_lora_to_opt_model below.
    unet_config = SimpleNamespace(**{'layers_per_block': 2})
    state_dicts = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
    state_dicts_trans, state_dicts_alpha = _convert_kohya_lora_to_diffusers(state_dicts)

    keys = list(state_dicts_trans.keys())
    for k in keys:
        key = k.replace('processor.', '')
        for x in ['.lora_linear_layer.', '_lora.', '.lora.']:
            key = key.replace(x, '.lora_')
        if key.find('text_encoder') >= 0:
            for x in ['q', 'k', 'v', 'out']:
                key = key.replace(f'.to_{x}.', f'.{x}_proj.')
        key = key.replace('to_out.', 'to_out.0.')
        if key != k:
            state_dicts_trans[key] = state_dicts_trans.pop(k)

    # Collapse the per-layer alphas into a single tensor of distinct values.
    alpha = torch.Tensor(list(set(state_dicts_alpha.values())))
    state_dicts_trans.update({'lora.alpha': alpha})

    return state_dicts_trans

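# For illustration only (exact input keys depend on the diffusers version that
# performed the conversion): a converted text-encoder entry such as
#   text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_linear_layer.down.weight
# is normalized by lora_trans to
#   text_encoder.text_model.encoder.layers.0.self_attn.q_proj.lora_down.weight
# and a unet entry such as
#   unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.processor.to_q_lora.down.weight
# becomes
#   unet.down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_q.lora_down.weight

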
def load_state_dict(filename, need_trans=True):
    state_dict = load_file(os.path.abspath(filename), device="cpu")
    if need_trans:
        state_dict = lora_trans(state_dict)
    return state_dict

def move_state_dict_to_cuda(state_dict):
    return {name: tensor.cuda() for name, tensor in state_dict.items()}

def add_lora_to_opt_model(state_dict, unet, clip_model, clip_model_2, alpha=1.0, need_trans=False):
    # Optionally convert a raw kohya checkpoint first (see lora_trans).
    if need_trans:
        state_dict = lora_trans(state_dict)

    state_dict = move_state_dict_to_cuda(state_dict)

    # lora_trans stores one network-wide alpha under 'lora.alpha'.
    alpha_ks = list(filter(lambda x: x.find('.alpha') >= 0, state_dict))
    lora_alpha = state_dict[alpha_ks[0]].item() if len(alpha_ks) > 0 else -1

    visited = set()
    for key in state_dict:
        if '.alpha' in key or key in visited:
            continue

        if "text" in key:
            curr_layer = clip_model_2 if key.find('text_encoder_2') >= 0 else clip_model

            # Resolve the dotted module path; the walk stops at the first
            # attribute that does not exist (the ".lora_*" suffix).
            layer_infos = key.split('.')[1:]
            for x in layer_infos:
                try:
                    curr_layer = curr_layer.__getattr__(x)
                except Exception:
                    break

            pair_keys = [key.replace("lora_down", "lora_up"),
                         key.replace("lora_up", "lora_down")]
            weight_up = state_dict[pair_keys[0]]
            weight_down = state_dict[pair_keys[1]]

            weight_scale = lora_alpha / weight_up.shape[1] if lora_alpha != -1 else 1.0

            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze([2, 3])
                weight_down = weight_down.squeeze([2, 3])
                if len(weight_down.shape) == 4:
                    # 3x3 conv LoRA: contract over the rank dimension.
                    adding_weight = torch.einsum('a b, b c h w -> a c h w', weight_up, weight_down)
                else:
                    # 1x1 conv LoRA reduces to a matrix product.
                    adding_weight = torch.mm(weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                adding_weight = torch.mm(weight_up, weight_down)
            adding_weight = alpha * weight_scale * adding_weight

            curr_layer.weight.data += adding_weight.to(torch.float16)

            for item in pair_keys:
                visited.add(item)

        elif "unet" in key:
            layer_infos = key.replace(".lora_up.weight", "").replace(".lora_down.weight", "")
            layer_infos = layer_infos[5:]  # strip the leading "unet."
            layer_names = layer_infos.split(".")

            # Re-join names that the dotted key splits apart, e.g.
            # ["down_blocks", "0"] -> "down_blocks_0"; everything past the
            # fourth component is folded into the last name.
            layers = []
            i = 0
            while i < len(layer_names):
                if len(layers) >= 4:
                    layers[-1] += "_" + layer_names[i]
                elif i + 1 < len(layer_names) and layer_names[i + 1].isdigit():
                    layers.append(layer_names[i] + "_" + layer_names[i + 1])
                    i += 1
                elif len(layers) > 0 and "samplers" in layers[-1]:
                    layers[-1] += "_" + layer_names[i]
                else:
                    layers.append(layer_names[i])
                i += 1

            pair_keys = [key.replace("lora_down", "lora_up"),
                         key.replace("lora_up", "lora_down")]

            weight_up = state_dict[pair_keys[0]].to(torch.float32)
            weight_down = state_dict[pair_keys[1]].to(torch.float32)
            weight_scale = lora_alpha / weight_up.shape[1] if lora_alpha != -1 else 1.0

            if len(weight_up.shape) == 4:
                weight_up = weight_up.squeeze([2, 3])
                weight_down = weight_down.squeeze([2, 3])
                if len(weight_down.shape) == 4:
                    curr_layer_weight = weight_scale * torch.einsum(
                        'a b, b c h w -> a c h w', weight_up, weight_down)
                else:
                    curr_layer_weight = weight_scale * torch.mm(
                        weight_up, weight_down).unsqueeze(2).unsqueeze(3)
                # The optimized unet keeps conv weights in NHWC layout.
                curr_layer_weight = curr_layer_weight.permute(0, 2, 3, 1)
            else:
                curr_layer_weight = weight_scale * torch.mm(weight_up, weight_down)

            curr_layer_weight = curr_layer_weight.to(torch.float16)

            unet.load_lora_by_name(layers, curr_layer_weight, alpha)

            for item in pair_keys:
                visited.add(item)
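
# A minimal end-to-end sketch, assuming fp16 models built elsewhere ("pipe"
# and its members are hypothetical stand-ins; the unet must expose the
# load_lora_by_name API used above):
#
#   state_dict = load_state_dict("Misaka.safetensors")  # converts kohya keys
#   add_lora_to_opt_model(state_dict, pipe.unet, pipe.text_encoder,
#                         pipe.text_encoder_2, alpha=0.8)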