Spaces:
Paused
Paused
import logging | |
import os | |
import re | |
import lora_patches | |
import network | |
import network_lora | |
import network_hada | |
import network_ia3 | |
import network_lokr | |
import network_full | |
import network_norm | |
import torch | |
from typing import Union | |
from modules import shared, devices, sd_models, errors, scripts, sd_hijack | |
# Network module-type handlers. load_network tries them in this order and keeps the first
# module returned by `create_module`, so the order of this list matters.
module_types = [
    module_type()
    for module_type in (
        network_lora.ModuleTypeLora,
        network_hada.ModuleTypeHada,
        network_ia3.ModuleTypeIa3,
        network_lokr.ModuleTypeLokr,
        network_full.ModuleTypeFull,
        network_norm.ModuleTypeNorm,
    )
]
re_digits = re.compile(r"\d+")  # a run of decimal digits
re_x_proj = re.compile(r"(.*)_([qkv]_proj)$")  # "<layer>_q_proj" / "_k_proj" / "_v_proj" -> (layer, suffix)
re_compiled = dict()  # regex source text -> compiled pattern, filled lazily by convert_diffusers_name_to_compvis

# Per block type: diffusers sub-layer name -> compvis sub-layer name.
# Names not listed here are passed through unchanged by the converter.
suffix_conversion = dict(
    attentions=dict(),
    resnets=dict(
        conv1="in_layers_2",
        conv2="out_layers_3",
        norm1="in_layers_0",
        norm2="out_layers_0",
        time_emb_proj="emb_layers_1",
        conv_shortcut="skip_connection",
    ),
)
def convert_diffusers_name_to_compvis(key, is_sd2):
    """Translate a diffusers/kohya-style LoRA layer name into webui's compvis layer name.

    ``key`` is the layer part of a state-dict key (the text before the first "."). Known UNet
    and text-encoder patterns are rewritten into the flat underscore-joined names used by
    ``network_layer_mapping``; any other key is returned unchanged.

    Args:
        key: LoRA layer name, e.g. ``lora_unet_down_blocks_0_attentions_0_proj_in``.
        is_sd2: True when the first text encoder uses ``model_transformer_resblocks`` names;
            selects the SD2 spelling for ``lora_te_`` keys.

    Returns:
        The converted layer name, or ``key`` if no pattern matched.
    """
    def match(match_list, regex_text):
        # Compile each pattern once and cache it in the module-level re_compiled dict.
        regex = re_compiled.get(regex_text)
        if regex is None:
            regex = re.compile(regex_text)
            re_compiled[regex_text] = regex

        r = regex.match(key)
        if not r:
            return False

        # Purely numeric groups become ints so the block indices below can use arithmetic.
        # fullmatch, not match: a group that merely *starts* with digits (e.g. "1_proj") must
        # stay a string - int() on it would raise ValueError out of load_network.
        match_list.clear()
        match_list.extend([int(x) if re_digits.fullmatch(x) else x for x in r.groups()])
        return True

    m = []

    if match(m, r"lora_unet_conv_in(.*)"):
        return f'diffusion_model_input_blocks_0_0{m[0]}'

    if match(m, r"lora_unet_conv_out(.*)"):
        return f'diffusion_model_out_2{m[0]}'

    if match(m, r"lora_unet_time_embedding_linear_(\d+)(.*)"):
        return f"diffusion_model_time_embed_{m[0] * 2 - 2}{m[1]}"

    if match(m, r"lora_unet_down_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_input_blocks_{1 + m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_mid_block_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[0], {}).get(m[2], m[2])
        return f"diffusion_model_middle_block_{1 if m[0] == 'attentions' else m[1] * 2}_{suffix}"

    if match(m, r"lora_unet_up_blocks_(\d+)_(attentions|resnets)_(\d+)_(.+)"):
        suffix = suffix_conversion.get(m[1], {}).get(m[3], m[3])
        return f"diffusion_model_output_blocks_{m[0] * 3 + m[2]}_{1 if m[1] == 'attentions' else 0}_{suffix}"

    if match(m, r"lora_unet_down_blocks_(\d+)_downsamplers_0_conv"):
        return f"diffusion_model_input_blocks_{3 + m[0] * 3}_0_op"

    if match(m, r"lora_unet_up_blocks_(\d+)_upsamplers_0_conv"):
        return f"diffusion_model_output_blocks_{2 + m[0] * 3}_{2 if m[0]>0 else 1}_conv"

    if match(m, r"lora_te_text_model_encoder_layers_(\d+)_(.+)"):
        if is_sd2:
            # SD2's text encoder names its MLP/attention sub-layers differently
            if 'mlp_fc1' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
            elif 'mlp_fc2' in m[1]:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
            else:
                return f"model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

        return f"transformer_text_model_encoder_layers_{m[0]}_{m[1]}"

    if match(m, r"lora_te2_text_model_encoder_layers_(\d+)_(.+)"):
        # second text encoder: names are prefixed with its embedder index "1_"
        if 'mlp_fc1' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc1', 'mlp_c_fc')}"
        elif 'mlp_fc2' in m[1]:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('mlp_fc2', 'mlp_c_proj')}"
        else:
            return f"1_model_transformer_resblocks_{m[0]}_{m[1].replace('self_attn', 'attn')}"

    return key
