|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import logging |
|
import math |
|
from typing import Dict, Union |
|
|
|
import matplotlib |
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from PIL.Image import Resampling |
|
from scipy.optimize import minimize |
|
from torch.utils.data import DataLoader, TensorDataset |
|
from tqdm.auto import tqdm |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
|
|
from diffusers import ( |
|
AutoencoderKL, |
|
DDIMScheduler, |
|
DiffusionPipeline, |
|
LCMScheduler, |
|
UNet2DConditionModel, |
|
) |
|
from diffusers.utils import BaseOutput, check_min_version |
|
|
|
|
|
|
|
check_min_version("0.31.0.dev0") |
|
|
|
|
|
class MarigoldDepthOutput(BaseOutput):
    """
    Result container returned by the Marigold monocular depth prediction pipeline.

    Args:
        depth_np (`np.ndarray`):
            Predicted depth map whose values lie in the range [0, 1].
        depth_colored (`None` or `PIL.Image.Image`):
            Colorized rendering of the depth map as a PIL image, or `None` when no colorization was requested.
        uncertainty (`None` or `np.ndarray`):
            Uncalibrated uncertainty (MAD, median absolute deviation) obtained from ensembling, or `None`
            when no ensembling took place.
    """

    depth_np: np.ndarray
    depth_colored: Union[Image.Image, None]
    uncertainty: Union[np.ndarray, None]
|
|
|
|
|
def get_pil_resample_method(method_str: str) -> Resampling:
    """
    Translate a resampling method name into the matching `PIL.Image.Resampling` member.

    Args:
        method_str (`str`):
            Name of the resampling method; one of `"bilinear"`, `"bicubic"` or `"nearest"`.

    Returns:
        `Resampling`: The PIL resampling filter registered under `method_str`.

    Raises:
        ValueError: If `method_str` is not one of the supported names.
    """
    resample_method_dic = {
        "bilinear": Resampling.BILINEAR,
        "bicubic": Resampling.BICUBIC,
        "nearest": Resampling.NEAREST,
    }
    resample_method = resample_method_dic.get(method_str)
    if resample_method is None:
        # Report the name the caller supplied: the looked-up value is always None on this path,
        # so interpolating it would hide which name was rejected.
        raise ValueError(f"Unknown resampling method: {method_str}")
    return resample_method
|
|
|
|
|
class MarigoldPipeline(DiffusionPipeline):
    """
    Monocular depth estimation pipeline based on Marigold: https://marigoldmonodepth.github.io.

    The class derives from [`DiffusionPipeline`]. Refer to the superclass documentation for the generic
    functionality the library provides for all pipelines (downloading, saving, running on a particular
    device, etc.).

    Args:
        unet (`UNet2DConditionModel`):
            Conditional U-Net that denoises the depth latent, conditioned on the image latent.
        vae (`AutoencoderKL`):
            Variational Auto-Encoder (VAE) model used to encode and decode images and depth maps
            to and from latent representations.
        scheduler (`DDIMScheduler`):
            Scheduler used together with `unet` to denoise the encoded image latents.
        text_encoder (`CLIPTextModel`):
            Text encoder, used for the empty text embedding.
        tokenizer (`CLIPTokenizer`):
            CLIP tokenizer.
    """

    # Factors applied to the VAE latents of the RGB input and of the depth map, respectively.
    rgb_latent_scale_factor = 0.18215
    depth_latent_scale_factor = 0.18215

    def __init__(
        self,
        unet: UNet2DConditionModel,
        vae: AutoencoderKL,
        scheduler: DDIMScheduler,
        text_encoder: CLIPTextModel,
        tokenizer: CLIPTokenizer,
    ):
        """Register the sub-models with the pipeline and reset the cached empty-prompt embedding."""
        super().__init__()

        components = {
            "unet": unet,
            "vae": vae,
            "scheduler": scheduler,
            "text_encoder": text_encoder,
            "tokenizer": tokenizer,
        }
        self.register_modules(**components)

        # Starts out unset; inference code fills it in when it finds it to be None.
        self.empty_text_embed = None
