|
import re |
|
from copy import deepcopy |
|
from dataclasses import asdict, dataclass |
|
from enum import Enum |
|
from typing import List, Optional, Union |
|
|
|
import numpy as np |
|
import torch |
|
from numpy import exp, pi, sqrt |
|
from torchvision.transforms.functional import resize |
|
from tqdm.auto import tqdm |
|
from transformers import CLIPImageProcessor, CLIPTextModel, CLIPTokenizer |
|
|
|
from diffusers.models import AutoencoderKL, UNet2DConditionModel |
|
from diffusers.pipelines.pipeline_utils import DiffusionPipeline, StableDiffusionMixin |
|
from diffusers.pipelines.stable_diffusion import StableDiffusionSafetyChecker |
|
from diffusers.schedulers import DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler |
|
|
|
|
|
def preprocess_image(image):
    """Preprocess an input image into a tensor suitable for VAE encoding.

    Same as
    https://github.com/huggingface/diffusers/blob/1138d63b519e37f0ce04e027b9f4a3261d27c628/src/diffusers/pipelines/stable_diffusion/pipeline_stable_diffusion_img2img.py#L44

    The image is resized (Lanczos) down to the nearest width/height that are
    multiples of 32. It is then converted to float32, given a leading batch axis
    in channels-first layout, and rescaled with ``2 * (x / 255) - 1``, which maps
    8-bit pixel values to [-1, 1].

    Args:
        image: a PIL image. The channels-first transpose assumes
            ``np.array(image)`` is 3-D (height, width, channels), as for RGB
            images — NOTE(review): other modes (e.g. "L") are not handled.

    Returns:
        torch.Tensor of shape (1, channels, height, width).

    Raises:
        ValueError: if a side of the image is smaller than 32 pixels, which would
            otherwise be resized to zero.
    """
    from PIL import Image  # local import: Pillow is only needed by this helper

    w, h = image.size
    # Round both sides down to a multiple of 32.
    w, h = (x - x % 32 for x in (w, h))
    if w == 0 or h == 0:
        raise ValueError(f"Image must be at least 32x32 pixels to be preprocessed, found size {image.size}")

    image = image.resize((w, h), resample=Image.LANCZOS)
    image = np.array(image).astype(np.float32) / 255.0
    # (H, W, C) -> (1, C, H, W)
    image = image[None].transpose(0, 3, 1, 2)
    image = torch.from_numpy(image)
    return 2.0 * image - 1.0
|
|
|
|
|
@dataclass
class CanvasRegion:
    """Class defining a rectangular region in the canvas.

    Coordinates are in pixel space: rows ``[row_init, row_end)`` and columns
    ``[col_init, col_end)``. They must be non-negative multiples of 8, with each
    end not smaller than its init. The matching latent-space coordinates
    (pixel coordinate // 8) are stored as ``latent_row_init``,
    ``latent_row_end``, ``latent_col_init`` and ``latent_col_end``.
    """

    row_init: int
    row_end: int
    col_init: int
    col_end: int
    # Seed for this region's torch.Generator; a random one is drawn when None.
    region_seed: Optional[int] = None
    # Scale of extra per-region noise added to the initial latents; must be >= 0.
    noise_eps: float = 0.0

    def __post_init__(self):
        """Draw a seed if needed, validate the coordinates and derive latent coordinates.

        Raises:
            ValueError: on negative coordinates, coordinates not divisible by 8,
                an end coordinate smaller than its init, or negative ``noise_eps``.
        """
        if self.region_seed is None:
            # Explicit int64: the default integer dtype is 32-bit on some platforms
            # (e.g. Windows with NumPy < 2), where this upper bound is out of range.
            self.region_seed = int(np.random.randint(9999999999, dtype=np.int64))

        coords = (self.row_init, self.row_end, self.col_init, self.col_end)
        if any(coord < 0 for coord in coords):
            raise ValueError(
                f"A CanvasRegion must be defined with non-negative indices, found ({self.row_init}, {self.row_end}, {self.col_init}, {self.col_end})"
            )
        # The latent space is 8x smaller than pixel space, so coordinates must map exactly.
        if any(coord % 8 != 0 for coord in coords):
            raise ValueError(
                f"A CanvasRegion must be defined with locations divisible by 8, found ({self.row_init}-{self.row_end}, {self.col_init}-{self.col_end})"
            )
        if self.row_end < self.row_init or self.col_end < self.col_init:
            raise ValueError(
                f"A CanvasRegion must be defined with end indices not smaller than init indices, found ({self.row_init}-{self.row_end}, {self.col_init}-{self.col_end})"
            )

        if self.noise_eps < 0:
            raise ValueError(f"A CanvasRegion must be defined noises eps non-negative, found {self.noise_eps}")

        self.latent_row_init = self.row_init // 8
        self.latent_row_end = self.row_end // 8
        self.latent_col_init = self.col_init // 8
        self.latent_col_end = self.col_end // 8

    @property
    def width(self):
        """Width of the region in pixels."""
        return self.col_end - self.col_init

    @property
    def height(self):
        """Height of the region in pixels."""
        return self.row_end - self.row_init

    def get_region_generator(self, device="cpu"):
        """Creates a torch.Generator on ``device`` seeded with this region's ``region_seed``."""
        return torch.Generator(device).manual_seed(self.region_seed)

    @property
    def __dict__(self):
        """Return the dataclass fields as a plain dict (via ``dataclasses.asdict``).

        Derived attributes such as ``latent_row_init`` are not dataclass fields and
        are therefore not included. NOTE(review): overriding ``__dict__`` also
        changes what ``vars(region)`` returns.
        """
        return asdict(self)