def assign_network_names_to_compvis_modules(sd_model):
    """Give every text-encoder and UNet submodule of ``sd_model`` a flat network layer name.

    Names are the dotted ``named_modules()`` paths with "." replaced by "_". For SDXL models,
    each text-encoder name is prefixed with its embedder index (``0_``, ``1_``, ...), and
    embedders without a ``wrapped`` attribute are skipped. Every module gets a
    ``network_layer_name`` attribute, and ``sd_model.network_layer_mapping`` is set to a dict
    from name to module. UNet names are registered last, so they win on a name collision.

    Args:
        sd_model: the model to index; it is read and annotated, and is also where the mapping is
            stored (previously the global ``shared.sd_model`` was read instead of this argument).
    """
    network_layer_mapping = {}

    def register(network_name, module):
        # record the module under its flat name and tag the module with that name
        network_layer_mapping[network_name] = module
        module.network_layer_name = network_name

    if sd_model.is_sdxl:
        for i, embedder in enumerate(sd_model.conditioner.embedders):
            if not hasattr(embedder, 'wrapped'):
                continue

            for name, module in embedder.wrapped.named_modules():
                register(f'{i}_{name.replace(".", "_")}', module)
    else:
        for name, module in sd_model.cond_stage_model.wrapped.named_modules():
            register(name.replace(".", "_"), module)

    for name, module in sd_model.model.named_modules():
        register(name.replace(".", "_"), module)

    sd_model.network_layer_mapping = network_layer_mapping