|
|
|
@torch.no_grad()
def __call__(
    self,
    input_image: Image.Image,
    denoising_steps: int = 10,
    ensemble_size: int = 10,
    processing_res: int = 768,
    match_input_res: bool = True,
    resample_method: str = "bilinear",
    batch_size: int = 0,
    seed: Union[int, None] = None,
    color_map: str = "Spectral",
    show_progress_bar: bool = True,
    ensemble_kwargs: Dict = None,
) -> MarigoldDepthOutput:
    """
    Function invoked when calling the pipeline.

    Args:
        input_image (`PIL.Image.Image`):
            Input RGB (or gray-scale) image. It is converted to RGB before inference.
        denoising_steps (`int`, *optional*, defaults to `10`):
            Number of diffusion denoising steps during inference.
        ensemble_size (`int`, *optional*, defaults to `10`):
            Number of predictions to be ensembled. Must be at least 1.
        processing_res (`int`, *optional*, defaults to `768`):
            Edge length the input is resized to (longer edge) before inference.
            If set to 0, the input is not resized at all.
        match_input_res (`bool`, *optional*, defaults to `True`):
            Resize the depth prediction back to the resolution of `input_image`.
        resample_method (`str`, *optional*, defaults to `"bilinear"`):
            Resampling method used to resize images and depth predictions. One of `bilinear`, `bicubic`
            or `nearest`.
        batch_size (`int`, *optional*, defaults to `0`):
            Inference batch size. If set to 0, a batch size is chosen automatically.
        seed (`int`, *optional*, defaults to `None`):
            Reproducibility seed. When given, the n-th inference batch is seeded with `seed + n`, so that
            runs are repeatable while different batches of one ensemble do not start from the same noise.
        color_map (`str`, *optional*, defaults to `"Spectral"`):
            Colormap used to colorize the depth map. Pass `None` to skip colorized depth map generation.
        show_progress_bar (`bool`, *optional*, defaults to `True`):
            Display progress bars for the inference batches and the diffusion denoising.
        ensemble_kwargs (`dict`, *optional*, defaults to `None`):
            Extra keyword arguments forwarded to `ensemble_depths`.

    Returns:
        `MarigoldDepthOutput`: Output class for Marigold monocular depth prediction pipeline, including:
        - **depth_np** (`np.ndarray`) Predicted depth map, with depth values in the range of [0, 1].
        - **depth_colored** (`PIL.Image.Image`) Colorized depth map, `None` if `color_map` is `None`.
        - **uncertainty** Uncertainty map exactly as returned by `ensemble_depths` (it is not resized to
          the input resolution); `None` if `ensemble_size = 1`.

    Raises:
        AssertionError: If `processing_res` is negative or `ensemble_size` is smaller than 1.
        ValueError: If `resample_method` is not a supported name.
    """
    device = self.device
    input_size = input_image.size

    if not match_input_res:
        assert (
            processing_res is not None
        ), "Value error: `match_input_res=False` is only valid with a specified `processing_res`."
    assert processing_res >= 0
    assert ensemble_size >= 1

    self._check_inference_step(denoising_steps)

    resampling: Resampling = get_pil_resample_method(resample_method)

    # ----------------- Image preprocessing -----------------
    if processing_res > 0:
        input_image = self.resize_max_res(
            input_image,
            max_edge_resolution=processing_res,
            resample_method=resampling,
        )

    # Converting to RGB drops an alpha channel and expands gray-scale input to three channels.
    input_image = input_image.convert("RGB")
    image = np.asarray(input_image)

    # [H, W, 3] in [0, 255] -> [3, H, W] in [-1, 1]
    rgb = np.transpose(image, (2, 0, 1))
    rgb_norm = rgb / 255.0 * 2.0 - 1.0
    rgb_norm = torch.from_numpy(rgb_norm).to(self.dtype)
    rgb_norm = rgb_norm.to(device)
    assert rgb_norm.min() >= -1.0 and rgb_norm.max() <= 1.0

    # ----------------- Batched inference -----------------
    duplicated_rgb = torch.stack([rgb_norm] * ensemble_size)
    single_rgb_dataset = TensorDataset(duplicated_rgb)
    if batch_size > 0:
        _bs = batch_size
    else:
        _bs = self._find_batch_size(
            ensemble_size=ensemble_size,
            input_res=max(rgb_norm.shape[1:]),
            dtype=self.dtype,
        )

    single_rgb_loader = DataLoader(single_rgb_dataset, batch_size=_bs, shuffle=False)

    depth_pred_ls = []
    if show_progress_bar:
        iterable = tqdm(single_rgb_loader, desc=" " * 2 + "Inference batches", leave=False)
    else:
        iterable = single_rgb_loader
    for batch_idx, (batched_img,) in enumerate(iterable):
        # Every batch holds copies of the same image. Re-using one seed for all of them would make
        # each batch start from identical noise, so offset the seed by the batch index instead.
        batch_seed = None if seed is None else seed + batch_idx
        depth_pred_raw = self.single_infer(
            rgb_in=batched_img,
            num_inference_steps=denoising_steps,
            show_pbar=show_progress_bar,
            seed=batch_seed,
        )
        depth_pred_ls.append(depth_pred_raw.detach())
    depth_preds = torch.concat(depth_pred_ls, dim=0).squeeze()
    torch.cuda.empty_cache()

    # ----------------- Test-time ensembling -----------------
    if ensemble_size > 1:
        depth_pred, pred_uncert = self.ensemble_depths(depth_preds, **(ensemble_kwargs or {}))
    else:
        depth_pred = depth_preds
        pred_uncert = None

    # ----------------- Post processing -----------------
    # Scale the prediction to [0, 1]; a constant map has no range to stretch and is mapped to zeros
    # rather than dividing by zero.
    min_d = torch.min(depth_pred)
    max_d = torch.max(depth_pred)
    depth_range = max_d - min_d
    if depth_range > 0:
        depth_pred = (depth_pred - min_d) / depth_range
    else:
        depth_pred = torch.zeros_like(depth_pred)

    depth_pred = depth_pred.cpu().numpy().astype(np.float32)

    # Resize back to the original resolution.
    if match_input_res:
        pred_img = Image.fromarray(depth_pred)
        pred_img = pred_img.resize(input_size, resample=resampling)
        depth_pred = np.asarray(pred_img)

    # Resampling may overshoot slightly, so clip to the valid range again.
    depth_pred = depth_pred.clip(0, 1)

    # ----------------- Colorization -----------------
    if color_map is not None:
        depth_colored = self.colorize_depth_maps(depth_pred, 0, 1, cmap=color_map).squeeze()
        depth_colored = (depth_colored * 255).astype(np.uint8)
        depth_colored_hwc = self.chw2hwc(depth_colored)
        depth_colored_img = Image.fromarray(depth_colored_hwc)
    else:
        depth_colored_img = None

    return MarigoldDepthOutput(
        depth_np=depth_pred,
        depth_colored=depth_colored_img,
        uncertainty=pred_uncert,
    )
