# lyraSD / lyrasd_model / lora_util.py
# Author: yibolu
# Last commit: 6eca12e ("update pipeline and demos")
import os
import re
import time
import torch
import numpy as np
from safetensors.torch import load_file
from diffusers.loaders import LoraLoaderMixin
from diffusers.loaders.lora_conversion_utils import _maybe_map_sgm_blocks_to_diffusers, _convert_kohya_lora_to_diffusers
from types import SimpleNamespace
import logging.handlers
# Key prefixes used by kohya-style LoRA checkpoints.
LORA_PREFIX_UNET = "lora_unet"
LORA_PREFIX_TEXT_ENCODER = "lora_te"

# kohya names of the UNet transformer-attention blocks, in order:
# 6 down-block attentions, 1 mid-block attention, 9 up-block attentions.
LORA_UNET_LAYERS = [
    # down blocks 0-2, two attentions each
    'lora_unet_down_blocks_0_attentions_0',
    'lora_unet_down_blocks_0_attentions_1',
    'lora_unet_down_blocks_1_attentions_0',
    'lora_unet_down_blocks_1_attentions_1',
    'lora_unet_down_blocks_2_attentions_0',
    'lora_unet_down_blocks_2_attentions_1',
    # mid block
    'lora_unet_mid_block_attentions_0',
    # up blocks 1-3, three attentions each
    'lora_unet_up_blocks_1_attentions_0',
    'lora_unet_up_blocks_1_attentions_1',
    'lora_unet_up_blocks_1_attentions_2',
    'lora_unet_up_blocks_2_attentions_0',
    'lora_unet_up_blocks_2_attentions_1',
    'lora_unet_up_blocks_2_attentions_2',
    'lora_unet_up_blocks_3_attentions_0',
    'lora_unet_up_blocks_3_attentions_1',
    'lora_unet_up_blocks_3_attentions_2',
]
def add_text_lora_layer(clip_model, lora_model_path="Misaka.safetensors", alpha=1.0, lora_file_format="fp32", device="cuda:0"):
    """Add pre-exported LoRA weight deltas to a CLIP text encoder in place.

    Every entry of the directory ``lora_model_path`` whose name contains
    ``"text"`` is read as a flat raw array of ``lora_file_format`` values.
    The part of the file stem after ``"text_model_"`` names the target
    sub-module of ``clip_model.text_model``. Underscore-separated tokens are
    joined until an attribute lookup succeeds, so
    ``encoder_layers_0_self_attn_k_proj`` resolves to
    ``text_model.encoder.layers[0].self_attn.k_proj``. ``alpha`` times the
    array, reshaped to that module's ``weight``, is added to the weight.

    Args:
        clip_model: object exposing ``text_model``, ``dtype`` and ``device``
            (presumably a transformers ``CLIPTextModel``).
        lora_model_path: directory holding the exported delta files. It is
            passed to ``os.scandir``, so despite the default it must be a
            directory, not a ``.safetensors`` file.
        alpha: scale applied to every delta before it is added.
        lora_file_format: element type of the files, ``"fp32"`` or ``"fp16"``.
        device: unused; deltas are moved to ``clip_model.device``. Kept for
            backward compatibility.

    Returns:
        A list of ``{"layer": module, "added_weight": tensor}`` dicts.
        Subtracting each ``added_weight`` from ``layer.weight.data`` undoes
        the merge.

    Raises:
        ValueError: unsupported ``lora_file_format``, or a file name that
            cannot be resolved to a sub-module of ``clip_model.text_model``.
    """
    np_dtypes = {"fp32": np.float32, "fp16": np.float16}
    if lora_file_format not in np_dtypes:
        raise ValueError(f"unsupported model dtype: {lora_file_format}")
    model_dtype = np_dtypes[lora_file_format]

    unload_dict = []
    # The context manager closes the directory handle even if a file fails.
    with os.scandir(lora_model_path) as entries:
        for file in entries:
            if 'text' not in file.name:
                continue
            layer_infos = file.name.split('.')[0].split('text_model_')[-1].split('_')
            curr_layer = clip_model.text_model

            # Attribute names may themselves contain underscores (self_attn,
            # k_proj), so keep appending tokens until a lookup succeeds.
            temp_name = layer_infos.pop(0)
            while True:
                try:
                    curr_layer = curr_layer.__getattr__(temp_name)
                except AttributeError as err:
                    if not layer_infos:
                        raise ValueError(
                            f"cannot resolve a text-encoder layer for LoRA file {file.name!r}") from err
                    next_token = layer_infos.pop(0)
                    temp_name = f"{temp_name}_{next_token}" if temp_name else next_token
                    continue
                if not layer_infos:
                    break
                temp_name = layer_infos.pop(0)

            weight = curr_layer.weight.data
            data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(
                clip_model.dtype).to(clip_model.device)
            # Test the number of dimensions; len() of a tensor is the size of dim 0.
            if weight.dim() == 4:
                # NOTE(review): the permute implies 4-D (conv) deltas are exported
                # channels-last as (out, kh, kw, in); confirm against the exporter.
                # CLIP text encoders have no conv layers, so this is not expected to run.
                out_ch, in_ch, kh, kw = weight.shape
                adding_weight = alpha * data.reshape(out_ch, kh, kw, in_ch).permute(0, 3, 1, 2)
            else:
                adding_weight = alpha * data.reshape(weight.shape)
            curr_layer.weight.data += adding_weight
            unload_dict.append({"layer": curr_layer, "added_weight": adding_weight})
    return unload_dict