def load_network(name, network_on_disk):
    """Read a network file and build a ``network.Network`` with one module per matched layer.

    Each state-dict key is split at its first "." into a layer part and a weight part. The layer
    part is converted to a compvis layer name (with fallbacks for files that already use
    compvis-style names) and looked up in ``shared.sd_model.network_layer_mapping``. All weights
    for one layer are collected into a ``network.NetworkWeights`` and handed to the first entry of
    ``module_types`` whose ``create_module`` accepts them. Keys that match no layer are only
    logged at debug level.

    Args:
        name: name for the resulting network.
        network_on_disk: on-disk descriptor whose ``filename`` is read.

    Returns:
        The populated ``network.Network``.

    Raises:
        AssertionError: if no module type accepts the weights collected for some layer.
    """
    net = network.Network(name, network_on_disk)
    net.mtime = os.path.getmtime(network_on_disk.filename)

    sd = sd_models.read_state_dict(network_on_disk.filename)

    # this should not be needed but is here as an emergency fix for an unknown error people are experiencing in 1.2.0
    if not hasattr(shared.sd_model, 'network_layer_mapping'):
        assign_network_names_to_compvis_modules(shared.sd_model)

    layer_mapping = shared.sd_model.network_layer_mapping
    keys_failed_to_match = {}
    is_sd2 = 'model_transformer_resblocks' in layer_mapping

    matched_networks = {}

    for key_network, weight in sd.items():
        # A key without "." has no weight part; unpacking its split used to raise ValueError and
        # abort the whole network. Record it as unmatched instead.
        if "." not in key_network:
            keys_failed_to_match[key_network] = key_network
            continue

        key_network_without_network_parts, network_part = key_network.split(".", 1)

        key = convert_diffusers_name_to_compvis(key_network_without_network_parts, is_sd2)
        sd_module = layer_mapping.get(key, None)

        if sd_module is None:
            # "<layer>_q_proj" etc.: the weights target a sub-projection of <layer>; the full key
            # is kept so network_apply_weights can find it by suffix
            m = re_x_proj.match(key)
            if m:
                sd_module = layer_mapping.get(m.group(1), None)

        # SDXL loras seem to already have correct compvis keys, so only need to replace "lora_unet" with "diffusion_model"
        if sd_module is None and "lora_unet" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_unet", "diffusion_model")
            sd_module = layer_mapping.get(key, None)
        elif sd_module is None and "lora_te1_text_model" in key_network_without_network_parts:
            key = key_network_without_network_parts.replace("lora_te1_text_model", "0_transformer_text_model")
            sd_module = layer_mapping.get(key, None)

            # some SD1 Loras also have correct compvis keys
            if sd_module is None:
                key = key_network_without_network_parts.replace("lora_te1_text_model", "transformer_text_model")
                sd_module = layer_mapping.get(key, None)

        if sd_module is None:
            keys_failed_to_match[key_network] = key
            continue

        if key not in matched_networks:
            matched_networks[key] = network.NetworkWeights(network_key=key_network, sd_key=key, w={}, sd_module=sd_module)

        matched_networks[key].w[network_part] = weight

    for key, weights in matched_networks.items():
        net_module = None
        for nettype in module_types:
            net_module = nettype.create_module(net, weights)
            if net_module is not None:
                break

        if net_module is None:
            raise AssertionError(f"Could not find a module type (out of {', '.join([x.__class__.__name__ for x in module_types])}) that would accept those keys: {', '.join(weights.w)}")

        net.modules[key] = net_module

    if keys_failed_to_match:
        logging.debug(f"Network {network_on_disk.filename} didn't match keys: {keys_failed_to_match}")

    return net
def purge_networks_from_memory():
    """Trim the ``networks_in_memory`` cache to ``shared.opts.lora_in_memory_limit`` entries.

    Entries are evicted from the front of the dict, i.e. in insertion order (load_networks
    re-inserts a network whenever it is loaded from disk). Always ends with ``devices.torch_gc()``.
    """
    while len(networks_in_memory) > shared.opts.lora_in_memory_limit and networks_in_memory:
        earliest = next(iter(networks_in_memory))
        networks_in_memory.pop(earliest, None)

    devices.torch_gc()
def load_networks(names, te_multipliers=None, unet_multipliers=None, dyn_dims=None):
    """Make ``names`` the active set of networks, loading from cache or disk as needed.

    ``loaded_networks`` is rebuilt in the order of ``names``. A network that was already active
    is reused. Otherwise the ``networks_in_memory`` cache is consulted. A network is (re)loaded
    from disk when it is not cached or its file's mtime is newer than the loaded copy. Load errors
    are displayed and that network is skipped. Names that cannot be resolved are added to the
    hijack comments.

    Args:
        names: network names or aliases to activate.
        te_multipliers: optional per-name text-encoder multipliers (parallel to ``names``); default 1.0.
        unet_multipliers: optional per-name UNet multipliers (parallel to ``names``); default 1.0.
        dyn_dims: optional per-name dyn_dim values (parallel to ``names``); default 1.0.
    """
    already_loaded = {}

    for net in loaded_networks:
        if net.name in names:
            already_loaded[net.name] = net

    loaded_networks.clear()

    networks_on_disk = [available_network_aliases.get(name, None) for name in names]
    if any(x is None for x in networks_on_disk):
        # something is unknown: rescan the network directories once and look up again
        list_available_networks()

        networks_on_disk = [available_network_aliases.get(name, None) for name in names]

    failed_to_load_networks = []

    for i, (network_on_disk, name) in enumerate(zip(networks_on_disk, names)):
        net = already_loaded.get(name, None)

        if network_on_disk is not None:
            if net is None:
                net = networks_in_memory.get(name)

            try:
                # getmtime is inside the try: a file removed since the last scan would otherwise
                # raise out of load_networks instead of just skipping this network
                if net is None or os.path.getmtime(network_on_disk.filename) > net.mtime:
                    net = load_network(name, network_on_disk)

                    # pop + insert moves the entry to the end of the cache's insertion order
                    networks_in_memory.pop(name, None)
                    networks_in_memory[name] = net
            except Exception as e:
                errors.display(e, f"loading network {network_on_disk.filename}")
                continue

            net.mentioned_name = name

            network_on_disk.read_hash()

        if net is None:
            failed_to_load_networks.append(name)
            logging.info(f"Couldn't find network with name {name}")
            continue

        net.te_multiplier = te_multipliers[i] if te_multipliers else 1.0
        net.unet_multiplier = unet_multipliers[i] if unet_multipliers else 1.0
        net.dyn_dim = dyn_dims[i] if dyn_dims else 1.0
        loaded_networks.append(net)

    if failed_to_load_networks:
        sd_hijack.model_hijack.comments.append("Networks not found: " + ", ".join(failed_to_load_networks))

    purge_networks_from_memory()
