import inspect
import os
import time
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
import gc
import torch
import numpy as np
from glob import glob
from diffusers import StableDiffusionXLInpaintPipeline, UNet2DConditionModel
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.image_processor import VaeImageProcessor, PipelineImageInput
from diffusers.models import AutoencoderKL
from diffusers.schedulers import (DPMSolverMultistepScheduler,
                                  EulerAncestralDiscreteScheduler,
                                  EulerDiscreteScheduler,
                                  KarrasDiffusionSchedulers)
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.utils.torch_utils import randn_tensor
from diffusers.utils import logging
from PIL import Image
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer, CLIPTextModelWithProjection
from .lyrasd_vae_model import LyraSdVaeModel
from .module.lyrasd_ip_adapter import LyraIPAdapter
from .lora_util import add_text_lora_layer, add_xltext_lora_layer, add_lora_to_opt_model, load_state_dict
from safetensors.torch import load_file
from .lyrasdxl_pipeline_base import LyraSDXLPipelineBase


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
    Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4

    Args:
        noise_cfg: Noise prediction after classifier-free guidance.
        noise_pred_text: Text-conditioned noise prediction.
        guidance_rescale: Blend factor between the rescaled prediction
            (weight `guidance_rescale`) and the unmodified `noise_cfg`
            (weight `1 - guidance_rescale`).

    Returns:
        The blended noise prediction.
    """
    # Standard deviation over every dimension except the leading one.
    std_text = noise_pred_text.std(
        dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + \
        (1 - guidance_rescale) * noise_cfg
    return noise_cfg


def retrieve_latents(
    encoder_output: torch.Tensor,
    generator: Optional[torch.Generator] = None,
    sample_mode: str = "sample"
):
    """
    Extract latents from a VAE encoder output.

    Args:
        encoder_output: Object returned by the VAE encoder.
        generator: Forwarded to `encoder_output.sample` when `sample_mode` is `"sample"`.
        sample_mode: `"sample"` calls `encoder_output.sample(generator)`, `"argmax"` calls
            `encoder_output.mode()`; any other value returns `encoder_output` unchanged.
    """
    if sample_mode == "sample":
        return encoder_output.sample(generator)
    elif sample_mode == "argmax":
        return encoder_output.mode()
    else:
        return encoder_output


def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    **kwargs,
):
    """
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used,
            `timesteps` must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to support arbitrary spacing between timesteps. If `None`, then the default
            timestep spacing strategy of the scheduler is used. If `timesteps` is passed, `num_inference_steps`
            must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.

    Raises:
        ValueError: If `timesteps` is given but the scheduler's `set_timesteps` has no `timesteps` parameter.
    """
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(
            inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class LyraSdXLTxt2ImgInpaintPipeline(LyraSDXLPipelineBase, StableDiffusionXLInpaintPipeline):
    """SDXL inpainting pipeline built on `LyraSDXLPipelineBase` and diffusers' `StableDiffusionXLInpaintPipeline`."""

    # Class-level defaults; `__init__` forwards the requested device/dtype to the base initializer.
    device = torch.device("cpu")
    dtype = torch.float32

    def __init__(self, device=torch.device("cuda"), dtype=torch.float16, vae_scale_factor=8, vae_scaling_factor=0.13025, num_channels_unet=9, num_channels_latents=4, requires_aesthetics_score: bool = False, force_zeros_for_empty_prompt: bool = True) -> None:
        """Register the two config flags, then delegate the remaining arguments to the base initializer."""
        self.register_to_config(
            force_zeros_for_empty_prompt=force_zeros_for_empty_prompt)
        self.register_to_config(
            requires_aesthetics_score=requires_aesthetics_score)
        super().__init__(device, dtype, num_channels_unet=num_channels_unet,
                         num_channels_latents=num_channels_latents, vae_scale_factor=vae_scale_factor, vae_scaling_factor=vae_scaling_factor)

    def encode_image(self, image, device, num_images_per_prompt):
        """
        Encode `image` with `self.image_encoder`.

        Non-tensor input is first converted by `self.feature_extractor`.

        Returns:
            `(image_embeds, uncond_image_embeds)`, where the embeddings are repeated
            `num_images_per_prompt` times along dim 0 and the unconditional embeddings
            are zeros of the same shape.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.feature_extractor(
                image, return_tensors="pt").pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeds = self.image_encoder(image).image_embeds
        image_embeds = image_embeds.repeat_interleave(
            num_images_per_prompt, dim=0)

        uncond_image_embeds = torch.zeros_like(image_embeds)
        return image_embeds, uncond_image_embeds

    def _encode_vae_image(self, image: torch.Tensor, generator: torch.Generator):
        """
        Encode `image` with `self.vae` and multiply the latents by `self.vae.scaling_factor`.

        When `generator` is a list, each batch element is encoded separately with its own
        generator. The result is cast back to the dtype of `image`.
        """
        dtype = image.dtype
        # The float32 upcast below is intentionally left disabled in this pipeline.
        # if self.vae.config.force_upcast:
        #     image = image.float()
        #     self.vae.to(dtype=torch.float32)

        if isinstance(generator, list):
            image_latents = [
                retrieve_latents(self.vae.encode(image[i : i + 1]), generator=generator[i])
                for i in range(image.shape[0])
            ]
            image_latents = torch.cat(image_latents, dim=0)
        else:
            image_latents = retrieve_latents(self.vae.encode(image), generator=generator)

        image_latents = image_latents.to(dtype)
        image_latents = self.vae.scaling_factor * image_latents

        return image_latents

    def _get_add_time_ids(
        self,
        original_size,
        crops_coords_top_left,
        target_size,
        aesthetic_score,
        negative_aesthetic_score,
        negative_original_size,
        negative_crops_coords_top_left,
        negative_target_size,
        dtype,
        text_encoder_projection_dim=None,
    ):
        """
        Build the positive and negative "added time id" tensors.

        With `self.config.requires_aesthetics_score` the ids are
        `original_size + crops_coords_top_left + (aesthetic_score,)`; otherwise they are
        `original_size + crops_coords_top_left + target_size`. The negative ids are built
        the same way from the `negative_*` arguments.

        Returns:
            `(add_time_ids, add_neg_time_ids)`, each a tensor of shape `(1, len(ids))` with dtype `dtype`.

        Raises:
            ValueError: If the embedding length implied by the ids and `text_encoder_projection_dim`
                differs from `self.add_embedding.linear_1.in_features`.
        """
        if self.config.requires_aesthetics_score:
            add_time_ids = list(original_size + crops_coords_top_left + (aesthetic_score,))
            add_neg_time_ids = list(
                negative_original_size + negative_crops_coords_top_left + (negative_aesthetic_score,)
            )
        else:
            add_time_ids = list(original_size + crops_coords_top_left + target_size)
            # Use the negative crop offset here, as the aesthetic-score branch above does;
            # previously the positive offset was used and the negative one was ignored.
            add_neg_time_ids = list(negative_original_size + negative_crops_coords_top_left + negative_target_size)

        passed_add_embed_dim = (
            self.addition_time_embed_dim * len(add_time_ids) + text_encoder_projection_dim
        )
        expected_add_embed_dim = self.add_embedding.linear_1.in_features

        # A difference of exactly one time-embedding slot points at a wrong
        # `requires_aesthetics_score` setting; anything else is a model config problem.
        if (
            expected_add_embed_dim > passed_add_embed_dim
            and (expected_add_embed_dim - passed_add_embed_dim) == self.addition_time_embed_dim
        ):
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to enable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=True)` to make sure `aesthetic_score` {aesthetic_score} and `negative_aesthetic_score` {negative_aesthetic_score} is correctly used by the model."
            )
        elif (
            expected_add_embed_dim < passed_add_embed_dim
            and (passed_add_embed_dim - expected_add_embed_dim) == self.addition_time_embed_dim
        ):
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. Please make sure to disable `requires_aesthetics_score` with `pipe.register_to_config(requires_aesthetics_score=False)` to make sure `target_size` {target_size} is correctly used by the model."
            )
        elif expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_neg_time_ids = torch.tensor([add_neg_time_ids], dtype=dtype)

        return add_time_ids, add_neg_time_ids

    def load_ip_adapter(self, dir_ip_adapter, ip_plus, image_encoder_path, num_ip_tokens, ip_projection_dim, dir_face_in=None, num_fp_tokens=1, fp_projection_dim=None, sdxl=True):
        """Create a `LyraIPAdapter` for this pipeline and store it as `self.ip_adapter_helper`."""
        # NOTE(review): the device is hard-coded to "cuda" rather than taken from the pipeline.
        self.ip_adapter_helper = LyraIPAdapter(self, sdxl, "cuda", dir_ip_adapter, ip_plus, image_encoder_path,
                                               num_ip_tokens, ip_projection_dim, dir_face_in, num_fp_tokens, fp_projection_dim)

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        image: PipelineImageInput = None,
        mask_image: PipelineImageInput = None,
        masked_image_latents: torch.FloatTensor = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        strength: float = 0.9999,
        num_inference_steps: int = 50,
        timesteps: List[int] = None,
        denoising_start: Optional[float] = None,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 7.5,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator,
                                  List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Tuple[int, int] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Tuple[int, int] = None,
        negative_original_size: Optional[Tuple[int, int]] = None,
        negative_crops_coords_top_left: Tuple[int, int] = (0, 0),
        negative_target_size: Optional[Tuple[int, int]] = None,
        aesthetic_score: float = 6.0,
        negative_aesthetic_score: float = 2.5,
        clip_skip: Optional[int] = None,
        extra_tensor_dict: Optional[Dict[str, torch.FloatTensor]] = None,
        param_scale_dict: Optional[Dict[str, int]] = None,
        **kwargs
    ):
        """
        Run the inpainting denoising loop and decode the result.

        Args:
            prompt, prompt_2, negative_prompt, negative_prompt_2: Text inputs forwarded to `encode_prompt`.
            image, mask_image: Init image and mask, preprocessed to `height` x `width`.
            masked_image_latents: Used in place of the masked init image when given.
            height, width: Default to `self.default_sample_size * self.vae_scale_factor`.
            strength: Forwarded to `get_timesteps`; exactly `1.0` sets `is_strength_max`.
            num_inference_steps, timesteps: Forwarded to `retrieve_timesteps`.
            denoising_start: Forwarded to `get_timesteps` only when it is a float strictly
                between 0 and 1. Any non-`None` value disables adding noise in `prepare_latents`.
            denoising_end: When a float strictly between 0 and 1, truncates the timestep schedule.
            guidance_scale: Classifier-free guidance is enabled when this is greater than 1.0.
            guidance_rescale: When greater than 0 (and guidance is enabled), applies `rescale_noise_cfg`.
            num_images_per_prompt, eta, generator, latents: Forwarded to the latent/scheduler helpers.
            prompt_embeds, negative_prompt_embeds, pooled_prompt_embeds,
            negative_pooled_prompt_embeds: Optional precomputed embeddings forwarded to `encode_prompt`.
            output_type: `"latent"` returns the raw latents; otherwise passed to the image postprocessor.
            return_dict: Accepted for interface compatibility; not used by this implementation.
            cross_attention_kwargs: Its `"scale"` entry, if any, is used as the text-encoder LoRA scale.
            original_size, crops_coords_top_left, target_size, negative_original_size,
            negative_crops_coords_top_left, negative_target_size, aesthetic_score,
            negative_aesthetic_score: Micro-conditioning inputs for `_get_add_time_ids`.
                `original_size` / `target_size` default to `(height, width)`; the negative sizes
                default to their positive counterparts.
            clip_skip: Forwarded to `encode_prompt`.
            extra_tensor_dict: Extra tensors forwarded to the UNet. Entries named
                `"fp_hidden_states"` / `"ip_hidden_states"` are indexed at `[0]` and `[1]`, each
                repeated `num_images_per_prompt` times and concatenated. Defaults to an empty dict.
            param_scale_dict: Forwarded unchanged to the UNet. Defaults to an empty dict.
            **kwargs: `callback` and `callback_steps` are read from here. When `callback` is given
                without `callback_steps`, it is invoked on every step.

        Returns:
            The latents when `output_type == "latent"`, otherwise the output of
            `self.image_processor.postprocess`.
        """
        callback = kwargs.pop("callback", None)
        callback_steps = kwargs.pop("callback_steps", None)

        # Fresh dicts per call instead of shared mutable defaults.
        extra_tensor_dict = {} if extra_tensor_dict is None else extra_tensor_dict
        param_scale_dict = {} if param_scale_dict is None else param_scale_dict

        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        self._guidance_scale = guidance_scale
        self._guidance_rescale = guidance_rescale
        self._clip_skip = clip_skip
        self._cross_attention_kwargs = cross_attention_kwargs
        self._denoising_end = denoising_end
        self._denoising_start = denoising_start

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            strength,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
        )

        # A callback without an explicit interval would otherwise fail on `i % None`.
        if callback is not None and callback_steps is None:
            callback_steps = 1

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device

        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get(
                "scale", None) if cross_attention_kwargs is not None else None
        )

        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
            clip_skip=clip_skip
        )

        def denoising_value_valid(dnv):
            # Valid only for a float strictly inside (0, 1).
            return isinstance(dnv, float) and 0 < dnv < 1

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(
            self.scheduler, num_inference_steps, device, timesteps)

        timesteps, num_inference_steps = self.get_timesteps(
            num_inference_steps,
            strength,
            device,
            denoising_start=self.denoising_start if denoising_value_valid(self.denoising_start) else None,
        )

        latent_timestep = timesteps[:1].repeat(
            batch_size * num_images_per_prompt)
        is_strength_max = strength == 1.0

        # 5. Prepare latent variables
        init_image = self.image_processor.preprocess(
            image, height=height, width=width)
        init_image = init_image.to(dtype=torch.float32)

        mask = self.mask_processor.preprocess(
            mask_image, height=height, width=width)

        if masked_image_latents is not None:
            masked_image = masked_image_latents
        elif init_image.shape[1] == 4:
            # if images are in latent space, we can't mask it
            masked_image = None
        else:
            # Keep only the pixels where the mask value is below 0.5.
            masked_image = init_image * (mask < 0.5)

        add_noise = self.denoising_start is None
        # Image latents are only needed for the 4-channel UNet blending step in the loop.
        return_image_latents = self.num_channels_unet == 4

        latents_outputs = self.prepare_latents(
            batch_size * num_images_per_prompt,
            self.num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
            image=init_image,
            timestep=latent_timestep,
            is_strength_max=is_strength_max,
            add_noise=add_noise,
            return_noise=True,
            return_image_latents=return_image_latents,
        )

        if return_image_latents:
            latents, noise, image_latents = latents_outputs
        else:
            latents, noise = latents_outputs

        mask, masked_image_latents = self.prepare_mask_latents(
            mask,
            masked_image,
            batch_size * num_images_per_prompt,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            do_classifier_free_guidance,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        if self.text_encoder_2 is None:
            text_encoder_projection_dim = int(pooled_prompt_embeds.shape[-1])
        else:
            text_encoder_projection_dim = self.text_encoder_2.config.projection_dim

        if negative_original_size is None:
            negative_original_size = original_size
        if negative_target_size is None:
            negative_target_size = target_size

        add_time_ids, add_neg_time_ids = self._get_add_time_ids(
            original_size,
            crops_coords_top_left,
            target_size,
            aesthetic_score,
            negative_aesthetic_score,
            negative_original_size,
            negative_crops_coords_top_left,
            negative_target_size,
            dtype=prompt_embeds.dtype,
            text_encoder_projection_dim=text_encoder_projection_dim,
        )
        add_time_ids = add_time_ids.repeat(
            batch_size * num_images_per_prompt, 1)

        if do_classifier_free_guidance:
            # Negative (unconditional) half first, then the conditional half.
            prompt_embeds = torch.cat(
                [negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_neg_time_ids = add_neg_time_ids.repeat(
                batch_size * num_images_per_prompt, 1)
            add_time_ids = torch.cat([add_neg_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device)

        # 8. Denoising loop
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 8.1 Apply denoising_end
        if denoising_value_valid(denoising_end):
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(
                list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        aug_emb = self._get_aug_emb(
            add_time_ids, add_text_embeds, prompt_embeds.dtype)

        # Expand the adapter hidden states to the per-prompt image count; pass others through.
        extra_tensor_dict2 = {}
        for name in extra_tensor_dict:
            if name in ["fp_hidden_states", "ip_hidden_states"]:
                v1, v2 = extra_tensor_dict[name][0], extra_tensor_dict[name][1]
                extra_tensor_dict2[name] = torch.cat(
                    [v1.repeat(num_images_per_prompt, 1, 1), v2.repeat(num_images_per_prompt, 1, 1)])
            else:
                extra_tensor_dict2[name] = extra_tensor_dict[name]

        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                if self.num_channels_unet == 9:
                    latent_model_input = torch.cat(
                        [latent_model_input, mask, masked_image_latents], dim=1)

                # Move dim 1 to the last position for the UNet call and restore it afterwards
                # (presumably NCHW <-> NHWC; confirm against the Lyra UNet).
                latent_model_input = latent_model_input.permute(
                    0, 2, 3, 1).contiguous()

                noise_pred = self.unet.forward(latent_model_input, prompt_embeds, t, aug_emb, None,
                                               None, None, None, None, extra_tensor_dict2, param_scale_dict).permute(0, 3, 1, 2).contiguous()

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and self.guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(
                        noise_pred, noise_pred_text, guidance_rescale=self.guidance_rescale)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                if self.num_channels_unet == 4:
                    # Blend: denoised latents where the mask is set, (re-noised) init latents elsewhere.
                    init_latents_proper = image_latents
                    if do_classifier_free_guidance:
                        init_mask, _ = mask.chunk(2)
                    else:
                        init_mask = mask

                    if i < len(timesteps) - 1:
                        noise_timestep = timesteps[i + 1]
                        init_latents_proper = self.scheduler.add_noise(
                            init_latents_proper, noise, torch.tensor(
                                [noise_timestep])
                        )

                    latents = (1 - init_mask) * \
                        init_latents_proper + init_mask * latents

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        if output_type == "latent":
            return latents

        image = self.vae.decode(1 / self.vae.scaling_factor * latents)
        image = self.image_processor.postprocess(
            image, output_type=output_type)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        return image