|
|
|
|
|
class MaskModes(Enum):
    """Shapes available for weighting a diffusion region's contribution.

    Regions store the string value of a member in their ``mask_type`` field.
    """

    # Same weight everywhere inside the region.
    CONSTANT = "constant"
    # Gaussian profile peaking at the region centre.
    GAUSSIAN = "gaussian"
    # Quartic kernel with bounded support, decaying smoothly towards the region edges.
    QUARTIC = "quartic"
|
|
|
|
|
@dataclass
class DiffusionRegion(CanvasRegion):
    """Base class for canvas regions driven by some diffusion process.

    It adds no fields of its own; the text-guided and image-guided region
    classes derive from it.
    """
|
|
|
|
|
@dataclass
class Text2ImageRegion(DiffusionRegion):
    """Class defining a region where a text guided diffusion process is acting.

    Fields:
        prompt: text prompt; runs of spaces are collapsed and newlines turned into spaces.
        guidance_scale: classifier-free guidance scale; if None, a random integer
            in [5, 30) is drawn with ``np.random.randint``.
        mask_type: a ``MaskModes`` member or its string value; stored as the string.
        mask_weight: non-negative multiplier applied to the region's mask weights.
    """

    prompt: str = ""
    guidance_scale: Optional[float] = 7.5
    mask_type: Union[MaskModes, str] = MaskModes.GAUSSIAN.value
    mask_weight: float = 1.0
    # Not dataclass fields (no type annotation): filled in by tokenize_prompt / encode_prompt.
    tokenized_prompt = None
    encoded_prompt = None

    def __post_init__(self):
        """Validate and normalise the text-region parameters.

        Raises:
            ValueError: on a negative ``mask_weight`` or an unknown ``mask_type``.
        """
        super().__post_init__()

        if self.mask_weight < 0:
            raise ValueError(
                f"A Text2ImageRegion must be defined with non-negative mask weight, found {self.mask_weight}"
            )

        # Accept enum members too; store the string value, which the mask builder dispatches on.
        if isinstance(self.mask_type, MaskModes):
            self.mask_type = self.mask_type.value
        if self.mask_type not in [e.value for e in MaskModes]:
            raise ValueError(
                f"A Text2ImageRegion was defined with mask {self.mask_type}, which is not an accepted mask ({[e.value for e in MaskModes]})"
            )

        if self.guidance_scale is None:
            self.guidance_scale = np.random.randint(5, 30)

        # Replace newlines first, then collapse spaces, so no double spaces are left behind.
        self.prompt = re.sub(" +", " ", self.prompt.replace("\n", " "))

    def tokenize_prompt(self, tokenizer):
        """Tokenizes the prompt, padded/truncated to ``tokenizer.model_max_length``, as PyTorch tensors."""
        self.tokenized_prompt = tokenizer(
            self.prompt,
            padding="max_length",
            max_length=tokenizer.model_max_length,
            truncation=True,
            return_tensors="pt",
        )

    def encode_prompt(self, text_encoder, device):
        """Encodes the previously tokenized prompt with ``text_encoder`` on ``device``.

        Stores the first element of the encoder output in ``encoded_prompt``.

        Raises:
            ValueError: if ``tokenize_prompt`` has not been called yet.
        """
        if self.tokenized_prompt is None:
            raise ValueError("Prompt in diffusion region must be tokenized before encoding")
        self.encoded_prompt = text_encoder(self.tokenized_prompt.input_ids.to(device))[0]
