import inspect
import os
import time
from typing import Any, Callable, Dict, List, Optional, Union, Tuple

import gc
import torch
import numpy as np
from glob import glob

from diffusers import StableDiffusionXLInpaintPipeline, UNet2DConditionModel
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
from diffusers.models import AutoencoderKL
from diffusers.schedulers import (DPMSolverMultistepScheduler,
                                  EulerAncestralDiscreteScheduler,
                                  EulerDiscreteScheduler,
                                  KarrasDiffusionSchedulers)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import logging
from PIL import Image
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from safetensors.torch import load_file

from .lyrasd_vae_model import LyraSdVaeModel
from .module.lyrasd_ip_adapter import LyraIPAdapter
from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
from .lyrasdxl_pipeline_base import LyraSDXLPipelineBase

def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4.
    """
    std_text = noise_pred_text.std(
        dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)

    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)

    noise_cfg = guidance_rescale * noise_pred_rescaled + \
        (1 - guidance_rescale) * noise_cfg
    return noise_cfg

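# Helper mirroring the utilities in diffusers' SDXL pipelines: pull a latent tensor out of a VAE
# encode result, either by sampling the distribution, taking its mode ("argmax"), or returning the
# value unchanged.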
def retrieve_latents(
    encoder_output: torch.Tensor, generator: Optional[torch.Generator] = None, sample_mode: str = "sample"
):
    if sample_mode == "sample":
        return encoder_output.sample(generator)
    elif sample_mode == "argmax":
        return encoder_output.mode()
    else:
        return encoder_output

def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and
        the second element is the number of inference steps.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps

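# SDXL text-to-image / inpainting pipeline driven by the Lyra (lyrasd) UNet and VAE. Prompt
# encoding, input checking and mask/latent preparation are inherited from diffusers'
# StableDiffusionXLInpaintPipeline via the base classes.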
class LyraSdXLTxt2ImgInpaintPipeline(LyraSDXLPipelineBase, StableDiffusionXLInpaintPipeline):
    device = torch.device("cpu")
    dtype = torch.float32

    def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.13025,
                 num_channels_unet=9, num_channels_latents=4, requires_aesthetics_score: bool = False,
                 force_zeros_for_empty_prompt: bool = True) -> None:
        self.register_to_config(
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.register_to_config(
            requires_aesthetics_score=requires_aesthetics_score)

        super().__init__(device, dtype, num_channels_unet=num_channels_unet, num_channels_latents=num_channels_latents,
                         vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)

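    # Encode a reference image with the CLIP image encoder (used for image-prompt conditioning)
    # and return the embeddings together with an all-zero unconditional counterpart.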
    def encode_image(self, image, device, num_images_per_prompt):
        dtype = next(self.image_encoder.parameters()).dtype
        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(
                image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeds = self.image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(
            num_images_per_prompt, dim=0)

        uncond_image_embeds = torch.zeros_like(image_embeds)
        return image_embeds, uncond_image_embeds

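    # Encode pixel-space images into scaled VAE latents. With a list of generators, each batch
    # element is encoded with its own generator.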
    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        dtype = image.dtype

        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(self.vae.encode(image[i: i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        image_latents = image_latents.to(dtype)
        image_latents = self.vae.scaling_factor * image_latents

        return image_latents

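    # Build the SDXL micro-conditioning ("add time ids") vectors from the original/target sizes,
    # crop coordinates and, when `requires_aesthetics_score` is set, the aesthetic scores, then
    # validate their width against the UNet's additional embedding layer.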
    def _get_add_time_ids(
        self,
        original_size,
        crops_coords_top_left,
        target_size,
        aesthetic_score,
        negative_aesthetic_score,
        negative_original_size,
        negative_crops_coords_top_left,
        negative_target_size,
        dtype,
        text_encoder_projection_dim=None,
    ):
        if self.config.requires_aesthetics_score:
            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
            add_neg_time_ids = list(
                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
            )
        else:
            add_time_ids = list(original_size + crops_coords_top_left + target_size)
            add_neg_time_ids = list(
                negative_original_size + negative_crops_coords_top_left + negative_target_size)

        passed_add_embed_dim = (
            self.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        expected_add_embed_dim = self.add_embedding.linear_1.in_features

        if (
            expected_add_embed_dim > passed_add_embed_dim
            and (expected_add_embed_dim - passed_add_embed_dim) == self.addition_time_embed_dim
        ):
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} are correctly used by the model."
            )
        elif (
            expected_add_embed_dim < passed_add_embed_dim
            and (passed_add_embed_dim - expected_add_embed_dim) == self.addition_time_embed_dim
        ):
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
            )
        elif expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)

        return add_time_ids, add_neg_time_ids

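    # Attach a LyraIPAdapter helper (image-prompt and optional face conditioning) to this pipeline.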
    def load_ip_adapter(self, dir_ip_adapter, ip_plus, image_encoder_path, num_ip_tokens, ip_projection_dim,
                        dir_face_in=None, num_fp_tokens=1, fp_projection_dim=None, sdxl=True):
        self.ip_adapter_helper = LyraIPAdapter(self, sdxl, "cuda", dir_ip_adapter, ip_plus, image_encoder_path,
                                               num_ip_tokens, ip_projection_dim, dir_face_in, num_fp_tokens, fp_projection_dim)

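    # Generation entry point. The signature follows diffusers' StableDiffusionXLInpaintPipeline.__call__,
    # with `extra_tensor_dict` / `param_scale_dict` forwarded to the Lyra UNet.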
    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        masked_image_latents: torch.FloatTensor = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        strength: float = 0.9999,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        denoising_start: Optional[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Tuple[int, int] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Tuple[int, int] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        aesthetic_score: float = 6.0,
        negative_aesthetic_score: float = 2.5,
        clip_skip: Optional[int] = None,
        extra_tensor_dict: Optional[Dict[str, torch.FloatTensor]] = {},
        param_scale_dict: Optional[Dict[str, int]] = {},
        **kwargs
    ):
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

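        # 0. Default height/width to the UNet sample size and fill in size-conditioning defaults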
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._denoising_start = denoising_start

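        # 1. Check inputs and raise early on invalid combinations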
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            strength,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
        )

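        # 2. Define call parameters: batch size, execution device, classifier-free guidance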
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

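        # 3. Encode the prompts with both SDXL text encoders (optionally applying a LoRA scale)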
        text_encoder_lora_scale = (
            cross_attention_kwargs.get(
                "scale", None) if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=clip_skip,
        )

        def denoising_value_valid(dnv):
            return isinstance(dnv, float) and 0 < dnv < 1

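        # 4. Prepare timesteps, honoring `strength` and an optional `denoising_start`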
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps)
        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps,
            strength,
            device,
            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
        )

        latent_timestep = timesteps[:1].repeat(
            batch_size * num_images_per_prompt)
        is_strength_max = strength == 1.0

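        # 5. Preprocess the input image and mask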
        init_image = self.image_processor.preprocess(
            image, height=height, width=width)
        init_image = init_image.to(dtype=torch.float32)

        mask = self.mask_processor.preprocess(
            mask_image, height=height, width=width)

        if masked_image_latents is not None:
            masked_image = masked_image_latents
        elif init_image.shape[1] == 4:
            masked_image = None
        else:
            masked_image = init_image * (mask < 0.5)

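        # 6. Prepare latent variables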
        add_noise = True if self.denoising_start is None else False

        return_image_latents = self.num_channels_unet == 4

        latents_outputs = self.prepare_latents(
            batch_size * num_images_per_prompt,
            self.num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            image=init_image,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            add_noise=add_noise,
            return_noise=True,
            return_image_latents=return_image_latents,
        )

        if return_image_latents:
            latents, noise, image_latents = latents_outputs
        else:
            latents, noise = latents_outputs

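        # 7. Prepare mask latents and masked-image latents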
        mask, masked_image_latents = self.prepare_mask_latents(
            mask,
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

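        # 8. Prepare extra kwargs for the scheduler step (eta / generator)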
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

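        # 9. Prepare the added time ids and pooled text embeddings (SDXL micro-conditioning)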
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        if negative_original_size is None:
            negative_original_size = original_size
        if negative_target_size is None:
            negative_target_size = target_size

        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            aesthetic_score,
            negative_aesthetic_score,
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        add_time_ids = add_time_ids.repeat(
            batch_size * num_images_per_prompt, 1)

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat(
                [negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_neg_time_ids = add_neg_time_ids.repeat(
                batch_size * num_images_per_prompt, 1)
            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device)

        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        if denoising_end is not None and isinstance(denoising_end, float) and 0 < denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(
                list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

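        # 10. Build the augmented embedding (time ids + pooled text) consumed by the Lyra UNet and
        # broadcast any IP-Adapter / face hidden states across num_images_per_prompt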
        aug_emb = self._get_aug_emb(
            add_time_ids, add_text_embeds, prompt_embeds.dtype)

        extra_tensor_dict2 = {}
        for name in extra_tensor_dict:
            if name in ["fp_hidden_states", "ip_hidden_states"]:
                v1, v2 = extra_tensor_dict[name][0], extra_tensor_dict[name][1]
                extra_tensor_dict2[name] = torch.cat(
                    [v1.repeat(num_images_per_prompt, 1, 1), v2.repeat(num_images_per_prompt, 1, 1)])
            else:
                extra_tensor_dict2[name] = extra_tensor_dict[name]

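        # 11. Denoising loop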
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents

                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                if self.num_channels_unet == 9:
                    latent_model_input = torch.cat(
                        [latent_model_input, mask, masked_image_latents], dim=1)

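                # the Lyra UNet takes channels-last (NHWC) input, so permute before the forward
                # pass and permute the prediction back to NCHW afterwards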
                latent_model_input = latent_model_input.permute(
                    0, 2, 3, 1).contiguous()

                noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, aug_emb, None, None,
                                               None, None, None, extra_tensor_dict2, param_scale_dict).permute(0, 3, 1, 2).contiguous()

                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    noise_pred = rescale_noise_cfg(
                        noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

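                # with a 4-channel (non-inpainting) UNet, keep the known region by re-blending the
                # appropriately noised original-image latents outside the mask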
                if self.num_channels_unet == 4:
                    init_latents_proper = image_latents
                    if do_classifier_free_guidance:
                        init_mask, _ = mask.chunk(2)
                    else:
                        init_mask = mask

                    if i < len(timesteps) - 1:
                        noise_timestep = timesteps[i + 1]
                        init_latents_proper = self.scheduler.add_noise(
                            init_latents_proper, noise, torch.tensor(
                                [noise_timestep])
                        )

                    latents = (1 - init_mask) * \
                        init_latents_proper + init_mask * latents

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            return latents

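        # 12. Decode the latents with the Lyra VAE and post-process to the requested output type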
        image = self.vae.decode(1 / self.vae.scaling_factor * latents)
        image = self.image_processor.postprocess(
            image, output_type=output_type)

        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        return image
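

# Illustrative usage sketch (not executed): model weights are assumed to have been loaded through
# the LyraSDXLPipelineBase machinery beforehand; `init_pil_image` and `mask_pil_image` are
# placeholder PIL images.
#
#   pipe = LyraSdXLTxt2ImgInpaintPipeline(device=torch.device("cuda"), dtype=torch.float16)
#   images = pipe(
#       prompt="a photo of a cat wearing sunglasses",
#       image=init_pil_image,
#       mask_image=mask_pil_image,
#       num_inference_steps=30,
#       guidance_scale=7.5,
#   )
#   images[0].save("out.png")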