|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
from diffusers.image_processor import VaeImageProcessor |
|
import pdb |
|
from typing import Dict, Optional, Union |
|
import PIL.Image |
|
import numpy as np |
|
import torch |
|
from diffusers import ( |
|
AutoencoderKL, |
|
DDIMScheduler, |
|
DiffusionPipeline, |
|
LCMScheduler, |
|
PNDMScheduler, |
|
UNet2DConditionModel, |
|
) |
|
from .duplicate_unet import DoubleUNet2DConditionModel |
|
from torch.nn import Conv2d |
|
from PIL import ImageDraw, ImageFont |
|
from torch.nn.parameter import Parameter |
|
from diffusers.utils import BaseOutput, make_image_grid |
|
from PIL import Image |
|
from torch.utils.data import DataLoader, TensorDataset |
|
from torchvision.transforms import InterpolationMode |
|
from torchvision.transforms.functional import pil_to_tensor, resize |
|
from tqdm.auto import tqdm |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
|
from .util.batchsize import find_batch_size |
|
from .util.ensemble import ensemble_depth |
|
from .util.image_util import ( |
|
chw2hwc, |
|
colorize_depth_maps, |
|
get_tv_resample_method, |
|
resize_max_res, |
|
) |
|
|
|
def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
    """
    Rescale a classifier-free-guidance noise prediction toward the spread of the
    text-conditioned prediction, then blend it with the original guided prediction.

    Follows Section 3.4 of "Common Diffusion Noise Schedules and Sample Steps are
    Flawed" (https://arxiv.org/pdf/2305.08891.pdf).

    Args:
        noise_cfg: Guided noise prediction; its std is taken over every dim but the first.
        noise_pred_text: Text-conditioned noise prediction; reduced the same way.
        guidance_rescale: Blend weight. The result is
            ``guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg``.

    Returns:
        The blended noise prediction.
    """
    text_dims = list(range(1, noise_pred_text.ndim))
    cfg_dims = list(range(1, noise_cfg.ndim))
    sigma_text = noise_pred_text.std(dim=text_dims, keepdim=True)
    sigma_cfg = noise_cfg.std(dim=cfg_dims, keepdim=True)

    # Give the guided prediction the same per-sample spread as the text prediction ...
    rescaled = noise_cfg * (sigma_text / sigma_cfg)
    # ... and interpolate between the rescaled and the untouched guided prediction.
    return guidance_rescale * rescaled + (1 - guidance_rescale) * noise_cfg
|
|
|
class MarigoldDepthOutput(BaseOutput):
    """
    Result container returned by the Marigold depth prediction `__call__`.

    Attributes:
        depth_np (`np.ndarray`):
            Predicted depth map; the pipeline clips its values to [0, 1].
        depth_colored (`PIL.Image.Image` or `None`):
            Colorized depth map as an 8-bit PIL image, or `None` when no color map
            was requested.
        uncertainty (`np.ndarray` or `None`):
            Uncalibrated uncertainty (MAD, median absolute deviation) from ensembling;
            `None` when only a single prediction was made.
    """

    depth_np: np.ndarray
    depth_colored: Optional[Image.Image]
    uncertainty: Optional[np.ndarray]