|
|
|
def _check_inference_step(self, n_step: int): |
|
""" |
|
Check if denoising step is reasonable |
|
Args: |
|
n_step (`int`): denoising steps |
|
""" |
|
assert n_step >= 1 |
|
|
|
if isinstance(self.scheduler, DDIMScheduler): |
|
if n_step < 10: |
|
logging.warning( |
|
f"Too few denoising steps: {n_step}. Recommended to use the LCM checkpoint for few-step inference." |
|
) |
|
elif isinstance(self.scheduler, LCMScheduler): |
|
if not 1 <= n_step <= 4: |
|
logging.warning(f"Non-optimal setting of denoising steps: {n_step}. Recommended setting is 1-4 steps.") |
|
else: |
|
raise RuntimeError(f"Unsupported scheduler type: {type(self.scheduler)}") |
|
|
|
def _encode_empty_text(self):
    """
    Compute the text embedding of the empty prompt and cache it on the pipeline.

    The embedding is stored in `self.empty_text_embed`, cast to the pipeline dtype.
    """
    tokenizer = self.tokenizer
    tokenized = tokenizer(
        "",
        padding="do_not_pad",
        max_length=tokenizer.model_max_length,
        truncation=True,
        return_tensors="pt",
    )
    token_ids = tokenized.input_ids.to(self.text_encoder.device)
    embedding = self.text_encoder(token_ids)[0]
    self.empty_text_embed = embedding.to(self.dtype)