def add_xltext_lora_layer(clip_model, clip_model_2, lora_model_path, alpha=1.0, lora_file_format="fp32", device="cuda:0"):
    """Add pre-exported LoRA deltas to both SDXL text encoders in place.

    Every entry of the directory ``lora_model_path`` whose name contains
    ``"text"`` is read as a flat raw array of ``lora_file_format`` values.
    Files whose name contains ``"text_encoder_2"`` target ``clip_model_2``;
    other names containing ``"text_encoder"`` target ``clip_model``. The part
    of the stem after ``"text_model_"`` names the sub-module of the chosen
    model's ``text_model``. Underscore-separated tokens are joined until an
    attribute lookup succeeds. ``alpha`` times the array is added to that
    module's ``weight``.

    Args:
        clip_model: first text encoder; exposes ``text_model``, ``dtype`` and
            ``device`` (presumably a transformers CLIP text model).
        clip_model_2: second text encoder with the same attributes.
        lora_model_path: directory holding the exported delta files.
        alpha: scale applied to every delta before it is added.
        lora_file_format: element type of the files, ``"fp32"`` or ``"fp16"``.
        device: unused; each delta goes to its target model's ``device``.
            Kept for backward compatibility.

    Returns:
        A list of ``{"layer": module, "added_weight": tensor}`` dicts that can
        be used to undo the merge.

    Raises:
        ValueError: unsupported ``lora_file_format``, a ``"text"`` file whose
            name names neither encoder, or a name that cannot be resolved to
            a sub-module.
    """
    np_dtypes = {"fp32": np.float32, "fp16": np.float16}
    if lora_file_format not in np_dtypes:
        raise ValueError(f"unsupported model dtype: {lora_file_format}")
    model_dtype = np_dtypes[lora_file_format]

    unload_dict = []
    # The context manager closes the directory handle even if a file fails.
    with os.scandir(lora_model_path) as entries:
        for file in entries:
            if 'text' not in file.name:
                continue
            # Check the more specific name first: "text_encoder" is a prefix of
            # "text_encoder_2".
            if "text_encoder_2" in file.name:
                target_model = clip_model_2
            elif "text_encoder" in file.name:
                target_model = clip_model
            else:
                raise ValueError(
                    f"Cannot identify clip model, need text_encoder or text_encoder_2 in filename, found: {file.name}")

            layer_infos = file.name.split('.')[0].split('text_model_')[-1].split('_')
            curr_layer = target_model.text_model

            # Attribute names may themselves contain underscores (self_attn,
            # k_proj), so keep appending tokens until a lookup succeeds.
            temp_name = layer_infos.pop(0)
            while True:
                try:
                    curr_layer = curr_layer.__getattr__(temp_name)
                except AttributeError as err:
                    if not layer_infos:
                        raise ValueError(
                            f"cannot resolve a text-encoder layer for LoRA file {file.name!r}") from err
                    next_token = layer_infos.pop(0)
                    temp_name = f"{temp_name}_{next_token}" if temp_name else next_token
                    continue
                if not layer_infos:
                    break
                temp_name = layer_infos.pop(0)

            weight = curr_layer.weight.data
            # Match the dtype/device of the encoder being patched, not always clip_model's.
            data = torch.from_numpy(np.fromfile(file.path, dtype=model_dtype)).to(
                target_model.dtype).to(target_model.device)
            # Test the number of dimensions; len() of a tensor is the size of dim 0.
            if weight.dim() == 4:
                # NOTE(review): the permute implies 4-D (conv) deltas are exported
                # channels-last as (out, kh, kw, in); confirm against the exporter.
                out_ch, in_ch, kh, kw = weight.shape
                adding_weight = alpha * data.reshape(out_ch, kh, kw, in_ch).permute(0, 3, 1, 2)
            else:
                adding_weight = alpha * data.reshape(weight.shape)
            curr_layer.weight.data += adding_weight
            unload_dict.append({"layer": curr_layer, "added_weight": adding_weight})
    return unload_dict