|
|
|
class MarigoldInpaintPipeline(DiffusionPipeline): |
|
""" |
|
Pipeline for monocular depth estimation and joint RGB-D inpainting, built on Marigold: https://marigoldmonodepth.github.io.
|
|
|
This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods the |
|
library implements for all the pipelines (such as downloading or saving, running on a particular device, etc.) |
|
|
|
Args: |
|
unet (`DoubleUNet2DConditionModel`):

Dual-branch conditional U-Net that denoises the RGB and depth latents (it takes separate `rgb_timestep` and `depth_timestep` arguments).
|
vae (`AutoencoderKL`): |
|
Variational Auto-Encoder (VAE) Model to encode and decode images and depth maps |
|
to and from latent representations. |
|
scheduler (`DDIMScheduler`, `LCMScheduler` or `PNDMScheduler`):
|
A scheduler to be used in combination with `unet` to denoise the encoded image latents. |
|
text_encoder (`CLIPTextModel`): |
|
Text-encoder, for empty text embedding. |
|
tokenizer (`CLIPTokenizer`): |
|
CLIP tokenizer. |
|
scale_invariant (`bool`, *optional*): |
|
A model property specifying whether the predicted depth maps are scale-invariant. This value must be set in |
|
the model config. When used together with the `shift_invariant=True` flag, the model is also called |
|
"affine-invariant". NB: overriding this value is not supported. |
|
shift_invariant (`bool`, *optional*): |
|
A model property specifying whether the predicted depth maps are shift-invariant. This value must be set in |
|
the model config. When used together with the `scale_invariant=True` flag, the model is also called |
|
"affine-invariant". NB: overriding this value is not supported. |
|
default_denoising_steps (`int`, *optional*): |
|
The minimum number of denoising diffusion steps that are required to produce a prediction of reasonable |
|
quality with the given model. This value must be set in the model config. When the pipeline is called |
|
without explicitly setting `num_inference_steps`, the default value is used. This is required to ensure |
|
reasonable results with various model flavors compatible with the pipeline, such as those relying on very |
|
short denoising schedules (`LCMScheduler`) and those with full diffusion schedules (`DDIMScheduler`). |
|
default_processing_resolution (`int`, *optional*): |
|
The recommended value of the `processing_resolution` parameter of the pipeline. This value must be set in |
|
the model config. When the pipeline is called without explicitly setting `processing_resolution`, the |
|
default value is used. This is required to ensure reasonable results with various model flavors trained |
|
with varying optimal processing resolution values. |
|
""" |
|
|
|
rgb_latent_scale_factor = 0.18215 |
|
depth_latent_scale_factor = 0.18215 |
|
|
|
def __init__(
    self,
    unet: DoubleUNet2DConditionModel,
    vae: AutoencoderKL,
    scheduler: Union[DDIMScheduler, LCMScheduler],
    text_encoder: CLIPTextModel,
    tokenizer: CLIPTokenizer,
    scale_invariant: Optional[bool] = True,
    shift_invariant: Optional[bool] = True,
    default_denoising_steps: Optional[int] = None,
    default_processing_resolution: Optional[int] = None,
    requires_safety_checker: bool = False,
):
    """
    Register the pipeline components and record the model-level configuration.

    Args:
        unet: Dual-branch U-Net used to denoise the RGB and depth latents.
        vae: Autoencoder used to encode/decode both RGB images and depth maps.
        scheduler: Default scheduler; `_check_inference_step` validates step counts against it.
        text_encoder: CLIP text encoder used for prompt embeddings.
        tokenizer: Tokenizer paired with `text_encoder`.
        scale_invariant: Forwarded to depth ensembling in `__call__`.
        shift_invariant: Forwarded to depth ensembling in `__call__`.
        default_denoising_steps: Fallback for `denoising_steps` in `__call__`.
        default_processing_resolution: Fallback for `processing_res` in `__call__`.
        requires_safety_checker: Only recorded in the config.
    """
    super().__init__()

    self.register_modules(
        unet=unet,
        vae=vae,
        scheduler=scheduler,
        text_encoder=text_encoder,
        tokenizer=tokenizer,
    )
    self.register_to_config(
        scale_invariant=scale_invariant,
        shift_invariant=shift_invariant,
        default_denoising_steps=default_denoising_steps,
        default_processing_resolution=default_processing_resolution,
        requires_safety_checker=requires_safety_checker,
    )

    # Plain-attribute copies of the config values read by `__call__`.
    self.scale_invariant = scale_invariant
    self.shift_invariant = shift_invariant
    self.default_denoising_steps = default_denoising_steps
    self.default_processing_resolution = default_processing_resolution

    # Per-modality schedulers used by the inpainting entry point; they are not set
    # here and must be assigned before `_rgbd_inpaint` is called.
    self.rgb_scheduler = None
    self.depth_scheduler = None

    # Filled lazily by `encode_empty_text`.
    self.empty_text_embed = None

    # 2 ** (number of VAE blocks - 1); presumably the pixel-to-latent down-sampling factor.
    self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1)
    self.image_processor = VaeImageProcessor(vae_scale_factor=self.vae_scale_factor)
    # Masks: grayscale, binarized, not normalized.
    self.mask_processor = VaeImageProcessor(
        vae_scale_factor=self.vae_scale_factor,
        do_normalize=False,
        do_binarize=True,
        do_convert_grayscale=True,
    )

    # NOTE(review): not read anywhere in the visible code — confirm it is still needed.
    self.separate_list = [0, 0]
