Spaces:
Paused
Paused
import torch | |
from types import SimpleNamespace | |
from .lora import ( | |
extract_lora_ups_down, | |
inject_trainable_lora_extended, | |
monkeypatch_or_replace_lora_extended, | |
) | |
# Argument names that are kept when building the keyword arguments passed to
# the cloneofsimo LoRA functions (see LoraHandler.get_lora_func_args).
CLONE_OF_SIMO_KEYS = ["model", "loras", "target_replace_module", "r"]

# Identifier tables; each key maps to the identical string value so the
# namespaces below expose them as attributes.
lora_versions = {"stable_lora": "stable_lora", "cloneofsimo": "cloneofsimo"}
lora_func_types = {"loader": "loader", "injector": "injector"}

# Default value for every LoRA argument this module knows about.
lora_args = {
    "model": None,
    "loras": None,
    "target_replace_module": [],
    "target_module": [],
    "r": 4,
    "search_class": [torch.nn.Linear],
    "dropout": 0,
    "lora_bias": "none",
}

# Attribute-style access to the identifier tables above.
LoraVersions = SimpleNamespace(**lora_versions)
LoraFuncTypes = SimpleNamespace(**lora_func_types)

# Flat lists of the valid identifiers, in declaration order.
LORA_VERSIONS = list(lora_versions.values())
LORA_FUNC_TYPES = list(lora_func_types.values())
def filter_dict(_dict, keys=None):
    """Return the entries of ``_dict`` whose key is listed in ``keys``.

    Args:
        _dict: Mapping to filter. It is not modified.
        keys: Non-empty collection of keys to keep. Every key must be one of
            the known LoRA arguments (a key of the module-level ``lora_args``).

    Returns:
        A new dict with the items of ``_dict`` whose key is in ``keys``, in
        the iteration order of ``_dict``.

    Raises:
        ValueError: If ``keys`` is empty or omitted, or if a requested key is
            not a known LoRA argument.
    """
    # The previous ``assert "<message>"`` checks asserted a non-empty string,
    # which is always truthy, so invalid input was never rejected.
    if not keys:
        raise ValueError("Keys cannot be empty when filtering a dict.")

    for k in keys:
        if k not in lora_args:
            raise ValueError(f"{k} does not exist in available LoRA arguments")

    return {k: v for k, v in _dict.items() if k in keys}