|
|
|
@torch.no_grad()
def single_infer(
    self,
    rgb_in: torch.Tensor,
    num_inference_steps: int,
    seed: Union[int, None],
    show_pbar: bool,
) -> torch.Tensor:
    """
    Run one depth prediction pass (no ensembling) on a batch of images.

    Args:
        rgb_in (`torch.Tensor`):
            Input RGB image batch.
        num_inference_steps (`int`):
            Number of diffusion denoising steps during inference.
        seed (`int` or `None`):
            Seed for the noise generator; `None` leaves the noise unseeded.
        show_pbar (`bool`):
            Display a progress bar of diffusion denoising.

    Returns:
        `torch.Tensor`: Predicted depth map, clipped and shifted to the range [0, 1].
    """
    device = rgb_in.device

    # Prepare the denoising schedule.
    self.scheduler.set_timesteps(num_inference_steps, device=device)
    timesteps = self.scheduler.timesteps

    # Latent of the conditioning image.
    rgb_latent = self.encode_rgb(rgb_in)

    # Initial depth latent: Gaussian noise with the same shape as the image latent.
    generator = None
    if seed is not None:
        generator = torch.Generator(device=device)
        generator.manual_seed(seed)
    depth_latent = torch.randn(
        rgb_latent.shape,
        device=device,
        dtype=self.dtype,
        generator=generator,
    )

    # One copy of the empty-prompt embedding per batch element.
    if self.empty_text_embed is None:
        self._encode_empty_text()
    text_embed = self.empty_text_embed.repeat((rgb_latent.shape[0], 1, 1))

    steps = enumerate(timesteps)
    if show_pbar:
        steps = tqdm(
            steps,
            total=len(timesteps),
            leave=False,
            desc=" " * 4 + "Diffusion denoising",
        )

    for _, t in steps:
        # The U-Net sees image latent and depth latent stacked along the channel axis.
        unet_input = torch.cat([rgb_latent, depth_latent], dim=1)
        noise_pred = self.unet(unet_input, t, encoder_hidden_states=text_embed).sample
        depth_latent = self.scheduler.step(noise_pred, t, depth_latent, generator=generator).prev_sample

    decoded = self.decode_depth(depth_latent)

    # Clip to [-1, 1], then shift to [0, 1].
    return (torch.clip(decoded, -1.0, 1.0) + 1.0) / 2.0
|
|
|
def encode_rgb(self, rgb_in: torch.Tensor) -> torch.Tensor:
    """
    Map an RGB image batch to its scaled latent representation.

    Args:
        rgb_in (`torch.Tensor`):
            Input RGB image to be encoded.

    Returns:
        `torch.Tensor`: Image latent (the mean half of the VAE moments times `rgb_latent_scale_factor`).
    """
    vae = self.vae
    moments = vae.quant_conv(vae.encoder(rgb_in))
    # The moments hold mean and log-variance along the channel axis; only the mean is used.
    mean, _ = torch.chunk(moments, 2, dim=1)
    return mean * self.rgb_latent_scale_factor
|
|
|
def decode_depth(self, depth_latent: torch.Tensor) -> torch.Tensor:
    """
    Turn a depth latent back into a single-channel depth map.

    Args:
        depth_latent (`torch.Tensor`):
            Depth latent to be decoded.

    Returns:
        `torch.Tensor`: Decoded depth map, i.e. the decoder output averaged over the channel axis
        (the channel axis is kept with size 1).
    """
    vae = self.vae
    unscaled_latent = depth_latent / self.depth_latent_scale_factor
    decoded = vae.decoder(vae.post_quant_conv(unscaled_latent))
    return decoded.mean(dim=1, keepdim=True)