|
|
|
@torch.no_grad()
def __call__(
    self,
    input_image: Union[Image.Image, torch.Tensor],
    denoising_steps: Optional[int] = None,
    ensemble_size: int = 5,
    processing_res: Optional[int] = None,
    match_input_res: bool = True,
    resample_method: str = "bilinear",
    batch_size: int = 0,
    generator: Union[torch.Generator, None] = None,
    color_map: Optional[str] = "Spectral",
    show_progress_bar: bool = True,
    ensemble_kwargs: Optional[Dict] = None,
) -> MarigoldDepthOutput:
    """
    Predict a depth map for one image, optionally ensembling several predictions.

    Args:
        input_image (`PIL.Image.Image` or `torch.Tensor`):
            Input image. PIL images are converted to RGB; tensors must have shape
            [1, 3, H, W] with values in [0, 255].
        denoising_steps (`int`, *optional*, defaults to `None`):
            Number of denoising diffusion steps. `None` falls back to the model config's
            `default_denoising_steps`. At least 10 steps are recommended for DDIM/PNDM
            schedulers and 1-4 for LCM schedulers.
        ensemble_size (`int`, *optional*, defaults to `5`):
            Number of predictions to be ensembled.
        processing_res (`int`, *optional*, defaults to `None`):
            Maximum edge length used for processing. `0` processes at the original
            resolution; `None` falls back to the model config's
            `default_processing_resolution`.
        match_input_res (`bool`, *optional*, defaults to `True`):
            Resize the prediction back to the input resolution.
        resample_method (`str`, *optional*, defaults to `"bilinear"`):
            Resampling used for resizing: `bilinear`, `bicubic` or `nearest`.
        batch_size (`int`, *optional*, defaults to `0`):
            Inference batch size; `0` lets `find_batch_size` pick one.
        generator (`torch.Generator`, *optional*):
            Random generator forwarded to `single_infer`.
        color_map (`str`, *optional*, defaults to `"Spectral"`):
            Colormap for the colorized output; `None` skips colorization.
        show_progress_bar (`bool`, *optional*, defaults to `True`):
            Show progress bars for the batch loop and denoising.
        ensemble_kwargs (`dict`, *optional*):
            Extra keyword arguments forwarded to `ensemble_depth`.

    Returns:
        `MarigoldDepthOutput` with `depth_np` (clipped to [0, 1]), `depth_colored`
        (`None` when `color_map` is `None`) and `uncertainty` (`None` when
        `ensemble_size == 1`).

    Raises:
        ValueError: If `denoising_steps` or `processing_res` is neither given nor
            available from the model config.
        TypeError: If `input_image` is neither a PIL image nor a tensor.
    """
    if denoising_steps is None:
        denoising_steps = self.default_denoising_steps
    if processing_res is None:
        processing_res = self.default_processing_resolution
    # Without these checks a missing config default surfaces as an opaque
    # "'>=' not supported between 'NoneType' and 'int'" TypeError.
    if denoising_steps is None:
        raise ValueError(
            "`denoising_steps` was not given and the model config has no `default_denoising_steps`."
        )
    if processing_res is None:
        raise ValueError(
            "`processing_res` was not given and the model config has no `default_processing_resolution`."
        )

    assert processing_res >= 0
    assert ensemble_size >= 1

    self._check_inference_step(denoising_steps)

    resample_method: InterpolationMode = get_tv_resample_method(resample_method)

    # ----------------- Image preprocess -----------------
    if isinstance(input_image, Image.Image):
        input_image = input_image.convert("RGB")
        rgb = pil_to_tensor(input_image).unsqueeze(0)
    elif isinstance(input_image, torch.Tensor):
        rgb = input_image
    else:
        raise TypeError(f"Unknown input type: {type(input_image) = }")
    input_size = rgb.shape
    assert (
        4 == rgb.dim() and 3 == input_size[-3]
    ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"

    if processing_res > 0:
        rgb = resize_max_res(
            rgb,
            max_edge_resolution=processing_res,
            resample_method=resample_method,
        )

    # Map [0, 255] -> [-1, 1].
    rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0
    rgb_norm = rgb_norm.to(self.dtype)
    assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

    # ----------------- Predicting depth -----------------
    # `expand` gives `ensemble_size` views of the same image without copying it.
    duplicated_rgb = rgb_norm.expand(ensemble_size, -1, -1, -1)
    single_rgb_dataset = TensorDataset(duplicated_rgb)
    if batch_size > 0:
        _bs = batch_size
    else:
        _bs = find_batch_size(
            ensemble_size=ensemble_size,
            input_res=max(rgb_norm.shape[1:]),
            dtype=self.dtype,
        )

    single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False)

    depth_pred_ls = []
    if show_progress_bar:
        iterable = tqdm(single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False)
    else:
        iterable = single_rgb_loader
    for batch in iterable:
        (batched_img,) = batch
        depth_pred_raw = self.single_infer(
            rgb_in=batched_img,
            num_inference_steps=denoising_steps,
            show_pbar=show_progress_bar,
            generator=generator,
        )
        depth_pred_ls.append(depth_pred_raw.detach())
    depth_preds = torch.concat(depth_pred_ls, dim=0)
    # Release cached GPU memory before ensembling.
    torch.cuda.empty_cache()

    # ----------------- Test-time ensembling -----------------
    if ensemble_size > 1:
        depth_pred, pred_uncert = ensemble_depth(
            depth_preds,
            scale_invariant=self.scale_invariant,
            shift_invariant=self.shift_invariant,
            max_res=50,
            **(ensemble_kwargs or {}),
        )
    else:
        depth_pred = depth_preds
        pred_uncert = None

    # ----------------- Resize back to input resolution -----------------
    if match_input_res:
        depth_pred = resize(
            depth_pred,
            input_size[-2:],
            interpolation=resample_method,
            antialias=True,
        )

    # ----------------- Convert to numpy and clip -----------------
    depth_pred = depth_pred.squeeze().cpu().numpy()
    if pred_uncert is not None:
        pred_uncert = pred_uncert.squeeze().cpu().numpy()
    depth_pred = depth_pred.clip(0, 1)

    # ----------------- Colorize -----------------
    if color_map is not None:
        depth_colored = colorize_depth_maps(depth_pred, 0, 1, cmap=color_map).squeeze()
        depth_colored = (depth_colored * 255).astype(np.uint8)
        depth_colored_img = Image.fromarray(chw2hwc(depth_colored))
    else:
        depth_colored_img = None

    return MarigoldDepthOutput(
        depth_np=depth_pred,
        depth_colored=depth_colored_img,
        uncertainty=pred_uncert,
    )
|
|
|
def _replace_unet_conv_in(self):
    """
    Double the number of input channels of `self.unet.conv_in`.

    The original kernel is kept for the first half of the input channels. The added
    second half is zero-initialized, so the new layer initially produces the same
    output as the old one and ignores the extra channels. The UNet config's
    `in_channels` entry is updated to the new channel count (8 for a 4-channel UNet).
    """
    old_conv = self.unet.conv_in
    old_weight = old_conv.weight.clone()  # [out, in, kh, kw]
    old_bias = old_conv.bias.clone()

    # `zeros_like` matches the weight's dtype and device; a plain `torch.zeros` is
    # always float32 and would change the dtype of the concatenated fp16/bf16 weight.
    new_weight = torch.cat([old_weight, torch.zeros_like(old_weight)], dim=1)
    new_in_channels = new_weight.shape[1]

    # Reuse the old layer's geometry instead of hard-coding 3x3 / stride 1 / pad 1.
    new_conv = Conv2d(
        new_in_channels,
        old_conv.out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
    )
    new_conv.weight = Parameter(new_weight)
    new_conv.bias = Parameter(old_bias)
    self.unet.conv_in = new_conv
    logging.info("Unet conv_in layer is replaced")

    self.unet.config["in_channels"] = new_in_channels
    logging.info("Unet config is updated")