class LoraHandler(object):
    """Select the LoRA loader/injector functions and apply them to a model.

    Only the ``cloneofsimo`` version is accepted; the loader is
    ``monkeypatch_or_replace_lora_extended`` and the injector is
    ``inject_trainable_lora_extended``.

    Attributes:
        version: LoRA version identifier given at construction.
        lora_loader: Function returned for ``LoraFuncTypes.loader``.
        lora_injector: Function returned for ``LoraFuncTypes.injector``.
        lora_bias: Bias mode string stored for later calls.
        use_unet_lora / use_text_lora: Flags stored as given.
        save_for_webui / only_for_webui: Flags stored as given.
        unet_replace_modules: List of module names stored as given.
        use_lora: True when either ``use_text_lora`` or ``use_unet_lora`` is set.
    """

    def __init__(
        self,
        version: str = LoraVersions.cloneofsimo,
        use_unet_lora: bool = False,
        use_text_lora: bool = False,
        save_for_webui: bool = False,
        only_for_webui: bool = False,
        lora_bias: str = "none",
        unet_replace_modules: "list | None" = None,
    ):
        """Store the configuration and resolve the loader/injector functions.

        Args:
            version: LoRA version; must equal ``LoraVersions.cloneofsimo``.
            use_unet_lora: Stored flag; contributes to ``use_lora``.
            use_text_lora: Stored flag; contributes to ``use_lora``.
            save_for_webui: Stored flag.
            only_for_webui: Stored flag.
            lora_bias: Bias mode string passed on by ``add_lora_to_model``.
            unet_replace_modules: Module names to store; defaults to a fresh
                ``["UNet3DConditionModel"]`` list per instance.

        Raises:
            ValueError: If ``version`` is not the cloneofsimo version.
        """
        self.version = version
        # Explicit raise instead of ``assert`` so the check survives ``python -O``.
        if not self.is_cloneofsimo_lora():
            raise ValueError(
                f"Unsupported LoRA version: {version!r}; "
                f"only {LoraVersions.cloneofsimo!r} is supported."
            )

        self.lora_loader = self.get_lora_func(func_type=LoraFuncTypes.loader)
        self.lora_injector = self.get_lora_func(func_type=LoraFuncTypes.injector)
        self.lora_bias = lora_bias
        self.use_unet_lora = use_unet_lora
        self.use_text_lora = use_text_lora
        self.save_for_webui = save_for_webui
        self.only_for_webui = only_for_webui
        # A fresh list per instance avoids sharing one mutable default object.
        if unet_replace_modules is None:
            unet_replace_modules = ["UNet3DConditionModel"]
        self.unet_replace_modules = unet_replace_modules
        self.use_lora = any([use_text_lora, use_unet_lora])

        if self.use_lora:
            print(f"Using LoRA Version: {self.version}")

    def is_cloneofsimo_lora(self):
        """Return True when ``self.version`` is the cloneofsimo version."""
        return self.version == LoraVersions.cloneofsimo

    def get_lora_func(self, func_type: str = LoraFuncTypes.loader):
        """Return the LoRA function matching ``func_type``.

        Args:
            func_type: ``LoraFuncTypes.loader`` or ``LoraFuncTypes.injector``.

        Returns:
            ``monkeypatch_or_replace_lora_extended`` for the loader type,
            ``inject_trainable_lora_extended`` for the injector type.

        Raises:
            ValueError: If ``func_type`` is neither of the two known types.
        """
        if func_type == LoraFuncTypes.loader:
            return monkeypatch_or_replace_lora_extended
        if func_type == LoraFuncTypes.injector:
            return inject_trainable_lora_extended
        # Previously ``assert "<message>"`` (always true), which fell through
        # and silently returned None.
        raise ValueError(f"LoRA function type {func_type!r} does not exist.")

    def get_lora_func_args(
        self, lora_path, use_lora, model, replace_modules, r, dropout, lora_bias
    ):
        """Build the keyword arguments for the cloneofsimo LoRA functions.

        Args:
            lora_path: Value stored under ``"loras"``.
            use_lora: Unused; kept for signature compatibility.
            model: Value stored under ``"model"``.
            replace_modules: Value stored under ``"target_replace_module"``.
            r: Value stored under ``"r"``.
            dropout: Unused; kept for signature compatibility.
            lora_bias: Unused; kept for signature compatibility.

        Returns:
            A dict with exactly the keys listed in ``CLONE_OF_SIMO_KEYS``.
        """
        # Start from the defaults restricted to the supported keys, then
        # overwrite each one with the caller's value.
        return_dict = filter_dict(lora_args.copy(), keys=CLONE_OF_SIMO_KEYS)
        return_dict.update(
            {
                "model": model,
                "loras": lora_path,
                "target_replace_module": replace_modules,
                "r": r,
            }
        )
        return return_dict

    def do_lora_injection(
        self,
        model,
        replace_modules,
        bias="none",
        dropout=0,
        r=4,
        lora_loader_args=None,
    ):
        """Run the injector on ``model`` and report whether LoRA layers exist.

        Args:
            model: Model passed to ``extract_lora_ups_down`` (and to the
                injector when ``lora_loader_args`` is omitted).
            replace_modules: Target module names used for the lookup.
            bias: Bias mode; only used to build default injector arguments.
            dropout: Dropout; only used to build default injector arguments.
            r: LoRA rank; only used to build default injector arguments.
            lora_loader_args: Keyword arguments for the injector. When None
                they are built from the other parameters with no LoRA path.

        Returns:
            The ``(params, negation)`` pair returned by the injector.
        """
        injector_args = lora_loader_args
        if injector_args is None:
            # ``**None`` would raise TypeError, so fall back to arguments
            # derived from this call's own parameters.
            injector_args = self.get_lora_func_args(
                None, True, model, replace_modules, r, dropout, bias
            )

        params, negation = self.lora_injector(**injector_args)

        # Report success as soon as one complete up/down pair is found.
        for _up, _down in extract_lora_ups_down(
            model, target_replace_module=replace_modules
        ):
            if _up is not None and _down is not None:
                print(
                    f"Lora successfully injected into {model.__class__.__name__}."
                )
                break

        return params, negation

    def add_lora_to_model(
        self, use_lora, model, replace_modules, dropout=0.0, lora_path=None, r=16
    ):
        """Inject LoRA into ``model`` when requested.

        Args:
            use_lora: When falsy, no injection is performed.
            model: Model to inject into.
            replace_modules: Target module names.
            dropout: Forwarded to ``do_lora_injection``.
            lora_path: Value passed to the injector under ``"loras"``.
            r: LoRA rank.

        Returns:
            ``(params, negation)``. ``params`` is ``model`` itself when LoRA
            is not used or the injector returned None for it; ``negation`` is
            None when LoRA is not used.
        """
        if not use_lora:
            return model, None

        lora_loader_args = self.get_lora_func_args(
            lora_path, use_lora, model, replace_modules, r, dropout, self.lora_bias
        )
        params, negation = self.do_lora_injection(
            model,
            replace_modules,
            bias=self.lora_bias,
            lora_loader_args=lora_loader_args,
            dropout=dropout,
            r=r,
        )

        if params is None:
            params = model
        return params, negation