def network_restore_weights_from_backup(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """Copy the backed-up weight and bias back into layer ``self``.

    Reads ``network_weights_backup`` and ``network_bias_backup``. It does nothing if neither
    backup exists.

    For MultiheadAttention, the weight backup is an (in_proj_weight, out_proj.weight) pair, and
    the bias refers to ``out_proj.bias``. When only a weight backup exists, the bias is set to
    None, removing any bias a network added.
    """
    weights_backup = getattr(self, "network_weights_backup", None)
    bias_backup = getattr(self, "network_bias_backup", None)

    if weights_backup is None and bias_backup is None:
        return

    # In-place copies into leaf Parameters must not be tracked by autograd (they raise when grad
    # mode is on). network_apply_weights does its in-place updates under no_grad as well.
    with torch.no_grad():
        if weights_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.in_proj_weight.copy_(weights_backup[0])
                self.out_proj.weight.copy_(weights_backup[1])
            else:
                self.weight.copy_(weights_backup)

        if bias_backup is not None:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.out_proj.bias.copy_(bias_backup)
            else:
                self.bias.copy_(bias_backup)
        else:
            if isinstance(self, torch.nn.MultiheadAttention):
                self.out_proj.bias = None
            else:
                self.bias = None
def network_apply_weights(self: Union[torch.nn.Conv2d, torch.nn.Linear, torch.nn.GroupNorm, torch.nn.LayerNorm, torch.nn.MultiheadAttention]):
    """
    Applies the currently selected set of networks to the weights of torch layer self.
    If weights already have this particular set of networks applied, does nothing.
    If not, restores original weights from backup and alters weights according to networks.

    The applied set is identified by (name, te_multiplier, unet_multiplier, dyn_dim) of every
    loaded network, so changing a multiplier also triggers re-application. Before the layer is
    first modified, its weight (and bias, if present) is copied to ``devices.cpu`` and stored in
    ``network_weights_backup`` / ``network_bias_backup``.

    Per-network failures are logged at debug level and counted in ``extra_network_lora.errors``.
    Layers without a ``network_layer_name`` are left untouched.

    Raises:
        RuntimeError: if networks are recorded as applied but no weight backup exists.
    """
    network_layer_name = getattr(self, 'network_layer_name', None)
    if network_layer_name is None:
        return

    current_names = getattr(self, "network_current_names", ())
    wanted_names = tuple((x.name, x.te_multiplier, x.unet_multiplier, x.dyn_dim) for x in loaded_networks)

    weights_backup = getattr(self, "network_weights_backup", None)
    if weights_backup is None and wanted_names != ():
        if current_names != ():
            raise RuntimeError("no backup weights found and current weights are not unchanged")

        if isinstance(self, torch.nn.MultiheadAttention):
            weights_backup = (self.in_proj_weight.to(devices.cpu, copy=True), self.out_proj.weight.to(devices.cpu, copy=True))
        else:
            weights_backup = self.weight.to(devices.cpu, copy=True)

        self.network_weights_backup = weights_backup

    # Snapshot the bias only while the layer is unmodified (current_names == ()) and networks are
    # about to be applied. Doing it on every call was wrong for a layer without a bias: once a
    # network had added one via ex_bias, the next call backed up that *modified* bias, and a later
    # restore kept the network's bias instead of removing it. It also copied every layer's bias to
    # CPU even when no networks were loaded.
    bias_backup = getattr(self, "network_bias_backup", None)
    if bias_backup is None and wanted_names != () and current_names == ():
        if isinstance(self, torch.nn.MultiheadAttention) and self.out_proj.bias is not None:
            bias_backup = self.out_proj.bias.to(devices.cpu, copy=True)
        elif getattr(self, 'bias', None) is not None:
            bias_backup = self.bias.to(devices.cpu, copy=True)
        else:
            bias_backup = None
        self.network_bias_backup = bias_backup

    if current_names != wanted_names:
        network_restore_weights_from_backup(self)

        for net in loaded_networks:
            module = net.modules.get(network_layer_name, None)
            if module is not None and hasattr(self, 'weight'):
                try:
                    with torch.no_grad():
                        updown, ex_bias = module.calc_updown(self.weight)

                        if len(self.weight.shape) == 4 and self.weight.shape[1] == 9:
                            # inpainting model. zero pad updown to make channel[1] 4 to 9
                            updown = torch.nn.functional.pad(updown, (0, 0, 0, 0, 0, 5))

                        self.weight += updown
                        if ex_bias is not None and hasattr(self, 'bias'):
                            if self.bias is None:
                                self.bias = torch.nn.Parameter(ex_bias)
                            else:
                                self.bias += ex_bias
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            # MultiheadAttention has no single `weight`; load_network stores its parts under
            # "<layer>_q_proj" / "_k_proj" / "_v_proj" / "_out_proj" keys
            module_q = net.modules.get(network_layer_name + "_q_proj", None)
            module_k = net.modules.get(network_layer_name + "_k_proj", None)
            module_v = net.modules.get(network_layer_name + "_v_proj", None)
            module_out = net.modules.get(network_layer_name + "_out_proj", None)

            if isinstance(self, torch.nn.MultiheadAttention) and module_q and module_k and module_v and module_out:
                try:
                    with torch.no_grad():
                        updown_q, _ = module_q.calc_updown(self.in_proj_weight)
                        updown_k, _ = module_k.calc_updown(self.in_proj_weight)
                        updown_v, _ = module_v.calc_updown(self.in_proj_weight)
                        updown_qkv = torch.vstack([updown_q, updown_k, updown_v])
                        updown_out, ex_bias = module_out.calc_updown(self.out_proj.weight)

                        self.in_proj_weight += updown_qkv
                        self.out_proj.weight += updown_out

                        # under no_grad like the other in-place updates
                        if ex_bias is not None:
                            if self.out_proj.bias is None:
                                self.out_proj.bias = torch.nn.Parameter(ex_bias)
                            else:
                                self.out_proj.bias += ex_bias
                except RuntimeError as e:
                    logging.debug(f"Network {net.name} layer {network_layer_name}: {e}")
                    extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

                continue

            if module is None:
                continue

            # a module exists for this layer but neither branch above could apply it
            logging.debug(f"Network {net.name} layer {network_layer_name}: couldn't find supported operation")
            extra_network_lora.errors[net.name] = extra_network_lora.errors.get(net.name, 0) + 1

        self.network_current_names = wanted_names
def network_forward(module, input, original_forward):
    """
    Old way of applying Lora by executing operations during layer's forward.
    Stacking many loras this way results in big performance degradation.

    The layer's weights are first restored from any backup made by network_apply_weights,
    and its cached network state is reset. The original forward then runs on the
    (``devices.cond_cast_unet``-cast) input. Finally, every loaded network with a module for
    this layer transforms the output via ``forward(input, y)``.
    """
    if not loaded_networks:
        return original_forward(module, input)

    input = devices.cond_cast_unet(input)

    network_restore_weights_from_backup(module)
    network_reset_cached_weight(module)

    y = original_forward(module, input)

    network_layer_name = getattr(module, 'network_layer_name', None)
    for lora in loaded_networks:
        # separate name: this used to rebind `module` (the layer itself) to the network module
        net_module = lora.modules.get(network_layer_name, None)
        if net_module is None:
            continue

        y = net_module.forward(input, y)

    return y
def network_reset_cached_weight(self: Union[torch.nn.Conv2d, torch.nn.Linear]):
    """Forget all network state cached on layer ``self``.

    Clears the applied-set marker and both the weight and bias backups. As a result, the next
    network_apply_weights call treats the layer's current parameters as the originals.

    The bias backup must be cleared too. network_apply_weights only snapshots a bias when no
    bias backup exists, so a stale backup from the previous weights would otherwise be copied
    back over freshly loaded ones on the next restore.
    """
    self.network_current_names = ()
    self.network_weights_backup = None
    self.network_bias_backup = None
def network_Linear_forward(self, input):
    """Linear forward with the loaded networks applied.

    With ``shared.opts.lora_functional`` off, networks are merged into the weights by
    network_apply_weights before the original forward runs; with it on, network_forward adds them
    to the output instead.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.Linear_forward(self, input)

    return network_forward(self, input, originals.Linear_forward)
def network_Linear_load_state_dict(self, *args, **kwargs):
    """Drop cached network state on this Linear layer, then delegate to the original loader."""
    network_reset_cached_weight(self)
    original_loader = originals.Linear_load_state_dict
    return original_loader(self, *args, **kwargs)
def network_Conv2d_forward(self, input):
    """Conv2d forward with the loaded networks applied.

    With ``shared.opts.lora_functional`` off, networks are merged into the weights by
    network_apply_weights before the original forward runs; with it on, network_forward adds them
    to the output instead.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.Conv2d_forward(self, input)

    return network_forward(self, input, originals.Conv2d_forward)
def network_Conv2d_load_state_dict(self, *args, **kwargs):
    """Drop cached network state on this Conv2d layer, then delegate to the original loader."""
    network_reset_cached_weight(self)
    original_loader = originals.Conv2d_load_state_dict
    return original_loader(self, *args, **kwargs)
def network_GroupNorm_forward(self, input):
    """GroupNorm forward with the loaded networks applied.

    With ``shared.opts.lora_functional`` off, networks are merged into the weights by
    network_apply_weights before the original forward runs; with it on, network_forward adds them
    to the output instead.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.GroupNorm_forward(self, input)

    return network_forward(self, input, originals.GroupNorm_forward)
def network_GroupNorm_load_state_dict(self, *args, **kwargs):
    """Drop cached network state on this GroupNorm layer, then delegate to the original loader."""
    network_reset_cached_weight(self)
    original_loader = originals.GroupNorm_load_state_dict
    return original_loader(self, *args, **kwargs)
def network_LayerNorm_forward(self, input):
    """LayerNorm forward with the loaded networks applied.

    With ``shared.opts.lora_functional`` off, networks are merged into the weights by
    network_apply_weights before the original forward runs; with it on, network_forward adds them
    to the output instead.
    """
    if not shared.opts.lora_functional:
        network_apply_weights(self)
        return originals.LayerNorm_forward(self, input)

    return network_forward(self, input, originals.LayerNorm_forward)
def network_LayerNorm_load_state_dict(self, *args, **kwargs):
    """Drop cached network state on this LayerNorm layer, then delegate to the original loader."""
    network_reset_cached_weight(self)
    original_loader = originals.LayerNorm_load_state_dict
    return original_loader(self, *args, **kwargs)
def network_MultiheadAttention_forward(self, *args, **kwargs):
    """MultiheadAttention forward with networks merged into the weights.

    Unlike the other layer types there is no ``lora_functional`` path here: the weights are always
    updated by network_apply_weights before the original forward runs with all arguments.
    """
    network_apply_weights(self)
    original_forward = originals.MultiheadAttention_forward
    return original_forward(self, *args, **kwargs)
def network_MultiheadAttention_load_state_dict(self, *args, **kwargs):
    """Drop cached network state on this attention layer, then delegate to the original loader."""
    network_reset_cached_weight(self)
    original_loader = originals.MultiheadAttention_load_state_dict
    return original_loader(self, *args, **kwargs)
def list_available_networks():
    """Rescan the Lora and LyCORIS back-compat directories and rebuild the lookup tables.

    Tables updated:
    - ``available_networks``: rebuilt as file name -> ``network.NetworkOnDisk``.
    - ``available_network_aliases``: rebuilt as file name and alias -> entry.
    - ``forbidden_network_aliases``: reset to the reserved words plus, lowercased, every alias
      that collides with an already registered name or alias.
    - ``available_network_hash_lookup``: only cleared.

    Files that raise OSError while being read are reported and skipped. The Lora directory is
    created if it is missing.
    """
    available_networks.clear()
    available_network_aliases.clear()
    forbidden_network_aliases.clear()
    available_network_hash_lookup.clear()
    # Keys are lowercase, like the colliding aliases recorded below; the previous mixed-case
    # "Addams" key could never match a lowercased alias.
    forbidden_network_aliases.update({"none": 1, "addams": 1})

    os.makedirs(shared.cmd_opts.lora_dir, exist_ok=True)

    candidates = list(shared.walk_files(shared.cmd_opts.lora_dir, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))
    candidates += list(shared.walk_files(shared.cmd_opts.lyco_dir_backcompat, allowed_extensions=[".pt", ".ckpt", ".safetensors"]))

    # Visit each path once (order preserved). A path yielded by both scans, e.g. when both
    # directory options point at the same place, would otherwise find its own alias already
    # registered and wrongly mark it as forbidden.
    for filename in dict.fromkeys(candidates):
        if os.path.isdir(filename):
            continue

        name = os.path.splitext(os.path.basename(filename))[0]
        try:
            entry = network.NetworkOnDisk(name, filename)
        except OSError:  # should catch FileNotFoundError and PermissionError etc.
            errors.report(f"Failed to load network {name} from {filename}", exc_info=True)
            continue

        available_networks[name] = entry

        if entry.alias in available_network_aliases:
            forbidden_network_aliases[entry.alias.lower()] = 1

        available_network_aliases[name] = entry
        available_network_aliases[entry.alias] = entry
# Splits an AddNet model field such as "name(1a2b3c)" or "name (1a2b3c)"; group(1) is the name.
# The (.*) group is greedy: it keeps any whitespace before "(", and the last parenthesised hex
# run is taken as the hash. re.match anchors only at the start, so text after ")" is allowed.
re_network_name = re.compile(r"(.*)\s*\([0-9a-fA-F]+\)")
def infotext_pasted(infotext, params):
    """Turn ``AddNet ...`` LoRA fields of pasted infotext into ``<lora:name:weight>`` prompt tags.

    For every ``AddNet Model N`` field whose ``AddNet Module N`` is "LoRA", a tag is appended to
    ``params["Prompt"]`` on a new line. The tag uses ``AddNet Weight A N`` (default "1.0") as the
    multiplier. A trailing "(hexhash)" is stripped from the model name.

    Nothing happens if the txt2img scripts already expose an "AddNet Module 1" infotext field,
    because another extension handles these fields then.

    Args:
        infotext: the raw infotext text (unused).
        params: parsed infotext fields; updated in place.
    """
    if "AddNet Module 1" in [x[1] for x in scripts.scripts_txt2img.infotext_fields]:
        return  # if the other extension is active, it will handle those fields, no need to do anything

    model_prefix = "AddNet Model "
    added = []

    for k in params:
        if not k.startswith(model_prefix):
            continue

        num = k[len(model_prefix):]

        if params.get("AddNet Module " + num) != "LoRA":
            continue

        name = params.get("AddNet Model " + num)
        if name is None:
            continue

        m = re_network_name.match(name)
        if m:
            # the greedy name group keeps the space in "name (hash)"; without rstrip the tag
            # would become "<lora:name :1.0>" and the name would not resolve
            name = m.group(1).rstrip()

        multiplier = params.get("AddNet Weight A " + num, "1.0")

        added.append(f"<lora:{name}:{multiplier}>")

    if added:
        params["Prompt"] = params.get("Prompt", "") + "\n" + "".join(added)
# Holds the original (unpatched) layer methods, e.g. originals.Linear_forward; None here.
# NOTE(review): presumably assigned by the extension's setup code elsewhere - confirm.
originals: lora_patches.LoraPatches = None
# Assigned elsewhere; network_apply_weights counts per-network failures in its `errors` dict.
extra_network_lora = None
# file name (without extension) -> network.NetworkOnDisk; rebuilt by list_available_networks()
available_networks = {}
# file name and alias -> network.NetworkOnDisk; rebuilt by list_available_networks()
available_network_aliases = {}
# networks activated by the last load_networks() call, in the order of its `names`
loaded_networks = []
# name -> network.Network cache, in insertion order; trimmed by purge_networks_from_memory()
networks_in_memory = {}
# only cleared in this module; presumably filled when network hashes are read - TODO confirm
available_network_hash_lookup = {}
# lowercased aliases that are reserved or shared by more than one file
forbidden_network_aliases = {}

# scan the network directories once at import time
list_available_networks()