|
|
|
def _replace_unet_conv_out(self):
    """
    Double the number of output channels of `self.unet.conv_out`.

    The existing kernel and bias are tiled along the output-channel axis, so both
    halves of the new output start as copies of the original prediction. The UNet
    config's `out_channels` entry is updated to the new channel count (8 for a
    4-channel UNet).
    """
    old_conv = self.unet.conv_out
    new_weight = old_conv.weight.clone().repeat((2, 1, 1, 1))  # [2 * out, in, kh, kw]
    new_bias = old_conv.bias.clone().repeat(2)
    new_out_channels = new_weight.shape[0]

    # The new layer consumes the same feature channels as the old one, so its input
    # size is the old layer's `in_channels` (not its `out_channels`); otherwise the
    # module's declared shape would disagree with the weight assigned below.
    new_conv = Conv2d(
        old_conv.in_channels,
        new_out_channels,
        kernel_size=old_conv.kernel_size,
        stride=old_conv.stride,
        padding=old_conv.padding,
    )
    new_conv.weight = Parameter(new_weight)
    new_conv.bias = Parameter(new_bias)
    self.unet.conv_out = new_conv
    logging.info("Unet conv_out layer is replaced")

    self.unet.config["out_channels"] = new_out_channels
    logging.info("Unet config is updated")
|
|
|
def _check_inference_step(self, n_step: int) -> None:
    """
    Warn when the number of denoising steps does not suit `self.scheduler`.

    DDIM and PNDM schedulers warn below 10 steps; LCM schedulers warn outside 1-4.

    Args:
        n_step (`int`): Requested number of denoising steps; must be >= 1.

    Raises:
        AssertionError: If `n_step < 1`.
        RuntimeError: If `self.scheduler` is not a DDIM, LCM or PNDM scheduler.
    """
    assert n_step >= 1

    scheduler = self.scheduler
    too_few_steps_msg = (
        f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference."
    )

    # Checked in the same order as the supported scheduler families are listed.
    if isinstance(scheduler, DDIMScheduler):
        if n_step < 10:
            logging.warning(too_few_steps_msg)
        return
    if isinstance(scheduler, LCMScheduler):
        if n_step < 1 or n_step > 4:
            logging.warning(
                f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps."
            )
        return
    if isinstance(scheduler, PNDMScheduler):
        if n_step < 10:
            logging.warning(too_few_steps_msg)
        return

    raise RuntimeError(f"Unsupported scheduler type: {type(scheduler)}")
|
|
|
def encode_empty_text(self):
    """
    Compute the text embedding of the empty prompt and cache it in
    `self.empty_text_embed`.
    """
    # The empty prompt goes through exactly the same tokenizer/encoder path as any
    # other prompt.
    self.empty_text_embed = self.encode_text("")
|
|
|
def encode_text(self, prompt):
    """
    Encode a prompt with the pipeline's CLIP text encoder.

    Args:
        prompt: Text passed to `self.tokenizer`. It is padded and truncated to the
            tokenizer's `model_max_length`.

    Returns:
        `torch.Tensor`: First output of `self.text_encoder`, cast to the pipeline dtype.
    """
    tokenizer = self.tokenizer
    token_ids = tokenizer(
        prompt,
        padding="max_length",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    ).input_ids
    encoder_output = self.text_encoder(token_ids.to(self.text_encoder.device))
    return encoder_output[0].to(self.dtype)
|
|
|
def numpy_to_pil(self, images: np.ndarray) -> PIL.Image.Image:
    """
    Convert a float image array, or a batch of them, to PIL images.

    Args:
        images: A single channel-last image (3-D) or a batch of them (4-D). Values are
            multiplied by 255, rounded and cast to `uint8`, so they are presumably in
            [0, 1].

    Returns:
        A list of PIL images, one per batch entry (also for a single image).
        Single-channel inputs become mode "L" images.
    """
    batch = images[None, ...] if images.ndim == 3 else images
    batch = (batch * 255).round().astype("uint8")

    if batch.shape[-1] != 1:
        return [Image.fromarray(frame) for frame in batch]
    # Grayscale: drop the trailing channel axis so PIL receives a 2-D array.
    return [Image.fromarray(frame.squeeze(), mode="L") for frame in batch]