|
|
|
|
|
@dataclass
class Image2ImageRegion(DiffusionRegion):
    """Class defining a region where an image guided diffusion process is acting.

    Fields:
        reference_image: image tensor (required). It is resized to the region's
            pixel size with torchvision ``resize``. NOTE(review): presumably a
            (1, C, H, W) tensor in [-1, 1], as produced by ``preprocess_image`` —
            confirm against callers.
        strength: in [0, 1]; controls how long the reference is re-imposed during denoising.

    After ``encode_reference_image`` the attribute ``reference_latents`` holds
    the scaled VAE latents of the reference image.
    """

    reference_image: Optional[torch.Tensor] = None
    strength: float = 0.8

    def __post_init__(self):
        """Validate the parameters and resize the reference image to the region size.

        Raises:
            ValueError: if no reference image is given or strength is outside [0, 1].
        """
        super().__post_init__()
        if self.reference_image is None:
            raise ValueError("Must provide a reference image when creating an Image2ImageRegion")
        if not 0.0 <= self.strength <= 1.0:
            raise ValueError(f"The value of strength should in [0.0, 1.0] but is {self.strength}")

        self.reference_image = resize(self.reference_image, size=[self.height, self.width])

    def encode_reference_image(self, encoder, device, generator, cpu_vae=False):
        """Encodes the reference image into the latent space and stores ``reference_latents``.

        Args:
            encoder: VAE exposing ``encode(...).latent_dist``.
            device: device the resulting latents are placed on.
            generator: generator used to sample the latent distribution (GPU path only).
            cpu_vae: if True, encode with a CPU copy of ``encoder`` and take the
                distribution mean instead of a sample. The caller's encoder is not moved.
        """
        if cpu_vae:
            # Work on a copy so the shared VAE stays on its original device.
            cpu_encoder = deepcopy(encoder).cpu()
            latents = cpu_encoder.encode(self.reference_image.cpu()).latent_dist.mean.to(device)
        else:
            latents = encoder.encode(self.reference_image.to(device)).latent_dist.sample(generator=generator)
        # Same latent scaling constant the pipeline divides out when decoding.
        self.reference_latents = 0.18215 * latents

    @property
    def __dict__(self):
        """Serialize the region as a dict.

        Returns the ``DiffusionRegion`` dataclass fields, the reference image as a
        nested Python list, and the strength.
        """
        super_fields = {key: getattr(self, key) for key in DiffusionRegion.__dataclass_fields__}
        return {**super_fields, "reference_image": self.reference_image.cpu().tolist(), "strength": self.strength}
|
|
|
|
|
class RerollModes(Enum):
    """Ways a reroll region can alter the initial latent noise.

    Reroll regions store the string value of a member in ``reroll_mode``.
    """

    # Replace the region's initial noise with fresh noise from the region's own generator.
    RESET = "reset"
    # Add region-seeded noise, scaled by the region's noise_eps, on top of the initial noise.
    EPSILON = "epsilon"
|
|
|
@dataclass
class RerollRegion(CanvasRegion):
    """Class defining a rectangular canvas region in which initial latent noise will be rerolled.

    ``reroll_mode`` is a ``RerollModes`` member or its string value and is stored
    as the string: ``"reset"`` replaces the region's initial noise, while
    ``"epsilon"`` adds ``noise_eps``-scaled noise to it.
    """

    reroll_mode: Union[RerollModes, str] = RerollModes.RESET.value

    def __post_init__(self):
        """Validate the base region and normalise ``reroll_mode`` to its string value.

        Raises:
            ValueError: if ``reroll_mode`` is not an accepted mode.
        """
        super().__post_init__()
        # The pipeline compares against RerollModes.*.value; an enum member left as-is
        # would never match and the region would be silently ignored.
        if isinstance(self.reroll_mode, RerollModes):
            self.reroll_mode = self.reroll_mode.value
        if self.reroll_mode not in [e.value for e in RerollModes]:
            raise ValueError(
                f"A RerollRegion was defined with reroll mode {self.reroll_mode}, which is not an accepted mode ({[e.value for e in RerollModes]})"
            )
