import torch

from contextlib import contextmanager
from typing import Union, Tuple

_size_2_t = Union[int, Tuple[int, int]]


class LinearWithLoRA(torch.nn.Module):
    """Drop-in replacement for torch.nn.Linear that allocates no weight of its
    own; the base model's weight is bound in later, and a low-rank (up @ down)
    delta is added on the fly during forward."""

    def __init__(
            self,
            in_features: int,
            out_features: int,
            bias: bool = True,
            device=None,
            dtype=None) -> None:
        super().__init__()
        # No parameters are allocated here; `weight`, `bias`, `up` and `down`
        # are filled in later by force_load_state_dict / bind_lora.
        self.weight_module = None
        self.up = None
        self.down = None
        self.bias = None
        self.in_features = in_features
        self.out_features = out_features
        self.device = device
        self.dtype = dtype
        self.weight = None

    def bind_lora(self, weight_module):
        # Wrap in a list so the bound module is not registered as a submodule.
        self.weight_module = [weight_module]

    def unbind_lora(self):
        if self.up is not None and self.down is not None:
            # SAI's model is weird and needs this
            self.weight_module = None

    def get_original_weight(self):
        if self.weight_module is None:
            return None
        return self.weight_module[0].weight

    def forward(self, x):
        if self.weight is not None:
            # A full weight was loaded directly; behave like a plain Linear.
            return torch.nn.functional.linear(
                x, self.weight.to(x),
                self.bias.to(x) if self.bias is not None else None)
        original_weight = self.get_original_weight()
        if original_weight is None:
            return None  # A1111 needs first_time_calculation
        if self.up is not None and self.down is not None:
            # Base weight plus the low-rank LoRA delta.
            weight = original_weight.to(x) + torch.mm(self.up, self.down).to(x)
        else:
            weight = original_weight.to(x)
        return torch.nn.functional.linear(
            x, weight,
            self.bias.to(x) if self.bias is not None else None)


class Conv2dWithLoRA(torch.nn.Module):
    """Drop-in replacement for torch.nn.Conv2d with the same deferred-weight
    and LoRA behaviour as LinearWithLoRA."""

    def __init__(
            self,
            in_channels: int,
            out_channels: int,
            kernel_size: _size_2_t,
            stride: _size_2_t = 1,
            padding: Union[str, _size_2_t] = 0,
            dilation: _size_2_t = 1,
            groups: int = 1,
            bias: bool = True,
            padding_mode: str = 'zeros',
            device=None,
            dtype=None
    ) -> None:
        super().__init__()
        self.stride = stride
        self.padding = padding
        self.dilation = dilation
        self.groups = groups
        self.weight_module = None
        self.bias = None
        self.up = None
        self.down = None
        self.in_channels = in_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.padding_mode = padding_mode
        self.device = device
        self.dtype = dtype
        self.weight = None

    def bind_lora(self, weight_module):
        # Wrap in a list so the bound module is not registered as a submodule.
        self.weight_module = [weight_module]

    def unbind_lora(self):
        if self.up is not None and self.down is not None:
            # SAI's model is weird and needs this
            self.weight_module = None

    def get_original_weight(self):
        if self.weight_module is None:
            return None
        return self.weight_module[0].weight

    def forward(self, x):
        if self.weight is not None:
            # A full weight was loaded directly; behave like a plain Conv2d.
            return torch.nn.functional.conv2d(
                x, self.weight.to(x),
                self.bias.to(x) if self.bias is not None else None,
                self.stride, self.padding, self.dilation, self.groups)
        original_weight = self.get_original_weight()
        if original_weight is None:
            return None  # A1111 needs first_time_calculation
        if self.up is not None and self.down is not None:
            # Flatten the conv kernels to 2D, apply the low-rank delta, then
            # reshape back to the original (out, in, kh, kw) weight shape.
            weight = original_weight.to(x) + torch.mm(
                self.up.flatten(start_dim=1),
                self.down.flatten(start_dim=1)
            ).reshape(original_weight.shape).to(x)
        else:
            weight = original_weight.to(x)
        return torch.nn.functional.conv2d(
            x, weight,
            self.bias.to(x) if self.bias is not None else None,
            self.stride, self.padding, self.dilation, self.groups)


@contextmanager
def controlnet_lora_hijack():
    # Temporarily swap torch.nn.Linear / torch.nn.Conv2d so that any model
    # constructed inside this context gets the LoRA-capable layers above.
    linear, conv2d = torch.nn.Linear, torch.nn.Conv2d
    torch.nn.Linear, torch.nn.Conv2d = LinearWithLoRA, Conv2dWithLoRA
    try:
        yield
    finally:
        torch.nn.Linear, torch.nn.Conv2d = linear, conv2d
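# A minimal usage sketch of the hijack (not part of the original module):
# layers constructed while the context manager is active become LoRA-capable
# placeholders, and construction after it exits is untouched. `_TinyNet` is a
# hypothetical stand-in for whatever model class the caller actually builds.
def _hijack_usage_sketch():
    class _TinyNet(torch.nn.Module):
        def __init__(self):
            super().__init__()
            self.fc = torch.nn.Linear(4, 4)
            self.conv = torch.nn.Conv2d(3, 8, 3)

    with controlnet_lora_hijack():
        net = _TinyNet()  # built with the swapped classes

    assert isinstance(net.fc, LinearWithLoRA)
    assert isinstance(net.conv, Conv2dWithLoRA)
    assert isinstance(_TinyNet().fc, torch.nn.Linear)  # swap was reverted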
def recursive_set(obj, key, value):
    # Walk a dotted attribute path ('a.b.c') and set the final attribute.
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_set(getattr(obj, k1, None), k2, value)
    else:
        setattr(obj, key, value)


def force_load_state_dict(model, state_dict):
    # Assign tensors directly by attribute path instead of load_state_dict,
    # since the hijacked layers never registered parameters to load into.
    for k in list(state_dict.keys()):
        recursive_set(model, k, torch.nn.Parameter(state_dict[k]))
        del state_dict[k]


def recursive_bind_lora(obj, key, value):
    # Walk a dotted attribute path and bind `value` into the final module if
    # it is one of the LoRA-capable layers above.
    if obj is None:
        return
    if '.' in key:
        k1, k2 = key.split('.', 1)
        recursive_bind_lora(getattr(obj, k1, None), k2, value)
    else:
        target = getattr(obj, key, None)
        if target is not None and hasattr(target, 'bind_lora'):
            target.bind_lora(value)


def recursive_get(obj, key):
    # Walk a dotted attribute path and return the final attribute, or None.
    if obj is None:
        return None
    if '.' in key:
        k1, k2 = key.split('.', 1)
        return recursive_get(getattr(obj, k1, None), k2)
    return getattr(obj, key, None)


def bind_control_lora(base_model, control_lora_model):
    # Collect every parameter-owning submodule of the base model (state-dict
    # keys with the trailing '.weight' / '.bias' stripped) and bind each one
    # into the matching LoRA layer of the control model.
    sd = base_model.state_dict()
    keys = list(sd.keys())
    keys = list(set([k.rsplit('.', 1)[0] for k in keys]))
    module_dict = {k: recursive_get(base_model, k) for k in keys}
    for k, v in module_dict.items():
        recursive_bind_lora(control_lora_model, k, v)


def torch_dfs(model: torch.nn.Module):
    # Depth-first list of a module and all of its descendants.
    result = [model]
    for child in model.children():
        result += torch_dfs(child)
    return result


def unbind_control_lora(control_lora_model):
    # Drop references to the base model's modules from every LoRA layer.
    for m in torch_dfs(control_lora_model):
        if hasattr(m, 'unbind_lora'):
            m.unbind_lora()
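# A hedged end-to-end sketch (assumed workflow, not part of the original
# module): build a control model under the hijack, bind it to a base model,
# run a forward pass against the shared weights, then unbind. The tensors,
# shapes, and module names here are illustrative only.
def _bind_usage_sketch():
    base = torch.nn.Module()
    base.fc = torch.nn.Linear(4, 4)

    with controlnet_lora_hijack():
        control = torch.nn.Module()
        control.fc = torch.nn.Linear(4, 4)  # becomes LinearWithLoRA

    bind_control_lora(base, control)

    # Give the layer a rank-2 delta; forward computes base weight + up @ down.
    control.fc.up = torch.zeros(4, 2)
    control.fc.down = torch.zeros(2, 4)

    x = torch.randn(2, 4)
    y = control.fc(x)
    # With a zero delta and no bias bound, this matches the bare base weight.
    assert torch.allclose(y, torch.nn.functional.linear(x, base.fc.weight))

    unbind_control_lora(control)  # drops the reference to base.fc
    assert control.fc.weight_module is None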