|
|
|
@staticmethod
def resize_max_res(img: Image.Image, max_edge_resolution: int, resample_method=Resampling.BILINEAR) -> Image.Image:
    """
    Resize an image so that its longer edge has the requested length, keeping the aspect ratio.

    Note that the same factor is applied whether it is below or above 1, so images smaller than
    `max_edge_resolution` are enlarged.

    Args:
        img (`Image.Image`):
            Image to be resized.
        max_edge_resolution (`int`):
            Target length of the longer edge (pixel).
        resample_method (`PIL.Image.Resampling`):
            Resampling method used to resize images.

    Returns:
        `Image.Image`: Resized image.
    """
    original_width, original_height = img.size
    downscale_factor = min(max_edge_resolution / original_width, max_edge_resolution / original_height)

    # Truncation can produce 0 for the short edge of a very elongated image; keep at least one pixel
    # so that the resize target is never empty.
    new_width = max(1, int(original_width * downscale_factor))
    new_height = max(1, int(original_height * downscale_factor))

    return img.resize((new_width, new_height), resample=resample_method)
|
|
|
@staticmethod
def colorize_depth_maps(depth_map, min_depth, max_depth, cmap="Spectral", valid_mask=None):
    """
    Colorize depth maps with a matplotlib colormap.

    Args:
        depth_map (`torch.Tensor` or `np.ndarray`):
            Depth map(s) with at least two dimensions. Singleton dimensions are squeezed; a 2D result is
            treated as a batch of one.
        min_depth (`float`):
            Depth value mapped to the lower end of the colormap.
        max_depth (`float`):
            Depth value mapped to the upper end of the colormap.
        cmap (`str`, *optional*, defaults to `"Spectral"`):
            Name of the matplotlib colormap.
        valid_mask (`torch.Tensor` or `np.ndarray`, *optional*):
            Mask of valid pixels; colors of pixels outside the mask are set to 0.

    Returns:
        `torch.Tensor` or `np.ndarray`: Colored maps of shape [B, 3, H, W], returned in the same
        container type as `depth_map` (a float tensor for tensor input).

    Raises:
        TypeError: If `depth_map` is neither a `torch.Tensor` nor an `np.ndarray`.
    """
    assert len(depth_map.shape) >= 2, "Invalid dimension"

    is_tensor = isinstance(depth_map, torch.Tensor)
    if is_tensor:
        # Move to host memory first: `.numpy()` is only available for CPU tensors.
        depth = depth_map.detach().cpu().squeeze().numpy()
    elif isinstance(depth_map, np.ndarray):
        depth = depth_map.squeeze()
    else:
        raise TypeError(f"Unsupported depth map type: {type(depth_map)}")

    if depth.ndim < 3:
        depth = depth[np.newaxis, :, :]

    # Normalize to [0, 1], look up the colors and drop the alpha channel: [B, H, W, 3].
    cm = matplotlib.colormaps[cmap]
    depth = ((depth - min_depth) / (max_depth - min_depth)).clip(0, 1)
    img_colored_np = cm(depth, bytes=False)[:, :, :, 0:3]
    # [B, H, W, 3] -> [B, 3, H, W]
    img_colored_np = np.moveaxis(img_colored_np, 3, 1)

    if valid_mask is not None:
        # Decide by the mask's own type, which need not match the type of `depth_map`.
        if isinstance(valid_mask, torch.Tensor):
            valid_mask = valid_mask.detach().cpu().numpy()
        # `~` is only a logical negation on boolean arrays.
        valid_mask = np.asarray(valid_mask).astype(bool).squeeze()
        if valid_mask.ndim < 3:
            valid_mask = valid_mask[np.newaxis, np.newaxis, :, :]
        else:
            valid_mask = valid_mask[:, np.newaxis, :, :]
        valid_mask = np.repeat(valid_mask, 3, axis=1)
        img_colored_np[~valid_mask] = 0

    if is_tensor:
        return torch.from_numpy(img_colored_np).float()
    return img_colored_np
