import inspect
import os
import time
import gc
from typing import Any, Callable, Dict, List, Optional, Union, Tuple
from glob import glob

import numpy as np
import torch
import PIL
from PIL import Image
from safetensors.torch import load_file
from transformers import (CLIPImageProcessor, CLIPTextModel, CLIPTokenizer,
                          CLIPTextModelWithProjection)

from diffusers import StableDiffusionXLPipeline, UNet2DConditionModel
from diffusers.loaders import TextualInversionLoaderMixin
from diffusers.image_processor import VaeImageProcessor
from diffusers.models import AutoencoderKL
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
from diffusers.schedulers import (DPMSolverMultistepScheduler,
                                  EulerAncestralDiscreteScheduler,
                                  EulerDiscreteScheduler,
                                  KarrasDiffusionSchedulers)
from diffusers.utils import PIL_INTERPOLATION, logging
from diffusers.utils.torch_utils import randn_tensor

from .lyrasd_vae_model import LyraSdVaeModel
from .lora_util import (add_text_lora_layer, add_xltext_lora_layer,
                        add_lora_to_opt_model, load_state_dict)
from .lyrasdxl_pipeline_base import LyraSDXLPipelineBase


def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of
    [Common Diffusion Noise Schedules and Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf).
    See Section 3.4.
    """
    std_text = noise_pred_text.std(
        dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
    std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
    # rescale the results from guidance (fixes overexposure)
    noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
    # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
    noise_cfg = guidance_rescale * noise_pred_rescaled + \
        (1 - guidance_rescale) * noise_cfg
    return noise_cfg
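
# Example (sketch): how the rescale is applied after classifier-free guidance,
# mirroring the call inside __call__ below:
#
#   noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
#   noise_pred = rescale_noise_cfg(noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)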

class LyraSdXLControlnetTxt2ImgPipeline(LyraSDXLPipelineBase, StableDiffusionXLPipeline):
    device = torch.device("cpu")
    dtype = torch.float32

    def __init__(self, device=torch.device("cuda"), dtype=torch.float16,
                 vae_scale_factor=8, vae_scaling_factor=0.13025) -> None:
        self.register_to_config(force_zeros_for_empty_prompt=True)
        super().__init__(device, dtype, vae_scale_factor=vae_scale_factor,
                         vae_scaling_factor=vae_scaling_factor)

    def prepare_image(
        self,
        image,
        width,
        height,
        batch_size,
        num_images_per_prompt,
        device,
        dtype,
        do_classifier_free_guidance=False,
        guess_mode=False,
    ):
        # Preprocess the ControlNet conditioning image and move it to NHWC layout,
        # which is what the Lyra UNet expects.
        image = self.control_image_processor.preprocess(image, height, width)
        image = image.permute(0, 2, 3, 1)
        image = image.to(device=device, dtype=dtype)
        return image

    @property
    def _execution_device(self):
        if not hasattr(self.unet, "_hf_hook"):
            return self.device
        for module in self.unet.modules():
            if (
                hasattr(module, "_hf_hook")
                and hasattr(module._hf_hook, "execution_device")
                and module._hf_hook.execution_device is not None
            ):
                return torch.device(module._hf_hook.execution_device)
        return self.device

    def _get_aug_emb(self, add_embedding, time_ids, text_embeds, dtype):
        # Project the added time ids and concatenate them with the pooled text
        # embeddings to form the SDXL "augmented" conditioning embedding.
        time_embeds = self.add_time_proj(time_ids.flatten())
        time_embeds = time_embeds.reshape((text_embeds.shape[0], -1))
        add_embeds = torch.concat([text_embeds, time_embeds], dim=-1)
        add_embeds = add_embeds.to(dtype)
        aug_emb = add_embedding(add_embeds)
        return aug_emb

    @torch.no_grad()
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        prompt_2: Optional[Union[str, List[str]]] = None,
        height: Optional[int] = None,
        width: Optional[int] = None,
        num_inference_steps: int = 50,
        denoising_end: Optional[float] = None,
        guidance_scale: float = 5.0,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        negative_prompt_2: Optional[Union[str, List[str]]] = None,
        num_images_per_prompt: Optional[int] = 1,
        controlnet_names: Optional[List[str]] = None,
        controlnet_images: Optional[List[PIL.Image.Image]] = None,
        controlnet_scale: Optional[List[float]] = None,
        guess_mode: bool = False,
        eta: float = 0.0,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_prompt_embeds: Optional[torch.FloatTensor] = None,
        pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        negative_pooled_prompt_embeds: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback: Optional[Callable[[int, int, torch.FloatTensor], None]] = None,
        callback_steps: int = 1,
        cross_attention_kwargs: Optional[Dict[str, Any]] = None,
        guidance_rescale: float = 0.0,
        original_size: Optional[Tuple[int, int]] = None,
        crops_coords_top_left: Tuple[int, int] = (0, 0),
        target_size: Optional[Tuple[int, int]] = None,
    ):
        # 0. Default height and width to unet
        height = height or self.default_sample_size * self.vae_scale_factor
        width = width or self.default_sample_size * self.vae_scale_factor

        original_size = original_size or (height, width)
        target_size = target_size or (height, width)

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt,
            prompt_2,
            height,
            width,
            callback_steps,
            negative_prompt,
            negative_prompt_2,
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        device = self._execution_device
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None)
            if cross_attention_kwargs is not None else None
        )
        (
            prompt_embeds,
            negative_prompt_embeds,
            pooled_prompt_embeds,
            negative_pooled_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            prompt_2=prompt_2,
            device=device,
            num_images_per_prompt=num_images_per_prompt,
            do_classifier_free_guidance=do_classifier_free_guidance,
            negative_prompt=negative_prompt,
            negative_prompt_2=negative_prompt_2,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            pooled_prompt_embeds=pooled_prompt_embeds,
            negative_pooled_prompt_embeds=negative_pooled_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # Preprocess every ControlNet conditioning image.
        control_images = []
        for image_ in controlnet_images:
            image_ = self.prepare_image(
                image=image_,
                width=width,
                height=height,
                batch_size=batch_size * num_images_per_prompt,
                num_images_per_prompt=num_images_per_prompt,
                device=device,
                dtype=prompt_embeds.dtype,
                do_classifier_free_guidance=do_classifier_free_guidance,
            )
            control_images.append(image_)

        # Per-ControlNet residual scales; in guess mode the scales are spaced
        # geometrically from 0.1 to 1.0 across the 10 ControlNet blocks.
        control_scales = []
        scales = [1.0, ] * 10
        if guess_mode:
            scales = torch.logspace(-1, 0, 10).tolist()
        for scale in controlnet_scale:
            scales_ = [d * scale for d in scales]
            control_scales.append(scales_)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps
        # 5. Prepare latent variables
        num_channels_latents = self.unet_in_channels
        latents = self.prepare_latents(
            batch_size * num_images_per_prompt,
            num_channels_latents,
            height,
            width,
            prompt_embeds.dtype,
            device,
            generator,
            latents,
        )

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Prepare added time ids & embeddings
        add_text_embeds = pooled_prompt_embeds
        add_time_ids = list(original_size + crops_coords_top_left + target_size)
        add_time_ids = torch.tensor([add_time_ids], dtype=prompt_embeds.dtype)

        if do_classifier_free_guidance:
            prompt_embeds = torch.cat(
                [negative_prompt_embeds, prompt_embeds], dim=0)
            add_text_embeds = torch.cat(
                [negative_pooled_prompt_embeds, add_text_embeds], dim=0)
            add_time_ids = torch.cat([add_time_ids, add_time_ids], dim=0)

        prompt_embeds = prompt_embeds.to(device)
        add_text_embeds = add_text_embeds.to(device)
        add_time_ids = add_time_ids.to(device).repeat(
            batch_size * num_images_per_prompt, 1)

        # 8. Denoising loop
        num_warmup_steps = max(
            len(timesteps) - num_inference_steps * self.scheduler.order, 0)

        # 8.1 Apply denoising_end
        if denoising_end is not None and isinstance(denoising_end, float) and 0 < denoising_end < 1:
            discrete_timestep_cutoff = int(
                round(
                    self.scheduler.config.num_train_timesteps
                    - (denoising_end * self.scheduler.config.num_train_timesteps)
                )
            )
            num_inference_steps = len(
                list(filter(lambda ts: ts >= discrete_timestep_cutoff, timesteps)))
            timesteps = timesteps[:num_inference_steps]

        # Augmented conditioning embeddings for the UNet and for each ControlNet.
        aug_emb = self._get_aug_emb(
            self.add_embedding, add_time_ids, add_text_embeds, prompt_embeds.dtype)

        controlnet_aug_embs = []
        for controlnet_name in controlnet_names:
            controlnet_aug_embs.append(
                self._get_aug_emb(self.controlnet_add_embedding[controlnet_name],
                                  add_time_ids, add_text_embeds, prompt_embeds.dtype))

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat(
                    [latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(
                    latent_model_input, t)

                # The Lyra UNet consumes NHWC inputs and returns NHWC outputs,
                # so permute before and after the forward pass.
                latent_model_input = latent_model_input.permute(
                    0, 2, 3, 1).contiguous()

                noise_pred = self.unet.forward(
                    latent_model_input, prompt_embeds, t, aug_emb,
                    controlnet_names, control_images, controlnet_aug_embs,
                    control_scales, guess_mode).permute(0, 3, 1, 2)

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * \
                        (noise_pred_text - noise_pred_uncond)

                if do_classifier_free_guidance and guidance_rescale > 0.0:
                    # Based on 3.4. in https://arxiv.org/pdf/2305.08891.pdf
                    noise_pred = rescale_noise_cfg(
                        noise_pred, noise_pred_text, guidance_rescale=guidance_rescale)
                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(
                    noise_pred, t, latents, **extra_step_kwargs, return_dict=False)[0]

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        # 9. Decode the latents with the Lyra VAE and post-process to the
        # requested output type. Note: the reference diffusers pipeline upcasts
        # the VAE to float32 here (it can overflow in float16) and supports
        # `output_type == "latent"`; neither is done in this pipeline.
        image = self.vae.decode(1 / self.vae.scaling_factor * latents)
        image = self.image_processor.postprocess(image, output_type=output_type)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        return image
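
# ---------------------------------------------------------------------------
# Usage sketch (kept as comments; not exercised by this module). The call
# below relies only on the __call__ signature defined above. How the UNet,
# text encoders, VAE and ControlNet weights are loaded depends on
# LyraSDXLPipelineBase, so the loader method names are placeholders for
# whatever that base class actually exposes.
#
#   from PIL import Image
#
#   pipe = LyraSdXLControlnetTxt2ImgPipeline()
#   # pipe.reload_pipe("/path/to/lyrasd_sdxl_model")          # hypothetical loader
#   # pipe.load_controlnet_model("canny", "/path/to/canny")   # hypothetical loader
#
#   control = Image.open("canny_edges.png").convert("RGB").resize((1024, 1024))
#   images = pipe(
#       prompt="a photo of an astronaut riding a horse on the moon",
#       height=1024,
#       width=1024,
#       num_inference_steps=30,
#       guidance_scale=7.0,
#       controlnet_names=["canny"],
#       controlnet_images=[control],
#       controlnet_scale=[0.8],
#   )
#   images[0].save("out.png")
# ---------------------------------------------------------------------------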