import inspect
import os
import time
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

import gc
import torch
import numpy as np
from glob import glob

from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL
from diffusers.schedulers import (DPMSolverMultistepScheduler,
                                  EulerAncestralDiscreteScheduler,
                                  EulerDiscreteScheduler,
                                  KarrasDiffusionSchedulers)
from transformers import (CLIPImageProcessor, CLIPTextModel, CLIPTokenizer,
                          CLIPTextModelWithProjection)

from .lyrasd_vae_model import LyraSdVaeModel
from .module.lyrasd_ip_adapter import LyraIPAdapter
from .lora_util import (add_text_lora_layer, add_xltext_lora_layer,
                        add_lora_to_opt_model, load_state_dict)
from safetensors.torch import load_file


class LyraSDXLPipelineBase(TextualInversionLoaderMixin):
    def __init__(self, device=torch.device("cuda"), dtype=torch.float16,
                 num_channels_unet=4, num_channels_latents=4,
                 vae_scale_factor=8, vae_scaling_factor=0.18215) -> None:
        self.device = device
        self.dtype = dtype

        self.num_channels_unet = num_channels_unet
        self.num_channels_latents = num_channels_latents
        self.vae_scale_factor = vae_scale_factor
        self.vae_scaling_factor = vae_scaling_factor

        # State-dict caches keyed by checkpoint path, so repeated reloads of
        # the same UNet / ControlNet skip disk I/O and layout conversion.
        self.unet_cache = {}
        self.unet_in_channels = 4

        self.controlnet_cache = {}

        self.loaded_lora = {}
        self.loaded_lora_strength = {}

        self.scheduler = None

        self.init_pipe()

    def init_pipe(self):
        self.vae = LyraSdVaeModel(
            scale_factor=self.vae_scale_factor,
            scaling_factor=self.vae_scaling_factor)

        # The UNet is a custom lyrasd TorchScript op, not a diffusers module.
        self.unet = torch.classes.lyrasd.Unet2dConditionalModelOp(
            3,
            "fp16",
            self.num_channels_unet,
            self.num_channels_latents
        )

        self.image_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor)

        self.mask_processor = VaeImageProcessor(
            vae_scale_factor=self.vae_scale_factor, do_normalize=False,
            do_binarize=True, do_convert_grayscale=True)

        self.feature_extractor = CLIPImageProcessor()

    def reload_pipe(self, model_path):
        self.tokenizer = CLIPTokenizer.from_pretrained(
            model_path, subfolder="tokenizer")
        self.text_encoder = CLIPTextModel.from_pretrained(
            model_path, subfolder="text_encoder").to(self.dtype).to(self.device)

        self.reload_unet_model_v2(model_path)
        self.reload_vae_model_v2(model_path)

        if not self.scheduler:
            self.scheduler = EulerAncestralDiscreteScheduler.from_pretrained(
                model_path, subfolder="scheduler")
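
    # Usage sketch (illustrative only): assumes the lyrasd torch extension has
    # already been loaded via torch.classes, and that the hypothetical path
    # below points to a diffusers-layout checkpoint directory.
    #
    #   pipe = LyraSDXLPipelineBase()
    #   pipe.reload_pipe("/path/to/model")
    #   # tokenizer, text encoder, UNet, VAE and scheduler are now ready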

    @property
    def _execution_device(self):
        # Mirrors the diffusers convention: prefer the execution device set by
        # accelerate hooks (e.g. CPU offload), falling back to self.device.
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def reload_unet_model(self, unet_path, unet_file_format='fp32'):
        if len(unet_path) > 0 and unet_path[-1] != "/":
            unet_path = unet_path + "/"
        self.unet.reload_unet_model(unet_path, unet_file_format)
        self.load_embedding_weight(
            self.add_embedding, f"{unet_path}add_embedding*",
            unet_file_format=unet_file_format)

    def reload_vae_model(self, vae_path, vae_file_format='fp32'):
        if len(vae_path) > 0 and vae_path[-1] != "/":
            vae_path = vae_path + "/"
        return self.vae.reload_vae_model(vae_path, vae_file_format)

    def load_lora(self, lora_model_path, lora_name, lora_strength,
                  lora_file_format='fp32'):
        if len(lora_model_path) > 0 and lora_model_path[-1] != "/":
            lora_model_path = lora_model_path + "/"
        lora = add_xltext_lora_layer(
            self.text_encoder, self.text_encoder_2, lora_model_path,
            lora_strength, lora_file_format)

        self.loaded_lora[lora_name] = lora
        self.unet.load_lora(lora_model_path, lora_name,
                            lora_strength, lora_file_format)

    def unload_lora(self, lora_name, clean_cache=False):
        # Revert the in-place merge by subtracting the exact weight deltas
        # that were added when the LoRA was loaded.
        for layer_data in self.loaded_lora[lora_name]:
            layer = layer_data['layer']
            added_weight = layer_data['added_weight']
            layer.weight.data -= added_weight
        self.unet.unload_lora(lora_name, clean_cache)
        del self.loaded_lora[lora_name]
        gc.collect()
        torch.cuda.empty_cache()

    def load_lora_v2(self, lora_model_path, lora_name, lora_strength):
        if lora_name in self.loaded_lora:
            state_dict = self.loaded_lora[lora_name]
        else:
            state_dict = load_state_dict(lora_model_path)
            self.loaded_lora[lora_name] = state_dict
        self.loaded_lora_strength[lora_name] = lora_strength
        add_lora_to_opt_model(state_dict, self.unet, self.text_encoder,
                              None, lora_strength)

    def unload_lora_v2(self, lora_name, clean_cache=False):
        # Applying the same LoRA with a negated strength cancels the merge.
        state_dict = self.loaded_lora[lora_name]
        lora_strength = self.loaded_lora_strength[lora_name]
        add_lora_to_opt_model(state_dict, self.unet, self.text_encoder,
                              None, -1.0 * lora_strength)
        del self.loaded_lora_strength[lora_name]

        if clean_cache:
            del self.loaded_lora[lora_name]
            gc.collect()
            torch.cuda.empty_cache()
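
    # LoRA round-trip sketch (illustrative; the path and name are hypothetical):
    #
    #   pipe.load_lora_v2("/path/to/lora.safetensors", "style_lora", 0.8)
    #   ...  # run inference with the LoRA delta merged into the weights
    #   pipe.unload_lora_v2("style_lora", clean_cache=True)  # revert and free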

    def clean_lora_cache(self):
        self.unet.clean_lora_cache()

    def get_loaded_lora(self):
        return self.unet.get_loaded_lora()

    def load_ip_adapter(self, dir_ip_adapter, ip_plus, image_encoder_path,
                        num_ip_tokens, ip_projection_dim, dir_face_in=None,
                        num_fp_tokens=1, fp_projection_dim=None, sdxl=True):
        self.ip_adapter_helper = LyraIPAdapter(
            self, sdxl, "cuda", dir_ip_adapter, ip_plus, image_encoder_path,
            num_ip_tokens, ip_projection_dim, dir_face_in, num_fp_tokens,
            fp_projection_dim)
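
    # IP-Adapter sketch (illustrative; the paths, token count and projection
    # dim are hypothetical and depend on the adapter checkpoint in use):
    #
    #   pipe.load_ip_adapter("/path/to/ip_adapter", ip_plus=True,
    #                        image_encoder_path="/path/to/image_encoder",
    #                        num_ip_tokens=16, ip_projection_dim=1024)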

    def reload_unet_model_v2(self, model_path):
        # Prefer the .bin checkpoint; fall back to .safetensors.
        checkpoint_file = os.path.join(
            model_path, "unet/diffusion_pytorch_model.bin")
        if not os.path.exists(checkpoint_file):
            checkpoint_file = os.path.join(
                model_path, "unet/diffusion_pytorch_model.safetensors")
        if checkpoint_file in self.unet_cache:
            state_dict = self.unet_cache[checkpoint_file]
        else:
            if checkpoint_file.endswith(".safetensors"):
                state_dict = load_file(checkpoint_file)
            else:
                state_dict = torch.load(checkpoint_file, map_location="cpu")

            for key in state_dict:
                if len(state_dict[key].shape) == 4:
                    # 4-D conv weights are permuted from NCHW to NHWC
                    # (channels-last) for the lyrasd fused kernels.
                    state_dict[key] = state_dict[key].to(
                        torch.float16).permute(0, 2, 3, 1).contiguous()
                state_dict[key] = state_dict[key].to(torch.float16)
            self.unet_cache[checkpoint_file] = state_dict

        self.unet.reload_unet_model_from_cache(state_dict, "cpu")

    def reload_vae_model_v2(self, model_path):
        self.vae.reload_vae_model_v2(model_path)

    def load_controlnet_model(self, model_name, controlnet_path,
                              model_dtype="fp32"):
        if len(controlnet_path) > 0 and controlnet_path[-1] != "/":
            controlnet_path = controlnet_path + "/"
        self.unet.load_controlnet_model(model_name, controlnet_path, model_dtype)

    def unload_controlnet_model(self, model_name):
        self.unet.unload_controlnet_model(model_name, True)

    def get_loaded_controlnet(self):
        return self.unet.get_loaded_controlnet()

    def load_controlnet_model_v2(self, model_name, controlnet_path):
        checkpoint_file = os.path.join(
            controlnet_path, "diffusion_pytorch_model.bin")
        if not os.path.exists(checkpoint_file):
            checkpoint_file = os.path.join(
                controlnet_path, "diffusion_pytorch_model.safetensors")
        if checkpoint_file in self.controlnet_cache:
            state_dict = self.controlnet_cache[checkpoint_file]
        else:
            if checkpoint_file.endswith(".safetensors"):
                state_dict = load_file(checkpoint_file)
            else:
                state_dict = torch.load(checkpoint_file, map_location="cpu")

            for key in state_dict:
                if len(state_dict[key].shape) == 4:
                    # Same NCHW -> NHWC conversion as in reload_unet_model_v2.
                    state_dict[key] = state_dict[key].to(
                        torch.float16).permute(0, 2, 3, 1).contiguous()
                state_dict[key] = state_dict[key].to(torch.float16)
            self.controlnet_cache[checkpoint_file] = state_dict

        self.unet.load_controlnet_model_from_state_dict(
            model_name, state_dict, "cpu")
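
    # ControlNet sketch (illustrative; the directory is hypothetical and must
    # contain diffusion_pytorch_model.bin or .safetensors):
    #
    #   pipe.load_controlnet_model_v2("canny", "/path/to/controlnet_canny")
    #   print(pipe.get_loaded_controlnet())
    #   pipe.unload_controlnet_model("canny")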