import torch

from contextlib import contextmanager
from typing import Union, Tuple

_size_2_t = Union[int, Tuple[int, int]]


class LinearWithLoRA(torch.nn.Module):
    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None) -> None:
        super().__init__()
        self.weight_module = None
        self.up = None
        self.down = None
        self.bias = None
        self.in_features = in_features
        self.out_features = out_features
        self.device = device
        self.dtype = dtype
        self.weight = None

    def bind_lora(self, weight_module):
        # Keep the bound module in a list so it is not registered as a submodule.
        self.weight_module = [weight_module]

    def unbind_lora(self):
        if self.up is not None and self.down is not None:  # SAI's model is weird and needs this
            self.weight_module = None

    def get_original_weight(self):
        if self.weight_module is None:
            return None
        return self.weight_module[0].weight

    def forward(self, x):
        if self.weight is not None:
            return torch.nn.functional.linear(x, self.weight.to(x),
                                              self.bias.to(x) if self.bias is not None else None)
        original_weight = self.get_original_weight()
        if original_weight is None:
            return None  # A1111 needs first_time_calculation
        if self.up is not None and self.down is not None:
            # Apply the low-rank update: W' = W + up @ down.
            weight = original_weight.to(x) + torch.mm(self.up, self.down).to(x)
        else:
            weight = original_weight.to(x)
        return torch.nn.functional.linear(x, weight, self.bias.to(x) if self.bias is not None else None)
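

# A minimal usage sketch for LinearWithLoRA, assuming a caller that wires it up by
# hand; the rank and tensor shapes below are illustrative assumptions only.
def _example_linear_with_lora():
    base = torch.nn.Linear(16, 32)      # frozen layer that owns the real weight
    lora = LinearWithLoRA(16, 32)
    lora.bind_lora(base)                # reuse base.weight without copying it
    rank = 4
    lora.up = torch.zeros(32, rank)     # up @ down is the low-rank residual added to the weight
    lora.down = torch.zeros(rank, 16)
    x = torch.randn(2, 16)
    return lora(x)                      # F.linear(x, base.weight + up @ down); bias stays None here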


class Conv2dWithLoRA(torch.nn.Module):
    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: _size_2_t,
            stride: _size_2_t = 1,
            padding: Union[str, _size_2_t] = 0,
            dilation: _size_2_t = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            device=None,
            dtype=None
    ) -> None:
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.weight_module = None
        self.bias = None
        self.up = None
        self.down = None
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding_mode = padding_mode
        self.device = device
        self.dtype = dtype
        self.weight = None

    def bind_lora(self, weight_module):
        # Keep the bound module in a list so it is not registered as a submodule.
        self.weight_module = [weight_module]

    def unbind_lora(self):
        if self.up is not None and self.down is not None:  # SAI's model is weird and needs this
            self.weight_module = None

    def get_original_weight(self):
        if self.weight_module is None:
            return None
        return self.weight_module[0].weight

    def forward(self, x):
        if self.weight is not None:
            return torch.nn.functional.conv2d(x, self.weight.to(x), self.bias.to(x) if self.bias is not None else None,
                                              self.stride, self.padding, self.dilation, self.groups)
        original_weight = self.get_original_weight()
        if original_weight is None:
            return None  # A1111 needs first_time_calculation
        if self.up is not None and self.down is not None:
            # Flatten the conv kernels so the low-rank product can be taken as a
            # matrix multiply, then reshape back to the original 4D weight shape.
            weight = original_weight.to(x) + torch.mm(self.up.flatten(start_dim=1),
                                                      self.down.flatten(start_dim=1)).reshape(original_weight.shape).to(x)
        else:
            weight = original_weight.to(x)
        return torch.nn.functional.conv2d(x, weight, self.bias.to(x) if self.bias is not None else None,
                                          self.stride, self.padding, self.dilation, self.groups)


@contextmanager
def controlnet_lora_hijack():
    # Temporarily replace torch.nn.Linear / torch.nn.Conv2d so that any model built
    # inside this context is constructed from the LoRA-aware classes above.
    linear, conv2d = torch.nn.Linear, torch.nn.Conv2d
    torch.nn.Linear, torch.nn.Conv2d = LinearWithLoRA, Conv2dWithLoRA
    try:
        yield
    finally:
        torch.nn.Linear, torch.nn.Conv2d = linear, conv2d
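

# A small sketch of how the hijack is meant to be used, assuming a model is
# constructed while the context is active; the layer sizes are arbitrary.
def _example_hijack_usage():
    with controlnet_lora_hijack():
        layer = torch.nn.Linear(8, 8)       # instantiated as LinearWithLoRA inside the context
        assert isinstance(layer, LinearWithLoRA)
    restored = torch.nn.Linear(8, 8)        # stock torch layer again once the context exits
    assert not isinstance(restored, LinearWithLoRA)
    return layer, restored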


def recursive_set(obj, key, value):
    # Walk a dotted key like 'a.b.c' and set the final attribute on the resolved object.
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_set(getattr(obj, k1, None), k2, value)
    else:
        setattr(obj, key, value)


def force_load_state_dict(model, state_dict):
    # Assign every tensor as a Parameter directly onto the addressed attribute,
    # bypassing load_state_dict() checks, and release the dict entries as we go.
    for k in list(state_dict.keys()):
        recursive_set(model, k, torch.nn.Parameter(state_dict[k]))
        del state_dict[k]
    return
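

# A small sketch of force_load_state_dict on a toy model; the keys and shapes are
# illustrative assumptions. Tensors are assigned straight onto dotted attribute
# paths, without load_state_dict() key/shape validation.
def _example_force_load():
    model = torch.nn.Sequential(torch.nn.Linear(4, 4))
    sd = {'0.weight': torch.zeros(4, 4), '0.bias': torch.zeros(4)}
    force_load_state_dict(model, sd)        # sd is emptied as its tensors are assigned
    return model, sd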


def recursive_bind_lora(obj, key, value):
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_bind_lora(getattr(obj, k1, None), k2, value)
    else:
        target = getattr(obj, key, None)
        if target is not None and hasattr(target, 'bind_lora'):
            target.bind_lora(value)


def recursive_get(obj, key):
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        return recursive_get(getattr(obj, k1, None), k2)
    else:
        return getattr(obj, key, None)


def bind_control_lora(base_model, control_lora_model):
    # Collect every parameterized submodule of the base model and bind it to the
    # matching LoRA module in the control model so the original weights are shared.
    sd = base_model.state_dict()
    keys = list(sd.keys())
    keys = list(set([k.rsplit('.', 1)[0] for k in keys]))
    module_dict = {k: recursive_get(base_model, k) for k in keys}
    for k, v in module_dict.items():
        recursive_bind_lora(control_lora_model, k, v)


def torch_dfs(model: torch.nn.Module):
    # Depth-first list of a module and all of its descendants.
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result


def unbind_control_lora(control_lora_model):
    for m in torch_dfs(control_lora_model):
        if hasattr(m, 'unbind_lora'):
            m.unbind_lora()
    return
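

# End-to-end sketch of the intended call order, assuming a hypothetical
# build_control_model() constructor and a loaded control_lora_state_dict:
#
#     with controlnet_lora_hijack():
#         control_lora_model = build_control_model()
#     force_load_state_dict(control_lora_model, control_lora_state_dict)
#     bind_control_lora(base_model, control_lora_model)
#     ...  # run inference; each LoRA layer reads the base weight through bind_lora
#     unbind_control_lora(control_lora_model)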