def lora_trans(state_dict):
    """Convert a kohya-format LoRA state dict to the key layout used in this module.

    Steps:
      1. Map SGM (original SDXL) block names to diffusers names, assuming
         ``layers_per_block=2``.
      2. Convert the kohya keys with diffusers' ``_convert_kohya_lora_to_diffusers``.
      3. Normalise each converted key:
         - drop ``processor.``;
         - rewrite the LoRA suffixes ``.lora_linear_layer.``, ``_lora.`` and
           ``.lora.`` to ``.lora_`` (giving ``...lora_up.weight`` /
           ``...lora_down.weight``);
         - for text-encoder keys, rename ``to_q/k/v/out`` to CLIP's
           ``q/k/v/out_proj``;
         - expand ``to_out.`` to ``to_out.0.`` unless it is already indexed.
      4. Store the distinct network alphas under ``'lora.alpha'``.

    Args:
        state_dict: tensors keyed by kohya LoRA names, e.g. as loaded from a
            ``.safetensors`` file.

    Returns:
        The converted dict. ``'lora.alpha'`` holds a 1-D tensor of the distinct
        alpha values in ascending order. It is omitted when the checkpoint
        carries no alphas, because an empty tensor has no usable value.
    """
    unet_config = SimpleNamespace(layers_per_block=2)
    mapped = _maybe_map_sgm_blocks_to_diffusers(state_dict, unet_config)
    converted, network_alphas = _convert_kohya_lora_to_diffusers(mapped)

    # Iterate over a snapshot of the keys because entries are renamed in place.
    for old_key in list(converted):
        new_key = old_key.replace('processor.', '')
        for marker in ('.lora_linear_layer.', '_lora.', '.lora.'):
            new_key = new_key.replace(marker, '.lora_')
        if 'text_encoder' in new_key:
            for proj in ('q', 'k', 'v', 'out'):
                new_key = new_key.replace(f'.to_{proj}.', f'.{proj}_proj.')
        # to_out is a (Linear, Dropout) container; target its first element,
        # without double-inserting when the key is already "to_out.0.".
        new_key = re.sub(r'to_out\.(?!0\.)', 'to_out.0.', new_key)
        if new_key != old_key:
            converted[new_key] = converted.pop(old_key)

    # Sorted so the tensor layout is deterministic (set order is not).
    distinct_alphas = sorted(set(network_alphas.values()))
    if distinct_alphas:
        converted['lora.alpha'] = torch.Tensor(distinct_alphas)
    return converted
def load_state_dict(filename, need_trans=True):
    """Read a ``.safetensors`` LoRA checkpoint into CPU memory.

    Args:
        filename: path to the checkpoint; it is made absolute before loading.
        need_trans: when truthy (the default), pass the raw tensors through
            ``lora_trans`` to obtain the converted key layout.

    Returns:
        Dict mapping key names to CPU tensors.
    """
    checkpoint_path = os.path.abspath(filename)
    tensors = load_file(checkpoint_path, device="cpu")
    if not need_trans:
        return tensors
    return lora_trans(tensors)
