import torch
from types import SimpleNamespace
from .lora import (
    extract_lora_ups_down,
    inject_trainable_lora_extended,
    monkeypatch_or_replace_lora_extended,
)
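
# Keys from `lora_args` that the cloneofsimo loader / injector functions accept.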
CLONE_OF_SIMO_KEYS = ["model", "loras", "target_replace_module", "r"]
lora_versions = dict(stable_lora="stable_lora", cloneofsimo="cloneofsimo")
lora_func_types = dict(loader="loader", injector="injector")
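
# Default LoRA arguments; filtered down to CLONE_OF_SIMO_KEYS before being
# passed to the loader or injector.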
lora_args = dict(
    model=None,
    loras=None,
    target_replace_module=[],
    target_module=[],
    r=4,
    search_class=[torch.nn.Linear],
    dropout=0,
    lora_bias="none",
)
LoraVersions = SimpleNamespace(**lora_versions)
LoraFuncTypes = SimpleNamespace(**lora_func_types)
LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]
def filter_dict(_dict, keys=[]):
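    """Return a copy of `_dict` restricted to `keys`, validating each key against `lora_args`."""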
    if len(keys) == 0:
        raise ValueError("Keys cannot be empty when filtering the return dict.")

    for k in keys:
        if k not in lora_args.keys():
            raise ValueError(f"{k} does not exist in the available LoRA arguments.")

    return {k: v for k, v in _dict.items() if k in keys}

class LoraHandler(object):
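    """Configures cloneofsimo-style LoRA loading and injection for UNet / text-encoder models."""
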
    def __init__(
        self,
        version: str = LoraVersions.cloneofsimo,
        use_unet_lora: bool = False,
        use_text_lora: bool = False,
        save_for_webui: bool = False,
        only_for_webui: bool = False,
        lora_bias: str = "none",
        unet_replace_modules: list = ["UNet3DConditionModel"],
    ):
        self.version = version
        assert self.is_cloneofsimo_lora(), "Only the cloneofsimo LoRA version is supported."
        self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
        self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
        self.lora_bias = lora_bias
        self.use_unet_lora = use_unet_lora
        self.use_text_lora = use_text_lora
        self.save_for_webui = save_for_webui
        self.only_for_webui = only_for_webui
        self.unet_replace_modules = unet_replace_modules
        self.use_lora = any([use_text_lora, use_unet_lora])

        if self.use_lora:
            print(f"Using LoRA Version: {self.version}")

    def is_cloneofsimo_lora(self):
        return self.version == LoraVersions.cloneofsimo

    def get_lora_func(self, func_type: str = LoraFuncTypes.loader):
        if func_type == LoraFuncTypes.loader:
            return monkeypatch_or_replace_lora_extended
        if func_type == LoraFuncTypes.injector:
            return inject_trainable_lora_extended
        raise ValueError(f"LoRA function type '{func_type}' does not exist.")

    def get_lora_func_args(
        self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias
    ):
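        # Build the kwargs for the cloneofsimo loader / injector; `use_lora`,
        # `dropout`, and `lora_bias` are accepted for call symmetry but unused here.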
        return_dict = lora_args.copy()
        return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
        return_dict.update(
            {
                "model": model,
                "loras": lora_path,
                "target_replace_module": replace_modules,
                "r": r,
            }
        )
        return return_dict

    def do_lora_injection(
        self,
        model,
        replace_modules,
        bias="none",
        dropout=0,
        r=4,
        lora_loader_args=None,
    ):
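        # Inject trainable LoRA layers into the model, then verify that at least
        # one up/down projection pair was actually created.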
        REPLACE_MODULES = replace_modules
        params = None
        negation = None

        injector_args = lora_loader_args
        params, negation = self.lora_injector(**injector_args)

        for _up, _down in extract_lora_ups_down(
            model, target_replace_module=REPLACE_MODULES
        ):
            if all(x is not None for x in [_up, _down]):
                print(
                    f"Lora successfully injected into {model.__class__.__name__}."
                )
                break

        return params, negation

    def add_lora_to_model(
        self, use_lora, model, replace_modules, dropout=0.0, lora_path=None, r=16
    ):
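        # Entry point: builds the loader args and, if `use_lora` is set, injects
        # LoRA into `model`. Returns the trainable LoRA params (or the model itself
        # if nothing was injected) together with the injector's negation output.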
        params = None
        negation = None

        lora_loader_args = self.get_lora_func_args(
            lora_path, use_lora, model, replace_modules, r, dropout, self.lora_bias
        )

        if use_lora:
            params, negation = self.do_lora_injection(
                model,
                replace_modules,
                bias=self.lora_bias,
                lora_loader_args=lora_loader_args,
                dropout=dropout,
                r=r,
            )

        params = model if params is None else params
        return params, negation
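
# Example usage (a minimal sketch; `unet` is assumed to be an already-constructed
# UNet3DConditionModel, and the replace-module names mirror the defaults above):
#
#   lora_manager = LoraHandler(
#       version=LoraVersions.cloneofsimo,
#       use_unet_lora=True,
#       unet_replace_modules=["UNet3DConditionModel"],
#   )
#   unet_lora_params, unet_negation = lora_manager.add_lora_to_model(
#       use_lora=True,
#       model=unet,
#       replace_modules=["UNet3DConditionModel"],
#       dropout=0.1,
#       lora_path=None,
#       r=16,
#   )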