|
|
|
def full_depth_rgb_inpaint(
    self,
    rgb_in,
    depth_in,
    image_mask,
    text_embed,
    timesteps,
    generator,
    guidance_scale,
):
    """
    Inpaint the masked RGB region while conditioning on the full, known depth map.

    The depth latent is encoded from `depth_in` and kept fixed; only the RGB latent
    is denoised with `self.rgb_scheduler`.

    Args:
        rgb_in: Preprocessed RGB batch. Pixels where `image_mask >= 0.5` are zeroed
            before encoding.
        depth_in: Preprocessed depth batch, treated as fully known.
        image_mask: Mask batch, assumed [B, 1, H, W] as produced by `mask_processor`
            (TODO confirm). Values >= 0.5 mark the region to inpaint.
        text_embed: Text embeddings concatenated as [unconditional, conditional]
            (2 * B entries).
        timesteps: Timesteps to iterate over.
        generator: Random generator for the initial noise and VAE sampling.
        guidance_scale: Currently unused; the guidance weights below are fixed.

    Returns:
        Tuple `(rgb_latent, depth_latent)`: the denoised RGB latent and the encoded
        depth latent.
    """
    # Depth is fully known: its encoding serves as both the clean latent and the
    # "masked" conditioning latent. `encode_depth` does not sample (it keeps the mean
    # half of `quant_conv`'s output), so one encoding is enough.
    depth_latent = self.encode_depth(depth_in)
    depth_mask_latent = depth_latent
    depth_mask = torch.zeros_like(image_mask)  # nothing to inpaint in depth

    rgb_latent = torch.randn(
        depth_latent.shape,
        device=self.device,
        dtype=self.unet.dtype,
        generator=generator,
    ) * self.rgb_scheduler.init_noise_sigma
    rgb_mask = image_mask
    # Keep the channel axis of the mask so it broadcasts over the RGB channels for
    # any batch size (a squeezed mask only broadcasts correctly for B == 1).
    rgb_mask_latent = self.encode_rgb(rgb_in * (image_mask < 0.5), generator=generator)

    latent_size = rgb_latent.shape[-2:]
    rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=latent_size)
    depth_mask = torch.nn.functional.interpolate(depth_mask, size=latent_size)

    for t in timesteps:
        cat_latent = torch.cat(
            [rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent,
             depth_latent, depth_mask, rgb_mask_latent, depth_mask_latent],
            dim=1,
        ).float()
        # One copy per half of `text_embed` (unconditional, conditional).
        latent_model_input = torch.cat([cat_latent] * 2)

        with torch.no_grad():
            # Same inputs with `depth2rgb_scale=0.2` (a DoubleUNet cross-branch
            # scale; presumably attenuates the depth->RGB interaction).
            partial_noise_pred = self.unet(
                latent_model_input,
                rgb_timestep=t,
                depth_timestep=t,
                encoder_hidden_states=text_embed,
                return_dict=False,
                depth2rgb_scale=0.2,
            )[0]
            noise_pred = self.unet(
                latent_model_input,
                rgb_timestep=t,
                depth_timestep=t,
                encoder_hidden_states=text_embed,
                return_dict=False,
            )[0]

        # Split the doubled batch into its halves instead of indexing rows 0 and 1,
        # which is only correct for a batch of one. RGB channels are the first 4.
        partial_uncond, _ = partial_noise_pred.chunk(2)
        full_uncond, full_cond = noise_pred.chunk(2)
        rgb_pred_wo_depth_text = partial_uncond[:, :4]
        rgb_pred_wo_text = full_uncond[:, :4]
        rgb_pred = full_cond[:, :4]

        # Two-stage guidance: cross-branch term (weight 2), then text term (weight 3).
        rgb_noise_pred = (
            rgb_pred_wo_depth_text
            + 2 * (rgb_pred_wo_text - rgb_pred_wo_depth_text)
            + 3 * (rgb_pred - rgb_pred_wo_text)
        )

        rgb_latent = self.rgb_scheduler.step(rgb_noise_pred, t, rgb_latent).prev_sample
    return rgb_latent, depth_latent
|
|
|
def full_rgb_depth_inpaint(
    self,
    rgb_in,
    depth_in,
    image_mask,
    text_embed,
    timesteps,
    generator,
    guidance_scale,
):
    """
    Inpaint the masked depth region while conditioning on the full, known RGB image.

    The RGB latent is encoded from `rgb_in` and kept fixed; only the depth latent is
    denoised with `self.depth_scheduler`.

    Args:
        rgb_in: Preprocessed RGB batch, treated as fully known.
        depth_in: Preprocessed depth batch. Pixels where `image_mask >= 0.5` are zeroed
            before encoding.
        image_mask: Mask batch, assumed [B, 1, H, W] as produced by `mask_processor`
            (TODO confirm). Values >= 0.5 mark the region to inpaint.
        text_embed: Text embeddings concatenated as [unconditional, conditional]
            (2 * B entries).
        timesteps: Timesteps to iterate over.
        generator: Random generator for VAE sampling, the initial noise and the
            scheduler steps.
        guidance_scale: Currently unused; the guidance weight below is fixed.

    Returns:
        Tuple `(rgb_latent, depth_latent)`: the encoded RGB latent and the denoised
        depth latent.
    """
    # RGB is fully known. `generator` is forwarded so VAE sampling is reproducible,
    # as in the other inpainting modes.
    rgb_latent = self.encode_rgb(rgb_in, generator=generator)
    rgb_mask = torch.zeros_like(image_mask)  # nothing to inpaint in RGB
    rgb_mask_latent = self.encode_rgb(rgb_in, generator=generator)

    depth_latent = torch.randn(
        rgb_latent.shape,
        device=self.device,
        dtype=self.unet.dtype,
        generator=generator,
    ) * self.depth_scheduler.init_noise_sigma
    depth_mask = image_mask
    # Keep the mask's channel axis so it broadcasts for any batch size.
    depth_mask_latent = self.encode_depth(depth_in * (image_mask < 0.5))

    latent_size = rgb_latent.shape[-2:]
    rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=latent_size)
    depth_mask = torch.nn.functional.interpolate(depth_mask, size=latent_size)

    for t in timesteps:
        cat_latent = torch.cat(
            [rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent,
             depth_latent, depth_mask, rgb_mask_latent, depth_mask_latent],
            dim=1,
        ).float()
        # One copy per half of `text_embed` (unconditional, conditional).
        latent_model_input = torch.cat([cat_latent] * 2)

        with torch.no_grad():
            # Same inputs with `rgb2depth_scale=0.2` (a DoubleUNet cross-branch
            # scale; presumably attenuates the RGB->depth interaction).
            partial_noise_pred = self.unet(
                latent_model_input,
                rgb_timestep=t,
                depth_timestep=t,
                encoder_hidden_states=text_embed,
                return_dict=False,
                rgb2depth_scale=0.2,
            )[0]
            noise_pred = self.unet(
                latent_model_input,
                rgb_timestep=t,
                depth_timestep=t,
                encoder_hidden_states=text_embed,
                return_dict=False,
            )[0]

        # Conditional (second) halves of the doubled batch, depth channels only.
        # Splitting with `chunk` instead of indexing row 1 keeps this correct for B > 1.
        depth_pre_wo_rgb = partial_noise_pred.chunk(2)[1][:, 4:]
        depth_pre_cond = noise_pred.chunk(2)[1][:, 4:]
        # Guidance on the cross-branch term with a fixed weight of 4.
        depth_pre = depth_pre_wo_rgb + 4 * (depth_pre_cond - depth_pre_wo_rgb)

        depth_latent = self.depth_scheduler.step(depth_pre, t, depth_latent, generator=generator).prev_sample
    return rgb_latent, depth_latent