|
|
|
@staticmethod |
|
def chw2hwc(chw): |
|
assert 3 == len(chw.shape) |
|
if isinstance(chw, torch.Tensor): |
|
hwc = torch.permute(chw, (1, 2, 0)) |
|
elif isinstance(chw, np.ndarray): |
|
hwc = np.moveaxis(chw, 0, -1) |
|
return hwc |
|
|
|
@staticmethod |
|
def _find_batch_size(ensemble_size: int, input_res: int, dtype: torch.dtype) -> int: |
|
""" |
|
Automatically search for suitable operating batch size. |
|
|
|
Args: |
|
ensemble_size (`int`): |
|
Number of predictions to be ensembled. |
|
input_res (`int`): |
|
Operating resolution of the input image. |
|
|
|
Returns: |
|
`int`: Operating batch size. |
|
""" |
|
|
|
bs_search_table = [ |
|
|
|
{"res": 768, "total_vram": 79, "bs": 35, "dtype": torch.float32}, |
|
{"res": 1024, "total_vram": 79, "bs": 20, "dtype": torch.float32}, |
|
|
|
{"res": 768, "total_vram": 39, "bs": 15, "dtype": torch.float32}, |
|
{"res": 1024, "total_vram": 39, "bs": 8, "dtype": torch.float32}, |
|
{"res": 768, "total_vram": 39, "bs": 30, "dtype": torch.float16}, |
|
{"res": 1024, "total_vram": 39, "bs": 15, "dtype": torch.float16}, |
|
|
|
{"res": 512, "total_vram": 23, "bs": 20, "dtype": torch.float32}, |
|
{"res": 768, "total_vram": 23, "bs": 7, "dtype": torch.float32}, |
|
{"res": 1024, "total_vram": 23, "bs": 3, "dtype": torch.float32}, |
|
{"res": 512, "total_vram": 23, "bs": 40, "dtype": torch.float16}, |
|
{"res": 768, "total_vram": 23, "bs": 18, "dtype": torch.float16}, |
|
{"res": 1024, "total_vram": 23, "bs": 10, "dtype": torch.float16}, |
|
|
|
{"res": 512, "total_vram": 10, "bs": 5, "dtype": torch.float32}, |
|
{"res": 768, "total_vram": 10, "bs": 2, "dtype": torch.float32}, |
|
{"res": 512, "total_vram": 10, "bs": 10, "dtype": torch.float16}, |
|
{"res": 768, "total_vram": 10, "bs": 5, "dtype": torch.float16}, |
|
{"res": 1024, "total_vram": 10, "bs": 3, "dtype": torch.float16}, |
|
] |
|
|
|
if not torch.cuda.is_available(): |
|
return 1 |
|
|
|
total_vram = torch.cuda.mem_get_info()[1] / 1024.0**3 |
|
filtered_bs_search_table = [s for s in bs_search_table if s["dtype"] == dtype] |
|
for settings in sorted( |
|
filtered_bs_search_table, |
|
key=lambda k: (k["res"], -k["total_vram"]), |
|
): |
|
if input_res <= settings["res"] and total_vram >= settings["total_vram"]: |
|
bs = settings["bs"] |
|
if bs > ensemble_size: |
|
bs = ensemble_size |
|
elif bs > math.ceil(ensemble_size / 2) and bs < ensemble_size: |
|
bs = math.ceil(ensemble_size / 2) |
|
return bs |
|
|
|
return 1 |
|
|
|
@staticmethod |
|
def ensemble_depths( |
|
input_images: torch.Tensor, |
|
regularizer_strength: float = 0.02, |
|
max_iter: int = 2, |
|
tol: float = 1e-3, |
|
reduction: str = "median", |
|
max_res: int = None, |
|
): |
|
""" |
|
To ensemble multiple affine-invariant depth images (up to scale and shift), |
|
by aligning estimating the scale and shift |
|
""" |
|
|
|
def inter_distances(tensors: torch.Tensor): |
|
""" |
|
To calculate the distance between each two depth maps. |
|
""" |
|
distances = [] |
|
for i, j in torch.combinations(torch.arange(tensors.shape[0])): |
|
arr1 = tensors[i : i + 1] |
|
arr2 = tensors[j : j + 1] |
|
distances.append(arr1 - arr2) |
|
dist = torch.concatenate(distances, dim=0) |
|
return dist |
|
|
|
device = input_images.device |
|
dtype = input_images.dtype |
|
np_dtype = np.float32 |
|
|
|
original_input = input_images.clone() |
|
n_img = input_images.shape[0] |
|
ori_shape = input_images.shape |
|
|
|
if max_res is not None: |
|
scale_factor = torch.min(max_res / torch.tensor(ori_shape[-2:])) |
|
if scale_factor < 1: |
|
downscaler = torch.nn.Upsample(scale_factor=scale_factor, mode="nearest") |
|
input_images = downscaler(torch.from_numpy(input_images)).numpy() |
|
|
|
|
|
_min = np.min(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) |
|
_max = np.max(input_images.reshape((n_img, -1)).cpu().numpy(), axis=1) |
|
s_init = 1.0 / (_max - _min).reshape((-1, 1, 1)) |
|
t_init = (-1 * s_init.flatten() * _min.flatten()).reshape((-1, 1, 1)) |
|
x = np.concatenate([s_init, t_init]).reshape(-1).astype(np_dtype) |
|
|
|
input_images = input_images.to(device) |
|
|
|
|
|
def closure(x): |
|
l = len(x) |
|
s = x[: int(l / 2)] |
|
t = x[int(l / 2) :] |
|
s = torch.from_numpy(s).to(dtype=dtype).to(device) |
|
t = torch.from_numpy(t).to(dtype=dtype).to(device) |
|
|
|
transformed_arrays = input_images * s.view((-1, 1, 1)) + t.view((-1, 1, 1)) |
|
dists = inter_distances(transformed_arrays) |
|
sqrt_dist = torch.sqrt(torch.mean(dists**2)) |
|
|
|
if "mean" == reduction: |
|
pred = torch.mean(transformed_arrays, dim=0) |
|
elif "median" == reduction: |
|
pred = torch.median(transformed_arrays, dim=0).values |
|
else: |
|
raise ValueError |
|
|
|
near_err = torch.sqrt((0 - torch.min(pred)) ** 2) |
|
far_err = torch.sqrt((1 - torch.max(pred)) ** 2) |
|
|
|
err = sqrt_dist + (near_err + far_err) * regularizer_strength |
|
err = err.detach().cpu().numpy().astype(np_dtype) |
|
return err |
|
|
|
res = minimize( |
|
closure, |
|
x, |
|
method="BFGS", |
|
tol=tol, |
|
options={"maxiter": max_iter, "disp": False}, |
|
) |
|
x = res.x |
|
l = len(x) |
|
s = x[: int(l / 2)] |
|
t = x[int(l / 2) :] |
|
|
|
|
|
s = torch.from_numpy(s).to(dtype=dtype).to(device) |
|
t = torch.from_numpy(t).to(dtype=dtype).to(device) |
|
transformed_arrays = original_input * s.view(-1, 1, 1) + t.view(-1, 1, 1) |
|
if "mean" == reduction: |
|
aligned_images = torch.mean(transformed_arrays, dim=0) |
|
std = torch.std(transformed_arrays, dim=0) |
|
uncertainty = std |
|
elif "median" == reduction: |
|
aligned_images = torch.median(transformed_arrays, dim=0).values |
|
|
|
abs_dev = torch.abs(transformed_arrays - aligned_images) |
|
mad = torch.median(abs_dev, dim=0).values |
|
uncertainty = mad |
|
else: |
|
raise ValueError(f"Unknown reduction method: {reduction}") |
|
|
|
|
|
_min = torch.min(aligned_images) |
|
_max = torch.max(aligned_images) |
|
aligned_images = (aligned_images - _min) / (_max - _min) |
|
uncertainty /= _max - _min |
|
|
|
return aligned_images, uncertainty |
|
|