def move_state_dict_to_cuda(state_dict, device="cuda"):
    """Return a new dict with every tensor of ``state_dict`` moved to ``device``.

    The input mapping is not modified. Tensors already on ``device`` are
    returned as-is by ``Tensor.to``.

    Args:
        state_dict: mapping of names to tensors.
        device: target device. The default ``"cuda"`` selects the current CUDA
            device, matching the previous ``.cuda()`` behaviour. Pass e.g.
            ``"cuda:1"`` or ``"cpu"`` to target another device.

    Returns:
        Dict with the same keys, in the same order, holding the moved tensors.
    """
    return {name: tensor.to(device) for name, tensor in state_dict.items()}
def add_lora_to_opt_model(state_dict, unet, clip_model, clip_model_2, alpha=1.0, need_trans=False):
    """Merge a LoRA state dict into the text encoders and the optimized UNet.

    Each ``lora_up``/``lora_down`` pair is combined into a full weight delta
    ``up @ down`` scaled by ``lora_alpha / rank``. Text-encoder deltas
    (keys containing ``"text"``) are scaled by ``alpha`` as well and added
    directly to the target module's ``weight.data``. UNet deltas (keys
    containing ``"unet"``) are handed to ``unet.load_lora_by_name`` together
    with ``alpha``. Keys matching neither are ignored.

    Args:
        state_dict: tensors keyed ``<root>.<dotted module path>.lora_up.weight``
            / ``...lora_down.weight``, plus an optional key containing
            ``.alpha``. Presumably the output of ``lora_trans``; this function
            does not call it. The dict is copied to CUDA by
            ``move_state_dict_to_cuda``; the caller's dict is not modified.
        unet: object exposing ``load_lora_by_name(layers, weight, alpha)``.
            Presumably the lyra optimized UNet; its semantics are not visible here.
        clip_model: text encoder patched for ``"text"`` keys without ``text_encoder_2``.
        clip_model_2: text encoder patched for keys containing ``text_encoder_2``.
        alpha: user-chosen LoRA strength.
        need_trans: unused.
    """
    # directly update weight in diffusers model
    state_dict = move_state_dict_to_cuda(state_dict)
    # The first key containing '.alpha' supplies the network alpha. It must be a
    # single-element tensor for .item(). -1 means "no alpha": scale 1.0 below.
    alpha_ks = list(filter(lambda x: x.find('.alpha') >= 0, state_dict))
    lora_alpha = state_dict[alpha_ks[0]].item() if len(alpha_ks) > 0 else -1
    # Both halves of a pair are marked visited when the first one is handled.
    visited = set()
    for key in state_dict:
        # Skip the alpha entry (already read into lora_alpha) and the second
        # half of an up/down pair that was merged when its partner was visited.
        if '.alpha' in key or key in visited:
            continue
        if "text" in key:
            curr_layer = clip_model_2 if key.find(
                'text_encoder_2') >= 0 else clip_model
            # if is_sdxl:
            # Walk the dotted path after the first segment. The walk stops at the
            # first segment that is not a sub-module attribute (normally
            # 'lora_up'/'lora_down'), leaving curr_layer on the target module.
            layer_infos = key.split('.')[1:]
            for x in layer_infos:
                try:
                    curr_layer = curr_layer.__getattr__(x)
                except Exception:
                    break
            # update weight
            # pair_keys is [up key, down key], whichever half `key` is.
            pair_keys = [key.replace("lora_down", "lora_up"),
                         key.replace("lora_up", "lora_down")]
            weight_up, weight_down = state_dict[pair_keys[0]
                                                ], state_dict[pair_keys[1]]
            # rank = inner dimension of up @ down (weight_up.shape[1]).
            weight_scale = lora_alpha/weight_up.shape[1] if lora_alpha != -1 else 1.0
            if len(weight_up.shape) == 4:
                # Conv LoRA: drop size-1 spatial dims. A down projection with a
                # spatial kernel stays 4-D and is combined via einsum.
                weight_up = weight_up.squeeze([2, 3])
                weight_down = weight_down.squeeze([2, 3])
                if len(weight_down.shape) == 4:
                    adding_weight = torch.einsum(
                        'a b, b c h w -> a c h w', weight_up, weight_down)
                else:
                    adding_weight = torch.mm(
                        weight_up, weight_down).unsqueeze(2).unsqueeze(3)
            else:
                adding_weight = torch.mm(weight_up, weight_down)
            adding_weight = alpha * weight_scale * adding_weight
            # NOTE(review): the delta is cast to float16 regardless of the layer's
            # dtype; presumably the encoders run in fp16 — confirm.
            curr_layer.weight.data += adding_weight.to(torch.float16)
            # update visited list
            for item in pair_keys:
                visited.add(item)
        elif "unet" in key:
            layer_infos = key
            layer_infos = layer_infos.replace(".lora_up.weight", "")
            layer_infos = layer_infos.replace(".lora_down.weight", "")
            # Drops the first 5 characters; assumes the key starts with "unet.".
            layer_infos = layer_infos[5:]
            layer_names = layer_infos.split(".")
            # Regroup dotted names into the names passed to load_lora_by_name:
            # - "name.<digit>" merges into "name_<digit>";
            # - once 4 groups exist, every remaining token is appended to the last
            #   group;
            # - tokens following a group containing "samplers" are appended to it.
            # e.g. down_blocks.0.attentions.0.transformer_blocks.0.attn1.to_out.0 ->
            #   ['down_blocks_0', 'attentions_0', 'transformer_blocks_0', 'attn1_to_out_0']
            layers = []
            i = 0
            while i < len(layer_names):
                if len(layers) >= 4:
                    layers[-1] += "_" + layer_names[i]
                elif i + 1 < len(layer_names) and layer_names[i+1].isdigit():
                    layers.append(layer_names[i] + "_" + layer_names[i+1])
                    i += 1
                elif len(layers) > 0 and "samplers" in layers[-1]:
                    layers[-1] += "_" + layer_names[i]
                else:
                    layers.append(layer_names[i])
                i += 1
            # NOTE(review): this joined string is not used below; load_lora_by_name
            # receives the list `layers`.
            layer_infos = ".".join(layers)
            pair_keys = [key.replace("lora_down", "lora_up"),
                         key.replace("lora_up", "lora_down")]
            # update weight
            if len(state_dict[pair_keys[0]].shape) == 4:
                # NOTE(review): weight_up/weight_down are re-read and re-squeezed a
                # few lines below. This first pair only feeds weight_scale (rank =
                # weight_up.shape[1], the same after either squeeze).
                weight_up = state_dict[pair_keys[0]].squeeze(
                    3).squeeze(2).to(torch.float32)
                weight_down = state_dict[pair_keys[1]].to(torch.float32)
                weight_scale = lora_alpha/weight_up.shape[1] if lora_alpha != -1 else 1.0
                weight_up, weight_down = state_dict[pair_keys[0]
                                                    ], state_dict[pair_keys[1]]
                weight_up = weight_up.squeeze([2, 3]).to(torch.float32)
                weight_down = weight_down.squeeze([2, 3]).to(torch.float32)
                if len(weight_down.shape) == 4:
                    curr_layer_weight = weight_scale * \
                        torch.einsum('a b, b c h w -> a c h w',
                                     weight_up, weight_down)
                else:
                    curr_layer_weight = weight_scale * \
                        torch.mm(weight_up, weight_down).unsqueeze(
                            2).unsqueeze(3)
                # (out, in, kh, kw) -> (out, kh, kw, in): channels-last, presumably
                # the layout load_lora_by_name expects for conv weights.
                curr_layer_weight = curr_layer_weight.permute(0, 2, 3, 1)
            else:
                weight_up = state_dict[pair_keys[0]].to(torch.float32)
                weight_down = state_dict[pair_keys[1]].to(torch.float32)
                weight_scale = lora_alpha/weight_up.shape[1] if lora_alpha != -1 else 1.0
                curr_layer_weight = weight_scale * \
                    torch.mm(weight_up, weight_down)
            # The delta is computed in fp32 and handed over as fp16. Unlike the
            # text branch, alpha is not folded in here; it is passed separately.
            curr_layer_weight = curr_layer_weight.to(torch.float16)
            unet.load_lora_by_name(layers, curr_layer_weight, alpha)
            for item in pair_keys:
                visited.add(item)