|
|
|
def joint_inpaint(
    self,
    rgb_in,
    depth_in,
    image_mask,
    text_embed,
    timesteps,
    generator,
    guidance_scale,
):
    """
    Jointly inpaint the masked region of both the RGB image and the depth map.

    Both latents start from noise and are denoised together, the RGB latent with
    `self.rgb_scheduler` and the depth latent with `self.depth_scheduler`.

    Args:
        rgb_in: Preprocessed RGB batch. Pixels where `image_mask >= 0.5` are zeroed
            before encoding the conditioning latent.
        depth_in: Preprocessed depth batch, masked the same way.
        image_mask: Mask batch, assumed [B, 1, H, W] as produced by `mask_processor`
            (TODO confirm). Values >= 0.5 mark the region to inpaint.
        text_embed: Text embeddings concatenated as [unconditional, conditional]
            (2 * B entries).
        timesteps: Timesteps to iterate over.
        generator: Random generator for the initial noise, VAE sampling and the depth
            scheduler steps.
        guidance_scale: Currently unused; the guidance weight below is fixed.

    Returns:
        Tuple `(rgb_latent, depth_latent)` of denoised latents.
    """
    bs = rgb_in.shape[0]
    # Latent resolution follows the VAE down-sampling factor rather than a
    # hard-coded 8 (identical for the standard SD VAE).
    h = rgb_in.shape[-2] // self.vae_scale_factor
    w = rgb_in.shape[-1] // self.vae_scale_factor

    # Keep the mask's channel axis so it broadcasts over channels for any batch size.
    known_region = image_mask < 0.5

    rgb_latent = torch.randn(
        [bs, 4, h, w],
        device=self.device,
        dtype=self.unet.dtype,
        generator=generator,
    ) * self.rgb_scheduler.init_noise_sigma
    rgb_mask = image_mask
    rgb_mask_latent = self.encode_rgb(rgb_in * known_region, generator=generator)

    depth_latent = torch.randn(
        [bs, 4, h, w],
        device=self.device,
        dtype=self.unet.dtype,
        generator=generator,
    ) * self.depth_scheduler.init_noise_sigma
    depth_mask = image_mask
    depth_mask_latent = self.encode_depth(depth_in * known_region)

    latent_size = rgb_latent.shape[-2:]
    rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=latent_size)
    depth_mask = torch.nn.functional.interpolate(depth_mask, size=latent_size)

    for t in timesteps:
        cat_latent = torch.cat(
            [rgb_latent, rgb_mask, rgb_mask_latent, depth_mask_latent,
             depth_latent, depth_mask, rgb_mask_latent, depth_mask_latent],
            dim=1,
        ).float()
        # One copy per half of `text_embed` (unconditional, conditional).
        latent_model_input = torch.cat([cat_latent] * 2)

        with torch.no_grad():
            # Same inputs with DoubleUNet cross-branch scales `depth2rgb_scale=0` and
            # `rgb2depth_scale=0.2` (presumably disabling / attenuating those paths).
            partial_noise_pred = self.unet(
                latent_model_input,
                rgb_timestep=t,
                depth_timestep=t,
                encoder_hidden_states=text_embed,
                return_dict=False,
                depth2rgb_scale=0,
                rgb2depth_scale=0.2,
            )[0]
            noise_pred = self.unet(
                latent_model_input,
                rgb_timestep=t,
                depth_timestep=t,
                encoder_hidden_states=text_embed,
                return_dict=False,
            )[0]

        # Only the conditional (second) halves are used.
        _, noise_pred_undual = partial_noise_pred.chunk(2)
        _, noise_pred_cond = noise_pred.chunk(2)

        # Guidance on the cross-branch term (fixed weight 3), applied to depth only.
        depth_noise_pred = noise_pred_undual + 3 * (noise_pred_cond - noise_pred_undual)

        # Channels 0-3 are RGB, 4-7 are depth.
        rgb_latent = self.rgb_scheduler.step(noise_pred_cond[:, :4, :, :], t, rgb_latent, return_dict=False)[0]
        depth_latent = self.depth_scheduler.step(
            depth_noise_pred[:, 4:, :, :], t, depth_latent, generator=generator, return_dict=False
        )[0]
    return rgb_latent, depth_latent