|
|
|
|
|
@dataclass
class MaskWeightsBuilder:
    """Auxiliary class to compute a tensor of weights for a given diffusion region.

    Every builder returns a float32 tensor of shape
    ``(nbatch, latent_space_dim, latent_height, latent_width)`` covering the
    region's latent window, scaled by the region's ``mask_weight``.
    """

    latent_space_dim: int
    nbatch: int = 1

    def compute_mask_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Computes the weights for ``region`` according to its ``mask_type``.

        ``mask_type`` may be a ``MaskModes`` member or its string value.
        """
        mask_type = region.mask_type
        if isinstance(mask_type, MaskModes):
            mask_type = mask_type.value
        mask_builders = {
            MaskModes.CONSTANT.value: self._constant_weights,
            MaskModes.GAUSSIAN.value: self._gaussian_weights,
            MaskModes.QUARTIC.value: self._quartic_weights,
        }
        return mask_builders[mask_type](region)

    def _tile(self, weights_2d: np.ndarray) -> torch.Tensor:
        """Broadcast a (height, width) weight map to (nbatch, latent_space_dim, height, width) float32."""
        # float32 keeps every mask mode consistent with the constant mask and usable on
        # devices without float64 support.
        weights = torch.tensor(weights_2d, dtype=torch.float32)
        return torch.tile(weights, (self.nbatch, self.latent_space_dim, 1, 1))

    def _constant_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Computes a tensor of constant weights (``mask_weight``) for a given diffusion region."""
        latent_width = region.latent_col_end - region.latent_col_init
        latent_height = region.latent_row_end - region.latent_row_init
        return torch.ones(self.nbatch, self.latent_space_dim, latent_height, latent_width) * region.mask_weight

    @staticmethod
    def _gaussian_profile(length: int, var: float = 0.01) -> np.ndarray:
        """1-D Gaussian over ``length`` positions, centred, with variance ``var`` in length-normalised units."""
        midpoint = (length - 1) / 2
        offsets = np.arange(length) - midpoint
        return np.exp(-offsets * offsets / (length * length) / (2 * var)) / np.sqrt(2 * np.pi * var)

    def _gaussian_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Generates a gaussian mask of weights for tile contributions."""
        latent_width = region.latent_col_end - region.latent_col_init
        latent_height = region.latent_row_end - region.latent_row_init

        x_probs = self._gaussian_profile(latent_width)
        y_probs = self._gaussian_profile(latent_height)
        return self._tile(np.outer(y_probs, x_probs) * region.mask_weight)

    @staticmethod
    def _quartic_profile(length: int) -> np.ndarray:
        """1-D quartic kernel sampled at ``length`` evenly spaced points in [-0.995, 0.995]."""
        quartic_constant = 15.0 / 16.0
        if length == 1:
            # A single position sits at the kernel centre; the general formula would divide by zero.
            support = np.zeros(1)
        else:
            # Span 1.99 rather than 2 so the outermost positions keep a small non-zero weight
            # (the kernel is exactly zero at +/-1).
            support = np.arange(length) / (length - 1) * 1.99 - (1.99 / 2.0)
        return quartic_constant * np.square(1 - np.square(support))

    def _quartic_weights(self, region: "DiffusionRegion") -> torch.Tensor:
        """Generates a quartic mask of weights for tile contributions.

        The quartic kernel has bounded support over the diffusion region, and a smooth decay to the region limits.
        """
        x_probs = self._quartic_profile(region.latent_col_end - region.latent_col_init)
        y_probs = self._quartic_profile(region.latent_row_end - region.latent_row_init)
        return self._tile(np.outer(y_probs, x_probs) * region.mask_weight)
|
|
|
|
|
class StableDiffusionCanvasPipeline(DiffusionPipeline, StableDiffusionMixin):
    """Stable Diffusion pipeline that mixes several diffusers in the same canvas.

    Each ``Text2ImageRegion`` runs its own classifier-free-guided UNet prediction
    on its latent window, and the per-region predictions are blended with the
    region mask weights. Each ``Image2ImageRegion`` overwrites its latent window
    with a noised copy of its reference latents until its strength-dependent
    timestep is reached.
    """

    def __init__(
        self,
        vae: AutoencoderKL,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
        unet: UNet2DConditionModel,
        scheduler: Union[DDIMScheduler, LMSDiscreteScheduler, PNDMScheduler],
        safety_checker: StableDiffusionSafetyChecker,
        feature_extractor: CLIPImageProcessor,
    ):
        super().__init__()
        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            unet=unet,
            scheduler=scheduler,
            safety_checker=safety_checker,
            feature_extractor=feature_extractor,
        )

    @staticmethod
    def _region_window(region):
        """Index selecting ``region``'s latent window in a (batch, channels, height, width) tensor."""
        return (
            slice(None),
            slice(None),
            slice(region.latent_row_init, region.latent_row_end),
            slice(region.latent_col_init, region.latent_col_end),
        )

    def decode_latents(self, latents, cpu_vae=False):
        """Decodes a given array of latents into pixel space.

        Args:
            latents: latent tensor scaled by 0.18215 (the scaling is undone here).
            cpu_vae: if True, decode on CPU with a copy of the VAE; ``self.vae`` is not moved.

        Returns:
            List of PIL images produced by ``self.numpy_to_pil``.
        """
        if cpu_vae:
            # `.cpu()` returns a separate tensor for non-CPU latents, and the scaling below
            # allocates a new tensor, so the caller's latents are never modified.
            lat = latents.cpu()
            vae = deepcopy(self.vae).cpu()
        else:
            lat = latents
            vae = self.vae

        lat = 1 / 0.18215 * lat
        image = vae.decode(lat).sample

        # [-1, 1] -> [0, 1], then NCHW -> NHWC numpy for PIL conversion.
        image = (image / 2 + 0.5).clamp(0, 1)
        image = image.cpu().permute(0, 2, 3, 1).numpy()

        return self.numpy_to_pil(image)

    def get_latest_timestep_img2img(self, num_inference_steps, strength):
        """Finds the latest timestep at which an img2img region of ``strength`` still imposes its latents.

        ``__call__`` keeps overwriting the region while the current timestep is
        greater than the returned value. Requires ``scheduler.set_timesteps`` to
        have been called.
        """
        offset = self.scheduler.config.get("steps_offset", 0)
        init_timestep = int(num_inference_steps * (1 - strength)) + offset
        init_timestep = min(init_timestep, num_inference_steps)

        t_start = min(max(num_inference_steps - init_timestep + offset, 0), num_inference_steps - 1)
        latest_timestep = self.scheduler.timesteps[t_start]

        return latest_timestep

    def _reroll_initial_noise(self, init_noise, regions, reroll_regions):
        """Apply region-specific rerolls to ``init_noise`` in place.

        RESET reroll regions get fresh noise from their own seeded generator. Then
        every diffusion region and every EPSILON reroll region with
        ``noise_eps > 0`` gets extra noise from its own generator, scaled by
        ``noise_eps``, added on top.
        """
        for region in reroll_regions:
            if region.reroll_mode == RerollModes.RESET.value:
                region_shape = (
                    init_noise.shape[0],
                    init_noise.shape[1],
                    region.latent_row_end - region.latent_row_init,
                    region.latent_col_end - region.latent_col_init,
                )
                init_noise[self._region_window(region)] = torch.randn(
                    region_shape, generator=region.get_region_generator(self.device), device=self.device
                )

        eps_regions = list(regions) + [r for r in reroll_regions if r.reroll_mode == RerollModes.EPSILON.value]
        for region in eps_regions:
            if region.noise_eps > 0:
                window = self._region_window(region)
                eps_noise = (
                    torch.randn(
                        init_noise[window].shape, generator=region.get_region_generator(self.device), device=self.device
                    )
                    * region.noise_eps
                )
                init_noise[window] += eps_noise

    def _prepend_unconditional_embeddings(self, text2image_regions, batch_size):
        """Prefix each text region's ``encoded_prompt`` with the empty-prompt embedding (for guidance).

        The empty-prompt embedding depends only on the padded length, so it is
        computed once per distinct length instead of once per region.
        """
        uncond_by_length = {}
        for region in text2image_regions:
            max_length = region.tokenized_prompt.input_ids.shape[-1]
            if max_length not in uncond_by_length:
                uncond_input = self.tokenizer(
                    [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
                )
                uncond_by_length[max_length] = self.text_encoder(uncond_input.input_ids.to(self.device))[0]
            region.encoded_prompt = torch.cat([uncond_by_length[max_length], region.encoded_prompt])

    def _blend_noise_predictions(self, latents, t, text2image_regions, mask_weights):
        """Predict guided noise for every text region at timestep ``t`` and blend by mask weight.

        Returns a tensor shaped like ``latents``: the weighted average of the
        regional predictions. Positions covered by no region (0/0 = NaN) become 0.
        """
        noise_pred = torch.zeros(latents.shape, device=self.device)
        contributors = torch.zeros(latents.shape, device=self.device)

        for region, mask_weights_region in zip(text2image_regions, mask_weights):
            window = self._region_window(region)

            # Duplicate the window: one copy for the unconditional and one for the text embedding.
            latent_model_input = torch.cat([latents[window]] * 2)
            latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
            region_pred = self.unet(latent_model_input, t, encoder_hidden_states=region.encoded_prompt)["sample"]

            noise_pred_uncond, noise_pred_text = region_pred.chunk(2)
            noise_pred_region = noise_pred_uncond + region.guidance_scale * (noise_pred_text - noise_pred_uncond)

            noise_pred[window] += noise_pred_region * mask_weights_region
            contributors[window] += mask_weights_region

        noise_pred /= contributors
        return torch.nan_to_num(noise_pred)

    @torch.no_grad()
    def __call__(
        self,
        canvas_height: int,
        canvas_width: int,
        regions: List[DiffusionRegion],
        num_inference_steps: Optional[int] = 50,
        seed: Optional[int] = 12345,
        reroll_regions: Optional[List[RerollRegion]] = None,
        cpu_vae: Optional[bool] = False,
        decode_steps: Optional[bool] = False,
    ):
        """Generate an image by mixing several diffusion regions on one canvas.

        Args:
            canvas_height, canvas_width: canvas size in pixels; latents are (height // 8, width // 8).
            regions: diffusion regions. Text2ImageRegion and Image2ImageRegion
                entries drive generation; every region's ``noise_eps`` is applied
                to the initial noise. Regions are mutated (prompt
                tokens/embeddings, reference latents are stored on them).
            num_inference_steps: number of scheduler steps.
            seed: seed for the canvas-wide generator (initial noise and sampling of
                img2img reference latents). If None, a non-deterministic seed is used.
            reroll_regions: regions whose initial noise is reset or perturbed.
            cpu_vae: decode with a CPU copy of the VAE.
            decode_steps: also decode the latents after every step.

        Returns:
            dict with ``"images"`` (list of PIL images) and, when ``decode_steps``
            is set, ``"steps_images"`` (one decoded list per step).
        """
        if reroll_regions is None:
            reroll_regions = []
        batch_size = 1
        steps_images = []

        self.scheduler.set_timesteps(num_inference_steps, device=self.device)

        text2image_regions = [region for region in regions if isinstance(region, Text2ImageRegion)]
        image2image_regions = [region for region in regions if isinstance(region, Image2ImageRegion)]

        for region in text2image_regions:
            region.tokenize_prompt(self.tokenizer)
            region.encode_prompt(self.text_encoder, self.device)

        latents_shape = (batch_size, self.unet.config.in_channels, canvas_height // 8, canvas_width // 8)
        generator = torch.Generator(self.device)
        if seed is None:
            generator.seed()
        else:
            generator.manual_seed(seed)
        init_noise = torch.randn(latents_shape, generator=generator, device=self.device)
        self._reroll_initial_noise(init_noise, regions, reroll_regions)

        latents = init_noise * self.scheduler.init_noise_sigma

        self._prepend_unconditional_embeddings(text2image_regions, batch_size)

        for region in image2image_regions:
            region.encode_reference_image(self.vae, device=self.device, generator=generator)

        mask_builder = MaskWeightsBuilder(latent_space_dim=self.unet.config.in_channels, nbatch=batch_size)
        # float32 to match the accumulation buffers (and devices without float64 support).
        mask_weights = [
            mask_builder.compute_mask_weights(region).to(self.device, dtype=torch.float32)
            for region in text2image_regions
        ]

        # Depends only on each region's strength, so compute it once rather than at every step.
        influence_steps = [
            self.get_latest_timestep_img2img(num_inference_steps, region.strength) for region in image2image_regions
        ]

        for t in tqdm(self.scheduler.timesteps):
            noise_pred = self._blend_noise_predictions(latents, t, text2image_regions, mask_weights)
            latents = self.scheduler.step(noise_pred, t, latents).prev_sample

            # Re-impose each image region's reference (noised to the current level) while it still has influence.
            for region, influence_step in zip(image2image_regions, influence_steps):
                if t > influence_step:
                    window = self._region_window(region)
                    timestep = t.repeat(batch_size)
                    latents[window] = self.scheduler.add_noise(region.reference_latents, init_noise[window], timestep)

            if decode_steps:
                steps_images.append(self.decode_latents(latents, cpu_vae))

        image = self.decode_latents(latents, cpu_vae)

        output = {"images": image}
        if decode_steps:
            output = {**output, "steps_images": steps_images}
        return output
|
|