import torch

from types import SimpleNamespace

from .lora import (
    extract_lora_ups_down,
    inject_trainable_lora_extended,
    monkeypatch_or_replace_lora_extended,
)

CLONE_OF_SIMO_KEYS = ["model", "loras", "target_replace_module", "r"]

lora_versions = dict(stable_lora="stable_lora", cloneofsimo="cloneofsimo")

lora_func_types = dict(loader="loader", injector="injector")

lora_args = dict(
    model=None,
    loras=None,
    target_replace_module=[],
    target_module=[],
    r=4,
    search_class=[torch.nn.Linear],
    dropout=0,
    lora_bias="none",
)

LoraVersions = SimpleNamespace(**lora_versions)
LoraFuncTypes = SimpleNamespace(**lora_func_types)

LORA_VERSIONS = [LoraVersions.stable_lora, LoraVersions.cloneofsimo]
LORA_FUNC_TYPES = [LoraFuncTypes.loader, LoraFuncTypes.injector]


def filter_dict(_dict, keys=[]):
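    """Return a copy of `_dict` restricted to `keys`, validating each key against `lora_args`."""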
    if len(keys) == 0:
        raise ValueError("Keys cannot be empty when filtering the return dict.")

    for k in keys:
        if k not in lora_args.keys():
            raise KeyError(f"{k} does not exist in the available LoRA arguments.")

    return {k: v for k, v in _dict.items() if k in keys}


class LoraHandler(object):
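    """Handles loading and injecting cloneofsimo-style LoRA weights into a model."""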
    def __init__(
        self,
        version: str = LoraVersions.cloneofsimo,
        use_unet_lora: bool = False,
        use_text_lora: bool = False,
        save_for_webui: bool = False,
        only_for_webui: bool = False,
        lora_bias: str = "none",
        unet_replace_modules: list = ["UNet3DConditionModel"],
    ):
        self.version = version
        assert self.is_cloneofsimo_lora(), "Only the cloneofsimo LoRA version is supported."

        self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
        self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
        self.lora_bias = lora_bias
        self.use_unet_lora = use_unet_lora
        self.use_text_lora = use_text_lora
        self.save_for_webui = save_for_webui
        self.only_for_webui = only_for_webui
        self.unet_replace_modules = unet_replace_modules
        self.use_lora = any([use_text_lora, use_unet_lora])

        if self.use_lora:
            print(f"Using LoRA Version: {self.version}")

    def is_cloneofsimo_lora(self):
        return self.version == LoraVersions.cloneofsimo

    def get_lora_func(self, func_type: str = LoraFuncTypes.loader):
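        """Return the LoRA loader or injector callable for the requested function type."""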
        if func_type == LoraFuncTypes.loader:
            return monkeypatch_or_replace_lora_extended

        if func_type == LoraFuncTypes.injector:
            return inject_trainable_lora_extended

        raise ValueError(f"LoRA function type does not exist: {func_type}")

    def get_lora_func_args(
        self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias
    ):
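        """Build the keyword arguments passed to the cloneofsimo loader/injector functions."""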
        return_dict = lora_args.copy()

        return_dict = filter_dict(return_dict, keys=CLONE_OF_SIMO_KEYS)
        return_dict.update(
            {
                "model": model,
                "loras": lora_path,
                "target_replace_module": replace_modules,
                "r": r,
            }
        )

        return return_dict

    def do_lora_injection(
        self,
        model,
        replace_modules,
        bias="none",
        dropout=0,
        r=4,
        lora_loader_args=None,
    ):
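        """Inject trainable LoRA layers into `model` and return the resulting params and negation."""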
        REPLACE_MODULES = replace_modules

        params = None
        negation = None

        injector_args = lora_loader_args

        params, negation = self.lora_injector(**injector_args)
        for _up, _down in extract_lora_ups_down(
            model, target_replace_module=REPLACE_MODULES
        ):
            if all(x is not None for x in [_up, _down]):
                print(f"Lora successfully injected into {model.__class__.__name__}.")

            break

        return params, negation

    def add_lora_to_model(
        self, use_lora, model, replace_modules, dropout=0.0, lora_path=None, r=16
    ):
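        """Inject LoRA into `model` when `use_lora` is True; returns (params, negation), defaulting params to the model itself."""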
        params = None
        negation = None

        lora_loader_args = self.get_lora_func_args(
            lora_path, use_lora, model, replace_modules, r, dropout, self.lora_bias
        )

        if use_lora:
            params, negation = self.do_lora_injection(
                model,
                replace_modules,
                bias=self.lora_bias,
                lora_loader_args=lora_loader_args,
                dropout=dropout,
                r=r,
            )

        params = model if params is None else params
        return params, negation
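

# Usage sketch (illustrative only): how a training script might wire this handler up.
# The `unet` instance and the hyperparameter values below are placeholders, not part of this module.
#
#     lora_manager = LoraHandler(
#         version=LoraVersions.cloneofsimo,
#         use_unet_lora=True,
#         unet_replace_modules=["UNet3DConditionModel"],
#     )
#     unet_lora_params, unet_negation = lora_manager.add_lora_to_model(
#         use_lora=True,
#         model=unet,  # e.g. an instance of UNet3DConditionModel
#         replace_modules=lora_manager.unet_replace_modules,
#         dropout=0.1,
#         lora_path=None,  # or a path to existing LoRA weights
#         r=16,
#     )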