|
|
|
@torch.no_grad()
def _rgbd_inpaint(
    self,
    input_image: Union[torch.Tensor, PIL.Image.Image],
    depth_image: Union[torch.Tensor, PIL.Image.Image],
    mask: Union[torch.Tensor, PIL.Image.Image],
    prompt: str = '',
    guidance_scale: float = 4.5,
    generator: Union[torch.Generator, None] = None,
    num_inference_steps: int = 50,
    resample_method: str = "bilinear",
    processing_res: int = 512,
    mode: str = 'full_depth_rgb_inpaint',
) -> PIL.Image.Image:
    """
    Inpaint an RGB-D pair and return a side-by-side visualization.

    Args:
        input_image: RGB image. Either a PIL image, or a tensor of shape [3, H, W] or
            [1, 3, H, W] with values in [0, 255].
        depth_image: Depth map. Either a PIL image or a tensor. Single-channel input is
            replicated to 3 channels, and values are min-max normalized to [-1, 1].
        mask: Inpainting mask (PIL image, tensor or array). Non-PIL masks with a
            non-constant range are first rescaled to [0, 255].
        prompt: A prompt string or a list of prompt strings.
        guidance_scale: Forwarded to the selected inpainting method.
        generator: Random generator forwarded to the selected inpainting method.
        num_inference_steps: Number of denoising steps for both schedulers.
        resample_method: Resampling method for tensor inputs (`bilinear`, `bicubic`,
            `nearest`).
        processing_res: Square processing resolution; `0` skips resizing of tensor
            inputs.
        mode: One of `'full_rgb_depth_inpaint'`, `'partial_depth_rgb_inpaint'`,
            `'full_depth_rgb_inpaint'` or `'joint_inpaint'`; the method of the same
            name is used.

    Returns:
        A 1x5 image grid: original RGB, original depth, mask, inpainted RGB,
        inpainted depth.

    Raises:
        RuntimeError: If `rgb_scheduler` or `depth_scheduler` has not been assigned.
        NotImplementedError: If `prompt` is neither a string nor a list.
        TypeError: If `input_image` or `depth_image` has an unsupported type.
        ValueError: If `mode` is not supported.
    """
    if self.rgb_scheduler is None or self.depth_scheduler is None:
        raise RuntimeError(
            "`rgb_scheduler` and `depth_scheduler` must be assigned before calling `_rgbd_inpaint`."
        )
    self._check_inference_step(num_inference_steps)

    resample_method: InterpolationMode = get_tv_resample_method(resample_method)

    # ----------------- Text embeddings -----------------
    # `encode_text` already returns a batch of one (the list branch concatenates those
    # along dim 0). Adding another batch dim for the str case made the tensor 4-D, and
    # the `torch.cat` with the 3-D empty-prompt embedding below then failed.
    if isinstance(prompt, list):
        batch_text_embed = torch.cat([self.encode_text(p) for p in prompt], dim=0)
    elif isinstance(prompt, str):
        batch_text_embed = self.encode_text(prompt)
    else:
        raise NotImplementedError

    if self.empty_text_embed is None:
        self.encode_empty_text()
    batch_empty_text_embed = self.empty_text_embed.repeat(
        (batch_text_embed.shape[0], 1, 1)
    ).to(self.device)
    # Unconditional first, conditional second: the inpainting methods split it this way.
    text_embed = torch.cat([batch_empty_text_embed, batch_text_embed], dim=0)

    # ----------------- RGB preprocess -----------------
    if isinstance(input_image, Image.Image):
        rgb_in = self.image_processor.preprocess(
            input_image, height=processing_res, width=processing_res
        ).to(self.dtype).to(self.device)
    elif isinstance(input_image, torch.Tensor):
        # Accept both [3, H, W] and [1, 3, H, W].
        rgb = input_image.unsqueeze(0) if input_image.dim() == 3 else input_image
        input_size = rgb.shape
        assert (
            4 == rgb.dim() and 3 == input_size[-3]
        ), f"Wrong input shape {input_size}, expected [1, rgb, H, W]"
        if processing_res > 0:
            rgb = resize(rgb, [processing_res, processing_res], resample_method, antialias=True)
        rgb_norm: torch.Tensor = rgb / 255.0 * 2.0 - 1.0
        rgb_in = rgb_norm.to(self.dtype).to(self.device)
        assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0
    else:
        raise TypeError(f"Unknown input type: {type(input_image) = }")

    # ----------------- Depth preprocess -----------------
    if isinstance(depth_image, Image.Image):
        depth = pil_to_tensor(depth_image).unsqueeze(0)
    elif isinstance(depth_image, torch.Tensor):
        depth = depth_image.unsqueeze(0) if depth_image.dim() == 3 else depth_image
    else:
        raise TypeError(f"Unknown depth type: {type(depth_image) = }")

    # Replicate depth to 3 channels for the VAE unless it already has them.
    if not (depth.dim() == 4 and depth.shape[1] == 3):
        depth = depth.repeat(1, 3, 1, 1)
    input_size = depth.shape
    assert (
        4 == depth.dim() and 3 == input_size[-3]
    ), f"Wrong input shape {input_size}, expected [1, 1, H, W]"
    if processing_res > 0:
        depth = resize(depth, [processing_res, processing_res], resample_method, antialias=True)
    # Min-max normalize to [-1, 1].
    depth_norm: torch.Tensor = (depth - depth.min()) / (depth.max() - depth.min()) * 2.0 - 1.0
    depth_in = depth_norm.to(self.dtype).to(self.device)
    assert depth_norm.min() >= -1.0 and depth_norm.max() <= 1.0

    # ----------------- Mask preprocess -----------------
    # PIL images have no `.max()`; they are handed to the mask processor unchanged.
    if not isinstance(mask, Image.Image) and (mask.max() - mask.min()) != 0:
        mask = (mask - mask.min()) / (mask.max() - mask.min()) * 255
    image_mask = self.mask_processor.preprocess(
        mask, height=processing_res, width=processing_res
    ).to(self.device)

    # ----------------- Denoising -----------------
    self.rgb_scheduler.set_timesteps(num_inference_steps, device=self.device)
    self.depth_scheduler.set_timesteps(num_inference_steps, device=self.device)
    timesteps = self.rgb_scheduler.timesteps

    supported_modes = (
        'full_rgb_depth_inpaint',
        'partial_depth_rgb_inpaint',
        'full_depth_rgb_inpaint',
        'joint_inpaint',
    )
    if mode not in supported_modes:
        raise ValueError(f"Unsupported mode {mode!r}; expected one of {supported_modes}")
    # Each mode name is also the name of the method implementing it; the attribute is
    # only looked up for the selected mode.
    inpaint_fn = getattr(self, mode)
    rgb_latent, depth_latent = inpaint_fn(
        rgb_in, depth_in, image_mask, text_embed, timesteps, generator, guidance_scale=guidance_scale
    )

    # ----------------- Visualization -----------------
    image = self.numpy_to_pil(self.decode_image(rgb_latent))[0]

    d_image = self.decode_depth(depth_latent).cpu().permute(0, 2, 3, 1).numpy()
    d_image = (d_image - d_image.min()) / (d_image.max() - d_image.min())
    d_image = self.numpy_to_pil(d_image)[0]

    depth = depth.squeeze().permute(1, 2, 0).cpu().numpy()
    depth = (depth - depth.min()) / (depth.max() - depth.min())
    ori_depth = self.numpy_to_pil(depth)[0]

    if isinstance(input_image, Image.Image):
        # PIL inputs have no tensor methods; show the RGB-converted original directly.
        ori_image = input_image.convert("RGB")
    else:
        ori_image = input_image.squeeze().permute(1, 2, 0).cpu().numpy()
        ori_image = self.numpy_to_pil(ori_image / 255)[0]

    mask_image = self.numpy_to_pil(image_mask.permute(0, 2, 3, 1).cpu().numpy())[0]
    return make_image_grid([ori_image, ori_depth, mask_image, image, d_image], rows=1, cols=5)
|
|
|
|
|
def encode_rgb(self, rgb_in: torch.Tensor, generator=None) -> torch.Tensor:
    """
    Encode an RGB batch into a scaled VAE latent.

    A latent is sampled from the VAE's latent distribution and multiplied by
    `vae.config.scaling_factor`.

    Args:
        rgb_in (`torch.Tensor`): RGB batch to encode.
        generator (`torch.Generator`, *optional*): Random generator for the sampling step.

    Returns:
        `torch.Tensor`: Scaled image latent.
    """
    posterior = self.vae.encode(rgb_in).latent_dist
    sampled = posterior.sample(generator=generator)
    return self.vae.config.scaling_factor * sampled
|
|
|
def encode_depth(self, depth_in: torch.Tensor) -> torch.Tensor:
    """
    Encode a depth batch into a scaled latent, without sampling.

    The VAE encoder and `quant_conv` produce moments. Their first half along the
    channel axis (the mean) is kept and multiplied by `depth_latent_scale_factor`.

    Args:
        depth_in (`torch.Tensor`): Depth batch to encode. The callers in this file
            pass depth replicated to 3 channels.

    Returns:
        `torch.Tensor`: Depth latent.
    """
    moments = self.vae.quant_conv(self.vae.encoder(depth_in))
    # Split into (mean, logvar); only the mean is used.
    mean = torch.chunk(moments, 2, dim=1)[0]
    return mean * self.depth_latent_scale_factor
|
|
|
def decode_image(self, latents):
    """
    Decode scaled RGB latents into a channel-last float32 NumPy batch in [0, 1].

    Args:
        latents: RGB latents scaled by `vae.config.scaling_factor`.

    Returns:
        `np.ndarray`: Decoded images, permuted to channel-last and clamped to [0, 1].
    """
    unscaled = 1 / self.vae.config.scaling_factor * latents
    decoded = self.vae.decoder(self.vae.post_quant_conv(unscaled))
    # Map the decoder range [-1, 1] to [0, 1].
    decoded = (decoded / 2 + 0.5).clamp(0, 1)
    return decoded.cpu().permute(0, 2, 3, 1).float().numpy()
|
|
|
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
    """
    Decode a depth latent into a single-channel depth map.

    The latent is divided by `depth_latent_scale_factor` and decoded by the VAE. The
    decoded channels are then averaged.

    Args:
        depth_latent (`torch.Tensor`): Depth latent to decode.

    Returns:
        `torch.Tensor`: Depth map with the channel dimension kept at size 1.
    """
    unscaled = depth_latent / self.depth_latent_scale_factor
    decoded = self.vae.decoder(self.vae.post_quant_conv(unscaled))
    return decoded.mean(dim=1, keepdim=True)
|
|
|
def post_process_rgbd(self, prompts, rgb_image, depth_image):
    """
    Place each RGB/depth image pair side by side under its prompt text.

    Args:
        prompts: Captions; entry ``i`` labels ``rgb_image[i]`` and ``depth_image[i]``.
        rgb_image: Indexable collection of PIL images (left side).
        depth_image: Indexable collection of PIL images (right side).

    Returns:
        List of tensors (via `pil_to_tensor`), one per prompt.
    """
    rgbd_images = []
    for idx, caption in enumerate(prompts):
        left, right = rgb_image[idx], depth_image[idx]
        left_w, left_h = left.size
        right_w, right_h = right.size

        font = ImageFont.load_default(size=20)
        # Measure the caption; the draw context on `left` is used only for measuring.
        x0, y0, x1, y1 = ImageDraw.Draw(left).textbbox((0, 0), caption, font=font)
        caption_w = x1 - x0
        caption_h = y1 - y0

        # White canvas: a caption strip on top, both images pasted below it.
        canvas = Image.new('RGB', (left_w + right_w, max(left_h, right_h) + caption_h), (255, 255, 255))
        ImageDraw.Draw(canvas).text(((canvas.width - caption_w) // 2, 0), caption, fill="black", font=font)
        canvas.paste(left, (0, caption_h))
        canvas.paste(right, (left_w, caption_h))

        rgbd_images.append(pil_to_tensor(canvas))
    return rgbd_images
|
|