diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000000000000000000000000000000000000..963e2c9ef50c09d69ebe43922ca2aaf72e46d0c6 --- /dev/null +++ b/.gitignore @@ -0,0 +1,25 @@ +/data +/download +/extract +/view_point_cloud +/view_depth_map +/blobcache +/snapshot +/reference_embeddings +/.gradio +/debug +/workspace +/mlruns +/infer_output +/video_output +/eval_output +/.blobcache +/test_images +/test_videos +/vis +/videos +/raid +/blobmnt +/eval_dump +/pretrained +__pycache__/ \ No newline at end of file diff --git a/app.py b/app.py new file mode 100644 index 0000000000000000000000000000000000000000..4872f2d04cecc429312978f6bca200a6d7d4d519 --- /dev/null +++ b/app.py @@ -0,0 +1,111 @@ +import os +import time +from pathlib import Path +import uuid +import tempfile +from typing import Union +import spaces +import atexit +from concurrent.futures import ThreadPoolExecutor + +import gradio as gr +import cv2 +import torch +import numpy as np + +from moge.model import MoGeModel +from moge.utils.vis import colorize_depth +import utils3d + +model = MoGeModel.from_pretrained('Ruicheng/moge-vitl').cuda().eval() +thread_pool_executor = ThreadPoolExecutor(max_workers=1) + + +def delete_later(path: Union[str, os.PathLike], delay: int = 300): + def _delete(): + try: + os.remove(path) + except: + pass + def _wait_and_delete(): + time.sleep(delay) + _delete(path) + thread_pool_executor.submit(_wait_and_delete) + atexit.register(_delete) + +@spaces.GPU +def run(image: np.ndarray, remove_edge: bool = True): + run_id = str(uuid.uuid4()) + + larger_size = max(image.shape[:2]) + if larger_size > 1024: + scale = 1024 / larger_size + image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA) + + image_tensor = torch.tensor(image, dtype=torch.float32, device=torch.device('cuda')).permute(2, 0, 1) / 255 + output = model.infer(image_tensor, resolution_level=9, apply_mask=True) + points, depth, mask = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy() + + if remove_edge: + mask = mask & ~utils3d.numpy.depth_edge(depth, mask=mask, rtol=0.02) + mask = mask & (depth > 0) + + _, faces, indices = utils3d.numpy.image_mesh(width=image.shape[1], height=image.shape[0], mask=mask) + faces = utils3d.numpy.triangulate(faces) + + tempdir = Path(tempfile.gettempdir(), 'moge') + tempdir.mkdir(exist_ok=True) + + output_glb_path = Path(tempdir, f'{run_id}.glb') + output_glb_path.parent.mkdir(exist_ok=True) + tempfile.TemporaryFile() + utils3d.io.write_glb( + output_glb_path, + vertices=points.reshape(-1, 3)[indices] * [-1, -1, 1], + faces=faces, + vertex_colors=image.reshape(-1, 3)[indices] / 255, + ) + + output_ply_path = Path(tempdir, f'{run_id}.ply') + output_ply_path.parent.mkdir(exist_ok=True) + utils3d.io.write_ply( + output_ply_path, + vertices=points.reshape(-1, 3)[indices] * [-1, -1, 1], + faces=faces, + vertex_colors=image.reshape(-1, 3)[indices] / 255, + ) + + colorized_depth = colorize_depth(depth) + + delete_later(output_glb_path, delay=300) + delete_later(output_ply_path, delay=300) + + return colorized_depth, output_glb_path, output_ply_path.as_posix() + + +DESCRIPTION = """ +MoGe turns 2D images into 3D point maps. + +NOTE: +* If the image is too large (> 1024px), it will be resized accordingly. +* The color in the 3D viewer may look dark due to rendering of 3D viewer. You may download the 3D model as .glb or .ply file to view it in other 3D viewers. +""" + +if __name__ == '__main__': + + gr.Interface( + fn=run, + inputs=[ + gr.Image(type="numpy", image_mode="RGB"), + gr.Checkbox(True, label="Remove edges"), + ], + outputs=[ + gr.Image(type="numpy", label="Depth map (colorized)"), + gr.Model3D(display_mode="solid", clear_color=[1.0, 1.0, 1.0, 1.0], label="3D Viewer"), + gr.File(type="filepath", label="Download the model as .ply file"), + ], + title="MoGe Live Demo", + description=DESCRIPTION, + clear_btn=None, + allow_flagging="never", + ).launch(share=False) \ No newline at end of file diff --git a/moge/model/__init__.py b/moge/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..52b16d09fc22a91d1c8f947909971ee9830d5db3 --- /dev/null +++ b/moge/model/__init__.py @@ -0,0 +1 @@ +from .moge_model import MoGeModel \ No newline at end of file diff --git a/moge/model/dinov2/__init__.py b/moge/model/dinov2/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..ae847e46898077fe3d8701b8a181d7b4e3d41cd9 --- /dev/null +++ b/moge/model/dinov2/__init__.py @@ -0,0 +1,6 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +__version__ = "0.0.1" diff --git a/moge/model/dinov2/hub/__init__.py b/moge/model/dinov2/hub/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/moge/model/dinov2/hub/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/moge/model/dinov2/hub/backbones.py b/moge/model/dinov2/hub/backbones.py new file mode 100644 index 0000000000000000000000000000000000000000..53fe83719d5107eb77a8f25ef1814c3d73446002 --- /dev/null +++ b/moge/model/dinov2/hub/backbones.py @@ -0,0 +1,156 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +from typing import Union + +import torch + +from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name + + +class Weights(Enum): + LVD142M = "LVD142M" + + +def _make_dinov2_model( + *, + arch_name: str = "vit_large", + img_size: int = 518, + patch_size: int = 14, + init_values: float = 1.0, + ffn_layer: str = "mlp", + block_chunks: int = 0, + num_register_tokens: int = 0, + interpolate_antialias: bool = False, + interpolate_offset: float = 0.1, + pretrained: bool = True, + weights: Union[Weights, str] = Weights.LVD142M, + **kwargs, +): + from ..models import vision_transformer as vits + + if isinstance(weights, str): + try: + weights = Weights[weights] + except KeyError: + raise AssertionError(f"Unsupported weights: {weights}") + + model_base_name = _make_dinov2_model_name(arch_name, patch_size) + vit_kwargs = dict( + img_size=img_size, + patch_size=patch_size, + init_values=init_values, + ffn_layer=ffn_layer, + block_chunks=block_chunks, + num_register_tokens=num_register_tokens, + interpolate_antialias=interpolate_antialias, + interpolate_offset=interpolate_offset, + ) + vit_kwargs.update(**kwargs) + model = vits.__dict__[arch_name](**vit_kwargs) + + if pretrained: + model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens) + url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth" + state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu") + model.load_state_dict(state_dict, strict=True) + + return model + + +def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs) + + +def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + **kwargs, + ) + + +def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_small", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_base", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_large", + pretrained=pretrained, + weights=weights, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) + + +def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs): + """ + DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset. + """ + return _make_dinov2_model( + arch_name="vit_giant2", + ffn_layer="swiglufused", + weights=weights, + pretrained=pretrained, + num_register_tokens=4, + interpolate_antialias=True, + interpolate_offset=0.0, + **kwargs, + ) diff --git a/moge/model/dinov2/hub/utils.py b/moge/model/dinov2/hub/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..9c6641404093652d5a2f19b4cf283d976ec39e64 --- /dev/null +++ b/moge/model/dinov2/hub/utils.py @@ -0,0 +1,39 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import itertools +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F + + +_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2" + + +def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str: + compact_arch_name = arch_name.replace("_", "")[:4] + registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else "" + return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}" + + +class CenterPadding(nn.Module): + def __init__(self, multiple): + super().__init__() + self.multiple = multiple + + def _get_pad(self, size): + new_size = math.ceil(size / self.multiple) * self.multiple + pad_size = new_size - size + pad_size_left = pad_size // 2 + pad_size_right = pad_size - pad_size_left + return pad_size_left, pad_size_right + + @torch.inference_mode() + def forward(self, x): + pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1])) + output = F.pad(x, pads) + return output diff --git a/moge/model/dinov2/layers/__init__.py b/moge/model/dinov2/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..05a0b61868e43abb821ca05a813bab2b8b43629e --- /dev/null +++ b/moge/model/dinov2/layers/__init__.py @@ -0,0 +1,11 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from .dino_head import DINOHead +from .mlp import Mlp +from .patch_embed import PatchEmbed +from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused +from .block import NestedTensorBlock +from .attention import MemEffAttention diff --git a/moge/model/dinov2/layers/attention.py b/moge/model/dinov2/layers/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..3fed573116d5c837be46a7525d8acf77422c2400 --- /dev/null +++ b/moge/model/dinov2/layers/attention.py @@ -0,0 +1,89 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +import logging +import os +import warnings + +from torch import Tensor +from torch import nn + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import memory_efficient_attention, unbind + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Attention)") + else: + # warnings.warn("xFormers is disabled (Attention)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Attention)") + + +class Attention(nn.Module): + def __init__( + self, + dim: int, + num_heads: int = 8, + qkv_bias: bool = False, + proj_bias: bool = True, + attn_drop: float = 0.0, + proj_drop: float = 0.0, + ) -> None: + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = head_dim**-0.5 + + self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias) + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim, bias=proj_bias) + self.proj_drop = nn.Dropout(proj_drop) + + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) + + q, k, v = qkv[0] * self.scale, qkv[1], qkv[2] + attn = q @ k.transpose(-2, -1) + + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + + x = (attn @ v).transpose(1, 2).reshape(B, N, C) + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class MemEffAttention(Attention): + def forward(self, x: Tensor, attn_bias=None) -> Tensor: + if not XFORMERS_AVAILABLE: + if attn_bias is not None: + raise AssertionError("xFormers is required for using nested tensors") + return super().forward(x) + + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads) + + q, k, v = unbind(qkv, 2) + + x = memory_efficient_attention(q, k, v, attn_bias=attn_bias) + x = x.reshape([B, N, C]) + + x = self.proj(x) + x = self.proj_drop(x) + return x diff --git a/moge/model/dinov2/layers/block.py b/moge/model/dinov2/layers/block.py new file mode 100644 index 0000000000000000000000000000000000000000..fd5b8a7bb8527b74186af7c1e060e37bdb52c73d --- /dev/null +++ b/moge/model/dinov2/layers/block.py @@ -0,0 +1,259 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +import logging +import os +from typing import Callable, List, Any, Tuple, Dict +import warnings + +import torch +from torch import nn, Tensor + +from .attention import Attention, MemEffAttention +from .drop_path import DropPath +from .layer_scale import LayerScale +from .mlp import Mlp + + +logger = logging.getLogger("dinov2") + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import fmha, scaled_index_add, index_select_cat + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (Block)") + else: + # warnings.warn("xFormers is disabled (Block)") + raise ImportError +except ImportError: + XFORMERS_AVAILABLE = False + # warnings.warn("xFormers is not available (Block)") + + +class Block(nn.Module): + def __init__( + self, + dim: int, + num_heads: int, + mlp_ratio: float = 4.0, + qkv_bias: bool = False, + proj_bias: bool = True, + ffn_bias: bool = True, + drop: float = 0.0, + attn_drop: float = 0.0, + init_values=None, + drop_path: float = 0.0, + act_layer: Callable[..., nn.Module] = nn.GELU, + norm_layer: Callable[..., nn.Module] = nn.LayerNorm, + attn_class: Callable[..., nn.Module] = Attention, + ffn_layer: Callable[..., nn.Module] = Mlp, + ) -> None: + super().__init__() + # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}") + self.norm1 = norm_layer(dim) + self.attn = attn_class( + dim, + num_heads=num_heads, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + attn_drop=attn_drop, + proj_drop=drop, + ) + self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.norm2 = norm_layer(dim) + mlp_hidden_dim = int(dim * mlp_ratio) + self.mlp = ffn_layer( + in_features=dim, + hidden_features=mlp_hidden_dim, + act_layer=act_layer, + drop=drop, + bias=ffn_bias, + ) + self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity() + self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity() + + self.sample_drop_ratio = drop_path + + def forward(self, x: Tensor) -> Tensor: + def attn_residual_func(x: Tensor) -> Tensor: + return self.ls1(self.attn(self.norm1(x))) + + def ffn_residual_func(x: Tensor) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + if self.training and self.sample_drop_ratio > 0.1: + # the overhead is compensated only for a drop path rate larger than 0.1 + x = drop_add_residual_stochastic_depth( + x, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + x = drop_add_residual_stochastic_depth( + x, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + ) + elif self.training and self.sample_drop_ratio > 0.0: + x = x + self.drop_path1(attn_residual_func(x)) + x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2 + else: + x = x + attn_residual_func(x) + x = x + ffn_residual_func(x) + return x + + +def drop_add_residual_stochastic_depth( + x: Tensor, + residual_func: Callable[[Tensor], Tensor], + sample_drop_ratio: float = 0.0, +) -> Tensor: + # 1) extract subset using permutation + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + x_subset = x[brange] + + # 2) apply residual_func to get residual + residual = residual_func(x_subset) + + x_flat = x.flatten(1) + residual = residual.flatten(1) + + residual_scale_factor = b / sample_subset_size + + # 3) add the residual + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + return x_plus_residual.view_as(x) + + +def get_branges_scales(x, sample_drop_ratio=0.0): + b, n, d = x.shape + sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1) + brange = (torch.randperm(b, device=x.device))[:sample_subset_size] + residual_scale_factor = b / sample_subset_size + return brange, residual_scale_factor + + +def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None): + if scaling_vector is None: + x_flat = x.flatten(1) + residual = residual.flatten(1) + x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor) + else: + x_plus_residual = scaled_index_add( + x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor + ) + return x_plus_residual + + +attn_bias_cache: Dict[Tuple, Any] = {} + + +def get_attn_bias_and_cat(x_list, branges=None): + """ + this will perform the index select, cat the tensors, and provide the attn_bias from cache + """ + batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list] + all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list)) + if all_shapes not in attn_bias_cache.keys(): + seqlens = [] + for b, x in zip(batch_sizes, x_list): + for _ in range(b): + seqlens.append(x.shape[1]) + attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens) + attn_bias._batch_sizes = batch_sizes + attn_bias_cache[all_shapes] = attn_bias + + if branges is not None: + cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1]) + else: + tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list) + cat_tensors = torch.cat(tensors_bs1, dim=1) + + return attn_bias_cache[all_shapes], cat_tensors + + +def drop_add_residual_stochastic_depth_list( + x_list: List[Tensor], + residual_func: Callable[[Tensor, Any], Tensor], + sample_drop_ratio: float = 0.0, + scaling_vector=None, +) -> Tensor: + # 1) generate random set of indices for dropping samples in the batch + branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list] + branges = [s[0] for s in branges_scales] + residual_scale_factors = [s[1] for s in branges_scales] + + # 2) get attention bias and index+concat the tensors + attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges) + + # 3) apply residual_func to get residual, and split the result + residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore + + outputs = [] + for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors): + outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x)) + return outputs + + +class NestedTensorBlock(Block): + def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]: + """ + x_list contains a list of tensors to nest together and run + """ + assert isinstance(self.attn, MemEffAttention) + + if self.training and self.sample_drop_ratio > 0.0: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.attn(self.norm1(x), attn_bias=attn_bias) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.mlp(self.norm2(x)) + + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=attn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None, + ) + x_list = drop_add_residual_stochastic_depth_list( + x_list, + residual_func=ffn_residual_func, + sample_drop_ratio=self.sample_drop_ratio, + scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None, + ) + return x_list + else: + + def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias)) + + def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor: + return self.ls2(self.mlp(self.norm2(x))) + + attn_bias, x = get_attn_bias_and_cat(x_list) + x = x + attn_residual_func(x, attn_bias=attn_bias) + x = x + ffn_residual_func(x) + return attn_bias.split(x) + + def forward(self, x_or_x_list): + if isinstance(x_or_x_list, Tensor): + return super().forward(x_or_x_list) + elif isinstance(x_or_x_list, list): + if not XFORMERS_AVAILABLE: + raise AssertionError("xFormers is required for using nested tensors") + return self.forward_nested(x_or_x_list) + else: + raise AssertionError diff --git a/moge/model/dinov2/layers/dino_head.py b/moge/model/dinov2/layers/dino_head.py new file mode 100644 index 0000000000000000000000000000000000000000..0ace8ffd6297a1dd480b19db407b662a6ea0f565 --- /dev/null +++ b/moge/model/dinov2/layers/dino_head.py @@ -0,0 +1,58 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import torch +import torch.nn as nn +from torch.nn.init import trunc_normal_ +from torch.nn.utils import weight_norm + + +class DINOHead(nn.Module): + def __init__( + self, + in_dim, + out_dim, + use_bn=False, + nlayers=3, + hidden_dim=2048, + bottleneck_dim=256, + mlp_bias=True, + ): + super().__init__() + nlayers = max(nlayers, 1) + self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias) + self.apply(self._init_weights) + self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False)) + self.last_layer.weight_g.data.fill_(1) + + def _init_weights(self, m): + if isinstance(m, nn.Linear): + trunc_normal_(m.weight, std=0.02) + if isinstance(m, nn.Linear) and m.bias is not None: + nn.init.constant_(m.bias, 0) + + def forward(self, x): + x = self.mlp(x) + eps = 1e-6 if x.dtype == torch.float16 else 1e-12 + x = nn.functional.normalize(x, dim=-1, p=2, eps=eps) + x = self.last_layer(x) + return x + + +def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True): + if nlayers == 1: + return nn.Linear(in_dim, bottleneck_dim, bias=bias) + else: + layers = [nn.Linear(in_dim, hidden_dim, bias=bias)] + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + for _ in range(nlayers - 2): + layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias)) + if use_bn: + layers.append(nn.BatchNorm1d(hidden_dim)) + layers.append(nn.GELU()) + layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias)) + return nn.Sequential(*layers) diff --git a/moge/model/dinov2/layers/drop_path.py b/moge/model/dinov2/layers/drop_path.py new file mode 100644 index 0000000000000000000000000000000000000000..1d640e0b969b8dcba96260243473700b4e5b24b5 --- /dev/null +++ b/moge/model/dinov2/layers/drop_path.py @@ -0,0 +1,34 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py + + +from torch import nn + + +def drop_path(x, drop_prob: float = 0.0, training: bool = False): + if drop_prob == 0.0 or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = x.new_empty(shape).bernoulli_(keep_prob) + if keep_prob > 0.0: + random_tensor.div_(keep_prob) + output = x * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks).""" + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) diff --git a/moge/model/dinov2/layers/layer_scale.py b/moge/model/dinov2/layers/layer_scale.py new file mode 100644 index 0000000000000000000000000000000000000000..51df0d7ce61f2b41fa9e6369f52391dd7fe7d386 --- /dev/null +++ b/moge/model/dinov2/layers/layer_scale.py @@ -0,0 +1,27 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110 + +from typing import Union + +import torch +from torch import Tensor +from torch import nn + + +class LayerScale(nn.Module): + def __init__( + self, + dim: int, + init_values: Union[float, Tensor] = 1e-5, + inplace: bool = False, + ) -> None: + super().__init__() + self.inplace = inplace + self.gamma = nn.Parameter(init_values * torch.ones(dim)) + + def forward(self, x: Tensor) -> Tensor: + return x.mul_(self.gamma) if self.inplace else x * self.gamma diff --git a/moge/model/dinov2/layers/mlp.py b/moge/model/dinov2/layers/mlp.py new file mode 100644 index 0000000000000000000000000000000000000000..bbf9432aae9258612caeae910a7bde17999e328e --- /dev/null +++ b/moge/model/dinov2/layers/mlp.py @@ -0,0 +1,40 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py + + +from typing import Callable, Optional + +from torch import Tensor, nn + + +class Mlp(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = nn.GELU, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features, bias=bias) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features, bias=bias) + self.drop = nn.Dropout(drop) + + def forward(self, x: Tensor) -> Tensor: + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x diff --git a/moge/model/dinov2/layers/patch_embed.py b/moge/model/dinov2/layers/patch_embed.py new file mode 100644 index 0000000000000000000000000000000000000000..8b7c0804784a42cf80c0297d110dcc68cc85b339 --- /dev/null +++ b/moge/model/dinov2/layers/patch_embed.py @@ -0,0 +1,88 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py + +from typing import Callable, Optional, Tuple, Union + +from torch import Tensor +import torch.nn as nn + + +def make_2tuple(x): + if isinstance(x, tuple): + assert len(x) == 2 + return x + + assert isinstance(x, int) + return (x, x) + + +class PatchEmbed(nn.Module): + """ + 2D image to patch embedding: (B,C,H,W) -> (B,N,D) + + Args: + img_size: Image size. + patch_size: Patch token size. + in_chans: Number of input image channels. + embed_dim: Number of linear projection output channels. + norm_layer: Normalization layer. + """ + + def __init__( + self, + img_size: Union[int, Tuple[int, int]] = 224, + patch_size: Union[int, Tuple[int, int]] = 16, + in_chans: int = 3, + embed_dim: int = 768, + norm_layer: Optional[Callable] = None, + flatten_embedding: bool = True, + ) -> None: + super().__init__() + + image_HW = make_2tuple(img_size) + patch_HW = make_2tuple(patch_size) + patch_grid_size = ( + image_HW[0] // patch_HW[0], + image_HW[1] // patch_HW[1], + ) + + self.img_size = image_HW + self.patch_size = patch_HW + self.patches_resolution = patch_grid_size + self.num_patches = patch_grid_size[0] * patch_grid_size[1] + + self.in_chans = in_chans + self.embed_dim = embed_dim + + self.flatten_embedding = flatten_embedding + + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW) + self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity() + + def forward(self, x: Tensor) -> Tensor: + _, _, H, W = x.shape + patch_H, patch_W = self.patch_size + + assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}" + assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}" + + x = self.proj(x) # B C H W + H, W = x.size(2), x.size(3) + x = x.flatten(2).transpose(1, 2) # B HW C + x = self.norm(x) + if not self.flatten_embedding: + x = x.reshape(-1, H, W, self.embed_dim) # B H W C + return x + + def flops(self) -> float: + Ho, Wo = self.patches_resolution + flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1]) + if self.norm is not None: + flops += Ho * Wo * self.embed_dim + return flops diff --git a/moge/model/dinov2/layers/swiglu_ffn.py b/moge/model/dinov2/layers/swiglu_ffn.py new file mode 100644 index 0000000000000000000000000000000000000000..5ce211515774d42e04c8b51003bae53b88f14b35 --- /dev/null +++ b/moge/model/dinov2/layers/swiglu_ffn.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import os +from typing import Callable, Optional +import warnings + +from torch import Tensor, nn +import torch.nn.functional as F + + +class SwiGLUFFN(nn.Module): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias) + self.w3 = nn.Linear(hidden_features, out_features, bias=bias) + + def forward(self, x: Tensor) -> Tensor: + x12 = self.w12(x) + x1, x2 = x12.chunk(2, dim=-1) + hidden = F.silu(x1) * x2 + return self.w3(hidden) + + +XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None +try: + if XFORMERS_ENABLED: + from xformers.ops import SwiGLU + + XFORMERS_AVAILABLE = True + # warnings.warn("xFormers is available (SwiGLU)") + else: + # warnings.warn("xFormers is disabled (SwiGLU)") + raise ImportError +except ImportError: + SwiGLU = SwiGLUFFN + XFORMERS_AVAILABLE = False + + # warnings.warn("xFormers is not available (SwiGLU)") + + +class SwiGLUFFNFused(SwiGLU): + def __init__( + self, + in_features: int, + hidden_features: Optional[int] = None, + out_features: Optional[int] = None, + act_layer: Callable[..., nn.Module] = None, + drop: float = 0.0, + bias: bool = True, + ) -> None: + out_features = out_features or in_features + hidden_features = hidden_features or in_features + hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8 + super().__init__( + in_features=in_features, + hidden_features=hidden_features, + out_features=out_features, + bias=bias, + ) diff --git a/moge/model/dinov2/models/__init__.py b/moge/model/dinov2/models/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..3fdff20badbd5244bf79f16bf18dd2cb73982265 --- /dev/null +++ b/moge/model/dinov2/models/__init__.py @@ -0,0 +1,43 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging + +from . import vision_transformer as vits + + +logger = logging.getLogger("dinov2") + + +def build_model(args, only_teacher=False, img_size=224): + args.arch = args.arch.removesuffix("_memeff") + if "vit" in args.arch: + vit_kwargs = dict( + img_size=img_size, + patch_size=args.patch_size, + init_values=args.layerscale, + ffn_layer=args.ffn_layer, + block_chunks=args.block_chunks, + qkv_bias=args.qkv_bias, + proj_bias=args.proj_bias, + ffn_bias=args.ffn_bias, + num_register_tokens=args.num_register_tokens, + interpolate_offset=args.interpolate_offset, + interpolate_antialias=args.interpolate_antialias, + ) + teacher = vits.__dict__[args.arch](**vit_kwargs) + if only_teacher: + return teacher, teacher.embed_dim + student = vits.__dict__[args.arch]( + **vit_kwargs, + drop_path_rate=args.drop_path_rate, + drop_path_uniform=args.drop_path_uniform, + ) + embed_dim = student.embed_dim + return student, teacher, embed_dim + + +def build_model_from_cfg(cfg, only_teacher=False): + return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size) diff --git a/moge/model/dinov2/models/vision_transformer.py b/moge/model/dinov2/models/vision_transformer.py new file mode 100644 index 0000000000000000000000000000000000000000..1007ba57ddb35109c91716f1f5bf203db346e7be --- /dev/null +++ b/moge/model/dinov2/models/vision_transformer.py @@ -0,0 +1,396 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +# References: +# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py +# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py + +from functools import partial +import math +import logging +from typing import Sequence, Tuple, Union, Callable + +import torch +import torch.nn as nn +import torch.utils.checkpoint +from torch.nn.init import trunc_normal_ + +from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block + + +logger = logging.getLogger("dinov2") + + +def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module: + if not depth_first and include_root: + fn(module=module, name=name) + for child_name, child_module in module.named_children(): + child_name = ".".join((name, child_name)) if name else child_name + named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True) + if depth_first and include_root: + fn(module=module, name=name) + return module + + +class BlockChunk(nn.ModuleList): + def forward(self, x): + for b in self: + x = b(x) + return x + + +class DinoVisionTransformer(nn.Module): + def __init__( + self, + img_size=224, + patch_size=16, + in_chans=3, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4.0, + qkv_bias=True, + ffn_bias=True, + proj_bias=True, + drop_path_rate=0.0, + drop_path_uniform=False, + init_values=None, # for layerscale: None or 0 => no layerscale + embed_layer=PatchEmbed, + act_layer=nn.GELU, + block_fn=Block, + ffn_layer="mlp", + block_chunks=1, + num_register_tokens=0, + interpolate_antialias=False, + interpolate_offset=0.1, + ): + """ + Args: + img_size (int, tuple): input image size + patch_size (int, tuple): patch size + in_chans (int): number of input channels + embed_dim (int): embedding dimension + depth (int): depth of transformer + num_heads (int): number of attention heads + mlp_ratio (int): ratio of mlp hidden dim to embedding dim + qkv_bias (bool): enable bias for qkv if True + proj_bias (bool): enable bias for proj in attn if True + ffn_bias (bool): enable bias for ffn if True + drop_path_rate (float): stochastic depth rate + drop_path_uniform (bool): apply uniform drop rate across blocks + weight_init (str): weight init scheme + init_values (float): layer-scale init values + embed_layer (nn.Module): patch embedding layer + act_layer (nn.Module): MLP activation layer + block_fn (nn.Module): transformer block class + ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity" + block_chunks: (int) split block sequence into block_chunks units for FSDP wrap + num_register_tokens: (int) number of extra cls tokens (so-called "registers") + interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings + interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings + """ + super().__init__() + norm_layer = partial(nn.LayerNorm, eps=1e-6) + + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + self.num_tokens = 1 + self.n_blocks = depth + self.num_heads = num_heads + self.patch_size = patch_size + self.num_register_tokens = num_register_tokens + self.interpolate_antialias = interpolate_antialias + self.interpolate_offset = interpolate_offset + + self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim) + num_patches = self.patch_embed.num_patches + + self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim)) + self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim)) + assert num_register_tokens >= 0 + self.register_tokens = ( + nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None + ) + + if drop_path_uniform is True: + dpr = [drop_path_rate] * depth + else: + dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule + + if ffn_layer == "mlp": + logger.info("using MLP layer as FFN") + ffn_layer = Mlp + elif ffn_layer == "swiglufused" or ffn_layer == "swiglu": + logger.info("using SwiGLU layer as FFN") + ffn_layer = SwiGLUFFNFused + elif ffn_layer == "identity": + logger.info("using Identity layer as FFN") + + def f(*args, **kwargs): + return nn.Identity() + + ffn_layer = f + else: + raise NotImplementedError + + blocks_list = [ + block_fn( + dim=embed_dim, + num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, + proj_bias=proj_bias, + ffn_bias=ffn_bias, + drop_path=dpr[i], + norm_layer=norm_layer, + act_layer=act_layer, + ffn_layer=ffn_layer, + init_values=init_values, + ) + for i in range(depth) + ] + if block_chunks > 0: + self.chunked_blocks = True + chunked_blocks = [] + chunksize = depth // block_chunks + for i in range(0, depth, chunksize): + # this is to keep the block index consistent if we chunk the block list + chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize]) + self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks]) + else: + self.chunked_blocks = False + self.blocks = nn.ModuleList(blocks_list) + + self.norm = norm_layer(embed_dim) + self.head = nn.Identity() + + self.mask_token = nn.Parameter(torch.zeros(1, embed_dim)) + + self.init_weights() + + def init_weights(self): + trunc_normal_(self.pos_embed, std=0.02) + nn.init.normal_(self.cls_token, std=1e-6) + if self.register_tokens is not None: + nn.init.normal_(self.register_tokens, std=1e-6) + named_apply(init_weights_vit_timm, self) + + def interpolate_pos_encoding(self, x, w, h): + previous_dtype = x.dtype + npatch = x.shape[1] - 1 + N = self.pos_embed.shape[1] - 1 + if npatch == N and w == h: + return self.pos_embed + pos_embed = self.pos_embed.float() + class_pos_embed = pos_embed[:, 0] + patch_pos_embed = pos_embed[:, 1:] + dim = x.shape[-1] + w0 = w // self.patch_size + h0 = h // self.patch_size + M = int(math.sqrt(N)) # Recover the number of patches in each dimension + assert N == M * M + kwargs = {} + if self.interpolate_offset: + # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8 + # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors + sx = float(w0 + self.interpolate_offset) / M + sy = float(h0 + self.interpolate_offset) / M + kwargs["scale_factor"] = (sx, sy) + else: + # Simply specify an output size instead of a scale factor + kwargs["size"] = (w0, h0) + patch_pos_embed = nn.functional.interpolate( + patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2), + mode="bicubic", + antialias=self.interpolate_antialias, + **kwargs, + ) + assert (w0, h0) == patch_pos_embed.shape[-2:] + patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim) + return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype) + + def prepare_tokens_with_masks(self, x, masks=None): + B, nc, w, h = x.shape + x = self.patch_embed(x) + if masks is not None: + x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x) + + x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1) + x = x + self.interpolate_pos_encoding(x, w, h) + + if self.register_tokens is not None: + x = torch.cat( + ( + x[:, :1], + self.register_tokens.expand(x.shape[0], -1, -1), + x[:, 1:], + ), + dim=1, + ) + + return x + + def forward_features_list(self, x_list, masks_list): + x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)] + for blk in self.blocks: + x = blk(x) + + all_x = x + output = [] + for x, masks in zip(all_x, masks_list): + x_norm = self.norm(x) + output.append( + { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + ) + return output + + def forward_features(self, x, masks=None): + if isinstance(x, list): + return self.forward_features_list(x, masks) + + x = self.prepare_tokens_with_masks(x, masks) + + for blk in self.blocks: + x = blk(x) + + x_norm = self.norm(x) + return { + "x_norm_clstoken": x_norm[:, 0], + "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1], + "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :], + "x_prenorm": x, + "masks": masks, + } + + def _get_intermediate_layers_not_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + # If n is an int, take the n last blocks. If it's a list, take them + output, total_block_len = [], len(self.blocks) + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for i, blk in enumerate(self.blocks): + x = blk(x) + if i in blocks_to_take: + output.append(x) + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def _get_intermediate_layers_chunked(self, x, n=1): + x = self.prepare_tokens_with_masks(x) + output, i, total_block_len = [], 0, len(self.blocks[-1]) + # If n is an int, take the n last blocks. If it's a list, take them + blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n + for block_chunk in self.blocks: + for blk in block_chunk[i:]: # Passing the nn.Identity() + x = blk(x) + if i in blocks_to_take: + output.append(x) + i += 1 + assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found" + return output + + def get_intermediate_layers( + self, + x: torch.Tensor, + n: Union[int, Sequence] = 1, # Layers or n last layers to take + reshape: bool = False, + return_class_token: bool = False, + norm=True, + ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]: + if self.chunked_blocks: + outputs = self._get_intermediate_layers_chunked(x, n) + else: + outputs = self._get_intermediate_layers_not_chunked(x, n) + if norm: + outputs = [self.norm(out) for out in outputs] + class_tokens = [out[:, 0] for out in outputs] + outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs] + if reshape: + B, _, w, h = x.shape + outputs = [ + out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous() + for out in outputs + ] + if return_class_token: + return tuple(zip(outputs, class_tokens)) + return tuple(outputs) + + def forward(self, *args, is_training=False, **kwargs): + ret = self.forward_features(*args, **kwargs) + if is_training: + return ret + else: + return self.head(ret["x_norm_clstoken"]) + + +def init_weights_vit_timm(module: nn.Module, name: str = ""): + """ViT weight initialization, original timm impl (for reproducibility)""" + if isinstance(module, nn.Linear): + trunc_normal_(module.weight, std=0.02) + if module.bias is not None: + nn.init.zeros_(module.bias) + + +def vit_small(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=384, + depth=12, + num_heads=6, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_base(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=768, + depth=12, + num_heads=12, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_large(patch_size=16, num_register_tokens=0, **kwargs): + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1024, + depth=24, + num_heads=16, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model + + +def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs): + """ + Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64 + """ + model = DinoVisionTransformer( + patch_size=patch_size, + embed_dim=1536, + depth=40, + num_heads=24, + mlp_ratio=4, + block_fn=partial(Block, attn_class=MemEffAttention), + num_register_tokens=num_register_tokens, + **kwargs, + ) + return model diff --git a/moge/model/dinov2/utils/__init__.py b/moge/model/dinov2/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..b88da6bf80be92af00b72dfdb0a806fa64a7a2d9 --- /dev/null +++ b/moge/model/dinov2/utils/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. diff --git a/moge/model/dinov2/utils/cluster.py b/moge/model/dinov2/utils/cluster.py new file mode 100644 index 0000000000000000000000000000000000000000..3df87dc3e1eb4f0f8a280dc3137cfef031886314 --- /dev/null +++ b/moge/model/dinov2/utils/cluster.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from enum import Enum +import os +from pathlib import Path +from typing import Any, Dict, Optional + + +class ClusterType(Enum): + AWS = "aws" + FAIR = "fair" + RSC = "rsc" + + +def _guess_cluster_type() -> ClusterType: + uname = os.uname() + if uname.sysname == "Linux": + if uname.release.endswith("-aws"): + # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws" + return ClusterType.AWS + elif uname.nodename.startswith("rsc"): + # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc" + return ClusterType.RSC + + return ClusterType.FAIR + + +def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]: + if cluster_type is None: + return _guess_cluster_type() + + return cluster_type + + +def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + CHECKPOINT_DIRNAMES = { + ClusterType.AWS: "checkpoints", + ClusterType.FAIR: "checkpoint", + ClusterType.RSC: "checkpoint/dino", + } + return Path("/") / CHECKPOINT_DIRNAMES[cluster_type] + + +def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]: + checkpoint_path = get_checkpoint_path(cluster_type) + if checkpoint_path is None: + return None + + username = os.environ.get("USER") + assert username is not None + return checkpoint_path / username + + +def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]: + cluster_type = get_cluster_type(cluster_type) + if cluster_type is None: + return None + + SLURM_PARTITIONS = { + ClusterType.AWS: "learnlab", + ClusterType.FAIR: "learnlab", + ClusterType.RSC: "learn", + } + return SLURM_PARTITIONS[cluster_type] + + +def get_slurm_executor_parameters( + nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs +) -> Dict[str, Any]: + # create default parameters + params = { + "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html + "gpus_per_node": num_gpus_per_node, + "tasks_per_node": num_gpus_per_node, # one task per GPU + "cpus_per_task": 10, + "nodes": nodes, + "slurm_partition": get_slurm_partition(cluster_type), + } + # apply cluster-specific adjustments + cluster_type = get_cluster_type(cluster_type) + if cluster_type == ClusterType.AWS: + params["cpus_per_task"] = 12 + del params["mem_gb"] + elif cluster_type == ClusterType.RSC: + params["cpus_per_task"] = 12 + # set additional parameters / apply overrides + params.update(kwargs) + return params diff --git a/moge/model/dinov2/utils/config.py b/moge/model/dinov2/utils/config.py new file mode 100644 index 0000000000000000000000000000000000000000..c9de578787bbcb376f8bd5a782206d0eb7ec1f52 --- /dev/null +++ b/moge/model/dinov2/utils/config.py @@ -0,0 +1,72 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import math +import logging +import os + +from omegaconf import OmegaConf + +import dinov2.distributed as distributed +from dinov2.logging import setup_logging +from dinov2.utils import utils +from dinov2.configs import dinov2_default_config + + +logger = logging.getLogger("dinov2") + + +def apply_scaling_rules_to_cfg(cfg): # to fix + if cfg.optim.scaling_rule == "sqrt_wrt_1024": + base_lr = cfg.optim.base_lr + cfg.optim.lr = base_lr + cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0) + logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}") + else: + raise NotImplementedError + return cfg + + +def write_config(cfg, output_dir, name="config.yaml"): + logger.info(OmegaConf.to_yaml(cfg)) + saved_cfg_path = os.path.join(output_dir, name) + with open(saved_cfg_path, "w") as f: + OmegaConf.save(config=cfg, f=f) + return saved_cfg_path + + +def get_cfg_from_args(args): + args.output_dir = os.path.abspath(args.output_dir) + args.opts += [f"train.output_dir={args.output_dir}"] + default_cfg = OmegaConf.create(dinov2_default_config) + cfg = OmegaConf.load(args.config_file) + cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts)) + return cfg + + +def default_setup(args): + distributed.enable(overwrite=True) + seed = getattr(args, "seed", 0) + rank = distributed.get_global_rank() + + global logger + setup_logging(output=args.output_dir, level=logging.INFO) + logger = logging.getLogger("dinov2") + + utils.fix_random_seeds(seed + rank) + logger.info("git:\n {}\n".format(utils.get_sha())) + logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items()))) + + +def setup(args): + """ + Create configs and perform basic setups. + """ + cfg = get_cfg_from_args(args) + os.makedirs(args.output_dir, exist_ok=True) + default_setup(args) + apply_scaling_rules_to_cfg(cfg) + write_config(cfg, args.output_dir) + return cfg diff --git a/moge/model/dinov2/utils/dtype.py b/moge/model/dinov2/utils/dtype.py new file mode 100644 index 0000000000000000000000000000000000000000..80f4cd74d99faa2731dbe9f8d3a13d71b3f8e3a8 --- /dev/null +++ b/moge/model/dinov2/utils/dtype.py @@ -0,0 +1,37 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + + +from typing import Dict, Union + +import numpy as np +import torch + + +TypeSpec = Union[str, np.dtype, torch.dtype] + + +_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = { + np.dtype("bool"): torch.bool, + np.dtype("uint8"): torch.uint8, + np.dtype("int8"): torch.int8, + np.dtype("int16"): torch.int16, + np.dtype("int32"): torch.int32, + np.dtype("int64"): torch.int64, + np.dtype("float16"): torch.float16, + np.dtype("float32"): torch.float32, + np.dtype("float64"): torch.float64, + np.dtype("complex64"): torch.complex64, + np.dtype("complex128"): torch.complex128, +} + + +def as_torch_dtype(dtype: TypeSpec) -> torch.dtype: + if isinstance(dtype, torch.dtype): + return dtype + if isinstance(dtype, str): + dtype = np.dtype(dtype) + assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}" + return _NUMPY_TO_TORCH_DTYPE[dtype] diff --git a/moge/model/dinov2/utils/param_groups.py b/moge/model/dinov2/utils/param_groups.py new file mode 100644 index 0000000000000000000000000000000000000000..9a5d2ff627cddadc222e5f836864ee39c865208f --- /dev/null +++ b/moge/model/dinov2/utils/param_groups.py @@ -0,0 +1,103 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +from collections import defaultdict +import logging + + +logger = logging.getLogger("dinov2") + + +def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False): + """ + Calculate lr decay rate for different ViT blocks. + Args: + name (string): parameter name. + lr_decay_rate (float): base lr decay rate. + num_layers (int): number of ViT blocks. + Returns: + lr decay rate for the given parameter. + """ + layer_id = num_layers + 1 + if name.startswith("backbone") or force_is_backbone: + if ( + ".pos_embed" in name + or ".patch_embed" in name + or ".mask_token" in name + or ".cls_token" in name + or ".register_tokens" in name + ): + layer_id = 0 + elif force_is_backbone and ( + "pos_embed" in name + or "patch_embed" in name + or "mask_token" in name + or "cls_token" in name + or "register_tokens" in name + ): + layer_id = 0 + elif ".blocks." in name and ".residual." not in name: + layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1 + elif chunked_blocks and "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1 + elif "blocks." in name and "residual." not in name: + layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1 + + return lr_decay_rate ** (num_layers + 1 - layer_id) + + +def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0): + chunked_blocks = False + if hasattr(model, "n_blocks"): + logger.info("chunked fsdp") + n_blocks = model.n_blocks + chunked_blocks = model.chunked_blocks + elif hasattr(model, "blocks"): + logger.info("first code branch") + n_blocks = len(model.blocks) + elif hasattr(model, "backbone"): + logger.info("second code branch") + n_blocks = len(model.backbone.blocks) + else: + logger.info("else code branch") + n_blocks = 0 + all_param_groups = [] + + for name, param in model.named_parameters(): + name = name.replace("_fsdp_wrapped_module.", "") + if not param.requires_grad: + continue + decay_rate = get_vit_lr_decay_rate( + name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks + ) + d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name} + + if "last_layer" in name: + d.update({"is_last_layer": True}) + + if name.endswith(".bias") or "norm" in name or "gamma" in name: + d.update({"wd_multiplier": 0.0}) + + if "patch_embed" in name: + d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult}) + + all_param_groups.append(d) + logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""") + + return all_param_groups + + +def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")): + fused_params_groups = defaultdict(lambda: {"params": []}) + for d in all_params_groups: + identifier = "" + for k in keys: + identifier += k + str(d[k]) + "_" + + for k in keys: + fused_params_groups[identifier][k] = d[k] + fused_params_groups[identifier]["params"].append(d["params"]) + + return fused_params_groups.values() diff --git a/moge/model/dinov2/utils/utils.py b/moge/model/dinov2/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..68f8e2c3be5f780bbb7e00359b5ac4fd0ba0785f --- /dev/null +++ b/moge/model/dinov2/utils/utils.py @@ -0,0 +1,95 @@ +# Copyright (c) Meta Platforms, Inc. and affiliates. +# +# This source code is licensed under the Apache License, Version 2.0 +# found in the LICENSE file in the root directory of this source tree. + +import logging +import os +import random +import subprocess +from urllib.parse import urlparse + +import numpy as np +import torch +from torch import nn + + +logger = logging.getLogger("dinov2") + + +def load_pretrained_weights(model, pretrained_weights, checkpoint_key): + if urlparse(pretrained_weights).scheme: # If it looks like an URL + state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu") + else: + state_dict = torch.load(pretrained_weights, map_location="cpu") + if checkpoint_key is not None and checkpoint_key in state_dict: + logger.info(f"Take key {checkpoint_key} in provided checkpoint dict") + state_dict = state_dict[checkpoint_key] + # remove `module.` prefix + state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()} + # remove `backbone.` prefix induced by multicrop wrapper + state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()} + msg = model.load_state_dict(state_dict, strict=False) + logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg)) + + +def fix_random_seeds(seed=31): + """ + Fix random seeds. + """ + torch.manual_seed(seed) + torch.cuda.manual_seed_all(seed) + np.random.seed(seed) + random.seed(seed) + + +def get_sha(): + cwd = os.path.dirname(os.path.abspath(__file__)) + + def _run(command): + return subprocess.check_output(command, cwd=cwd).decode("ascii").strip() + + sha = "N/A" + diff = "clean" + branch = "N/A" + try: + sha = _run(["git", "rev-parse", "HEAD"]) + subprocess.check_output(["git", "diff"], cwd=cwd) + diff = _run(["git", "diff-index", "HEAD"]) + diff = "has uncommitted changes" if diff else "clean" + branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"]) + except Exception: + pass + message = f"sha: {sha}, status: {diff}, branch: {branch}" + return message + + +class CosineScheduler(object): + def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0): + super().__init__() + self.final_value = final_value + self.total_iters = total_iters + + freeze_schedule = np.zeros((freeze_iters)) + + warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters) + + iters = np.arange(total_iters - warmup_iters - freeze_iters) + schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters))) + self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule)) + + assert len(self.schedule) == self.total_iters + + def __getitem__(self, it): + if it >= self.total_iters: + return self.final_value + else: + return self.schedule[it] + + +def has_batchnorms(model): + bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm) + for name, module in model.named_modules(): + if isinstance(module, bn_types): + return True + return False diff --git a/moge/model/moge_model.py b/moge/model/moge_model.py new file mode 100644 index 0000000000000000000000000000000000000000..60316f0caffe9f6e6d5d725619668ef71f116268 --- /dev/null +++ b/moge/model/moge_model.py @@ -0,0 +1,376 @@ +from typing import * +from numbers import Number +from functools import partial +from pathlib import Path +import importlib +import warnings +import json + +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils +import torch.utils.checkpoint +import torch.version +import utils3d +from huggingface_hub import hf_hub_download + +from ..utils.geometry_torch import image_plane_uv, point_map_to_depth, gaussian_blur_2d +from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing +from ..utils.tools import timeit + + +class ResidualConvBlock(nn.Module): + def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'): + super(ResidualConvBlock, self).__init__() + if out_channels is None: + out_channels = in_channels + if hidden_channels is None: + hidden_channels = in_channels + + if activation =='relu': + activation_cls = lambda: nn.ReLU(inplace=True) + elif activation == 'leaky_relu': + activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True) + elif activation =='silu': + activation_cls = lambda: nn.SiLU(inplace=True) + elif activation == 'elu': + activation_cls = lambda: nn.ELU(inplace=True) + else: + raise ValueError(f'Unsupported activation function: {activation}') + + self.layers = nn.Sequential( + nn.GroupNorm(1, in_channels), + activation_cls(), + nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode), + nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels), + activation_cls(), + nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode) + ) + + self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity() + + def forward(self, x): + skip = self.skip_connection(x) + x = self.layers(x) + x = x + skip + return x + + +class Head(nn.Module): + def __init__( + self, + num_features: int, + dim_in: int, + dim_out: List[int], + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm', + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1 + ): + super().__init__() + + self.projects = nn.ModuleList([ + nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features) + ]) + + self.upsample_blocks = nn.ModuleList([ + nn.Sequential( + self._make_upsampler(in_ch + 2, out_ch), + *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks)) + ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample) + ]) + + self.output_block = nn.ModuleList([ + self._make_output_block( + dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm, + ) for dim_out_ in dim_out + ]) + + def _make_upsampler(self, in_channels: int, out_channels: int): + upsampler = nn.Sequential( + nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2), + nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate') + ) + upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1] + return upsampler + + def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']): + return nn.Sequential( + nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'), + *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)), + nn.ReLU(inplace=True), + nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'), + ) + + def forward(self, hidden_states: torch.Tensor, image: torch.Tensor): + img_h, img_w = image.shape[-2:] + patch_h, patch_w = img_h // 14, img_w // 14 + + # Process the hidden states + x = torch.stack([ + proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous()) + for proj, (feat, clstoken) in zip(self.projects, hidden_states) + ], dim=1).sum(dim=1) + + # Upsample stage + # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8) + for i, block in enumerate(self.upsample_blocks): + # UV coordinates is for awareness of image aspect ratio + uv = image_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + for layer in block: + x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False) + + # (patch_h * 8, patch_w * 8) -> (img_h, img_w) + x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False) + uv = image_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device) + uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1) + x = torch.cat([x, uv], dim=1) + + if isinstance(self.output_block, nn.ModuleList): + output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block] + else: + output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False) + + return output + + +class MoGeModel(nn.Module): + image_mean: torch.Tensor + image_std: torch.Tensor + + def __init__(self, + encoder: str = 'dinov2_vitb14', + intermediate_layers: Union[int, List[int]] = 4, + dim_proj: int = 512, + dim_upsample: List[int] = [256, 128, 128], + dim_times_res_block_hidden: int = 1, + num_res_blocks: int = 1, + output_mask: bool = False, + split_head: bool = False, + remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'linear', + res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm', + trained_diagonal_size_range: Tuple[Number, Number] = (600, 900), + trained_area_range: Tuple[Number, Number] = (500 * 500, 700 * 700), + last_res_blocks: int = 0, + last_conv_channels: int = 32, + last_conv_size: int = 1, + **deprecated_kwargs + ): + super(MoGeModel, self).__init__() + if deprecated_kwargs: + warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}") + + self.encoder = encoder + self.remap_output = remap_output + self.intermediate_layers = intermediate_layers + self.trained_diagonal_size_range = trained_diagonal_size_range + self.trained_area_range = trained_area_range + self.output_mask = output_mask + self.split_head = split_head + + # NOTE: We have copied the DINOv2 code in torchhub to this repository. + # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues. + hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder) + self.backbone = hub_loader(pretrained=False) + dim_feature = self.backbone.blocks[0].attn.qkv.in_features + + self.head = Head( + num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers), + dim_in=dim_feature, + dim_out=3 if not output_mask else 4 if output_mask and not split_head else [3, 1], + dim_proj=dim_proj, + dim_upsample=dim_upsample, + dim_times_res_block_hidden=dim_times_res_block_hidden, + num_res_blocks=num_res_blocks, + res_block_norm=res_block_norm, + last_res_blocks=last_res_blocks, + last_conv_channels=last_conv_channels, + last_conv_size=last_conv_size + ) + + image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1) + image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1) + + self.register_buffer("image_mean", image_mean) + self.register_buffer("image_std", image_std) + + if torch.__version__ >= '2.0': + self.enable_pytorch_native_sdpa() + + @classmethod + def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel': + """ + Load a model from a checkpoint file. + + ### Parameters: + - `pretrained_model_name_or_path`: path to the checkpoint file or repo id. + - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint. + - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path. + + ### Returns: + - A new instance of `MoGe` with the parameters loaded from the checkpoint. + """ + if Path(pretrained_model_name_or_path).exists(): + checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True) + else: + cached_checkpoint_path = hf_hub_download( + repo_id=pretrained_model_name_or_path, + repo_type="model", + filename="model.pt", + **hf_kwargs + ) + checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True) + model_config = checkpoint['model_config'] + if model_kwargs is not None: + model_config.update(model_kwargs) + model = cls(**model_config) + model.load_state_dict(checkpoint['model']) + return model + + @staticmethod + def cache_pretrained_backbone(encoder: str, pretrained: bool): + _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained) + + def load_pretrained_backbone(self): + "Load the backbone with pretrained dinov2 weights from torch hub" + state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict() + self.backbone.load_state_dict(state_dict) + + def enable_backbone_gradient_checkpointing(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i]) + + def enable_pytorch_native_sdpa(self): + for i in range(len(self.backbone.blocks)): + self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn) + + def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]: + raw_img_h, raw_img_w = image.shape[-2:] + patch_h, patch_w = raw_img_h // 14, raw_img_w // 14 + + image = (image - self.image_mean) / self.image_std + + # Apply image transformation for DINOv2 + image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True) + + # Get intermediate layers from the backbone + with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision): + features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True) + + # Predict points (and mask) + output = self.head(features, image) + if self.output_mask: + if self.split_head: + points, mask = output + else: + points, mask = output.split([3, 1], dim=1) + points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1) + else: + points = output.permute(0, 2, 3, 1) + + if self.remap_output == 'linear' or self.remap_output == False: + pass + elif self.remap_output =='sinh' or self.remap_output == True: + points = torch.sinh(points) + elif self.remap_output == 'exp': + xy, z = points.split([2, 1], dim=-1) + z = torch.exp(z) + points = torch.cat([xy * z, z], dim=-1) + elif self.remap_output =='sinh_exp': + xy, z = points.split([2, 1], dim=-1) + points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1) + else: + raise ValueError(f"Invalid remap output type: {self.remap_output}") + + return_dict = {'points': points} + if self.output_mask: + return_dict['mask'] = mask + return return_dict + + @torch.inference_mode() + def infer( + self, + image: torch.Tensor, + force_projection: bool = True, + resolution_level: int = 9, + apply_mask: bool = True, + ) -> Dict[str, torch.Tensor]: + """ + User-friendly inference function + + ### Parameters + - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W) + - `resolution_level`: the resolution level to use for the output point map in 0-9. Default: 9 (highest) + - `interpolation_mode`: interpolation mode for the output points map. Default: 'bilinear'. + + ### Returns + + A dictionary containing the following keys: + - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3). + - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map. + - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics. + """ + if image.dim() == 3: + omit_batch_dim = True + image = image.unsqueeze(0) + else: + omit_batch_dim = False + + original_height, original_width = image.shape[-2:] + area = original_height * original_width + + min_area, max_area = self.trained_area_range + expected_area = min_area + (max_area - min_area) * (resolution_level / 9) + + if expected_area != area: + expected_width, expected_height = int(original_width * (expected_area / area) ** 0.5), int(original_height * (expected_area / area) ** 0.5) + image = F.interpolate(image, (expected_height, expected_width), mode="bicubic", align_corners=False, antialias=True) + + output = self.forward(image) + points, mask = output['points'], output.get('mask', None) + + # Get camera-origin-centered point map + depth, fov_x, fov_y, z_shift = point_map_to_depth(points, None if mask is None else mask > 0.5) + intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov_x, fov_y) + + # If projection constraint is forces, recompute the point map using the actual depth map + if force_projection: + points = utils3d.torch.unproject_cv(utils3d.torch.image_uv(width=expected_width, height=expected_height, dtype=points.dtype, device=points.device), depth, extrinsics=None, intrinsics=intrinsics[..., None, :, :]) + else: + points = points + torch.stack([torch.zeros_like(z_shift), torch.zeros_like(z_shift), z_shift], dim=-1)[..., None, None, :] + + # Resize the output to the original resolution + if expected_area != area: + points = F.interpolate(points.permute(0, 3, 1, 2), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).permute(0, 2, 3, 1) + depth = F.interpolate(depth.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1) + mask = None if mask is None else F.interpolate(mask.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1) + + # Apply mask if needed + if self.output_mask and apply_mask: + mask_binary = (depth > 0) & (mask > 0.5) + points = torch.where(mask_binary[..., None], points, torch.inf) + depth = torch.where(mask_binary, depth, torch.inf) + + if omit_batch_dim: + points = points.squeeze(0) + intrinsics = intrinsics.squeeze(0) + depth = depth.squeeze(0) + if self.output_mask: + mask = mask.squeeze(0) + + return_dict = { + 'points': points, + 'intrinsics': intrinsics, + 'depth': depth, + } + if self.output_mask: + return_dict['mask'] = mask > 0.5 + + return return_dict diff --git a/moge/model/utils.py b/moge/model/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..af0d0a042209ed87cb60f340529940359fdfa900 --- /dev/null +++ b/moge/model/utils.py @@ -0,0 +1,38 @@ +from typing import * + +import torch +import torch.nn as nn +import torch.nn.functional as F + +def wrap_module_with_gradient_checkpointing(module: nn.Module): + from torch.utils.checkpoint import checkpoint + class _CheckpointingWrapper(module.__class__): + _restore_cls = module.__class__ + def forward(self, *args, **kwargs): + return checkpoint(super().forward, *args, use_reentrant=False, **kwargs) + + module.__class__ = _CheckpointingWrapper + return module + + +def unwrap_module_with_gradient_checkpointing(module: nn.Module): + module.__class__ = module.__class__._restore_cls + + +def wrap_dinov2_attention_with_sdpa(module: nn.Module): + assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later" + class _AttentionWrapper(module.__class__): + def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor: + B, N, C = x.shape + qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H) + + q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H) + + x = F.scaled_dot_product_attention(q, k, v, attn_bias) + x = x.permute(0, 2, 1, 3).reshape(B, N, C) + + x = self.proj(x) + x = self.proj_drop(x) + return x + module.__class__ = _AttentionWrapper + return module \ No newline at end of file diff --git a/moge/utils/__init__.py b/moge/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/moge/utils/blob.py b/moge/utils/blob.py new file mode 100644 index 0000000000000000000000000000000000000000..cc5adf649726e8663ca344642ad0cd90e8b9f817 --- /dev/null +++ b/moge/utils/blob.py @@ -0,0 +1,314 @@ +from typing import IO, Generator, Tuple, Union, overload +from pathlib import Path, PosixPath, PurePosixPath +import io +import os +import re +import requests +import fnmatch + +from azure.identity import DefaultAzureCredential +from azure.storage.blob import ContainerClient, BlobClient +import requests.adapters +import requests.packages +from urllib3.util.retry import Retry + + +__all__ = [ + 'download_blob', 'upload_blob', + 'download_blob_with_cache', + 'open_blob', 'open_blob_with_cache', + 'blob_file_exists', + 'AzureBlobPath','SmartPath' +] + +DEFAULT_CREDENTIAL = DefaultAzureCredential() + +BLOB_CACHE_DIR = './.blobcache' + +def download_blob(blob: Union[str, BlobClient]) -> bytes: + if isinstance(blob, str): + blob_client = BlobClient.from_blob_url(blob_client) + else: + blob_client = blob + return blob_client.download_blob().read() + + +def upload_blob(blob: Union[str, BlobClient], data: Union[str, bytes]): + if isinstance(blob, str): + blob_client = BlobClient.from_blob_url(blob) + else: + blob_client = blob + blob_client.upload_blob(data, overwrite=True) + + +def download_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> bytes: + """ + Download a blob file from a container and return its content as bytes. + If the file is already present in the cache, it is read from there. + """ + cache_path = Path(cache_dir) / blob_name + if cache_path.exists(): + return cache_path.read_bytes() + data = download_blob(container, blob_name) + cache_path.parent.mkdir(parents=True, exist_ok=True) + cache_path.write_bytes(data) + return data + + +def open_blob(container: Union[str, ContainerClient], blob_name: str) -> io.BytesIO: + """ + Open a blob file for reading from a container and return its content as a BytesIO object. + """ + return io.BytesIO(download_blob(container, blob_name)) + + +def open_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> io.BytesIO: + """ + Open a blob file for reading from a container and return its content as a BytesIO object. + If the file is already present in the cache, it is read from there. + """ + return io.BytesIO(download_blob_with_cache(container, blob_name, cache_dir=cache_dir)) + + +def blob_file_exists(container: Union[str, ContainerClient], blob_name: str) -> bool: + """ + Check if a blob file exists in a container. + """ + if isinstance(container, str): + container = ContainerClient.from_container_url(container) + blob_client = container.get_blob_client(blob_name) + return blob_client.exists() + +def is_blob_url(url: str) -> bool: + return re.match(r'https://[^/]+blob.core.windows.net/+', url) is not None + + +def split_blob_url(url: str) -> Tuple[str, str, str]: + match = re.match(r'(https://[^/]+blob.core.windows.net/[^/?]+)(/([^\?]*))?(\?.+)?', url) + if match: + container, _, path, sas = match.groups() + return container, path or '', sas or '' + raise ValueError(f'Not a valid blob URL: {url}') + + +def join_blob_path(url: str, *others: str) -> str: + container, path, sas = split_blob_url(url) + return container + '/' + os.path.join(path, *others) + sas + + +class AzureBlobStringWriter(io.StringIO): + def __init__(self, blob_client: BlobClient, encoding: str = 'utf-8', **kwargs): + self._encoding = encoding + self.blob_client = blob_client + self.kwargs = kwargs + super().__init__() + + def close(self): + self.blob_client.upload_blob(self.getvalue().encode(self._encoding), blob_type='BlockBlob', overwrite=True, **self.kwargs) + + +class AzureBlobBytesWriter(io.BytesIO): + def __init__(self, blob_client: BlobClient, **kwargs): + super().__init__() + self.blob_client = blob_client + self.kwargs = kwargs + + def close(self): + self.blob_client.upload_blob(self.getvalue(), blob_type='BlockBlob', overwrite=True, **self.kwargs) + + +def open_azure_blob(blob: Union[str, BlobClient], mode: str = 'r', encoding: str = 'utf-8', newline: str = None, cache_blob: bool = False, **kwargs) -> IO: + if isinstance(blob, str): + blob_client = BlobClient.from_blob_url(blob) + elif isinstance(blob, BlobClient): + blob_client = blob + else: + raise ValueError(f'Must be a blob URL or a BlobClient object: {blob}') + + if cache_blob: + cache_path = Path(BLOB_CACHE_DIR, blob_client.account_name, blob_client.container_name, blob_client.blob_name) + + if mode == 'r' or mode == 'rb': + if cache_blob: + if cache_path.exists(): + data = cache_path.read_bytes() + else: + data = blob_client.download_blob(**kwargs).read() + cache_path.parent.mkdir(parents=True, exist_ok=True) + cache_path.write_bytes(data) + else: + data = blob_client.download_blob(**kwargs).read() + if mode == 'r': + return io.StringIO(data.decode(encoding), newline=newline) + else: + return io.BytesIO(data) + elif mode == 'w': + return AzureBlobStringWriter(blob_client, **kwargs) + elif mode == 'wb': + return AzureBlobBytesWriter(blob_client, **kwargs) + else: + raise ValueError(f'Unsupported mode: {mode}') + + +def smart_open(path_or_url: Union[Path, str], mode: str = 'r', encoding: str = 'utf-8') -> IO: + if is_blob_url(str(path_or_url)): + return open_azure_blob(str(path_or_url), mode, encoding) + return open(path_or_url, mode, encoding) + + +class AzureBlobPath(PurePosixPath): + """ + Implementation of pathlib.Path like interface for Azure Blob Storage. + """ + container_client: ContainerClient + _parse_path = PurePosixPath._parse_args if hasattr(PurePosixPath, '_parse_args') else PurePosixPath._parse_path + + def __new__(cls, *args, **kwargs): + """Override the old __new__ method. Parts are parsed in __init__""" + return object.__new__(cls) + + def __init__(self, root: Union[str, 'AzureBlobPath', ContainerClient], *others: Union[str, PurePosixPath], pool_maxsize: int = 256, retries: int = 3): + if isinstance(root, AzureBlobPath): + self.container_client = root.container_client + parts = root.parts + others + elif isinstance(root, str): + url = root + container, path, sas = split_blob_url(url) + session = self._get_session(pool_maxsize=pool_maxsize, retries=retries) + if sas: + self.container_client = ContainerClient.from_container_url(container + sas, session=session) + else: + self.container_client = ContainerClient.from_container_url(container, credential=DEFAULT_CREDENTIAL, session=session) + parts = (path, *others) + elif isinstance(root, ContainerClient): + self.container_client = root + parts = others + else: + raise ValueError(f'Invalid root: {root}') + + if hasattr(PurePosixPath, '_parse_args'): + # For compatibility with Python 3.10 + drv, root, parts = PurePosixPath._parse_args(parts) + self._drv = drv + self._root = root + self._parts = parts + else: + super().__init__(*parts) + + def _get_session(self, pool_maxsize: int = 1024, retries: int = 3) -> requests.Session: + session = requests.Session() + retry_strategy = Retry( + total=retries, + status_forcelist=[429, 500, 502, 503, 504], + allowed_methods=["HEAD", "GET", "PUT", "DELETE"], + backoff_factor=1, + raise_on_status=False, + read=retries, + connect=retries, + redirect=retries, + ) + adapter = requests.adapters.HTTPAdapter(pool_connections=pool_maxsize, pool_maxsize=pool_maxsize, max_retries=retry_strategy) + session.mount('http://', adapter) + session.mount('https://', adapter) + return session + + def _from_parsed_parts(self, drv, root, parts): + "For compatibility with Python 3.10" + return AzureBlobPath(self.container_client, drv, root, *parts) + + def with_segments(self, *pathsegments): + return AzureBlobPath(self.container_client, *pathsegments) + + @property + def path(self) -> str: + return '/'.join(self.parts) + + @property + def blob_client(self) -> BlobClient: + return self.container_client.get_blob_client(self.path) + + @property + def url(self) -> str: + if len(self.parts) == 0: + return self.container_client.url + return self.container_client.get_blob_client(self.path).url + + @property + def container_name(self) -> str: + return self.container_client.container_name + + @property + def account_name(self) -> str: + return self.container_client.account_name + + def __str__(self): + return self.url + + def __repr__(self): + return self.url + + def open(self, mode: str = 'r', encoding: str = 'utf-8', cache_blob: bool = False, **kwargs) -> IO: + return open_azure_blob(self.blob_client, mode, encoding, cache_blob=cache_blob, **kwargs) + + def __truediv__(self, other: Union[str, Path]) -> 'AzureBlobPath': + return self.joinpath(other) + + def mkdir(self, parents: bool = False, exist_ok: bool = False): + pass + + def iterdir(self) -> Generator['AzureBlobPath', None, None]: + path = self.path + if not path.endswith('/'): + path += '/' + for item in self.container_client.walk_blobs(self.path): + yield AzureBlobPath(self.container_client, item.name) + + def glob(self, pattern: str) -> Generator['AzureBlobPath', None, None]: + special_chars = ".^$+{}[]()|/" + for char in special_chars: + pattern = pattern.replace(char, "\\" + char) + pattern = pattern.replace('**', './/.') + pattern = pattern.replace('*', '[^/]*') + pattern = pattern.replace('.//.', '.*') + pattern = "^" + pattern + "$" + reg = re.compile(pattern) + + for item in self.container_client.list_blobs(self.path): + if reg.match(os.path.relpath(item.name, self.path)): + yield AzureBlobPath(self.container_client, item.name) + + def exists(self) -> bool: + return self.blob_client.exists() + + def read_bytes(self, cache_blob: bool = False) -> bytes: + with self.open('rb', cache_blob=cache_blob) as f: + return f.read() + + def read_text(self, encoding: str = 'utf-8', cache_blob: bool = False) -> str: + with self.open('r', encoding=encoding, cache_blob=cache_blob) as f: + return f.read() + + def write_bytes(self, data: bytes): + self.blob_client.upload_blob(data, overwrite=True) + + def write_text(self, data: str, encoding: str = 'utf-8'): + self.blob_client.upload_blob(data.encode(encoding), overwrite=True) + + def unlink(self): + self.blob_client.delete_blob() + + def new_client(self) -> 'AzureBlobPath': + return AzureBlobPath(self.container_client.url, self.path) + + +class SmartPath(Path, AzureBlobPath): + """ + Supports both local file paths and Azure Blob Storage URLs. + """ + def __new__(cls, first: Union[Path, str], *others: Union[str, PurePosixPath]) -> Union[Path, AzureBlobPath]: + if is_blob_url(str(first)): + return AzureBlobPath(str(first), *others) + return Path(first, *others) + + + \ No newline at end of file diff --git a/moge/utils/download.py b/moge/utils/download.py new file mode 100644 index 0000000000000000000000000000000000000000..886edbccc81cc0c3daed4d858f641097bdfceee2 --- /dev/null +++ b/moge/utils/download.py @@ -0,0 +1,55 @@ +from pathlib import Path +from typing import * +import requests + +from tqdm import tqdm + + +__all__ = ["download_file", "download_bytes"] + + +def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None: + # Ensure headers is a dict if not provided + headers = headers or {} + + # Initialize local variables + file_path = Path(filepath) + downloaded_bytes = 0 + + # Check if we should resume the download + if resume and file_path.exists(): + downloaded_bytes = file_path.stat().st_size + headers['Range'] = f"bytes={downloaded_bytes}-" + + # Make a GET request to fetch the file + with requests.get(url, stream=True, headers=headers) as response: + response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx + + # Calculate the total size to download + total_size = downloaded_bytes + int(response.headers.get('content-length', 0)) + + # Display a progress bar while downloading + with ( + tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar, + open(file_path, 'ab') as file, + ): + # Set the initial position of the progress bar + pbar.update(downloaded_bytes) + + # Write the content to the file in chunks + for chunk in response.iter_content(chunk_size=4096): + file.write(chunk) + pbar.update(len(chunk)) + + +def download_bytes(url: str, headers: dict = None) -> bytes: + # Ensure headers is a dict if not provided + headers = headers or {} + + # Make a GET request to fetch the file + with requests.get(url, stream=True, headers=headers) as response: + response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx + + # Read the content of the response + return response.content + \ No newline at end of file diff --git a/moge/utils/geometry_numpy.py b/moge/utils/geometry_numpy.py new file mode 100644 index 0000000000000000000000000000000000000000..3ed53dae60e73dcca31cc9b64946a227a967a600 --- /dev/null +++ b/moge/utils/geometry_numpy.py @@ -0,0 +1,175 @@ +from typing import * +from functools import partial +import math + +import numpy as np +import utils3d + +from .tools import timeit + +def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray: + if w is None: + return np.mean(x, axis=axis) + else: + w = w.astype(x.dtype) + return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None) + + +def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray: + if w is None: + return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis) + else: + w = w.astype(x.dtype) + return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps) + + +def image_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 + span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5 + + u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype) + v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + uv = np.stack([u, v], axis=-1) + return uv + + +def focal_to_fov_numpy(focal: np.ndarray): + return 2 * np.arctan(0.5 / focal) + + +def fov_to_focal_numpy(fov: np.ndarray): + return 0.5 / np.tan(fov / 2) + + +def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0]) + fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1]) + return fov_x, fov_y + + +def solve_optimal_shift_focal(uv: np.ndarray, xyz: np.ndarray, ransac_iters: int = None, ransac_hypothetical_size: float = 0.1, ransac_threshold: float = 0.1): + "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal" + from scipy.optimize import least_squares + uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1) + + def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray): + xy_proj = xy / (z + shift)[: , None] + f = (xy_proj * uv).sum() / np.square(xy_proj).sum() + err = (f * xy_proj - uv).ravel() + return err + + initial_shift = 0 #-z.min(keepdims=True) + 1.0 + + if ransac_iters is None: + solution = least_squares(partial(fn, uv, xy, z), x0=initial_shift, ftol=1e-3, method='lm') + optim_shift = solution['x'].squeeze().astype(np.float32) + else: + best_err, best_shift = np.inf, None + for _ in range(ransac_iters): + maybe_inliers = np.random.choice(len(z), size=int(ransac_hypothetical_size * len(z)), replace=False) + solution = least_squares(partial(fn, uv[maybe_inliers], xy[maybe_inliers], z[maybe_inliers]), x0=initial_shift, ftol=1e-3, method='lm') + maybe_shift = solution['x'].squeeze().astype(np.float32) + confirmed_inliers = np.linalg.norm(fn(uv, xy, z, maybe_shift).reshape(-1, 2), axis=-1) < ransac_threshold + if confirmed_inliers.sum() > 10: + solution = least_squares(partial(fn, uv[confirmed_inliers], xy[confirmed_inliers], z[confirmed_inliers]), x0=maybe_shift, ftol=1e-3, method='lm') + better_shift = solution['x'].squeeze().astype(np.float32) + else: + better_shift = maybe_shift + err = np.linalg.norm(fn(uv, xy, z, better_shift).reshape(-1, 2), axis=-1).clip(max=ransac_threshold).mean() + if err < best_err: + best_err, best_shift = err, better_shift + initial_shift = best_shift + + optim_shift = best_shift + + xy_proj = xy / (z + optim_shift)[: , None] + optim_focal = (xy_proj * uv).sum() / (xy_proj * xy_proj).sum() + + return optim_shift, optim_focal + + +def point_map_to_depth_numpy(points: np.ndarray, mask: np.ndarray = None, downsample_size: Tuple[int, int] = (64, 64)): + import cv2 + assert points.shape[-1] == 3, "Points should (H, W, 3)" + + height, width = points.shape[-3], points.shape[-2] + diagonal = (height ** 2 + width ** 2) ** 0.5 + + uv = image_plane_uv_numpy(width=width, height=height) + + if mask is None: + points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3) + uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2) + else: + index, mask_lr = mask_aware_nearest_resize_numpy(mask, *downsample_size) + points_lr, uv_lr = points[index][mask_lr], uv[index][mask_lr] + + if points_lr.size == 0: + return np.zeros((height, width)), 0, 0, 0 + + optim_shift, optim_focal = solve_optimal_shift_focal(uv_lr, points_lr, ransac_iters=None) + + fov_x = 2 * np.arctan(width / diagonal / optim_focal) + fov_y = 2 * np.arctan(height / diagonal / optim_focal) + + depth = points[:, :, 2] + optim_shift + return depth, fov_x, fov_y, optim_shift + + +def mask_aware_nearest_resize_numpy(mask: np.ndarray, target_width: int, target_height: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map. + + ### Parameters + - `mask`: Input 2D mask of shape (..., H, W) + - `target_width`: target width of the resized map + - `target_height`: target height of the resized map + + ### Returns + - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width). Indices are like j + i * W, where j is the row index and i is the column index. + - `target_mask`: Mask of the resized map of shape (..., target_height, target_width) + """ + height, width = mask.shape[-2:] + filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) + filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f) + filter_size = filter_h_i * filter_w_i + padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2) + + # Window the original mask and uv + uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32) + indices = np.arange(height * width, dtype=np.int32).reshape(height, width) + padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32) + padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv + padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool) + padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask + padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32) + padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices + windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1)) + windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1)) + windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1)) + + # Gather the target pixels's local window + target_uv = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32) + target_corner = target_uv - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32) + target_corner = np.round(target_corner - 0.5).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32) + + target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) + target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) + target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) + + # Compute nearest neighbor in the local window for each pixel + dist = np.square(target_window_uv - target_uv[..., None]) + dist = dist[..., 0, :] + dist[..., 1, :] + dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size) + nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1) + nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width) + nearest_i, nearest_j = nearest_idx // width, nearest_idx % width + target_mask = np.any(target_window_mask, axis=-1) + batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])] + + return (*batch_indices, nearest_i, nearest_j), target_mask \ No newline at end of file diff --git a/moge/utils/geometry_torch.py b/moge/utils/geometry_torch.py new file mode 100644 index 0000000000000000000000000000000000000000..1691dd5976459ed1ea75655c7273d0dda7d680b3 --- /dev/null +++ b/moge/utils/geometry_torch.py @@ -0,0 +1,231 @@ +from typing import * +import math +from collections import namedtuple + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.types +import utils3d + +from .tools import timeit +from .geometry_numpy import solve_optimal_shift_focal + + +def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.mean(dim=dim, keepdim=keepdim) + else: + w = w.to(x.dtype) + return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps) + + +def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal() + else: + w = w.to(x.dtype) + return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal() + + +def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor: + if w is None: + return x.add(eps).log().mean(dim=dim).exp() + else: + w = w.to(x.dtype) + return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp() + + +def image_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor: + "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)" + if aspect_ratio is None: + aspect_ratio = width / height + + span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5 + span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5 + + u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device) + v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device) + u, v = torch.meshgrid(u, v, indexing='xy') + uv = torch.stack([u, v], dim=-1) + return uv + + +def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor: + kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2)) + kernel = kernel / kernel.sum() + kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size) + input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate') + input = F.conv2d(input, kernel, groups=input.shape[1]) + return input + + +def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs): + batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0] + n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0) + splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args) + splited_kwargs = {k: [v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks] for k, v in kwargs.items()} + results = [] + for i in range(n_chunks): + chunk_args = tuple(arg[i] for arg in splited_args) + chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()} + results.append(fn(*chunk_args, **chunk_kwargs)) + + if isinstance(results[0], tuple): + return tuple(torch.cat(r, dim=0) for r in zip(*results)) + else: + return torch.cat(results, dim=0) + + +def focal_to_fov(focal: torch.Tensor): + return 2 * torch.atan(0.5 / focal) + + +def fov_to_focal(fov: torch.Tensor): + return 0.5 / torch.tan(fov / 2) + + +def intrinsics_to_fov(intrinsics: torch.Tensor): + """ + Returns field of view in radians from normalized intrinsics matrix. + ### Parameters: + - intrinsics: torch.Tensor of shape (..., 3, 3) + + ### Returns: + - fov_x: torch.Tensor of shape (...) + - fov_y: torch.Tensor of shape (...) + """ + focal_x = intrinsics[..., 0, 0] + focal_y = intrinsics[..., 1, 1] + return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y) + + +def point_map_to_depth_legacy(points: torch.Tensor): + height, width = points.shape[-3:-1] + diagonal = (height ** 2 + width ** 2) ** 0.5 + uv = image_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2) + + # Solve least squares problem + b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2) + A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2) + + M = A.transpose(-2, -1) @ A + solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1) + focal, shift = solution.unbind(-1) + + depth = points[..., 2] + shift[..., None, None] + fov_x = torch.atan(width / diagonal / focal) * 2 + fov_y = torch.atan(height / diagonal / focal) * 2 + return depth, fov_x, fov_y, shift + + +def point_map_to_depth(points: torch.Tensor, mask: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)): + """ + Recover the depth map and FoV from a point map with unknown z shift and focal. + + Note that it assumes: + - the optical center is at the center of the map + - the map is undistorted + - the map is isometric in the x and y directions + + ### Parameters: + - `points: torch.Tensor` of shape (..., H, W, 3) + - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps. + + ### Returns: + - `depth: torch.Tensor` of shape (..., H, W) + - `fov_x: torch.Tensor` of shape (...) + - `fov_y: torch.Tensor` of shape (...) + - `shift: torch.Tensor` of shape (...), the z shift, making `depth = points[..., 2] + shift` + """ + shape = points.shape + height, width = points.shape[-3], points.shape[-2] + diagonal = (height ** 2 + width ** 2) ** 0.5 + + points = points.reshape(-1, *shape[-3:]) + mask = None if mask is None else mask.reshape(-1, *shape[-3:-1]) + uv = image_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2) + + points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1) + uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0) + mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0 + + uv_lr_np = uv_lr.cpu().numpy() + points_lr_np = points_lr.detach().cpu().numpy() + mask_lr_np = None if mask is None else mask_lr.cpu().numpy() + optim_shift, optim_focal = [], [] + for i in range(points.shape[0]): + points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]] + uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]] + optim_shift_i, optim_focal_i = solve_optimal_shift_focal(uv_lr_i_np, points_lr_i_np, ransac_iters=None) + optim_shift.append(float(optim_shift_i)) + optim_focal.append(float(optim_focal_i)) + optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype) + optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype) + + fov_x = 2 * torch.atan(width / diagonal / optim_focal) + fov_y = 2 * torch.atan(height / diagonal / optim_focal) + + depth = (points[..., 2] + optim_shift[:, None, None]).reshape(shape[:-1]) + fov_x = fov_x.reshape(shape[:-3]) + fov_y = fov_y.reshape(shape[:-3]) + optim_shift = optim_shift.reshape(shape[:-3]) + + return depth, fov_x, fov_y, optim_shift + + +def mask_aware_nearest_resize(mask: torch.BoolTensor, target_width: int, target_height: int) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]: + """ + Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map. + + ### Parameters + - `mask`: Input 2D mask of shape (..., H, W) + - `target_width`: target width of the resized map + - `target_height`: target height of the resized map + + ### Returns + - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension + - `target_mask`: Mask of the resized map of shape (..., target_height, target_width) + """ + height, width = mask.shape[-2:] + device = mask.device + filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width) + filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f) + filter_size = filter_h_i * filter_w_i + padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2) + + # Window the original mask and uv + uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device) + indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width) + padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device) + padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv + padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device) + padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask + padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device) + padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices + windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1)) + windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1)) + windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1)) + + # Gather the target pixels's local window + target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device) + target_corner = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device) + target_corner = torch.round(target_corner - 0.5).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device) + + target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size) + target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size) + target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size) + target_window_indices = target_window_indices.expand_as(target_window_mask) + + # Compute nearest neighbor in the local window for each pixel + dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size) + nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1) + nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width) + target_mask = torch.any(target_window_mask, dim=-1) + nearest_i, nearest_j = nearest_idx // width, nearest_idx % width + batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])] + + return (*batch_indices, nearest_i, nearest_j), target_mask + + \ No newline at end of file diff --git a/moge/utils/io.py b/moge/utils/io.py new file mode 100644 index 0000000000000000000000000000000000000000..396c954f35f8f4282b278d6504301c0e1bd9cae5 --- /dev/null +++ b/moge/utils/io.py @@ -0,0 +1,347 @@ +import os +os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1' +from typing import IO +import zipfile +import json +import io +from typing import * +from pathlib import Path +import re + +import numpy as np +import cv2 + +from .tools import timeit + + +LEGACY_SEGFORMER_CLASSES = [ + 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ', + 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth', + 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car', + 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug', + 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe', + 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column', + 'signboard', 'chest of drawers', 'counter', 'sand', 'sink', + 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path', + 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door', + 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table', + 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove', + 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar', + 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower', + 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver', + 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister', + 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van', + 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything', + 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent', + 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank', + 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake', + 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce', + 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen', + 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass', + 'clock', 'flag' +] +LEGACY_SEGFORMER_LABELS = {k: i for i, k in enumerate(LEGACY_SEGFORMER_CLASSES)} + + +def write_rgbd_zip( + file: Union[IO, os.PathLike], + image: Union[np.ndarray, bytes], + depth: Union[np.ndarray, bytes], mask: Union[np.ndarray, bytes], + segmentation_mask: Union[np.ndarray, bytes] = None, segmentation_labels: Union[Dict[str, int], bytes] = None, + intrinsics: np.ndarray = None, + normal: np.ndarray = None, normal_mask: np.ndarray = None, + meta: Union[Dict[str, Any], bytes] = None, + *, image_quality: int = 95, depth_type: Literal['linear', 'log', 'disparity'] = 'linear', depth_format: Literal['png', 'exr'] = 'png', depth_max_dynamic_range: float = 1e4, png_compression: int = 7 +): + """ + Write RGBD data as zip archive containing the image, depth, mask, segmentation_mask, and meta data. + In the zip file there will be: + - `meta.json`: The meta data as a JSON file. + - `image.jpg`: The RGB image as a JPEG file. + - `depth.png/exr`: The depth map as a PNG or EXR file, depending on the `depth_type`. + - `mask.png` (optional): The mask as a uint8 PNG file. + - `segmentation_mask.png` (optional): The segformer mask as a uint8/uint16 PNG file. + + You can provided those data as np.ndarray or bytes. If you provide them as np.ndarray, they will be properly processed and encoded. + If you provide them as bytes, they will be written as is, assuming they are already encoded. + """ + if meta is None: + meta = {} + elif isinstance(meta, bytes): + meta = json.loads(meta.decode()) + + if isinstance(image, bytes): + image_bytes = image + elif isinstance(image, np.ndarray): + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + image_bytes = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, image_quality])[1].tobytes() + + if isinstance(depth, bytes): + depth_bytes = depth + elif isinstance(depth, np.ndarray): + meta['depth_type'] = depth_type + if depth_type == 'linear': + if depth.dtype == np.float16: + depth_format = 'exr' + depth_bytes = cv2.imencode('.exr', depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])[1].tobytes() + elif np.issubdtype(depth.dtype, np.floating): + depth_format = 'exr' + depth_bytes = cv2.imencode('.exr', depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])[1].tobytes() + elif depth.dtype in [np.uint8, np.uint16]: + depth_format = 'png' + depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes() + elif depth_type == 'log': + depth_format = 'png' + depth = depth.astype(np.float32) + near = max(depth[mask].min(), 1e-3) + far = min(depth[mask].max(), near * depth_max_dynamic_range) + depth = ((np.log(depth.clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65535).astype(np.uint16) + depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes() + meta['depth_near'] = float(near) + meta['depth_far'] = float(far) + elif depth_type == 'disparity': + depth_format = 'png' + depth = depth.astype(np.float32) + depth = 1 / (depth + 1e-12) + depth = (depth / depth[mask].max()).clip(0, 1) + if np.unique(depth) < 200: + depth = (depth * 255).astype(np.uint8) + else: + depth = (depth * 65535).astype(np.uint16) + depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes() + + if isinstance(mask, bytes): + mask_bytes = mask + elif isinstance(mask, np.ndarray): + mask_bytes = cv2.imencode('.png', mask.astype(np.uint8) * 255)[1].tobytes() + + if segmentation_mask is not None: + if isinstance(segmentation_mask, bytes): + segmentation_mask_bytes = segmentation_mask + else: + segmentation_mask_bytes = cv2.imencode('.png', segmentation_mask)[1].tobytes() + assert segmentation_labels is not None, "You provided a segmentation mask, but not the corresponding labels." + if isinstance(segmentation_labels, bytes): + segmentation_labels = json.loads(segmentation_labels) + meta['segmentation_labels'] = segmentation_labels + + if intrinsics is not None: + meta['intrinsics'] = intrinsics.tolist() + + if normal is not None: + if isinstance(normal, bytes): + normal_bytes = normal + elif isinstance(normal, np.ndarray): + normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16) + normal = cv2.cvtColor(normal, cv2.COLOR_RGB2BGR) + normal_bytes = cv2.imencode('.png', normal, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes() + if normal_mask is None: + normal_mask = np.ones(image.shape[:2], dtype=bool) + normal_mask_bytes = cv2.imencode('.png', normal_mask.astype(np.uint8) * 255)[1].tobytes() + + meta_bytes = meta if isinstance(meta, bytes) else json.dumps(meta).encode() + + with zipfile.ZipFile(file, 'w') as z: + z.writestr('meta.json', meta_bytes) + z.writestr('image.jpg', image_bytes) + z.writestr(f'depth.{depth_format}', depth_bytes) + z.writestr('mask.png', mask_bytes) + if segmentation_mask is not None: + z.writestr('segmentation_mask.png', segmentation_mask_bytes) + if normal is not None: + z.writestr('normal.png', normal_bytes) + z.writestr('normal_mask.png', normal_mask_bytes) + + +def read_rgbd_zip(file: Union[str, Path, IO], return_bytes: bool = False) -> Dict[str, Union[np.ndarray, Dict[str, Any], bytes]]: + """ + Read an RGBD zip file and return the image, depth, mask, segmentation_mask, intrinsics, and meta data. + + ### Parameters: + - `file: Union[str, Path, IO]` + The file path or file object to read from. + - `return_bytes: bool = False` + If True, return the image, depth, mask, and segmentation_mask as raw bytes. + + ### Returns: + - `Tuple[Dict[str, Union[np.ndarray, Dict[str, Any]]], Dict[str, bytes]]` + A dictionary containing: (If missing, the value will be None; if return_bytes is True, the value will be bytes) + - `image`: RGB numpy.ndarray of shape (H, W, 3). + - `depth`: float32 numpy.ndarray of shape (H, W). + - `mask`: bool numpy.ndarray of shape (H, W). + - `segformer_mask`: uint8 numpy.ndarray of shape (H, W). + - `intrinsics`: float32 numpy.ndarray of shape (3, 3). + - `meta`: Dict[str, Any]. + """ + # Load & extract archive + with zipfile.ZipFile(file, 'r') as z: + meta = z.read('meta.json') + if not return_bytes: + meta = json.loads(z.read('meta.json')) + + image = z.read('image.jpg') + if not return_bytes: + image = cv2.imdecode(np.frombuffer(z.read('image.jpg'), np.uint8), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + + depth_name = next(s for s in z.namelist() if s.startswith('depth')) + depth = z.read(depth_name) + if not return_bytes: + depth = cv2.imdecode(np.frombuffer(z.read(depth_name), np.uint8), cv2.IMREAD_UNCHANGED) + + if 'mask.png' in z.namelist(): + mask = z.read('mask.png') + if not return_bytes: + mask = cv2.imdecode(np.frombuffer(z.read('mask.png'), np.uint8), cv2.IMREAD_UNCHANGED) > 0 + else: + mask = None + + if 'segformer_mask.png' in z.namelist(): + # NOTE: Legacy support for segformer_mask.png + segmentation_mask = z.read('segformer_mask.png') + segmentation_labels = None + if not return_bytes: + segmentation_mask = cv2.imdecode(np.frombuffer(segmentation_mask, np.uint8), cv2.IMREAD_UNCHANGED) + segmentation_labels = LEGACY_SEGFORMER_LABELS + elif 'segmentation_mask.png' in z.namelist(): + segmentation_mask = z.read('segmentation_mask.png') + segmentation_labels = None + if not return_bytes: + segmentation_mask = cv2.imdecode(np.frombuffer(segmentation_mask, np.uint8), cv2.IMREAD_UNCHANGED) + segmentation_labels = meta['segmentation_labels'] + else: + segmentation_mask = None + segmentation_labels = None + + if 'normal.png' in z.namelist(): + normal = z.read('normal.png') + if not return_bytes: + normal = cv2.imdecode(np.frombuffer(z.read('normal.png'), np.uint8), cv2.IMREAD_UNCHANGED) + normal = cv2.cvtColor(normal, cv2.COLOR_BGR2RGB) + normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0] + normal = normal / np.linalg.norm(normal, axis=-1, keepdims=True) + + if 'normal_mask.png' in z.namelist(): + normal_mask = z.read('normal_mask.png') + normal_mask = cv2.imdecode(np.frombuffer(normal_mask, np.uint8), cv2.IMREAD_UNCHANGED) > 0 + else: + normal_mask = np.ones(image.shape[:2], dtype=bool) + else: + normal, normal_mask = None, None + + # recover linear depth + if not return_bytes: + if mask is None: + mask = np.ones(image.shape[:2], dtype=bool) + if meta['depth_type'] == 'linear': + depth = depth.astype(np.float32) + mask = mask & (depth > 0) + elif meta['depth_type'] == 'log': + near, far = meta['depth_near'], meta['depth_far'] + if depth.dtype == np.uint16: + depth = depth.astype(np.float32) / 65535 + elif depth.dtype == np.uint8: + depth = depth.astype(np.float32) / 255 + depth = near ** (1 - depth) * far ** depth + mask = mask & ~np.isnan(depth) + elif meta['depth_type'] == 'disparity': + mask = mask & (depth > 0) + if depth.dtype == np.uint16: + depth = depth.astype(np.float32) / 65535 + elif depth.dtype == np.uint8: + depth = depth.astype(np.float32) / 255 + depth = 1 / (depth + 1e-12) + + # intrinsics + if not return_bytes and 'intrinsics' in meta: + intrinsics = np.array(meta['intrinsics'], dtype=np.float32) + else: + intrinsics = None + + # depth unit + if not return_bytes and 'depth_unit' in meta: + depth_unit_str = meta['depth_unit'] + if r := re.match(r'([\d.]*)(\w*)', depth_unit_str): + digits, unit = r.groups() + depth_unit = float(digits or 1) * {'m': 1, 'cm': 0.01, 'mm': 0.001}[unit] + else: + depth_unit = None + else: + depth_unit = None + + return_dict = { + 'image': image, + 'depth': depth, + 'mask': mask, + 'segmentation_mask': segmentation_mask, + 'segmentation_labels': segmentation_labels, + 'normal': normal, + 'normal_mask': normal_mask, + 'intrinsics': intrinsics, + 'depth_unit': depth_unit, + 'meta': meta, + } + return_dict = {k: v for k, v in return_dict.items() if v is not None} + + return return_dict + +def write_rgbxyz(file: Union[IO, Path], image: np.ndarray, points: np.ndarray, mask: np.ndarray = None, image_quality: int = 95): + if isinstance(image, bytes): + image_bytes = image + elif isinstance(image, np.ndarray): + image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR) + image_bytes = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, image_quality])[1].tobytes() + + if isinstance(points, bytes): + points_bytes = points + elif isinstance(points, np.ndarray): + points_bytes = cv2.imencode('.exr', points.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])[1].tobytes() + + if mask is None: + mask = np.ones(image.shape[:2], dtype=bool) + if isinstance(mask, bytes): + mask_bytes = mask + elif isinstance(mask, np.ndarray): + mask_bytes = cv2.imencode('.png', mask.astype(np.uint8) * 255)[1].tobytes() + + is_archive = hasattr(file, 'write') or Path(file).suffix == '.zip' + if is_archive: + with zipfile.ZipFile(file, 'w') as z: + z.writestr('image.jpg', image_bytes) + z.writestr('points.exr', points_bytes) + if mask is not None: + z.writestr('mask.png', mask_bytes) + else: + file = Path(file) + file.mkdir(parents=True, exist_ok=True) + with open(file / 'image.jpg', 'wb') as f: + f.write(image_bytes) + with open(file / 'points.exr', 'wb') as f: + f.write(points_bytes) + if mask is not None: + with open(file / 'mask.png', 'wb') as f: + f.write(mask_bytes) + + +def read_rgbxyz(file: Union[IO, str, Path]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict[str, Any]]: + is_archive = hasattr(file, 'read') or Path(file).suffix == '.zip' + if is_archive: + with zipfile.ZipFile(file, 'r') as z: + image = cv2.imdecode(np.frombuffer(z.read('image.jpg'), np.uint8), cv2.IMREAD_COLOR) + points = cv2.imdecode(np.frombuffer(z.read('points.exr'), np.uint8), cv2.IMREAD_UNCHANGED) + if 'mask.png' in z.namelist(): + mask = cv2.imdecode(np.frombuffer(z.read('mask.png'), np.uint8), cv2.IMREAD_UNCHANGED) > 0 + else: + mask = np.ones(image.shape[:2], dtype=bool) + else: + file = Path(file) + file.mkdir(parents=True, exist_ok=True) + image = cv2.imread(str(file / 'image.jpg'), cv2.IMREAD_COLOR) + image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB) + points = cv2.imread(str(file / 'points.exr'), cv2.IMREAD_UNCHANGED) + if (file /'mask.png').exists(): + mask = cv2.imread(str(file / 'mask.png'), cv2.IMREAD_UNCHANGED) > 0 + else: + mask = np.ones(image.shape[:2], dtype=bool) + + return image, points, mask diff --git a/moge/utils/pipeline.py b/moge/utils/pipeline.py new file mode 100644 index 0000000000000000000000000000000000000000..f652595c8b840bc672b8a3033b0694aa0986c4bf --- /dev/null +++ b/moge/utils/pipeline.py @@ -0,0 +1,503 @@ +from typing import * +from abc import abstractmethod +from queue import Empty, Full +from threading import Thread +from queue import Queue +from multiprocessing import Process +from threading import Thread, Event +import multiprocessing +import threading +import inspect +import time +import uuid +from copy import deepcopy +import itertools +import functools + +__all__ = [ + 'Node', + 'Link', + 'ConcurrentNode', + 'Worker', + 'WorkerFunction', + 'Provider', + 'ProviderFunction', + 'Sequential', + 'Batch', + 'Unbatch', + 'Parallel', + 'Graph', + 'Buffer', +] + +TERMINATE_CHECK_INTERVAL = 0.5 + + +class _ItemWrapper: + def __init__(self, data: Any, id: Union[int, List[int]] = None): + self.data = data + self.id = id + + +class Terminate(Exception): + pass + + +def _get_queue_item(queue: Queue, terminate_flag: Event, timeout: float = None) -> _ItemWrapper: + while True: + try: + item: _ItemWrapper = queue.get(block=True, timeout=TERMINATE_CHECK_INTERVAL if timeout is None else min(timeout, TERMINATE_CHECK_INTERVAL)) + if terminate_flag.is_set(): + raise Terminate() + return item + except Empty: + if terminate_flag.is_set(): + raise Terminate() + + if timeout is not None: + timeout -= TERMINATE_CHECK_INTERVAL + if timeout <= 0: + raise Empty() + + +def _put_queue_item(queue: Queue, item: _ItemWrapper, terminate_flag: Event): + while True: + try: + queue.put(item, block=True, timeout=TERMINATE_CHECK_INTERVAL) + if terminate_flag.is_set(): + raise Terminate() + return + except Full: + if terminate_flag.is_set(): + raise Terminate() + +class Node: + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + self.input: Queue = Queue(maxsize=in_buffer_size) + self.output: Queue = Queue(maxsize=out_buffer_size) + self.in_buffer_size = in_buffer_size + self.out_buffer_size = out_buffer_size + + @abstractmethod + def start(self): + pass + + @abstractmethod + def terminate(self): + pass + + def stop(self): + self.terminate() + self.join() + + @abstractmethod + def join(self): + pass + + def put(self, data: Any, key: str = None, block: bool = True) -> None: + item = _ItemWrapper(data) + self.input.put(item, block=block) + + def get(self, key: str = None, block: bool = True) -> Any: + item: _ItemWrapper = self.output.get(block=block) + return item.data + + def __enter__(self): + self.start() + return self + + def __exit__(self, exc_type, exc_value, traceback): + self.terminate() + self.join() + + +class ConcurrentNode(Node): + job: Union[Thread, Process] + + def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + super().__init__(in_buffer_size, out_buffer_size) + self.running_as = running_as + + @abstractmethod + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + pass + + def start(self): + if self.running_as == 'thread': + terminate_flag = threading.Event() + job = Thread(target=self._loop_fn, args=(self.input, self.output, terminate_flag)) + elif self.running_as == 'process': + terminate_flag = multiprocessing.Event() + job = Process(target=self._loop_fn, args=(self.input, self.output, terminate_flag)) + job.start() + self.job = job + self.terminate_flag = terminate_flag + + def terminate(self): + self.terminate_flag.set() + + def join(self): + self.job.join() + + +class Worker(ConcurrentNode): + def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 0, out_buffer_size: int = 0) -> None: + super().__init__(running_as, in_buffer_size, out_buffer_size) + + def init(self) -> None: + """ + This method is called the the thread is started, to initialize any resources that is only held in the thread. + """ + pass + + @abstractmethod + def work(self, *args, **kwargs) -> Union[Any, Dict[str, Any]]: + """ + This method defines the job that the node should do for each input item. + A item obtained from the input queue is passed as arguments to this method, and the result is placed in the output queue. + The method is executed concurrently with other nodes. + """ + pass + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + self.init() + try: + while True: + item = _get_queue_item(input, terminate_flag) + result = self.work(item.data) + _put_queue_item(output, _ItemWrapper(result, item.id), terminate_flag) + + except Terminate: + return + + +class Provider(ConcurrentNode): + """ + A node that provides data to successive nodes. It takes no input and provides data to the output queue. + """ + def __init__(self, running_as: Literal['thread', 'process'], out_buffer_size: int = 1) -> None: + super().__init__(running_as, 0, out_buffer_size) + + def init(self) -> None: + """ + This method is called the the thread or process is started, to initialize any resources that is only held in the thread or process. + """ + pass + + @abstractmethod + def provide(self) -> Generator[Any, None, None]: + pass + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + self.init() + try: + for data in self.provide(): + _put_queue_item(output, _ItemWrapper(data), terminate_flag) + except Terminate: + return + + +class WorkerFunction(Worker): + def __init__(self, fn: Callable, running_as: 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None: + super().__init__(running_as, in_buffer_size, out_buffer_size) + self.fn = fn + + def work(self, *args, **kwargs): + return self.fn(*args, **kwargs) + + +class ProviderFunction(Provider): + def __init__(self, fn: Callable, running_as: 'thread', out_buffer_size: int = 1) -> None: + super().__init__(running_as, out_buffer_size) + self.fn = fn + + def provide(self): + for item in self.fn(): + yield item + + +class Link: + def __init__(self, src: Queue, dst: Queue): + self.src = src + self.dst = dst + + def _thread_fn(self): + try: + while True: + item = _get_queue_item(self.src, self.terminate_flag) + _put_queue_item(self.dst, item, self.terminate_flag) + except Terminate: + return + + def start(self): + self.terminate_flag = threading.Event() + self.thread = Thread(target=self._thread_fn) + self.thread.start() + + def terminate(self): + self.terminate_flag.set() + + def join(self): + self.thread.join() + + +class Graph(Node): + """ + Graph pipeline of nodes and links + """ + nodes: List[Node] + links: List[Link] + + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1): + super().__init__(in_buffer_size, out_buffer_size) + self.nodes = [] + self.links = [] + + def add(self, node: Node): + self.nodes.append(node) + + def link(self, src: Union[Node, Tuple[Node, str]], dst: Union[Node, Tuple[Node, str]]): + """ + Links the output of the source node to the input of the destination node. + If the source or destination node is None, the pipeline's input or output is used. + """ + src_queue = self.input if src is None else src.output + dst_queue = self.output if dst is None else dst.input + self.links.append(Link(src_queue, dst_queue)) + + def chain(self, nodes: Iterable[Node]): + """ + Link the output of each node to the input of the next node. + """ + nodes = list(nodes) + for i in range(len(nodes) - 1): + self.link(nodes[i], nodes[i + 1]) + + def start(self): + for node in self.nodes: + node.start() + for link in self.links: + link.start() + + def terminate(self): + for node in self.nodes: + node.terminate() + for link in self.links: + link.terminate() + + def join(self): + for node in self.nodes: + node.join() + for link in self.links: + link.join() + + def __iter__(self): + providers = [node for node in self.nodes if isinstance(node, Provider)] + if len(providers) == 0: + raise ValueError("No provider node found in the pipeline. If you want to iterate over the pipeline, the pipeline must be driven by a provider node.") + with self: + # while all(provider.job.is_alive() for provider in providers): + while True: + yield self.get() + + def __call__(self, data: Any) -> Any: + """ + Submit data to the pipeline's input queue, and return the output data asynchronously. + NOTE: The pipeline must be streamed (i.e., every output item is uniquely associated with an input item) for this to work. + """ + # TODO + + +class Sequential(Graph): + """ + Pipeline of nodes in sequential order, where each node takes the output of the previous node as input. + The order of input and output items is preserved (FIFO) + """ + def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1): + """ + Initialize the pipeline with a list of nodes to execute sequentially. + ### Parameters: + - nodes: List of nodes or functions to execute sequentially. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes. + - function_running_as: Whether to wrap the function as a thread or process worker. Default is 'thread'. + - in_buffer_size: Maximum size of the input queue of the pipeline. Default is 0 (unlimited). + - out_buffer_size: Maximum size of the output queue of the pipeline. Default is 0 (unlimited). + """ + super().__init__(in_buffer_size, out_buffer_size) + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.add(node) + self.chain([None, *self.nodes, None]) + + +class Parallel(Node): + """ + A FIFO node that runs multiple nodes in parallel to process the input items. Each input item is handed to one of the nodes whoever is available. + NOTE: It is FIFO if and only if all the nested nodes are FIFO. + """ + nodes: List[Node] + + def __init__(self, nodes: Iterable[Node], in_buffer_size: int = 1, out_buffer_size: int = 1, function_running_as: Literal['thread', 'process'] = 'thread'): + super().__init__(in_buffer_size, out_buffer_size) + self.nodes = [] + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.nodes.append(node) + self.output_order = Queue() + self.lock = threading.Lock() + + def _in_thread_fn(self, node: Node): + try: + while True: + with self.lock: + # A better idea: first make sure its node is vacant, then get it a new item. + # Currently we will not be able to know which node is busy util there is at least one item already waiting in the queue of the node. + # This could lead to suboptimal scheduling. + item = _get_queue_item(self.input, self.terminate_flag) + self.output_order.put(node.output) + _put_queue_item(node.input, item, self.terminate_flag) + except Terminate: + return + + def _out_thread_fn(self): + try: + while True: + queue = _get_queue_item(self.output_order, self.terminate_flag) + item = _get_queue_item(queue, self.terminate_flag) + _put_queue_item(self.output, item, self.terminate_flag) + except Terminate: + return + + def start(self): + self.terminate_flag = threading.Event() + self.in_threads = [] + for node in self.nodes: + thread = Thread(target=self._in_thread_fn, args=(node,)) + thread.start() + self.in_threads.append(thread) + thread = Thread(target=self._out_thread_fn) + thread.start() + self.out_thread = thread + for node in self.nodes: + node.start() + + def terminate(self): + self.terminate_flag.set() + for node in self.nodes: + node.terminate() + + def join(self): + for thread in self.in_threads: + thread.join() + self.out_thread.join() + + +class UnorderedParallel(Graph): + """ + Pipeline of nodes in parallel, where each input item is handed to one of the nodes whoever is available. + NOTE: The order of the output items is NOT guaranteed to be the same as the input items, depending on how fast the nodes handle their input. + """ + def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1): + """ + Initialize the pipeline with a list of nodes to execute in parallel. If a function is given, it is wrapped in a worker node. + ### Parameters: + - nodes: List of nodes or functions to execute in parallel. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes. + - function_running_as: Whether to wrap the function as a thread or process worker. Default is 'thread'. + - in_buffer_size: Maximum size of the input queue of the pipeline. Default is 0 (unlimited). + - out_buffer_size: Maximum size of the output queue of the pipeline. Default is 0 (unlimited). + """ + super().__init__(in_buffer_size, out_buffer_size) + for node in nodes: + if isinstance(node, Node): + pass + elif isinstance(node, Callable): + if inspect.isgeneratorfunction(node): + node = ProviderFunction(node, function_running_as) + else: + node = WorkerFunction(node, function_running_as) + else: + raise ValueError(f"Invalid node type: {type(node)}") + self.add(node) + for i in range(len(nodes)): + self.chain([None, self.nodes[i], None]) + + +class Batch(ConcurrentNode): + """ + Groups every `batch_size` items into a batch (a list of items) and passes the batch to successive nodes. + The `patience` parameter specifies the maximum time to wait for a batch to be filled before sending it to the next node, + i.e., when the earliest item in the batch is out of `patience` seconds, the batch is sent regardless of its size. + """ + def __init__(self, batch_size: int, patience: float = None, in_buffer_size: int = 1, out_buffer_size: int = 1): + assert batch_size > 0, "Batch size must be greater than 0." + super().__init__('thread', in_buffer_size, out_buffer_size) + self.batch_size = batch_size + self.patience = patience + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + try: + while True: + batch_id, batch_data = [], [] + # Try to fill the batch + for i in range(self.batch_size): + if i == 0 or self.patience is None: + timeout = None + else: + timeout = self.patience - (time.time() - earliest_time) + if timeout < 0: + break + try: + item = _get_queue_item(input, terminate_flag, timeout) + except Empty: + break + + if i == 0: + earliest_time = time.time() + batch_data.append(item.data) + batch_id.append(item.id) + + batch = _ItemWrapper(batch_data, batch_id) + _put_queue_item(output, batch, terminate_flag) + except Terminate: + return + + +class Unbatch(ConcurrentNode): + """ + Ungroups every batch (a list of items) into individual items and passes them to successive nodes. + """ + def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1): + super().__init__('thread', in_buffer_size, out_buffer_size) + + def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event): + try: + while True: + batch = _get_queue_item(input, terminate_flag) + for id, data in zip(batch.id or itertools.repeat(None), batch.data): + item = _ItemWrapper(data, id) + _put_queue_item(output, item, terminate_flag) + except Terminate: + return + + +class Buffer(Node): + "A FIFO node that buffers items in a queue. Usefull achieve better temporal balance when its successor node has a variable processing time." + def __init__(self, size: int): + super().__init__(size, size) + self.size = size + self.input = self.output = Queue(maxsize=size) \ No newline at end of file diff --git a/moge/utils/tools.py b/moge/utils/tools.py new file mode 100644 index 0000000000000000000000000000000000000000..4659462450dd946388cfd9638b3b6b17cc03ba0d --- /dev/null +++ b/moge/utils/tools.py @@ -0,0 +1,240 @@ +from typing import * +import time +from pathlib import Path +from numbers import Number + + +def catch_exception(fn): + def wrapper(*args, **kwargs): + try: + return fn(*args, **kwargs) + except Exception as e: + import traceback + print(f"Exception in {fn.__name__}({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})") + traceback.print_exc(chain=False) + time.sleep(0.1) + return None + return wrapper + + +class CallbackOnException: + def __init__(self, callback: Callable, exception: type): + self.exception = exception + self.callback = callback + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_val, exc_tb): + if isinstance(exc_val, self.exception): + self.callback() + return True + return False + +def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]: + for k, v in d.items(): + if isinstance(v, dict): + for sub_key in traverse_nested_dict_keys(v): + yield (k, ) + sub_key + else: + yield (k, ) + + +def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None): + for k in keys: + d = d.get(k, default) + if d is None: + break + return d + +def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any): + for k in keys[:-1]: + d = d.setdefault(k, {}) + d[keys[-1]] = value + + +def key_average(list_of_dicts: list) -> Dict[str, Any]: + """ + Returns a dictionary with the average value of each key in the input list of dictionaries. + """ + _nested_dict_keys = set() + for d in list_of_dicts: + _nested_dict_keys.update(traverse_nested_dict_keys(d)) + _nested_dict_keys = sorted(_nested_dict_keys) + result = {} + for k in _nested_dict_keys: + values = [ + get_nested_dict(d, k) for d in list_of_dicts + if get_nested_dict(d, k) is not None + ] + avg = sum(values) / len(values) if values else float('nan') + set_nested_dict(result, k, avg) + return result + + +def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]: + """ + Flattens a nested dictionary into a single-level dictionary, with keys as tuples. + """ + items = [] + if parent_key is None: + parent_key = () + for k, v in d.items(): + new_key = parent_key + (k, ) + if isinstance(v, MutableMapping): + items.extend(flatten_nested_dict(v, new_key).items()) + else: + items.append((new_key, v)) + return dict(items) + + +def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]: + """ + Unflattens a single-level dictionary into a nested dictionary, with keys as tuples. + """ + result = {} + for k, v in d.items(): + sub_dict = result + for k_ in k[:-1]: + if k_ not in sub_dict: + sub_dict[k_] = {} + sub_dict = sub_dict[k_] + sub_dict[k[-1]] = v + return result + + +def read_jsonl(file): + import json + with open(file, 'r') as f: + data = f.readlines() + return [json.loads(line) for line in data] + + +def write_jsonl(data: List[dict], file): + import json + with open(file, 'w') as f: + for item in data: + f.write(json.dumps(item) + '\n') + + +def save_metrics(save_path: Union[str, Path], all_metrics: Dict[str, List[Dict]]): + import pandas as pd + import json + + with open(save_path, 'w') as f: + json.dump(all_metrics, f, indent=4) + + +def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]): + import pandas as pd + data = [flatten_nested_dict(d) for d in data] + df = pd.DataFrame(data) + df = df.sort_index(axis=1) + df.columns = pd.MultiIndex.from_tuples(df.columns) + return df + + +def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]): + if isinstance(d, str): + for old, new in mapping.items(): + d = d.replace(old, new) + elif isinstance(d, list): + for i, item in enumerate(d): + d[i] = recursive_replace(item, mapping) + elif isinstance(d, dict): + for k, v in d.items(): + d[k] = recursive_replace(v, mapping) + return d + + +class timeit: + _history: Dict[str, List['timeit']] = {} + + def __init__(self, name: str = None, verbose: bool = True, multiple: bool = False): + self.name = name + self.verbose = verbose + self.start = None + self.end = None + self.multiple = multiple + if multiple and name not in timeit._history: + timeit._history[name] = [] + + def __call__(self, func: Callable): + import inspect + if inspect.iscoroutinefunction(func): + async def wrapper(*args, **kwargs): + with timeit(self.name or func.__qualname__): + ret = await func(*args, **kwargs) + return ret + return wrapper + else: + def wrapper(*args, **kwargs): + with timeit(self.name or func.__qualname__): + ret = func(*args, **kwargs) + return ret + return wrapper + + def __enter__(self): + self.start = time.time() + + @property + def time(self) -> float: + assert self.start is not None, "Time not yet started." + assert self.end is not None, "Time not yet ended." + return self.end - self.start + + @property + def history(self) -> List['timeit']: + return timeit._history.get(self.name, []) + + def __exit__(self, exc_type, exc_val, exc_tb): + self.end = time.time() + if self.multiple: + timeit._history[self.name].append(self) + if self.verbose: + if self.multiple: + avg = sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name]) + print(f"{self.name or 'It'} took {avg} seconds in average.") + else: + print(f"{self.name or 'It'} took {self.time} seconds.") + + +def strip_common_prefix_suffix(strings: List[str]) -> List[str]: + first = strings[0] + + for start in range(len(first)): + if any(s[start] != strings[0][start] for s in strings): + break + + for end in range(1, min(len(s) for s in strings)): + if any(s[-end] != first[-end] for s in strings): + break + + return [s[start:len(s) - end + 1] for s in strings] + + +def multithead_execute(inputs: List[Any], num_workers: int, pbar = None): + from concurrent.futures import ThreadPoolExecutor + from contextlib import nullcontext + from tqdm import tqdm + + if pbar is not None: + pbar.total = len(inputs) if hasattr(inputs, '__len__') else None + else: + pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None) + + def decorator(fn: Callable): + with ( + ThreadPoolExecutor(max_workers=num_workers) as executor, + pbar + ): + pbar.refresh() + @catch_exception + def _fn(input): + ret = fn(input) + pbar.update() + return ret + executor.map(_fn, inputs) + executor.shutdown(wait=True) + + return decorator \ No newline at end of file diff --git a/moge/utils/vis.py b/moge/utils/vis.py new file mode 100644 index 0000000000000000000000000000000000000000..a85945ce10adfbf29bdbd95fb9ad765082b3e4df --- /dev/null +++ b/moge/utils/vis.py @@ -0,0 +1,51 @@ +import numpy as np +import matplotlib + + +def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray: + if mask is None: + depth = np.where(depth > 0, depth, np.nan) + else: + depth = np.where((depth > 0) & mask, depth, np.nan) + disp = 1 / depth + if normalize: + min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.999) + disp = (disp - min_disp) / (max_disp - min_disp) + colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp), 0) + colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3] + return colored + + +def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray: + if mask is not None: + depth = np.where(mask, depth, np.nan) + + min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999) + depth = (depth - min_depth) / (max_depth - min_depth) + colored = np.nan_to_num(matplotlib.colormaps[cmap](depth), 0) + colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3] + return colored + + +def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray: + if mask is not None: + disparity = np.where(mask, disparity, np.nan) + + if normalize: + min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999) + disparity = (disparity - min_disp) / (max_disp - min_disp) + colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity), 0) + colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3] + return colored + + +def colorize_segmentation(segmentation: np.ndarray, cmap: str = 'Set1') -> np.ndarray: + colored = matplotlib.colormaps[cmap]((segmentation % 20) / 20) + colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3] + return colored + + +def colorize_normal(normal: np.ndarray) -> np.ndarray: + normal = normal * [0.5, -0.5, -0.5] + 0.5 + normal = (normal.clip(0, 1) * 255).astype(np.uint8) + return normal \ No newline at end of file diff --git a/moge/utils/webfile.py b/moge/utils/webfile.py new file mode 100644 index 0000000000000000000000000000000000000000..1e98abf8413e1c9f408849b74f4d2025d25511b6 --- /dev/null +++ b/moge/utils/webfile.py @@ -0,0 +1,73 @@ +import requests +from typing import * + +__all__ = ["WebFile"] + + +class WebFile: + def __init__(self, url: str, session: Optional[requests.Session] = None, headers: Optional[Dict[str, str]] = None, size: Optional[int] = None): + self.url = url + self.session = session or requests.Session() + self.session.headers.update(headers or {}) + self._offset = 0 + self.size = size if size is not None else self._fetch_size() + + def _fetch_size(self): + with self.session.get(self.url, stream=True) as response: + response.raise_for_status() + content_length = response.headers.get("Content-Length") + if content_length is None: + raise ValueError("Missing Content-Length in header") + return int(content_length) + + def _fetch_data(self, offset: int, n: int) -> bytes: + headers = {"Range": f"bytes={offset}-{min(offset + n - 1, self.size)}"} + response = self.session.get(self.url, headers=headers) + response.raise_for_status() + return response.content + + def seekable(self) -> bool: + return True + + def tell(self) -> int: + return self._offset + + def available(self) -> int: + return self.size - self._offset + + def seek(self, offset: int, whence: int = 0) -> None: + if whence == 0: + new_offset = offset + elif whence == 1: + new_offset = self._offset + offset + elif whence == 2: + new_offset = self.size + offset + else: + raise ValueError("Invalid value for whence") + + self._offset = max(0, min(new_offset, self.size)) + + def read(self, n: Optional[int] = None) -> bytes: + if n is None or n < 0: + n = self.available() + else: + n = min(n, self.available()) + + if n == 0: + return b'' + + data = self._fetch_data(self._offset, n) + self._offset += len(data) + + return data + + def close(self) -> None: + pass + + def __enter__(self): + return self + + def __exit__(self, exc_type, exc_value, traceback): + pass + + \ No newline at end of file diff --git a/moge/utils/webzipfile.py b/moge/utils/webzipfile.py new file mode 100644 index 0000000000000000000000000000000000000000..25ed1d3cd34720335eb001d77a278539ffef569b --- /dev/null +++ b/moge/utils/webzipfile.py @@ -0,0 +1,128 @@ +from typing import * +import io +import os +from zipfile import ( + ZipInfo, BadZipFile, ZipFile, ZipExtFile, + sizeFileHeader, structFileHeader, stringFileHeader, + _FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS, + _MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED +) +import struct +from requests import Session + +from .webfile import WebFile + + +class _SharedWebFile(WebFile): + def __init__(self, webfile: WebFile, pos: int): + super().__init__(webfile.url, webfile.session, size=webfile.size) + self.seek(pos) + + +class WebZipFile(ZipFile): + "Lock-free version of ZipFile that reads from a WebFile, allowing for concurrent reads." + def __init__(self, url: str, session: Optional[Session] = None, headers: Optional[Dict[str, str]] = None): + """Open the ZIP file with mode read 'r', write 'w', exclusive create 'x', + or append 'a'.""" + webf = WebFile(url, session=session, headers=headers) + super().__init__(webf, mode='r') + + def open(self, name, mode="r", pwd=None, *, force_zip64=False): + """Return file-like object for 'name'. + + name is a string for the file name within the ZIP file, or a ZipInfo + object. + + mode should be 'r' to read a file already in the ZIP file, or 'w' to + write to a file newly added to the archive. + + pwd is the password to decrypt files (only used for reading). + + When writing, if the file size is not known in advance but may exceed + 2 GiB, pass force_zip64 to use the ZIP64 format, which can handle large + files. If the size is known in advance, it is best to pass a ZipInfo + instance for name, with zinfo.file_size set. + """ + if mode not in {"r", "w"}: + raise ValueError('open() requires mode "r" or "w"') + if pwd and (mode == "w"): + raise ValueError("pwd is only supported for reading files") + if not self.fp: + raise ValueError( + "Attempt to use ZIP archive that was already closed") + + assert mode == "r", "Only read mode is supported for now" + + # Make sure we have an info object + if isinstance(name, ZipInfo): + # 'name' is already an info object + zinfo = name + elif mode == 'w': + zinfo = ZipInfo(name) + zinfo.compress_type = self.compression + zinfo._compresslevel = self.compresslevel + else: + # Get info object for name + zinfo = self.getinfo(name) + + if mode == 'w': + return self._open_to_write(zinfo, force_zip64=force_zip64) + + if self._writing: + raise ValueError("Can't read from the ZIP file while there " + "is an open writing handle on it. " + "Close the writing handle before trying to read.") + + # Open for reading: + self._fileRefCnt += 1 + zef_file = _SharedWebFile(self.fp, zinfo.header_offset) + + try: + # Skip the file header: + fheader = zef_file.read(sizeFileHeader) + if len(fheader) != sizeFileHeader: + raise BadZipFile("Truncated file header") + fheader = struct.unpack(structFileHeader, fheader) + if fheader[_FH_SIGNATURE] != stringFileHeader: + raise BadZipFile("Bad magic number for file header") + + fname = zef_file.read(fheader[_FH_FILENAME_LENGTH]) + if fheader[_FH_EXTRA_FIELD_LENGTH]: + zef_file.seek(fheader[_FH_EXTRA_FIELD_LENGTH], whence=1) + + if zinfo.flag_bits & _MASK_COMPRESSED_PATCH: + # Zip 2.7: compressed patched data + raise NotImplementedError("compressed patched data (flag bit 5)") + + if zinfo.flag_bits & _MASK_STRONG_ENCRYPTION: + # strong encryption + raise NotImplementedError("strong encryption (flag bit 6)") + + if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & _MASK_UTF_FILENAME: + # UTF-8 filename + fname_str = fname.decode("utf-8") + else: + fname_str = fname.decode(self.metadata_encoding or "cp437") + + if fname_str != zinfo.orig_filename: + raise BadZipFile( + 'File name in directory %r and header %r differ.' + % (zinfo.orig_filename, fname)) + + # check for encrypted flag & handle password + is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED + if is_encrypted: + if not pwd: + pwd = self.pwd + if pwd and not isinstance(pwd, bytes): + raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__) + if not pwd: + raise RuntimeError("File %r is encrypted, password " + "required for extraction" % name) + else: + pwd = None + + return ZipExtFile(zef_file, mode, zinfo, pwd, True) + except: + zef_file.close() + raise \ No newline at end of file diff --git a/packages.txt b/packages.txt new file mode 100644 index 0000000000000000000000000000000000000000..d4458ef18f1d7c83facf3d8f9653b23a58947437 --- /dev/null +++ b/packages.txt @@ -0,0 +1 @@ +python3-opencv \ No newline at end of file diff --git a/requirements.txt b/requirements.txt new file mode 100644 index 0000000000000000000000000000000000000000..0bbafa8a66beb75039a039fa30f3e13279d87847 --- /dev/null +++ b/requirements.txt @@ -0,0 +1,5 @@ +opencv-python +plyfile +pygltflib +transformers +scikit-learn \ No newline at end of file diff --git a/utils3d/__init__.py b/utils3d/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..737cda0b37c0dc4cc4e83165a6fabb40d28cc975 --- /dev/null +++ b/utils3d/__init__.py @@ -0,0 +1,14 @@ +""" +A package for common utility functions in 3D computer graphics and vision. Providing NumPy utilities in `utils3d.numpy`, PyTorch utilities in `utils3d.torch`, and IO utilities in `utils3d.io`. +""" +import importlib + +__all__ = ['numpy', 'torch', 'io'] + +def __getattr__(module_name: str): + return importlib.import_module(f'.{module_name}', __package__) + +if __name__ == '__main__': + from . import torch + from . import numpy + from . import io \ No newline at end of file diff --git a/utils3d/io/__init__.py b/utils3d/io/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..209a65543c4fc4367687a58d0e20dfa84b9ec7df --- /dev/null +++ b/utils3d/io/__init__.py @@ -0,0 +1,4 @@ +from .wavefront_obj import * +from .colmap import * +from .ply import * +from .glb import * \ No newline at end of file diff --git a/utils3d/io/colmap.py b/utils3d/io/colmap.py new file mode 100644 index 0000000000000000000000000000000000000000..9f3993fc68eb11e644a97f8341d5dd96e4afb56b --- /dev/null +++ b/utils3d/io/colmap.py @@ -0,0 +1,139 @@ +from typing import * +from pathlib import Path + +import numpy as np +from scipy.spatial.transform import Rotation + + +__all__ = ['read_extrinsics_from_colmap', 'read_intrinsics_from_colmap', 'write_extrinsics_as_colmap', 'write_intrinsics_as_colmap'] + + +def write_extrinsics_as_colmap(file: Union[str, Path], extrinsics: np.ndarray, image_names: Union[str, List[str]] = 'image_{i:04d}.png', camera_ids: List[int] = None): + """ + Write extrinsics to colmap `images.txt` file. + Args: + file: Path to `images.txt` file. + extrinsics: (N, 4, 4) array of extrinsics. + image_names: str or List of str, image names. Length is N. + If str, it should be a format string with `i` as the index. (i starts from 1, in correspondence with IMAGE_ID in colmap) + camera_ids: List of int, camera ids. Length is N. + If None, it will be set to [1, 2, ..., N]. + """ + assert extrinsics.shape[1:] == (4, 4) and extrinsics.ndim == 3 or extrinsics.shape == (4, 4) + if extrinsics.ndim == 2: + extrinsics = extrinsics[np.newaxis, ...] + quats = Rotation.from_matrix(extrinsics[:, :3, :3]).as_quat() + trans = extrinsics[:, :3, 3] + if camera_ids is None: + camera_ids = list(range(1, len(extrinsics) + 1)) + if isinstance(image_names, str): + image_names = [image_names.format(i=i) for i in range(1, len(extrinsics) + 1)] + assert len(extrinsics) == len(image_names) == len(camera_ids), \ + f'Number of extrinsics ({len(extrinsics)}), image_names ({len(image_names)}), and camera_ids ({len(camera_ids)}) must be the same' + with open(file, 'w') as fp: + print("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME", file=fp) + for i, (quat, t, name, camera_id) in enumerate(zip(quats.tolist(), trans.tolist(), image_names, camera_ids)): + # Colmap has wxyz order while scipy.spatial.transform.Rotation has xyzw order. Haha, wcnm. + qx, qy, qz, qw = quat + tx, ty, tz = t + print(f'{i + 1} {qw:f} {qx:f} {qy:f} {qz:f} {tx:f} {ty:f} {tz:f} {camera_id:d} {name}', file=fp) + print() + + +def write_intrinsics_as_colmap(file: Union[str, Path], intrinsics: np.ndarray, width: int, height: int, normalized: bool = False): + """ + Write intrinsics to colmap `cameras.txt` file. Currently only support PINHOLE model (no distortion) + Args: + file: Path to `cameras.txt` file. + intrinsics: (N, 3, 3) array of intrinsics. + width: Image width. + height: Image height. + normalized: Whether the intrinsics are normalized. If True, the intrinsics will unnormalized for writing. + """ + assert intrinsics.shape[1:] == (3, 3) and intrinsics.ndim == 3 or intrinsics.shape == (3, 3) + if intrinsics.ndim == 2: + intrinsics = intrinsics[np.newaxis, ...] + if normalized: + intrinsics = intrinsics * np.array([width, height, 1])[:, None] + with open(file, 'w') as fp: + print("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]", file=fp) + for i, intr in enumerate(intrinsics): + fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2] + print(f'{i + 1} PINHOLE {width:d} {height:d} {fx:f} {fy:f} {cx:f} {cy:f}', file=fp) + + +def read_extrinsics_from_colmap(file: Union[str, Path]) -> Union[np.ndarray, List[int], List[str]]: + """ + Read extrinsics from colmap `images.txt` file. + Args: + file: Path to `images.txt` file. + Returns: + extrinsics: (N, 4, 4) array of extrinsics. + camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1. + image_names: List of str, image names. Length is N. + """ + with open(file) as fp: + lines = fp.readlines() + image_names, quats, trans, camera_ids = [], [], [], [] + i_line = 0 + for line in lines: + line = line.strip() + if line.startswith('#'): + continue + i_line += 1 + if i_line % 2 == 0: + continue + image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name = line.split() + quats.append([float(qx), float(qy), float(qz), float(qw)]) + trans.append([float(tx), float(ty), float(tz)]) + camera_ids.append(int(camera_id)) + image_names.append(name) + + quats = np.array(quats, dtype=np.float32) + trans = np.array(trans, dtype=np.float32) + rotation = Rotation.from_quat(quats).as_matrix() + extrinsics = np.concatenate([ + np.concatenate([rotation, trans[..., None]], axis=-1), + np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :].repeat(len(quats), axis=0) + ], axis=-2) + + return extrinsics, camera_ids, image_names + + +def read_intrinsics_from_colmap(file: Union[str, Path], normalize: bool = False) -> Tuple[List[int], np.ndarray, np.ndarray]: + """ + Read intrinsics from colmap `cameras.txt` file. + Args: + file: Path to `cameras.txt` file. + normalize: Whether to normalize the intrinsics. If True, the intrinsics will be normalized. (mapping coordinates to [0, 1] range) + Returns: + camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1. + intrinsics: (N, 3, 3) array of intrinsics. + distortions: (N, 5) array of distortions. + """ + with open(file) as fp: + lines = fp.readlines() + intrinsics, distortions, camera_ids = [], [], [] + for line in lines: + line = line.strip() + if not line or line.startswith('#'): + continue + camera_id, model, width, height, *params = line.split() + camera_id, width, height = int(camera_id), int(width), int(height) + if model == 'PINHOLE': + fx, fy, cx, cy = map(float, params[:4]) + k1 = k2 = k3 = p1 = p2 = 0.0 + elif model == 'OPENCV': + fx, fy, cx, cy, k1, k2, p1, p2, k3 = *map(float, params[:8]), 0.0 + elif model == 'SIMPLE_RADIAL': + f, cx, cy, k = map(float, params[:4]) + fx = fy = f + k1, k2, p1, p2, k3 = k, 0.0, 0.0, 0.0, 0.0 + camera_ids.append(camera_id) + if normalize: + fx, fy, cx, cy = fx / width, fy / height, cx / width, cy / height + intrinsics.append([[fx, 0, cx], [0, fy, cy], [0, 0, 1]]) + distortions.append([k1, k2, p1, p2, k3]) + intrinsics = np.array(intrinsics, dtype=np.float32) + distortions = np.array(distortions, dtype=np.float32) + return camera_ids, intrinsics, distortions diff --git a/utils3d/io/glb.py b/utils3d/io/glb.py new file mode 100644 index 0000000000000000000000000000000000000000..1595c3c1d1a152809a791f09f28a3f57ee5843cf --- /dev/null +++ b/utils3d/io/glb.py @@ -0,0 +1,105 @@ +from typing import * +from pathlib import Path + +import numpy as np + + +def write_glb(path: Union[str, Path], vertices: np.ndarray, faces: np.ndarray, vertex_colors: np.ndarray = None, uv: np.ndarray = None): + import pygltflib + + has_colors = vertex_colors is not None + has_uv = uv is not None + + triangles_bytes = faces.astype(np.uint32).flatten().tobytes() + vertices_bytes = vertices.astype(np.float32).tobytes() + vertex_colors_bytes = vertex_colors.astype(np.float32).tobytes() if has_colors else None + uv_bytes = uv.astype(np.float32).tobytes() if has_uv else None + + + gltf = pygltflib.GLTF2( + scene=0, + scenes=[pygltflib.Scene(nodes=[0])], + nodes=[pygltflib.Node(mesh=0)], + meshes=[ + pygltflib.Mesh( + primitives=[ + pygltflib.Primitive( + attributes=pygltflib.Attributes( + POSITION=1, + COLOR_0=2 if has_colors else None, + TEXCOORD_0=2 + has_colors if has_uv else None + ), + indices=0 + ) + ] + ) + ], + accessors=list(filter(None, [ + pygltflib.Accessor( # triangles accessor + bufferView=0, + componentType=pygltflib.UNSIGNED_INT, + count=faces.size, + type=pygltflib.SCALAR, + max=[int(faces.max())], + min=[int(faces.min())], + ), + pygltflib.Accessor( # vertices accessor + bufferView=1, + componentType=pygltflib.FLOAT, + count=len(vertices), + type=pygltflib.VEC3, + max=vertices.max(axis=0).tolist(), + min=vertices.min(axis=0).tolist(), + ), + pygltflib.Accessor( # vertex colors accessor + bufferView=2, + componentType=pygltflib.FLOAT, + count=len(vertices), + type=pygltflib.VEC3, + max=vertex_colors.max(axis=0).tolist(), + min=vertex_colors.min(axis=0).tolist(), + ) if has_colors else None, + pygltflib.Accessor( # uv accessor + bufferView=3, + componentType=pygltflib.FLOAT, + count=len(uv), + type=pygltflib.VEC2, + max=uv.max(axis=0).tolist(), + min=uv.min(axis=0).tolist(), + ) if has_uv else None, + ])), + bufferViews=list(filter(None, [ + pygltflib.BufferView( # triangles buffer view + buffer=0, + byteLength=len(triangles_bytes), + target=pygltflib.ELEMENT_ARRAY_BUFFER, + ), + pygltflib.BufferView( # vertices buffer view + buffer=0, + byteOffset=len(triangles_bytes), + byteLength=len(vertices_bytes), + target=pygltflib.ARRAY_BUFFER, + ), + pygltflib.BufferView( # vertex colors buffer view + buffer=0, + byteOffset=len(triangles_bytes) + len(vertices_bytes), + byteLength=len(vertex_colors_bytes), + target=pygltflib.ARRAY_BUFFER, + ) if has_colors else None, + pygltflib.BufferView( # uv buffer view + buffer=0, + byteOffset=len(triangles_bytes) + len(vertices_bytes) + (len(vertex_colors_bytes) if has_colors else 0), + byteLength=len(uv_bytes), + target=pygltflib.ARRAY_BUFFER, + ) if has_uv else None, + ])), + buffers=[ + pygltflib.Buffer( + byteLength=len(triangles_bytes) + len(vertices_bytes) + (len(vertex_colors_bytes) if has_colors else 0) + (len(uv_bytes) if has_uv else 0), + ) + ] + ) + gltf.set_binary_blob(triangles_bytes + vertices_bytes + (vertex_colors_bytes or b'') + (uv_bytes or b'')) + with open(path, 'wb') as f: + for chunk in gltf.save_to_bytes(): + f.write(chunk) \ No newline at end of file diff --git a/utils3d/io/ply.py b/utils3d/io/ply.py new file mode 100644 index 0000000000000000000000000000000000000000..aea4d1fdf5e339810778cf4cd7c749b89e239f0a --- /dev/null +++ b/utils3d/io/ply.py @@ -0,0 +1,104 @@ +import numpy as np + +from typing import * +from pathlib import Path + + +def read_ply( + file: Union[str, Path], + encoding: Union[str, None] = None, + ignore_unknown: bool = False + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Read .ply file, without preprocessing. + + Args: + file (Any): filepath + encoding (str, optional): + + Returns: + Tuple[np.ndarray, np.ndarray]: vertices, faces + """ + import plyfile + plydata = plyfile.PlyData.read(file) + vertices = np.stack([plydata['vertex'][k] for k in ['x', 'y', 'z']], axis=-1) + if 'face' in plydata: + faces = np.array(plydata['face']['vertex_indices'].tolist()) + else: + faces = None + return vertices, faces + + +def write_ply( + file: Union[str, Path], + vertices: np.ndarray, + faces: np.ndarray = None, + edges: np.ndarray = None, + vertex_colors: np.ndarray = None, + edge_colors: np.ndarray = None, + text: bool = False +): + """ + Write .ply file, without preprocessing. + + Args: + file (Any): filepath + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, E] + edges (np.ndarray): [E, 2] + vertex_colors (np.ndarray, optional): [N, 3]. Defaults to None. + edge_colors (np.ndarray, optional): [E, 3]. Defaults to None. + text (bool, optional): save data in text format. Defaults to False. + """ + import plyfile + assert vertices.ndim == 2 and vertices.shape[1] == 3 + vertices = vertices.astype(np.float32) + if faces is not None: + assert faces.ndim == 2 + faces = faces.astype(np.int32) + if edges is not None: + assert edges.ndim == 2 and edges.shape[1] == 2 + edges = edges.astype(np.int32) + + if vertex_colors is not None: + assert vertex_colors.ndim == 2 and vertex_colors.shape[1] == 3 + if vertex_colors.dtype in [np.float32, np.float64]: + vertex_colors = vertex_colors * 255 + vertex_colors = np.clip(vertex_colors, 0, 255).astype(np.uint8) + vertices_data = np.zeros(len(vertices), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) + vertices_data['x'] = vertices[:, 0] + vertices_data['y'] = vertices[:, 1] + vertices_data['z'] = vertices[:, 2] + vertices_data['red'] = vertex_colors[:, 0] + vertices_data['green'] = vertex_colors[:, 1] + vertices_data['blue'] = vertex_colors[:, 2] + else: + vertices_data = np.array([tuple(v) for v in vertices], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')]) + + if faces is not None: + faces_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (faces.shape[1],))]) + faces_data['vertex_indices'] = faces + + if edges is not None: + if edge_colors is not None: + assert edge_colors.ndim == 2 and edge_colors.shape[1] == 3 + if edge_colors.dtype in [np.float32, np.float64]: + edge_colors = edge_colors * 255 + edge_colors = np.clip(edge_colors, 0, 255).astype(np.uint8) + edges_data = np.zeros(len(edges), dtype=[('vertex1', 'i4'), ('vertex2', 'i4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')]) + edges_data['vertex1'] = edges[:, 0] + edges_data['vertex2'] = edges[:, 1] + edges_data['red'] = edge_colors[:, 0] + edges_data['green'] = edge_colors[:, 1] + edges_data['blue'] = edge_colors[:, 2] + else: + edges_data = np.array([tuple(e) for e in edges], dtype=[('vertex1', 'i4'), ('vertex2', 'i4')]) + + ply_data = [plyfile.PlyElement.describe(vertices_data, 'vertex')] + if faces is not None: + ply_data.append(plyfile.PlyElement.describe(faces_data, 'face')) + if edges is not None: + ply_data.append(plyfile.PlyElement.describe(edges_data, 'edge')) + + plyfile.PlyData(ply_data, text=text).write(file) + \ No newline at end of file diff --git a/utils3d/io/wavefront_obj.py b/utils3d/io/wavefront_obj.py new file mode 100644 index 0000000000000000000000000000000000000000..3471e490bd758fbf58173cfb7297ec747f46f173 --- /dev/null +++ b/utils3d/io/wavefront_obj.py @@ -0,0 +1,146 @@ +from io import TextIOWrapper +from typing import Dict, Any, Union, Iterable +import numpy as np +from pathlib import Path + +__all__ = [ + 'read_obj', + 'write_obj', + 'simple_write_obj' +] + +def read_obj( + file : Union[str, Path, TextIOWrapper], + encoding: Union[str, None] = None, + ignore_unknown: bool = False +): + """ + Read wavefront .obj file, without preprocessing. + + Why bothering having this read_obj() while we already have other libraries like `trimesh`? + This function read the raw format from .obj file and keeps the order of vertices and faces, + while trimesh which involves modification like merge/split vertices, which could break the orders of vertices and faces, + Those libraries are commonly aiming at geometry processing and rendering supporting various formats. + If you want mesh geometry processing, you may turn to `trimesh` for more features. + + ### Parameters + `file` (str, Path, TextIOWrapper): filepath or file object + encoding (str, optional): + + ### Returns + obj (dict): A dict containing .obj components + { + 'mtllib': [], + 'v': [[0,1, 0.2, 1.0], [1.2, 0.0, 0.0], ...], + 'vt': [[0.5, 0.5], ...], + 'vn': [[0., 0.7, 0.7], [0., -0.7, 0.7], ...], + 'f': [[0, 1, 2], [2, 3, 4],...], + 'usemtl': [{'name': 'mtl1', 'f': 7}] + } + """ + if hasattr(file,'read'): + lines = file.read().splitlines() + else: + with open(file, 'r', encoding=encoding) as fp: + lines = fp.read().splitlines() + mtllib = [] + v, vt, vn, vp = [], [], [], [] # Vertex coordinates, Vertex texture coordinate, Vertex normal, Vertex parameter + f, ft, fn = [], [], [] # Face indices, Face texture indices, Face normal indices + o = [] + s = [] + usemtl = [] + + def pad(l: list, n: Any): + return l + [n] * (3 - len(l)) + + for i, line in enumerate(lines): + sq = line.strip().split() + if len(sq) == 0: + continue + if sq[0] == 'v': + assert 4 <= len(sq) <= 5, f'Invalid format of line {i}: {line}' + v.append([float(e) for e in sq[1:]][:3]) + elif sq[0] == 'vt': + assert 3 <= len(sq) <= 4, f'Invalid format of line {i}: {line}' + vt.append([float(e) for e in sq[1:]][:2]) + elif sq[0] == 'vn': + assert len(sq) == 4, f'Invalid format of line {i}: {line}' + vn.append([float(e) for e in sq[1:]]) + elif sq[0] == 'vp': + assert 2 <= len(sq) <= 4, f'Invalid format of line {i}: {line}' + vp.append(pad([float(e) for e in sq[1:]], 0)) + elif sq[0] == 'f': + spliting = [pad([int(j) - 1 for j in e.split('/')], -1) for e in sq[1:]] + f.append([e[0] for e in spliting]) + ft.append([e[1] for e in spliting]) + fn.append([e[2] for e in spliting]) + elif sq[0] == 'usemtl': + assert len(sq) == 2 + usemtl.append((sq[1], len(f))) + elif sq[0] == 'o': + assert len(sq) == 2 + o.append((sq[1], len(f))) + elif sq[0] == 's': + s.append((sq[1], len(f))) + elif sq[0] == 'mtllib': + assert len(sq) == 2 + mtllib.append(sq[1]) + elif sq[0][0] == '#': + continue + else: + if not ignore_unknown: + raise Exception(f'Unknown keyword {sq[0]}') + + min_poly_vertices = min(len(f) for f in f) + max_poly_vertices = max(len(f) for f in f) + + return { + 'mtllib': mtllib, + 'v': np.array(v, dtype=np.float32), + 'vt': np.array(vt, dtype=np.float32), + 'vn': np.array(vn, dtype=np.float32), + 'vp': np.array(vp, dtype=np.float32), + 'f': np.array(f, dtype=np.int32) if min_poly_vertices == max_poly_vertices else f, + 'ft': np.array(ft, dtype=np.int32) if min_poly_vertices == max_poly_vertices else ft, + 'fn': np.array(fn, dtype=np.int32) if min_poly_vertices == max_poly_vertices else fn, + 'o': o, + 's': s, + 'usemtl': usemtl, + } + + +def write_obj( + file: Union[str, Path], + obj: Dict[str, Any], + encoding: Union[str, None] = None + ): + with open(file, 'w', encoding=encoding) as fp: + for k in ['v', 'vt', 'vn', 'vp']: + if k not in obj: + continue + for v in obj[k]: + print(k, *map(float, v), file=fp) + for f in obj['f']: + print('f', *((str('/').join(map(int, i)) if isinstance(int(i), Iterable) else i) for i in f), file=fp) + + +def simple_write_obj( + file: Union[str, Path], + vertices: np.ndarray, + faces: np.ndarray, + encoding: Union[str, None] = None + ): + """ + Write wavefront .obj file, without preprocessing. + + Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + file (Any): filepath + encoding (str, optional): + """ + with open(file, 'w', encoding=encoding) as fp: + for v in vertices: + print('v', *map(float, v), file=fp) + for f in faces: + print('f', *map(int, f + 1), file=fp) diff --git a/utils3d/numpy/__init__.py b/utils3d/numpy/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..f6188c6c0f104b70161c736e0ac517f515d531b3 --- /dev/null +++ b/utils3d/numpy/__init__.py @@ -0,0 +1,135 @@ +""" +3D utility functions workings with NumPy. +""" +import importlib +import itertools +import numpy + + +__modules_all__ = { + 'mesh':[ + 'triangulate', + 'compute_face_normal', + 'compute_face_angle', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'remove_corrupted_faces', + 'merge_duplicate_vertices', + 'remove_unreferenced_vertices', + 'subdivide_mesh_simple', + 'mesh_relations', + 'flatten_mesh_indices' + ], + 'quadmesh': [ + 'calc_quad_candidates', + 'calc_quad_distortion', + 'calc_quad_direction', + 'calc_quad_smoothness', + 'sovle_quad', + 'sovle_quad_qp', + 'tri_to_quad' + ], + 'utils': [ + 'sliding_window_1d', + 'sliding_window_nd', + 'sliding_window_2d', + 'max_pool_1d', + 'max_pool_2d', + 'max_pool_nd', + 'depth_edge', + 'depth_aliasing', + 'interpolate', + 'image_scrcoord', + 'image_uv', + 'image_pixel_center', + 'image_pixel', + 'image_mesh', + 'image_mesh_from_depth', + 'depth_to_normal', + 'point_to_normal', + 'chessboard', + 'cube', + 'square', + 'camera_frustum', + ], + 'transforms': [ + 'perspective', + 'perspective_from_fov', + 'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'perspective_to_near_far', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'unproject_cv', + 'unproject_gl', + 'project_cv', + 'project_gl', + 'quaternion_to_matrix', + 'axis_angle_to_matrix', + 'matrix_to_quaternion', + 'extrinsics_to_essential', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'ray_intersection', + 'se3_matrix', + 'slerp_quaternion', + 'slerp_vector', + 'lerp', + 'lerp_se3_matrix', + 'piecewise_lerp', + 'piecewise_lerp_se3_matrix', + 'apply_transform' + ], + 'spline': [ + 'linear_spline_interpolate', + ], + 'rasterization': [ + 'RastContext', + 'rasterize_triangle_faces', + 'rasterize_edges', + 'texture', + 'warp_image_by_depth', + ], +} + + +__all__ = list(itertools.chain(*__modules_all__.values())) + +def __getattr__(name): + try: + return globals()[name] + except KeyError: + pass + + try: + module_name = next(m for m in __modules_all__ if name in __modules_all__[m]) + except StopIteration: + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + module = importlib.import_module(f'.{module_name}', __name__) + for key in __modules_all__[module_name]: + globals()[key] = getattr(module, key) + + return globals()[name] + + +if __name__ == '__main__': + from .quadmesh import * + from .transforms import * + from .mesh import * + from .utils import * + from .rasterization import * + from .spline import * \ No newline at end of file diff --git a/utils3d/numpy/_helpers.py b/utils3d/numpy/_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..db489b5a0a44596e879af6dec3683038b5933da8 --- /dev/null +++ b/utils3d/numpy/_helpers.py @@ -0,0 +1,88 @@ +# decorator +import numpy as np +from numbers import Number +import inspect + + +def get_args_order(func, args, kwargs): + """ + Get the order of the arguments of a function. + """ + names = inspect.getfullargspec(func).args + names_idx = {name: i for i, name in enumerate(names)} + args_order = [] + kwargs_order = {} + for name, arg in kwargs.items(): + if name in names: + kwargs_order[name] = names_idx[name] + names.remove(name) + for i, arg in enumerate(args): + if i < len(names): + args_order.append(names_idx[names[i]]) + return args_order, kwargs_order + + +def broadcast_args(args, kwargs, args_dim, kwargs_dim): + spatial = [] + for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())): + if isinstance(arg, np.ndarray) and arg_dim is not None: + arg_spatial = arg.shape[:arg.ndim-arg_dim] + if len(arg_spatial) > len(spatial): + spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial + for j in range(len(arg_spatial)): + if spatial[-j] < arg_spatial[-j]: + if spatial[-j] == 1: + spatial[-j] = arg_spatial[-j] + else: + raise ValueError("Cannot broadcast arguments.") + for i, arg in enumerate(args): + if isinstance(arg, np.ndarray) and args_dim[i] is not None: + args[i] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]]) + for key, arg in kwargs.items(): + if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None: + kwargs[key] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + return args, kwargs, spatial + + +def batched(*dims): + """ + Decorator that allows a function to be called with batched arguments. + """ + def decorator(func): + def wrapper(*args, **kwargs): + args = list(args) + # get arguments dimensions + args_order, kwargs_order = get_args_order(func, args, kwargs) + args_dim = [dims[i] for i in args_order] + kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()} + # convert to numpy array + for i, arg in enumerate(args): + if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None: + args[i] = np.array(arg) + for key, arg in kwargs.items(): + if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None: + kwargs[key] = np.array(arg) + # broadcast arguments + args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim) + for i, (arg, arg_dim) in enumerate(zip(args, args_dim)): + if isinstance(arg, np.ndarray) and arg_dim is not None: + args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]]) + for key, arg in kwargs.items(): + if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None: + kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + # call function + results = func(*args, **kwargs) + type_results = type(results) + results = list(results) if isinstance(results, (tuple, list)) else [results] + # restore spatial dimensions + for i, result in enumerate(results): + results[i] = result.reshape([*spatial, *result.shape[1:]]) + if type_results == tuple: + results = tuple(results) + elif type_results == list: + results = list(results) + else: + results = results[0] + return results + return wrapper + return decorator \ No newline at end of file diff --git a/utils3d/numpy/mesh.py b/utils3d/numpy/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..afadb5f2510b58a1c5acbabff2ff798c041744d6 --- /dev/null +++ b/utils3d/numpy/mesh.py @@ -0,0 +1,355 @@ +import numpy as np +from typing import * +from ._helpers import batched + + +__all__ = [ + 'triangulate', + 'compute_face_normal', + 'compute_face_angle', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'remove_corrupted_faces', + 'merge_duplicate_vertices', + 'remove_unreferenced_vertices', + 'subdivide_mesh_simple', + 'mesh_relations', + 'flatten_mesh_indices' +] + + +def triangulate( + faces: np.ndarray, + vertices: np.ndarray = None, + backslash: np.ndarray = None +) -> np.ndarray: + """ + Triangulate a polygonal mesh. + + Args: + faces (np.ndarray): [L, P] polygonal faces + vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices. + If given, the triangulation is performed according to the distance + between vertices. Defaults to None. + backslash (np.ndarray, optional): [L] boolean array indicating + how to triangulate the quad faces. Defaults to None. + + Returns: + (np.ndarray): [L * (P - 2), 3] triangular faces + """ + if faces.shape[-1] == 3: + return faces + P = faces.shape[-1] + if vertices is not None: + assert faces.shape[-1] == 4, "now only support quad mesh" + if backslash is None: + backslash = np.linalg.norm(vertices[faces[:, 0]] - vertices[faces[:, 2]], axis=-1) < \ + np.linalg.norm(vertices[faces[:, 1]] - vertices[faces[:, 3]], axis=-1) + if backslash is None: + loop_indice = np.stack([ + np.zeros(P - 2, dtype=int), + np.arange(1, P - 1, 1, dtype=int), + np.arange(2, P, 1, dtype=int) + ], axis=1) + return faces[:, loop_indice].reshape((-1, 3)) + else: + assert faces.shape[-1] == 4, "now only support quad mesh" + faces = np.where( + backslash[:, None], + faces[:, [0, 1, 2, 0, 2, 3]], + faces[:, [0, 1, 3, 3, 1, 2]] + ).reshape((-1, 3)) + return faces + + +@batched(2, None) +def compute_face_normal( + vertices: np.ndarray, + faces: np.ndarray +) -> np.ndarray: + """ + Compute face normals of a triangular mesh + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + normals (np.ndarray): [..., T, 3] face normals + """ + normal = np.cross( + vertices[..., faces[:, 1], :] - vertices[..., faces[:, 0], :], + vertices[..., faces[:, 2], :] - vertices[..., faces[:, 0], :] + ) + normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True) + normal_norm[normal_norm == 0] = 1 + normal /= normal_norm + return normal + + +@batched(2, None) +def compute_face_angle( + vertices: np.ndarray, + faces: np.ndarray, + eps: float = 1e-12 + ) -> np.ndarray: + """ + Compute face angles of a triangular mesh + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + angles (np.ndarray): [..., T, 3] face angles + """ + face_angle = np.zeros_like(faces, dtype=vertices.dtype) + for i in range(3): + edge1 = vertices[..., faces[:, (i + 1) % 3], :] - vertices[..., faces[:, i], :] + edge2 = vertices[..., faces[:, (i + 2) % 3], :] - vertices[..., faces[:, i], :] + face_angle[..., i] = np.arccos(np.sum( + edge1 / np.clip(np.linalg.norm(edge1, axis=-1, keepdims=True), eps, None) * + edge2 / np.clip(np.linalg.norm(edge2, axis=-1, keepdims=True), eps, None), + axis=-1 + )) + return face_angle + + +@batched(2, None, 2) +def compute_vertex_normal( + vertices: np.ndarray, + faces: np.ndarray, + face_normal: np.ndarray = None +) -> np.ndarray: + """ + Compute vertex normals of a triangular mesh by averaging neightboring face normals + TODO: can be improved. + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + face_normal (np.ndarray, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + + Returns: + normals (np.ndarray): [..., N, 3] vertex normals + """ + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + vertex_normal = np.zeros_like(vertices, dtype=vertices.dtype) + for n in range(vertices.shape[0]): + for i in range(3): + vertex_normal[n, :, 0] += np.bincount(faces[:, i], weights=face_normal[n, :, 0], minlength=vertices.shape[1]) + vertex_normal[n, :, 1] += np.bincount(faces[:, i], weights=face_normal[n, :, 1], minlength=vertices.shape[1]) + vertex_normal[n, :, 2] += np.bincount(faces[:, i], weights=face_normal[n, :, 2], minlength=vertices.shape[1]) + vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True) + vertex_normal_norm[vertex_normal_norm == 0] = 1 + vertex_normal /= vertex_normal_norm + return vertex_normal + + +@batched(2, None, 2) +def compute_vertex_normal_weighted( + vertices: np.ndarray, + faces: np.ndarray, + face_normal: np.ndarray = None +) -> np.ndarray: + """ + Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals + according to the angles + + Args: + vertices (np.ndarray): [..., N, 3] 3-dimensional vertices + faces (np.ndarray): [..., T, 3] triangular face indices + face_normal (np.ndarray, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + + Returns: + normals (np.ndarray): [..., N, 3] vertex normals + """ + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + face_angle = compute_face_angle(vertices, faces) + vertex_normal = np.zeros_like(vertices) + for n in range(vertices.shape[0]): + for i in range(3): + vertex_normal[n, :, 0] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 0] * face_angle[n, :, i], minlength=vertices.shape[1]) + vertex_normal[n, :, 1] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 1] * face_angle[n, :, i], minlength=vertices.shape[1]) + vertex_normal[n, :, 2] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 2] * face_angle[n, :, i], minlength=vertices.shape[1]) + vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True) + vertex_normal_norm[vertex_normal_norm == 0] = 1 + vertex_normal /= vertex_normal_norm + return vertex_normal + + +def remove_corrupted_faces( + faces: np.ndarray + ) -> np.ndarray: + """ + Remove corrupted faces (faces with duplicated vertices) + + Args: + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + np.ndarray: [T_, 3] triangular face indices + """ + corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0]) + return faces[~corrupted] + + +def merge_duplicate_vertices( + vertices: np.ndarray, + faces: np.ndarray, + tol: float = 1e-6 + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Merge duplicate vertices of a triangular mesh. + Duplicate vertices are merged by selecte one of them, and the face indices are updated accordingly. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + tol (float, optional): tolerance for merging. Defaults to 1e-6. + + Returns: + vertices (np.ndarray): [N_, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + """ + vertices_round = np.round(vertices / tol) + _, uni_i, uni_inv = np.unique(vertices_round, return_index=True, return_inverse=True, axis=0) + vertices = vertices[uni_i] + faces = uni_inv[faces] + return vertices, faces + + +def remove_unreferenced_vertices( + faces: np.ndarray, + *vertice_attrs, + return_indices: bool = False +) -> Tuple[np.ndarray, ...]: + """ + Remove unreferenced vertices of a mesh. + Unreferenced vertices are removed, and the face indices are updated accordingly. + + Args: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + + Returns: + faces (np.ndarray): [T, P] face indices + *vertice_attrs: vertex attributes + indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None. + """ + P = faces.shape[-1] + fewer_indices, inv_map = np.unique(faces, return_inverse=True) + faces = inv_map.astype(np.int32).reshape(-1, P) + ret = [faces] + for attr in vertice_attrs: + ret.append(attr[fewer_indices]) + if return_indices: + ret.append(fewer_indices) + return tuple(ret) + + +def subdivide_mesh_simple( + vertices: np.ndarray, + faces: np.ndarray, + n: int = 1 +) -> Tuple[np.ndarray, np.ndarray]: + """ + Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles. + NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + n (int, optional): number of subdivisions. Defaults to 1. + + Returns: + vertices (np.ndarray): [N_, 3] subdivided 3-dimensional vertices + faces (np.ndarray): [4 * T, 3] subdivided triangular face indices + """ + for _ in range(n): + edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0) + edges = np.sort(edges, axis=2) + uni_edges, uni_inv = np.unique(edges.reshape(-1, 2), return_inverse=True, axis=0) + uni_inv = uni_inv.reshape(3, -1) + midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2 + + n_vertices = vertices.shape[0] + vertices = np.concatenate([vertices, midpoints], axis=0) + faces = np.concatenate([ + np.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1), + np.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1), + np.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1), + np.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1), + ], axis=0) + return vertices, faces + + +def mesh_relations( + faces: np.ndarray, +) -> Tuple[np.ndarray, np.ndarray]: + """ + Calculate the relation between vertices and faces. + NOTE: The input mesh must be a manifold triangle mesh. + + Args: + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + edges (np.ndarray): [E, 2] edge indices + edge2face (np.ndarray): [E, 2] edge to face relation. The second column is -1 if the edge is boundary. + face2edge (np.ndarray): [T, 3] face to edge relation + face2face (np.ndarray): [T, 3] face to face relation + """ + T = faces.shape[0] + edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=1).reshape(-1, 2) # [3T, 2] + edges = np.sort(edges, axis=1) # [3T, 2] + edges, face2edge, occurence = np.unique(edges, axis=0, return_inverse=True, return_counts=True) # [E, 2], [3T], [E] + E = edges.shape[0] + assert np.all(occurence <= 2), "The input mesh is not a manifold mesh." + + # Edge to face relation + padding = np.arange(E, dtype=np.int32)[occurence == 1] + padded_face2edge = np.concatenate([face2edge, padding], axis=0) # [2E] + edge2face = np.argsort(padded_face2edge, kind='stable').reshape(-1, 2) // 3 # [E, 2] + edge2face_valid = edge2face[:, 1] < T # [E] + edge2face[~edge2face_valid, 1] = -1 + + # Face to edge relation + face2edge = face2edge.reshape(-1, 3) # [T, 3] + + # Face to face relation + face2face = edge2face[face2edge] # [T, 3, 2] + face2face = face2face[face2face != np.arange(T)[:, None, None]].reshape(T, 3) # [T, 3] + + return edges, edge2face, face2edge, face2face + + +@overload +def flatten_mesh_indices(faces1: np.ndarray, attr1: np.ndarray, *other_faces_attrs_pairs: np.ndarray) -> Tuple[np.ndarray, ...]: + """ + Rearrange the indices of a mesh to a flattened version. Vertices will be no longer shared. + + ### Parameters: + - `faces1`: [T, P] face indices of the first attribute + - `attr1`: [N1, ...] attributes of the first mesh + - ... + + ### Returns: + - `faces`: [T, P] flattened face indices, contigous from 0 to T * P - 1 + - `attr1`: [T * P, ...] attributes of the first mesh, where every P values correspond to a face + _ ... + """ +def flatten_mesh_indices(*args: np.ndarray) -> Tuple[np.ndarray, ...]: + assert len(args) % 2 == 0, "The number of arguments must be even." + T, P = args[0].shape + assert all(arg.shape[0] == T and arg.shape[1] == P for arg in args[::2]), "The faces must have the same shape." + attr_flat = [] + for faces_, attr_ in zip(args[::2], args[1::2]): + attr_flat_ = attr_[faces_].reshape(-1, *attr_.shape[1:]) + attr_flat.append(attr_flat_) + faces_flat = np.arange(T * P, dtype=np.int32).reshape(T, P) + return faces_flat, *attr_flat \ No newline at end of file diff --git a/utils3d/numpy/quadmesh.py b/utils3d/numpy/quadmesh.py new file mode 100644 index 0000000000000000000000000000000000000000..de20dbba4d6fb78e67ad655c21c2b7d136216ca4 --- /dev/null +++ b/utils3d/numpy/quadmesh.py @@ -0,0 +1,472 @@ +import numpy as np +import scipy as sp +import scipy.optimize as spopt +import piqp +from typing import * + + +__all__ = [ + 'calc_quad_candidates', + 'calc_quad_distortion', + 'calc_quad_direction', + 'calc_quad_smoothness', + 'sovle_quad', + 'sovle_quad_qp', + 'tri_to_quad' +] + + +def calc_quad_candidates( + edges: np.ndarray, + face2edge: np.ndarray, + edge2face: np.ndarray, +): + """ + Calculate the candidate quad faces. + + Args: + edges (np.ndarray): [E, 2] edge indices + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + + Returns: + quads (np.ndarray): [Q, 4] quad candidate indices + quad2edge (np.ndarray): [Q, 4] edge to quad candidate relation + quad2adj (np.ndarray): [Q, 8] adjacent quad candidates of each quad candidate + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + """ + E = edges.shape[0] + T = face2edge.shape[0] + + quads_valid = edge2face[:, 1] != -1 + Q = quads_valid.sum() + quad2face = edge2face[quads_valid] # [Q, 2] + quad2edge = face2edge[quad2face] # [Q, 2, 3] + flag = quad2edge == np.arange(E)[quads_valid][:, None, None] # [Q, 2, 3] + flag = flag.argmax(axis=-1) # [Q, 2] + quad2edge = np.stack([ + quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 1) % 3], + quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 2) % 3], + ], axis=-1).reshape(Q, 4) # [Q, 4] + + quads = np.concatenate([ + np.where( + (edges[quad2edge[:, 0:1], 1:] == edges[quad2edge[:, 1:2], :]).any(axis=-1), + edges[quad2edge[:, 0:1], [[0, 1]]], + edges[quad2edge[:, 0:1], [[1, 0]]], + ), + np.where( + (edges[quad2edge[:, 2:3], 1:] == edges[quad2edge[:, 3:4], :]).any(axis=-1), + edges[quad2edge[:, 2:3], [[0, 1]]], + edges[quad2edge[:, 2:3], [[1, 0]]], + ), + ], axis=1) # [Q, 4] + + quad2adj = edge2face[quad2edge] # [Q, 4, 2] + quad2adj = quad2adj[quad2adj != quad2face[:, [0,0,1,1], None]].reshape(Q, 4) # [Q, 4] + quad2adj_valid = quad2adj != -1 + quad2adj = face2edge[quad2adj] # [Q, 4, 3] + quad2adj[~quad2adj_valid, 0] = quad2edge[~quad2adj_valid] + quad2adj[~quad2adj_valid, 1:] = -1 + quad2adj = quad2adj[quad2adj != quad2edge[..., None]].reshape(Q, 8) # [Q, 8] + edge_valid = -np.ones(E, dtype=np.int32) + edge_valid[quads_valid] = np.arange(Q) + quad2adj_valid = quad2adj != -1 + quad2adj[quad2adj_valid] = edge_valid[quad2adj[quad2adj_valid]] # [Q, 8] + + return quads, quad2edge, quad2adj, quads_valid + + +def calc_quad_distortion( + vertices: np.ndarray, + quads: np.ndarray, +): + """ + Calculate the distortion of each candidate quad face. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + quads (np.ndarray): [Q, 4] quad face indices + + Returns: + distortion (np.ndarray): [Q] distortion of each quad face + """ + edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3] + edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3] + edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3] + edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3] + cross = vertices[quads[:, 0]] - vertices[quads[:, 2]] # [Q, 3] + + len0 = np.maximum(np.linalg.norm(edge0, axis=-1), 1e-10) # [Q] + len1 = np.maximum(np.linalg.norm(edge1, axis=-1), 1e-10) # [Q] + len2 = np.maximum(np.linalg.norm(edge2, axis=-1), 1e-10) # [Q] + len3 = np.maximum(np.linalg.norm(edge3, axis=-1), 1e-10) # [Q] + len_cross = np.maximum(np.linalg.norm(cross, axis=-1), 1e-10) # [Q] + + angle0 = np.arccos(np.clip(np.sum(-edge0 * edge1, axis=-1) / (len0 * len1), -1, 1)) # [Q] + angle1 = np.arccos(np.clip(np.sum(-edge1 * cross, axis=-1) / (len1 * len_cross), -1, 1)) \ + + np.arccos(np.clip(np.sum(cross * edge2, axis=-1) / (len_cross * len2), -1, 1)) # [Q] + angle2 = np.arccos(np.clip(np.sum(-edge2 * edge3, axis=-1) / (len2 * len3), -1, 1)) # [Q] + angle3 = np.arccos(np.clip(np.sum(-edge3 * -cross, axis=-1) / (len3 * len_cross), -1, 1)) \ + + np.arccos(np.clip(np.sum(-cross * edge0, axis=-1) / (len_cross * len0), -1, 1)) # [Q] + + normal0 = np.cross(edge0, edge1) # [Q, 3] + normal1 = np.cross(edge2, edge3) # [Q, 3] + normal0 = normal0 / np.maximum(np.linalg.norm(normal0, axis=-1, keepdims=True), 1e-10) # [Q, 3] + normal1 = normal1 / np.maximum(np.linalg.norm(normal1, axis=-1, keepdims=True), 1e-10) # [Q, 3] + angle_normal = np.arccos(np.clip(np.sum(normal0 * normal1, axis=-1), -1, 1)) # [Q] + + D90 = np.pi / 2 + D180 = np.pi + D360 = np.pi * 2 + ang_eng = (np.abs(angle0 - D90)**2 + np.abs(angle1 - D90)**2 + np.abs(angle2 - D90)**2 + np.abs(angle3 - D90)**2) / 4 # [Q] + dist_eng = np.abs(angle0 - angle2)**2 / np.minimum(np.maximum(np.minimum(angle0, angle2), 1e-10), np.maximum(D180 - np.maximum(angle0, angle2), 1e-10)) \ + + np.abs(angle1 - angle3)**2 / np.minimum(np.maximum(np.minimum(angle1, angle3), 1e-10), np.maximum(D180 - np.maximum(angle1, angle3), 1e-10)) # [Q] + plane_eng = np.where(angle_normal < D90/2, np.abs(angle_normal)**2, 1e10) # [Q] + eng = ang_eng + 2 * dist_eng + 2 * plane_eng # [Q] + + return eng + + +def calc_quad_direction( + vertices: np.ndarray, + quads: np.ndarray, + ): + """ + Calculate the direction of each candidate quad face. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + quads (np.ndarray): [Q, 4] quad face indices + + Returns: + direction (np.ndarray): [Q, 4] direction of each quad face. + Represented by the angle between the crossing and each edge. + """ + mid0 = (vertices[quads[:, 0]] + vertices[quads[:, 1]]) / 2 # [Q, 3] + mid1 = (vertices[quads[:, 1]] + vertices[quads[:, 2]]) / 2 # [Q, 3] + mid2 = (vertices[quads[:, 2]] + vertices[quads[:, 3]]) / 2 # [Q, 3] + mid3 = (vertices[quads[:, 3]] + vertices[quads[:, 0]]) / 2 # [Q, 3] + + cross0 = mid2 - mid0 # [Q, 3] + cross1 = mid3 - mid1 # [Q, 3] + cross0 = cross0 / np.maximum(np.linalg.norm(cross0, axis=-1, keepdims=True), 1e-10) # [Q, 3] + cross1 = cross1 / np.maximum(np.linalg.norm(cross1, axis=-1, keepdims=True), 1e-10) # [Q, 3] + + edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3] + edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3] + edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3] + edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3] + edge0 = edge0 / np.maximum(np.linalg.norm(edge0, axis=-1, keepdims=True), 1e-10) # [Q, 3] + edge1 = edge1 / np.maximum(np.linalg.norm(edge1, axis=-1, keepdims=True), 1e-10) # [Q, 3] + edge2 = edge2 / np.maximum(np.linalg.norm(edge2, axis=-1, keepdims=True), 1e-10) # [Q, 3] + edge3 = edge3 / np.maximum(np.linalg.norm(edge3, axis=-1, keepdims=True), 1e-10) # [Q, 3] + + direction = np.stack([ + np.arccos(np.clip(np.sum(cross0 * edge0, axis=-1), -1, 1)), + np.arccos(np.clip(np.sum(cross1 * edge1, axis=-1), -1, 1)), + np.arccos(np.clip(np.sum(-cross0 * edge2, axis=-1), -1, 1)), + np.arccos(np.clip(np.sum(-cross1 * edge3, axis=-1), -1, 1)), + ], axis=-1) # [Q, 4] + + return direction + + +def calc_quad_smoothness( + quad2edge: np.ndarray, + quad2adj: np.ndarray, + quads_direction: np.ndarray, + ): + """ + Calculate the smoothness of each candidate quad face connection. + + Args: + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_direction (np.ndarray): [Q, 4] direction of each quad face + + Returns: + smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + """ + Q = quad2adj.shape[0] + quad2adj_valid = quad2adj != -1 + connections = np.stack([ + np.arange(Q)[:, None].repeat(8, axis=1), + quad2adj, + ], axis=-1)[quad2adj_valid] # [C, 2] + shared_edge_idx_0 = np.array([[0, 0, 1, 1, 2, 2, 3, 3]]).repeat(Q, axis=0)[quad2adj_valid] # [C] + shared_edge_idx_1 = np.argmax(quad2edge[quad2adj][quad2adj_valid] == quad2edge[connections[:, 0], shared_edge_idx_0][:, None], axis=-1) # [C] + valid_smoothness = np.abs(quads_direction[connections[:, 0], shared_edge_idx_0] - quads_direction[connections[:, 1], shared_edge_idx_1])**2 # [C] + smoothness = np.zeros([Q, 8], dtype=np.float32) + smoothness[quad2adj_valid] = valid_smoothness + return smoothness + + +def sovle_quad( + face2edge: np.ndarray, + edge2face: np.ndarray, + quad2adj: np.ndarray, + quads_distortion: np.ndarray, + quads_smoothness: np.ndarray, + quads_valid: np.ndarray, + ): + """ + Solve the quad mesh from the candidate quad faces. + + Args: + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_distortion (np.ndarray): [Q] distortion of each quad face + quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + + Returns: + weights (np.ndarray): [Q] weight of each valid quad face + """ + T = face2edge.shape[0] + E = edge2face.shape[0] + Q = quads_distortion.shape[0] + edge_valid = -np.ones(E, dtype=np.int32) + edge_valid[quads_valid] = np.arange(Q) + + quads_connection = np.stack([ + np.arange(Q)[:, None].repeat(8, axis=1), + quad2adj, + ], axis=-1)[quad2adj != -1] # [C, 2] + quads_connection = np.sort(quads_connection, axis=-1) # [C, 2] + quads_connection, quads_connection_idx = np.unique(quads_connection, axis=0, return_index=True) # [C, 2], [C] + quads_smoothness = quads_smoothness[quad2adj != -1] # [C] + quads_smoothness = quads_smoothness[quads_connection_idx] # [C] + C = quads_connection.shape[0] + + # Construct the linear programming problem + + # Variables: + # quads_weight: [Q] weight of each quad face + # tri_min_weight: [T] minimum weight of each triangle face + # conn_min_weight: [C] minimum weight of each quad face connection + # conn_max_weight: [C] maximum weight of each quad face connection + # Objective: + # mimi + + c = np.concatenate([ + quads_distortion - 3, + quads_smoothness*4 - 2, + quads_smoothness*4, + ], axis=0) # [Q+C] + + A_ub_triplet = np.concatenate([ + np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T, T+C), np.arange(Q, Q+C), np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T, T+C), quads_connection[:, 0], -np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T, T+C), quads_connection[:, 1], -np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T+C, T+2*C), np.arange(Q+C, Q+2*C), -np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T+C, T+2*C), quads_connection[:, 0], np.ones(C)], axis=1), # [C, 3] + np.stack([np.arange(T+C, T+2*C), quads_connection[:, 1], np.ones(C)], axis=1), # [C, 3] + ], axis=0) # [3T+6C, 3] + A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3] + A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T+2*C, Q+2*C]) # [T, + b_ub = np.concatenate([np.ones(T), -np.ones(C), np.ones(C)], axis=0) # [T+2C] + bound = np.stack([ + np.concatenate([np.zeros(Q), -np.ones(C), np.zeros(C)], axis=0), + np.concatenate([np.ones(Q), np.ones(C), np.ones(C)], axis=0), + ], axis=1) # [Q+2C, 2] + A_eq = None + b_eq = None + + print('Solver statistics:') + print(f' #T = {T}') + print(f' #Q = {Q}') + print(f' #C = {C}') + + # Solve the linear programming problem + last_num_valid = 0 + for i in range(100): + res_ = spopt.linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bound) + if not res_.success: + print(f' Iter {i} | Failed with {res_.message}') + break + res = res_ + weights = res.x[:Q] + valid = (weights > 0.5) + num_valid = valid.sum() + print(f' Iter {i} | #Q_valid = {num_valid}') + if num_valid == last_num_valid: + break + last_num_valid = num_valid + A_eq_triplet = np.stack([ + np.arange(num_valid), + np.arange(Q)[valid], + np.ones(num_valid), + ], axis=1) # [num_valid, 3] + A_eq = sp.sparse.coo_matrix((A_eq_triplet[:, 2], (A_eq_triplet[:, 0], A_eq_triplet[:, 1])), shape=[num_valid, Q+2*C]) # [num_valid, Q+C] + b_eq = np.where(weights[valid] > 0.5, 1, 0) # [num_valid] + + # Return the result + quads_weight = res.x[:Q] + conn_min_weight = res.x[Q:Q+C] + conn_max_weight = res.x[Q+C:Q+2*C] + return quads_weight, conn_min_weight, conn_max_weight + + +def sovle_quad_qp( + face2edge: np.ndarray, + edge2face: np.ndarray, + quad2adj: np.ndarray, + quads_distortion: np.ndarray, + quads_smoothness: np.ndarray, + quads_valid: np.ndarray, + ): + """ + Solve the quad mesh from the candidate quad faces. + + Args: + face2edge (np.ndarray): [T, 3] face to edge relation + edge2face (np.ndarray): [E, 2] edge to face relation + quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face + quads_distortion (np.ndarray): [Q] distortion of each quad face + quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection + quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid + + Returns: + weights (np.ndarray): [Q] weight of each valid quad face + """ + T = face2edge.shape[0] + E = edge2face.shape[0] + Q = quads_distortion.shape[0] + edge_valid = -np.ones(E, dtype=np.int32) + edge_valid[quads_valid] = np.arange(Q) + + # Construct the quadratic programming problem + C_smoothness_triplet = np.stack([ + np.arange(Q)[:, None].repeat(8, axis=1)[quad2adj != -1], + quad2adj[quad2adj != -1], + 5 * quads_smoothness[quad2adj != -1], + ], axis=-1) # [C, 3] + # C_smoothness_triplet = np.concatenate([ + # C_smoothness_triplet, + # np.stack([np.arange(Q), np.arange(Q), 20*np.ones(Q)], axis=1), + # ], axis=0) # [C+Q, 3] + C_smoothness = sp.sparse.coo_matrix((C_smoothness_triplet[:, 2], (C_smoothness_triplet[:, 0], C_smoothness_triplet[:, 1])), shape=[Q, Q]) # [Q, Q] + C_smoothness = C_smoothness.tocsc() + C_dist = quads_distortion - 20 # [Q] + + A_eq = sp.sparse.coo_matrix((np.zeros(Q), (np.zeros(Q), np.arange(Q))), shape=[1, Q]) # [1, Q]\ + A_eq = A_eq.tocsc() + b_eq = np.array([0]) + + A_ub_triplet = np.concatenate([ + np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3] + np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3] + ], axis=0) # [3T, 3] + A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3] + A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T, Q]) # [T, Q] + A_ub = A_ub.tocsc() + b_ub = np.ones(T) + + lb = np.zeros(Q) + ub = np.ones(Q) + + solver = piqp.SparseSolver() + solver.settings.verbose = True + solver.settings.compute_timings = True + solver.setup(C_smoothness, C_dist, A_eq, b_eq, A_ub, b_ub, lb, ub) + + status = solver.solve() + + # x = cp.Variable(Q) + # prob = cp.Problem( + # cp.Minimize(cp.quad_form(x, C_smoothness) + C_dist.T @ x), + # [ + # A_ub @ x <= b_ub, + # x >= 0, x <= 1, + # ] + # ) + + # # Solve the quadratic programming problem + # prob.solve(solver=cp.PIQP, verbose=True) + + # Return the result + weights = solver.result.x + return weights + + +def tri_to_quad( + vertices: np.ndarray, + faces: np.ndarray, + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Convert a triangle mesh to a quad mesh. + NOTE: The input mesh must be a manifold mesh. + + Args: + vertices (np.ndarray): [N, 3] 3-dimensional vertices + faces (np.ndarray): [T, 3] triangular face indices + + Returns: + vertices (np.ndarray): [N_, 3] 3-dimensional vertices + faces (np.ndarray): [Q, 4] quad face indices + """ + raise NotImplementedError + + +if __name__ == '__main__': + import os + import sys + sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..'))) + import utils3d + import numpy as np + import cv2 + from vis import vis_edge_color + + file = 'miku' + + vertices, faces = utils3d.io.read_ply(f'test/assets/{file}.ply') + edges, edge2face, face2edge, face2face = calc_relations(faces) + quad_cands, quad2edge, quad2adj, quad_valid = calc_quad_candidates(edges, face2edge, edge2face) + distortion = calc_quad_distortion(vertices, quad_cands) + direction = calc_quad_direction(vertices, quad_cands) + smoothness = calc_quad_smoothness(quad2edge, quad2adj, direction) + boundary_edges = edges[edge2face[:, 1] == -1] + quads_weight, conn_min_weight, conn_max_weight = sovle_quad(face2edge, edge2face, quad2adj, distortion, smoothness, quad_valid) + quads = quad_cands[quads_weight > 0.5] + print('Mesh statistics') + print(f' #V = {vertices.shape[0]}') + print(f' #F = {faces.shape[0]}') + print(f' #E = {edges.shape[0]}') + print(f' #B = {boundary_edges.shape[0]}') + print(f' #Q_cand = {quad_cands.shape[0]}') + print(f' #Q = {quads.shape[0]}') + + utils3d.io.write_ply(f'test/assets/{file}_boundary_edges.ply', vertices=vertices, edges=boundary_edges) + utils3d.io.write_ply(f'test/assets/{file}_quad_candidates.ply', vertices=vertices, faces=quads) + + edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8) + distortion = (distortion - distortion.min()) / (distortion.max() - distortion.min()) + distortion = (distortion * 255).astype(np.uint8) + edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap(distortion, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_distortion.ply', **vis_edge_color(vertices, edges, edge_colors)) + + edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8) + edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap((quads_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_weights.ply', **vis_edge_color(vertices, edges, edge_colors)) + utils3d.io.write_ply(f'test/assets/{file}_quad.ply', vertices=vertices, faces=quads) + + quad_centers = vertices[quad_cands].mean(axis=1) + conns = np.stack([ + np.arange(quad_cands.shape[0])[:, None].repeat(8, axis=1), + quad2adj, + ], axis=-1)[quad2adj != -1] # [C, 2] + conns, conns_idx = np.unique(np.sort(conns, axis=-1), axis=0, return_index=True) # [C, 2], [C] + smoothness = smoothness[quad2adj != -1][conns_idx] # [C] + conns_color = cv2.cvtColor(cv2.applyColorMap((smoothness * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_conn_smoothness.ply', **vis_edge_color(quad_centers, conns, conns_color)) + conns_color = cv2.cvtColor(cv2.applyColorMap((conn_min_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_conn_min.ply', **vis_edge_color(quad_centers, conns, conns_color)) + conns_color = cv2.cvtColor(cv2.applyColorMap((conn_max_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3) + utils3d.io.write_ply(f'test/assets/{file}_quad_conn_max.ply', **vis_edge_color(quad_centers, conns, conns_color)) + + \ No newline at end of file diff --git a/utils3d/numpy/rasterization.py b/utils3d/numpy/rasterization.py new file mode 100644 index 0000000000000000000000000000000000000000..7b8962f7a7403fda60997595c33140290be734a8 --- /dev/null +++ b/utils3d/numpy/rasterization.py @@ -0,0 +1,471 @@ +import os +from typing import * + +import numpy as np +import moderngl + +from . import transforms, utils, mesh + + +__all__ = [ + 'RastContext', + 'rasterize_triangle_faces', + 'rasterize_edges', + 'texture', + 'warp_image_by_depth', +] + + +def map_np_dtype(dtype) -> str: + if dtype == int: + return 'i4' + elif dtype == np.uint8: + return 'u1' + elif dtype == np.uint32: + return 'u2' + elif dtype == np.float16: + return 'f2' + elif dtype == np.float32: + return 'f4' + + +def one_value(dtype): + if dtype == 'u1': + return 255 + elif dtype == 'u2': + return 65535 + else: + return 1 + + +class RastContext: + def __init__(self, standalone: bool = True, backend: str = None, **kwargs): + """ + Create a moderngl context. + + Args: + standalone (bool, optional): whether to create a standalone context. Defaults to True. + backend (str, optional): backend to use. Defaults to None. + + Keyword Args: + See moderngl.create_context + """ + if backend is None: + self.mgl_ctx = moderngl.create_context(standalone=standalone, **kwargs) + else: + self.mgl_ctx = moderngl.create_context(standalone=standalone, backend=backend, **kwargs) + + self.__prog_src = {} + self.__prog = {} + + def __del__(self): + self.mgl_ctx.release() + + def screen_quad(self) -> moderngl.VertexArray: + self.screen_quad_vbo = self.mgl_ctx.buffer(np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype='f4')) + self.screen_quad_ibo = self.mgl_ctx.buffer(np.array([0, 1, 2, 0, 2, 3], dtype=np.int32)) + + def program_vertex_attribute(self, n: int) -> moderngl.Program: + assert n in [1, 2, 3, 4], 'vertex attribute only supports channels 1, 2, 3, 4' + + if 'vertex_attribute_vsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.vsh'), 'r') as f: + self.__prog_src['vertex_attribute_vsh'] = f.read() + if 'vertex_attribute_fsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.fsh'), 'r') as f: + self.__prog_src['vertex_attribute_fsh'] = f.read() + + if f'vertex_attribute_{n}' not in self.__prog: + vsh = self.__prog_src['vertex_attribute_vsh'].replace('vecN', f'vec{n}') + fsh = self.__prog_src['vertex_attribute_fsh'].replace('vecN', f'vec{n}') + self.__prog[f'vertex_attribute_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh) + + return self.__prog[f'vertex_attribute_{n}'] + + def program_texture(self, n: int) -> moderngl.Program: + assert n in [1, 2, 3, 4], 'texture only supports channels 1, 2, 3, 4' + + if 'texture_vsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.vsh'), 'r') as f: + self.__prog_src['texture_vsh'] = f.read() + if 'texture_fsh' not in self.__prog_src: + with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.fsh'), 'r') as f: + self.__prog_src['texture_fsh'] = f.read() + + if f'texture_{n}' not in self.__prog: + vsh = self.__prog_src['texture_vsh'].replace('vecN', f'vec{n}') + fsh = self.__prog_src['texture_fsh'].replace('vecN', f'vec{n}') + self.__prog[f'texture_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh) + self.__prog[f'texture_{n}']['tex'] = 0 + self.__prog[f'texture_{n}']['uv'] = 1 + + return self.__prog[f'texture_{n}'] + + +def rasterize_triangle_faces( + ctx: RastContext, + vertices: np.ndarray, + faces: np.ndarray, + attr: np.ndarray, + width: int, + height: int, + transform: np.ndarray = None, + cull_backface: bool = True, + return_depth: bool = False, + image: np.ndarray = None, + depth: np.ndarray = None +) -> Tuple[np.ndarray, np.ndarray]: + """ + Rasterize vertex attribute. + + Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + attr (np.ndarray): [N, C] + width (int): width of rendered image + height (int): height of rendered image + transform (np.ndarray): [4, 4] model-view-projection transformation matrix. + cull_backface (bool): whether to cull backface + image: (np.ndarray): [H, W, C] background image + depth: (np.ndarray): [H, W] background depth + + Returns: + image (np.ndarray): [H, W, C] rendered image + depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None. + """ + assert vertices.ndim == 2 and vertices.shape[1] == 3 + assert faces.ndim == 2 and faces.shape[1] == 3, f"Faces should be a 2D array with shape (T, 3), but got {faces.shape}" + assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}' + assert vertices.shape[0] == attr.shape[0] + assert vertices.dtype == np.float32 + assert faces.dtype == np.uint32 or faces.dtype == np.int32 + assert attr.dtype == np.float32, "Attribute should be float32" + + C = attr.shape[1] + prog = ctx.program_vertex_attribute(C) + + transform = np.eye(4, np.float32) if transform is None else transform + + # Create buffers + ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(faces, dtype='i4')) + vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4')) + vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4')) + vao = ctx.mgl_ctx.vertex_array( + prog, + [ + (vbo_vertices, '3f', 'i_position'), + (vbo_attr, f'{C}f', 'i_attr'), + ], + ibo, + mode=moderngl.TRIANGLES, + ) + + # Create framebuffer + image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None) + depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None) + fbo = ctx.mgl_ctx.framebuffer( + color_attachments=[image_tex], + depth_attachment=depth_tex, + ) + + # Render + prog['u_mvp'].write(transform.transpose().copy().astype('f4')) + fbo.use() + fbo.viewport = (0, 0, width, height) + ctx.mgl_ctx.depth_func = '<' + ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST) + if cull_backface: + ctx.mgl_ctx.enable(ctx.mgl_ctx.CULL_FACE) + else: + ctx.mgl_ctx.disable(ctx.mgl_ctx.CULL_FACE) + vao.render() + ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST) + + # Read + image = np.zeros((height, width, C), dtype='f4') + image_tex.read_into(image) + image = image[::-1, :, :] + if return_depth: + depth = np.zeros((height, width), dtype='f4') + depth_tex.read_into(depth) + depth = depth[::-1, :] + else: + depth = None + + # Release + vao.release() + ibo.release() + vbo_vertices.release() + vbo_attr.release() + fbo.release() + image_tex.release() + depth_tex.release() + + return image, depth + + +def rasterize_edges( + ctx: RastContext, + vertices: np.ndarray, + edges: np.ndarray, + attr: np.ndarray, + width: int, + height: int, + transform: np.ndarray = None, + line_width: float = 1.0, + return_depth: bool = False, + image: np.ndarray = None, + depth: np.ndarray = None +) -> Tuple[np.ndarray, ...]: + """ + Rasterize vertex attribute. + + Args: + vertices (np.ndarray): [N, 3] + faces (np.ndarray): [T, 3] + attr (np.ndarray): [N, C] + width (int): width of rendered image + height (int): height of rendered image + transform (np.ndarray): [4, 4] model-view-projection matrix + line_width (float): width of line. Defaults to 1.0. NOTE: Values other than 1.0 may not work across all platforms. + cull_backface (bool): whether to cull backface + + Returns: + image (np.ndarray): [H, W, C] rendered image + depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None. + """ + assert vertices.ndim == 2 and vertices.shape[1] == 3 + assert edges.ndim == 2 and edges.shape[1] == 2, f"Edges should be a 2D array with shape (T, 2), but got {edges.shape}" + assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}' + assert vertices.shape[0] == attr.shape[0] + assert vertices.dtype == np.float32 + assert edges.dtype == np.uint32 or edges.dtype == np.int32 + assert attr.dtype == np.float32, "Attribute should be float32" + + C = attr.shape[1] + prog = ctx.program_vertex_attribute(C) + + transform = transform if transform is not None else np.eye(4, np.float32) + + # Create buffers + ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(edges, dtype='i4')) + vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4')) + vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4')) + vao = ctx.mgl_ctx.vertex_array( + prog, + [ + (vbo_vertices, '3f', 'i_position'), + (vbo_attr, f'{C}f', 'i_attr'), + ], + ibo, + mode=moderngl.LINES, + ) + + # Create framebuffer + image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None) + depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None) + fbo = ctx.mgl_ctx.framebuffer( + color_attachments=[image_tex], + depth_attachment=depth_tex, + ) + + # Render + prog['u_mvp'].write(transform.transpose().copy().astype('f4')) + fbo.use() + fbo.viewport = (0, 0, width, height) + ctx.mgl_ctx.depth_func = '<' + ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST) + ctx.mgl_ctx.line_width = line_width + vao.render() + ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST) + + # Read + image = np.zeros((height, width, C), dtype='f4') + image_tex.read_into(image) + image = image[::-1, :, :] + if return_depth: + depth = np.zeros((height, width), dtype='f4') + depth_tex.read_into(depth) + depth = depth[::-1, :] + else: + depth = None + + # Release + vao.release() + ibo.release() + vbo_vertices.release() + vbo_attr.release() + fbo.release() + image_tex.release() + depth_tex.release() + + return image, depth + + +def texture( + ctx: RastContext, + uv: np.ndarray, + texture: np.ndarray, + interpolation: str= 'linear', + wrap: str = 'clamp' +) -> np.ndarray: + """ + Given an UV image, texturing from the texture map + """ + assert len(texture.shape) == 3 and 1 <= texture.shape[2] <= 4 + assert uv.shape[2] == 2 + height, width = uv.shape[:2] + texture_dtype = map_np_dtype(texture.dtype) + + # Create VAO + screen_quad_vbo = ctx.mgl_ctx.buffer(np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype='f4')) + screen_quad_ibo = ctx.mgl_ctx.buffer(np.array([0, 1, 2, 0, 2, 3], dtype=np.int32)) + screen_quad_vao = ctx.mgl_ctx.vertex_array(ctx.program_texture(texture.shape[2]), [(screen_quad_vbo, '2f4', 'in_vert')], index_buffer=screen_quad_ibo, index_element_size=4) + + # Create texture, set filter and bind. TODO: min mag filter, mipmap + texture_tex = ctx.mgl_ctx.texture((texture.shape[1], texture.shape[0]), texture.shape[2], dtype=texture_dtype, data=np.ascontiguousarray(texture)) + if interpolation == 'linear': + texture_tex.filter = (moderngl.LINEAR, moderngl.LINEAR) + elif interpolation == 'nearest': + texture_tex.filter = (moderngl.NEAREST, moderngl.NEAREST) + texture_tex.use(location=0) + texture_uv = ctx.mgl_ctx.texture((width, height), 2, dtype='f4', data=np.ascontiguousarray(uv.astype('f4', copy=False))) + texture_uv.filter = (moderngl.NEAREST, moderngl.NEAREST) + texture_uv.use(location=1) + + # Create render buffer and frame buffer + rb = ctx.mgl_ctx.renderbuffer((uv.shape[1], uv.shape[0]), texture.shape[2], dtype=texture_dtype) + fbo = ctx.mgl_ctx.framebuffer(color_attachments=[rb]) + + # Render + fbo.use() + fbo.viewport = (0, 0, width, height) + ctx.mgl_ctx.disable(ctx.mgl_ctx.BLEND) + screen_quad_vao.render() + + # Read buffer + image_buffer = np.frombuffer(fbo.read(components=texture.shape[2], attachment=0, dtype=texture_dtype), dtype=texture_dtype).reshape((height, width, texture.shape[2])) + + # Release + texture_tex.release() + rb.release() + fbo.release() + + return image_buffer + + +def warp_image_by_depth( + ctx: RastContext, + src_depth: np.ndarray, + src_image: np.ndarray = None, + width: int = None, + height: int = None, + *, + extrinsics_src: np.ndarray = None, + extrinsics_tgt: np.ndarray = None, + intrinsics_src: np.ndarray = None, + intrinsics_tgt: np.ndarray = None, + near: float = 0.1, + far: float = 100.0, + cull_backface: bool = True, + ssaa: int = 1, + return_depth: bool = False, +) -> Tuple[np.ndarray, ...]: + """ + Warp image by depth map. + + Args: + ctx (RastContext): rasterizer context + src_depth (np.ndarray): [H, W] + src_image (np.ndarray, optional): [H, W, C]. The image to warp. Defaults to None (use uv coordinates). + width (int, optional): width of the output image. None to use depth map width. Defaults to None. + height (int, optional): height of the output image. None to use depth map height. Defaults to None. + extrinsics_src (np.ndarray, optional): extrinsics matrix of the source camera. Defaults to None (identity). + extrinsics_tgt (np.ndarray, optional): extrinsics matrix of the target camera. Defaults to None (identity). + intrinsics_src (np.ndarray, optional): intrinsics matrix of the source camera. Defaults to None (use the same as intrinsics_tgt). + intrinsics_tgt (np.ndarray, optional): intrinsics matrix of the target camera. Defaults to None (use the same as intrinsics_src). + cull_backface (bool, optional): whether to cull backface. Defaults to True. + ssaa (int, optional): super sampling anti-aliasing. Defaults to 1. + + Returns: + tgt_image (np.ndarray): [H, W, C] warped image (or uv coordinates if image is None). + tgt_depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None. + """ + assert src_depth.ndim == 2 + + if width is None: + width = src_depth.shape[1] + if height is None: + height = src_depth.shape[0] + if src_image is not None: + assert src_image.shape[-2:] == src_depth.shape[-2:], f'Shape of source image {src_image.shape} does not match shape of source depth {src_depth.shape}' + + # set up default camera parameters + extrinsics_src = np.eye(4) if extrinsics_src is None else extrinsics_src + extrinsics_tgt = np.eye(4) if extrinsics_tgt is None else extrinsics_tgt + intrinsics_src = intrinsics_tgt if intrinsics_src is None else intrinsics_src + intrinsics_tgt = intrinsics_src if intrinsics_tgt is None else intrinsics_tgt + + assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters." + + # check shapes + assert extrinsics_src.shape == (4, 4) and extrinsics_tgt.shape == (4, 4) + assert intrinsics_src.shape == (3, 3) and intrinsics_tgt.shape == (3, 3) + + # convert to view and perspective matrices + view_tgt = transforms.extrinsics_to_view(extrinsics_tgt) + perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far) + + # unproject depth map + uv, faces = utils.image_mesh(*src_depth.shape[-2:]) + pts = transforms.unproject_cv(uv, src_depth.reshape(-1), extrinsics_src, intrinsics_src) + faces = mesh.triangulate(faces, vertices=pts) + + # rasterize attributes + if src_image is not None: + attr = src_image.reshape(-1, src_image.shape[-1]) + else: + attr = uv + + tgt_image, tgt_depth = rasterize_triangle_faces( + ctx, + pts, + faces, + attr, + width * ssaa, + height * ssaa, + transform=perspective_tgt @ view_tgt, + cull_backface=cull_backface, + return_depth=return_depth, + ) + + if ssaa > 1: + tgt_image = tgt_image.reshape(height, ssaa, width, ssaa, -1).mean(axis=(1, 3)) + tgt_depth = tgt_depth.reshape(height, ssaa, width, ssaa, -1).mean(axis=(1, 3)) if return_depth else None + + return tgt_image, tgt_depth + +def test(): + """ + Test if rasterization works. It will render a cube with random colors and save it as a CHECKME.png file. + """ + ctx = RastContext(backend='egl') + vertices, faces = utils.cube(tri=True) + attr = np.random.rand(len(vertices), 3).astype(np.float32) + perspective = transforms.perspective(np.deg2rad(60), 1, 0.01, 100) + view = transforms.view_look_at(np.array([2, 2, 2]), np.array([0, 0, 0]), np.array([0, 1, 0])) + image, _ = rasterize_triangle_faces( + ctx, + vertices, + faces, + attr, + 512, 512, + view=view, + projection=perspective, + cull_backface=True, + ssaa=1, + return_depth=True, + ) + import cv2 + cv2.imwrite('CHECKME.png', cv2.cvtColor((image.clip(0, 1) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR)) + \ No newline at end of file diff --git a/utils3d/numpy/shaders/texture.fsh b/utils3d/numpy/shaders/texture.fsh new file mode 100644 index 0000000000000000000000000000000000000000..c8be72f94cbf38fb0b2a9609e8db4d50ac7753d6 --- /dev/null +++ b/utils3d/numpy/shaders/texture.fsh @@ -0,0 +1,11 @@ +#version 330 + +uniform sampler2D tex; +uniform sampler2D uv; + +in vec2 scr_coord; +out vecN tex_color; + +void main() { + tex_color = vecN(texture(tex, texture(uv, scr_coord).xy)); +} \ No newline at end of file diff --git a/utils3d/numpy/shaders/texture.vsh b/utils3d/numpy/shaders/texture.vsh new file mode 100644 index 0000000000000000000000000000000000000000..f96c6b14a8931fbcd5f4ca22ea917b9c8f80f195 --- /dev/null +++ b/utils3d/numpy/shaders/texture.vsh @@ -0,0 +1,9 @@ + #version 330 core + +in vec2 in_vert; +out vec2 scr_coord; + +void main() { + scr_coord = in_vert * 0.5 + 0.5; + gl_Position = vec4(in_vert, 0., 1.); +} \ No newline at end of file diff --git a/utils3d/numpy/shaders/vertex_attribute.fsh b/utils3d/numpy/shaders/vertex_attribute.fsh new file mode 100644 index 0000000000000000000000000000000000000000..54409764c5600ee190db89313b07dd91b940d6eb --- /dev/null +++ b/utils3d/numpy/shaders/vertex_attribute.fsh @@ -0,0 +1,9 @@ +#version 330 + +in vecN v_attr; + +out vecN f_attr; + +void main() { + f_attr = v_attr; +} diff --git a/utils3d/numpy/shaders/vertex_attribute.vsh b/utils3d/numpy/shaders/vertex_attribute.vsh new file mode 100644 index 0000000000000000000000000000000000000000..7c94f91aaabfd714a47a194b93f8e53bf63577f5 --- /dev/null +++ b/utils3d/numpy/shaders/vertex_attribute.vsh @@ -0,0 +1,13 @@ +#version 330 + +uniform mat4 u_mvp; + +in vec3 i_position; +in vecN i_attr; + +out vecN v_attr; + +void main() { + gl_Position = u_mvp * vec4(i_position, 1.0); + v_attr = i_attr; +} diff --git a/utils3d/numpy/spline.py b/utils3d/numpy/spline.py new file mode 100644 index 0000000000000000000000000000000000000000..03c664136bc3734215d37669a3446c248dffe097 --- /dev/null +++ b/utils3d/numpy/spline.py @@ -0,0 +1,82 @@ +from typing import * + +import numpy as np + + +__all__ = ['linear_spline_interpolate'] + + +def linear_spline_interpolate(x: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray: + """ + Linear spline interpolation. + + ### Parameters: + - `x`: np.ndarray, shape (n, d): the values of data points. + - `t`: np.ndarray, shape (n,): the times of the data points. + - `s`: np.ndarray, shape (m,): the times to be interpolated. + - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + + ### Returns: + - `y`: np.ndarray, shape (..., m, d): the interpolated values. + """ + i = np.searchsorted(t, s, side='left') + if extrapolation_mode == 'constant': + prev = np.clip(i - 1, 0, len(t) - 1) + suc = np.clip(i, 0, len(t) - 1) + elif extrapolation_mode == 'linear': + prev = np.clip(i - 1, 0, len(t) - 2) + suc = np.clip(i, 1, len(t) - 1) + else: + raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}') + + u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12) + y = u * x[suc] + (1 - u) * x[prev] + + return y + + + +def _solve_tridiagonal(a: np.ndarray, b: np.ndarray, c: np.ndarray, d: np.ndarray) -> np.ndarray: + n = b.shape[-1] + cc = np.zeros_like(b) + dd = np.zeros_like(b) + cc[..., 0] = c[..., 0] / b[..., 0] + dd[..., 0] = d[..., 0] / b[..., 0] + for i in range(1, n): + cc[..., i] = c[..., i] / (b[..., i] - a[..., i - 1] * cc[..., i - 1]) + dd[..., i] = (d[..., i] - a[..., i - 1] * dd[..., i - 1]) / (b[..., i] - a[..., i - 1] * cc[..., i - 1]) + x = np.zeros_like(b) + x[..., -1] = dd[..., -1] + for i in range(n - 2, -1, -1): + x[..., i] = dd[..., i] - cc[..., i] * x[..., i + 1] + return x + + +def cubic_spline_interpolate(x: np.ndarray, t: np.ndarray, s: np.ndarray, v0: np.ndarray = None, vn: np.ndarray = None) -> np.ndarray: + """ + Cubic spline interpolation. + + ### Parameters: + - `x`: np.ndarray, shape (..., n,): the x-coordinates of the data points. + - `t`: np.ndarray, shape (n,): the knot vector. NOTE: t must be sorted in ascending order. + - `s`: np.ndarray, shape (..., m,): the y-coordinates of the data points. + - `v0`: np.ndarray, shape (...,): the value of the derivative at the first knot, as the boundary condition. If None, it is set to zero. + - `vn`: np.ndarray, shape (...,): the value of the derivative at the last knot, as the boundary condition. If None, it is set to zero. + + ### Returns: + - `y`: np.ndarray, shape (..., m): the interpolated values. + """ + h = t[..., 1:] - t[..., :-1] + mu = h[..., :-1] / (h[..., :-1] + h[..., 1:]) + la = 1 - mu + d = (x[..., 1:] - x[..., :-1]) / h + d = 6 * (d[..., 1:] - d[..., :-1]) / (t[..., 2:] - t[..., :-2]) + + mu = np.concatenate([mu, np.ones_like(mu[..., :1])], axis=-1) + la = np.concatenate([np.ones_like(la[..., :1]), la], axis=-1) + d = np.concatenate([(((x[..., 1] - x[..., 0]) / h[0] - v0) / h[0])[..., None], d, ((vn - (x[..., -1] - x[..., -2]) / h[-1]) / h[-1])[..., None]], axis=-1) + + M = _solve_tridiagonal(mu, np.full_like(d, fill_value=2), la, d) + + i = np.searchsorted(t, s, side='left') + diff --git a/utils3d/numpy/transforms.py b/utils3d/numpy/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..867ba49b3287e70572275f9ae8d689e6f854d100 --- /dev/null +++ b/utils3d/numpy/transforms.py @@ -0,0 +1,1084 @@ +import numpy as np +from typing import * +from numbers import Number +from ._helpers import batched + + +__all__ = [ + 'perspective', + 'perspective_from_fov', + 'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'perspective_to_near_far', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'unproject_cv', + 'unproject_gl', + 'project_cv', + 'project_gl', + 'quaternion_to_matrix', + 'axis_angle_to_matrix', + 'matrix_to_quaternion', + 'extrinsics_to_essential', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'ray_intersection', + 'se3_matrix', + 'slerp_quaternion', + 'slerp_vector', + 'lerp', + 'lerp_se3_matrix', + 'piecewise_lerp', + 'piecewise_lerp_se3_matrix', + 'apply_transform' +] + + +@batched(0,0,0,0) +def perspective( + fov_y: Union[float, np.ndarray], + aspect: Union[float, np.ndarray], + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + Get OpenGL perspective matrix + + Args: + fov_y (float | np.ndarray): field of view in y axis + aspect (float | np.ndarray): aspect ratio + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + + Returns: + (np.ndarray): [..., 4, 4] perspective matrix + """ + N = fov_y.shape[0] + ret = np.zeros((N, 4, 4), dtype=fov_y.dtype) + ret[:, 0, 0] = 1. / (np.tan(fov_y / 2) * aspect) + ret[:, 1, 1] = 1. / (np.tan(fov_y / 2)) + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. + return ret + + +def perspective_from_fov( + fov: Union[float, np.ndarray], + width: Union[int, np.ndarray], + height: Union[int, np.ndarray], + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] + ) -> np.ndarray: + """ + Get OpenGL perspective matrix from field of view in largest dimension + + Args: + fov (float | np.ndarray): field of view in largest dimension + width (int | np.ndarray): image width + height (int | np.ndarray): image height + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + + Returns: + (np.ndarray): [..., 4, 4] perspective matrix + """ + fov_y = 2 * np.arctan(np.tan(fov / 2) * height / np.maximum(width, height)) + aspect = width / height + return perspective(fov_y, aspect, near, far) + + +def perspective_from_fov_xy( + fov_x: Union[float, np.ndarray], + fov_y: Union[float, np.ndarray], + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + Get OpenGL perspective matrix from field of view in x and y axis + + Args: + fov_x (float | np.ndarray): field of view in x axis + fov_y (float | np.ndarray): field of view in y axis + near (float | np.ndarray): near plane to clip + far (float | np.ndarray): far plane to clip + + Returns: + (np.ndarray): [..., 4, 4] perspective matrix + """ + aspect = np.tan(fov_x / 2) / np.tan(fov_y / 2) + return perspective(fov_y, aspect, near, far) + + +def intrinsics_from_focal_center( + fx: Union[float, np.ndarray], + fy: Union[float, np.ndarray], + cx: Union[float, np.ndarray], + cy: Union[float, np.ndarray], + dtype: Optional[np.dtype] = np.float32 +) -> np.ndarray: + """ + Get OpenCV intrinsics matrix + + Returns: + (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix + """ + if any(isinstance(x, np.ndarray) for x in (fx, fy, cx, cy)): + dtype = np.result_type(fx, fy, cx, cy) + fx, fy, cx, cy = np.broadcast_arrays(fx, fy, cx, cy) + ret = np.zeros((*fx.shape, 3, 3), dtype=dtype) + ret[..., 0, 0] = fx + ret[..., 1, 1] = fy + ret[..., 0, 2] = cx + ret[..., 1, 2] = cy + ret[..., 2, 2] = 1. + return ret + + +def intrinsics_from_fov( + fov_max: Union[float, np.ndarray] = None, + fov_min: Union[float, np.ndarray] = None, + fov_x: Union[float, np.ndarray] = None, + fov_y: Union[float, np.ndarray] = None, + width: Union[int, np.ndarray] = None, + height: Union[int, np.ndarray] = None, +) -> np.ndarray: + """ + Get normalized OpenCV intrinsics matrix from given field of view. + You can provide either fov_max, fov_min, fov_x or fov_y + + Args: + width (int | np.ndarray): image width + height (int | np.ndarray): image height + fov_max (float | np.ndarray): field of view in largest dimension + fov_min (float | np.ndarray): field of view in smallest dimension + fov_x (float | np.ndarray): field of view in x axis + fov_y (float | np.ndarray): field of view in y axis + + Returns: + (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix + """ + if fov_max is not None: + fx = np.maximum(width, height) / width / (2 * np.tan(fov_max / 2)) + fy = np.maximum(width, height) / height / (2 * np.tan(fov_max / 2)) + elif fov_min is not None: + fx = np.minimum(width, height) / width / (2 * np.tan(fov_min / 2)) + fy = np.minimum(width, height) / height / (2 * np.tan(fov_min / 2)) + elif fov_x is not None and fov_y is not None: + fx = 1 / (2 * np.tan(fov_x / 2)) + fy = 1 / (2 * np.tan(fov_y / 2)) + elif fov_x is not None: + fx = 1 / (2 * np.tan(fov_x / 2)) + fy = fx * width / height + elif fov_y is not None: + fy = 1 / (2 * np.tan(fov_y / 2)) + fx = fy * height / width + cx = 0.5 + cy = 0.5 + ret = intrinsics_from_focal_center(fx, fy, cx, cy) + return ret + + +@batched(1,1,1) +def view_look_at( + eye: np.ndarray, + look_at: np.ndarray, + up: np.ndarray + ) -> np.ndarray: + """ + Get OpenGL view matrix looking at something + + Args: + eye (np.ndarray): [..., 3] the eye position + look_at (np.ndarray): [..., 3] the position to look at + up (np.ndarray): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction + + Returns: + (np.ndarray): [..., 4, 4], view matrix + """ + z = eye - look_at + x = np.cross(up, z) + y = np.cross(z, x) + # x = np.cross(y, z) + x = x / np.linalg.norm(x, axis=-1, keepdims=True) + y = y / np.linalg.norm(y, axis=-1, keepdims=True) + z = z / np.linalg.norm(z, axis=-1, keepdims=True) + R = np.stack([x, y, z], axis=-2) + t = -np.matmul(R, eye[..., None]) + return np.concatenate([ + np.concatenate([R, t], axis=-1), + np.array([[[0., 0., 0., 1.]]]).repeat(eye.shape[0], axis=0) + ], axis=-2) + + +@batched(1,1,1) +def extrinsics_look_at( + eye: np.ndarray, + look_at: np.ndarray, + up: np.ndarray +) -> np.ndarray: + """ + Get OpenCV extrinsics matrix looking at something + + Args: + eye (np.ndarray): [..., 3] the eye position + look_at (np.ndarray): [..., 3] the position to look at + up (np.ndarray): [..., 3] head up direction (-y axis in screen space). Not necessarily othogonal to view direction + + Returns: + (np.ndarray): [..., 4, 4], extrinsics matrix + """ + z = look_at - eye + x = np.cross(-up, z) + y = np.cross(z, x) + # x = np.cross(y, z) + x = x / np.linalg.norm(x, axis=-1, keepdims=True) + y = y / np.linalg.norm(y, axis=-1, keepdims=True) + z = z / np.linalg.norm(z, axis=-1, keepdims=True) + R = np.stack([x, y, z], axis=-2) + t = -np.matmul(R, eye[..., None]) + return np.concatenate([ + np.concatenate([R, t], axis=-1), + np.array([[[0., 0., 0., 1.]]], dtype=eye.dtype).repeat(eye.shape[0], axis=0) + ], axis=-2) + + +def perspective_to_intrinsics( + perspective: np.ndarray +) -> np.ndarray: + """ + OpenGL perspective matrix to OpenCV intrinsics + + Args: + perspective (np.ndarray): [..., 4, 4] OpenGL perspective matrix + + Returns: + (np.ndarray): shape [..., 3, 3] OpenCV intrinsics + """ + ret = np.array([[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]], dtype=perspective.dtype) \ + @ perspective[..., [0, 1, 3], :3] \ + @ np.diag(np.array([1, -1, -1], dtype=perspective.dtype)) + return ret + + +def perspective_to_near_far(perspective: np.ndarray) -> Tuple[np.ndarray, np.ndarray]: + """ + Get near and far planes from OpenGL perspective matrix + + Args: + """ + a, b = perspective[..., 2, 2], perspective[..., 2, 3] + near, far = b / (a - 1), b / (a + 1) + return near, far + + +@batched(2,0,0) +def intrinsics_to_perspective( + intrinsics: np.ndarray, + near: Union[float, np.ndarray], + far: Union[float, np.ndarray], +) -> np.ndarray: + """ + OpenCV intrinsics to OpenGL perspective matrix + NOTE: not work for tile-shifting intrinsics currently + + Args: + intrinsics (np.ndarray): [..., 3, 3] OpenCV intrinsics matrix + near (float | np.ndarray): [...] near plane to clip + far (float | np.ndarray): [...] far plane to clip + Returns: + (np.ndarray): [..., 4, 4] OpenGL perspective matrix + """ + N = intrinsics.shape[0] + fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1] + cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2] + ret = np.zeros((N, 4, 4), dtype=intrinsics.dtype) + ret[:, 0, 0] = 2 * fx + ret[:, 1, 1] = 2 * fy + ret[:, 0, 2] = -2 * cx + 1 + ret[:, 1, 2] = 2 * cy - 1 + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. + return ret + + +@batched(2) +def extrinsics_to_view( + extrinsics: np.ndarray + ) -> np.ndarray: + """ + OpenCV camera extrinsics to OpenGL view matrix + + Args: + extrinsics (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix + + Returns: + (np.ndarray): [..., 4, 4] OpenGL view matrix + """ + return extrinsics * np.array([1, -1, -1, 1], dtype=extrinsics.dtype)[:, None] + + +@batched(2) +def view_to_extrinsics( + view: np.ndarray + ) -> np.ndarray: + """ + OpenGL view matrix to OpenCV camera extrinsics + + Args: + view (np.ndarray): [..., 4, 4] OpenGL view matrix + + Returns: + (np.ndarray): [..., 4, 4] OpenCV camera extrinsics matrix + """ + return view * np.array([1, -1, -1, 1], dtype=view.dtype)[:, None] + + +@batched(2, 0, 0, None) +def normalize_intrinsics( + intrinsics: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray], + integer_pixel_centers: bool = True +) -> np.ndarray: + """ + Normalize intrinsics from pixel cooridnates to uv coordinates + + Args: + intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to normalize + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + integer_pixel_centers (bool): whether the integer pixel coordinates are at the center of the pixel. If False, the integer coordinates are at the left-top corner of the pixel. + + Returns: + (np.ndarray): [..., 3, 3] normalized camera intrinsics(s) + """ + zeros = np.zeros_like(width) + ones = np.ones_like(width) + if integer_pixel_centers: + transform = np.stack([ + 1 / width, zeros, 0.5 / width, + zeros, 1 / height, 0.5 / height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3) + else: + transform = np.stack([ + 1 / width, zeros, zeros, + zeros, 1 / height, zeros, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3) + return transform @ intrinsics + + +@batched(2,0,0,0,0,0,0) +def crop_intrinsics( + intrinsics: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray], + left: Union[int, np.ndarray], + top: Union[int, np.ndarray], + crop_width: Union[int, np.ndarray], + crop_height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width] + + Args: + intrinsics (np.ndarray): [..., 3, 3] camera intrinsics(s) to crop + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + left (int | np.ndarray): [...] left crop boundary + top (int | np.ndarray): [...] top crop boundary + crop_width (int | np.ndarray): [...] crop width + crop_height (int | np.ndarray): [...] crop height + + Returns: + (np.ndarray): [..., 3, 3] cropped camera intrinsics(s) + """ + zeros = np.zeros_like(width) + ones = np.ones_like(width) + transform = np.stack([ + width / crop_width, zeros, -left / crop_width, + zeros, height / crop_height, -top / crop_height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3) + return transform @ intrinsics + + +@batched(1,0,0) +def pixel_to_uv( + pixel: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + + Returns: + (np.ndarray): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + if not np.issubdtype(pixel.dtype, np.floating): + pixel = pixel.astype(np.float32) + dtype = pixel.dtype + uv = (pixel + np.array(0.5, dtype=dtype)) / np.stack([width, height], axis=-1) + return uv + + +@batched(1,0,0) +def uv_to_pixel( + uv: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + + Returns: + (np.ndarray): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + pixel = uv * np.stack([width, height], axis=-1) - 0.5 + return pixel + + +@batched(1,0,0) +def pixel_to_ndc( + pixel: np.ndarray, + width: Union[int, np.ndarray], + height: Union[int, np.ndarray] +) -> np.ndarray: + """ + Args: + pixel (np.ndarray): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | np.ndarray): [...] image width(s) + height (int | np.ndarray): [...] image height(s) + + Returns: + (np.ndarray): [..., 2] pixel coordinrates defined in ndc space, the range is (-1, 1) + """ + if not np.issubdtype(pixel.dtype, np.floating): + pixel = pixel.astype(np.float32) + dtype = pixel.dtype + ndc = (pixel + np.array(0.5, dtype=dtype)) / (np.stack([width, height], dim=-1) * np.array([2, -2], dtype=dtype)) \ + + np.array([-1, 1], dtype=dtype) + return ndc + + +@batched(0,0,0) +def project_depth( + depth: np.ndarray, + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + Project linear depth to depth value in screen space + + Args: + depth (np.ndarray): [...] depth value + near (float | np.ndarray): [...] near plane to clip + far (float | np.ndarray): [...] far plane to clip + + Returns: + (np.ndarray): [..., 1] depth value in screen space, value ranging in [0, 1] + """ + return (far - near * far / depth) / (far - near) + + +@batched(0,0,0) +def depth_buffer_to_linear( + depth_buffer: np.ndarray, + near: Union[float, np.ndarray], + far: Union[float, np.ndarray] +) -> np.ndarray: + """ + OpenGL depth buffer to linear depth + + Args: + depth_buffer (np.ndarray): [...] depth value + near (float | np.ndarray): [...] near plane to clip + far (float | np.ndarray): [...] far plane to clip + + Returns: + (np.ndarray): [..., 1] linear depth + """ + return near * far / (far - (far - near) * depth_buffer) + + +@batched(2,2,2,2) +def project_gl( + points: np.ndarray, + model: np.ndarray = None, + view: np.ndarray = None, + perspective: np.ndarray = None + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Project 3D points to 2D following the OpenGL convention (except for row major matrice) + + Args: + points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + model (np.ndarray): [..., 4, 4] model matrix + view (np.ndarray): [..., 4, 4] view matrix + perspective (np.ndarray): [..., 4, 4] perspective matrix + + Returns: + scr_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + linear_depth (np.ndarray): [..., N] linear depth + """ + assert perspective is not None, "perspective matrix is required" + if points.shape[-1] == 3: + points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1) + if model is not None: + points = points @ model.swapaxes(-1, -2) + if view is not None: + points = points @ view.swapaxes(-1, -2) + clip_coord = points @ perspective.swapaxes(-1, -2) + ndc_coord = clip_coord[..., :3] / clip_coord[..., 3:] + scr_coord = ndc_coord * 0.5 + 0.5 + linear_depth = clip_coord[..., 3] + return scr_coord, linear_depth + + +@batched(2,2,2) +def project_cv( + points: np.ndarray, + extrinsics: np.ndarray = None, + intrinsics: np.ndarray = None + ) -> Tuple[np.ndarray, np.ndarray]: + """ + Project 3D points to 2D following the OpenCV convention + + Args: + points (np.ndarray): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix + intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix + + Returns: + uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + linear_depth (np.ndarray): [..., N] linear depth + """ + assert intrinsics is not None, "intrinsics matrix is required" + if points.shape[-1] == 3: + points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1) + if extrinsics is not None: + points = points @ extrinsics.swapaxes(-1, -2) + points = points[..., :3] @ intrinsics.swapaxes(-1, -2) + uv_coord = points[..., :2] / points[..., 2:] + linear_depth = points[..., 2] + return uv_coord, linear_depth + + +@batched(2,2,2,2) +def unproject_gl( + screen_coord: np.ndarray, + model: np.ndarray = None, + view: np.ndarray = None, + perspective: np.ndarray = None + ) -> np.ndarray: + """ + Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice) + + Args: + screen_coord (np.ndarray): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + model (np.ndarray): [..., 4, 4] model matrix + view (np.ndarray): [..., 4, 4] view matrix + perspective (np.ndarray): [..., 4, 4] perspective matrix + + Returns: + points (np.ndarray): [..., N, 3] 3d points + """ + assert perspective is not None, "perspective matrix is required" + ndc_xy = screen_coord * 2 - 1 + clip_coord = np.concatenate([ndc_xy, np.ones_like(ndc_xy[..., :1])], axis=-1) + transform = perspective + if view is not None: + transform = transform @ view + if model is not None: + transform = transform @ model + transform = np.linalg.inv(transform) + points = clip_coord @ transform.swapaxes(-1, -2) + points = points[..., :3] / points[..., 3:] + return points + + +@batched(2,1,2,2) +def unproject_cv( + uv_coord: np.ndarray, + depth: np.ndarray, + extrinsics: np.ndarray = None, + intrinsics: np.ndarray = None +) -> np.ndarray: + """ + Unproject uv coordinates to 3D view space following the OpenCV convention + + Args: + uv_coord (np.ndarray): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + depth (np.ndarray): [..., N] depth value + extrinsics (np.ndarray): [..., 4, 4] extrinsics matrix + intrinsics (np.ndarray): [..., 3, 3] intrinsics matrix + + Returns: + points (np.ndarray): [..., N, 3] 3d points + """ + assert intrinsics is not None, "intrinsics matrix is required" + points = np.concatenate([uv_coord, np.ones_like(uv_coord[..., :1])], axis=-1) + points = points @ np.linalg.inv(intrinsics).swapaxes(-1, -2) + points = points * depth[..., None] + if extrinsics is not None: + points = np.concatenate([points, np.ones_like(points[..., :1])], axis=-1) + points = (points @ np.linalg.inv(extrinsics).swapaxes(-1, -2))[..., :3] + return points + + +def quaternion_to_matrix(quaternion: np.ndarray, eps: float = 1e-12) -> np.ndarray: + """Converts a batch of quaternions (w, x, y, z) to rotation matrices + + Args: + quaternion (np.ndarray): shape (..., 4), the quaternions to convert + + Returns: + np.ndarray: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions + """ + assert quaternion.shape[-1] == 4 + quaternion = quaternion / np.linalg.norm(quaternion, axis=-1, keepdims=True).clip(min=eps) + w, x, y, z = quaternion[..., 0], quaternion[..., 1], quaternion[..., 2], quaternion[..., 3] + zeros = np.zeros_like(w) + I = np.eye(3, dtype=quaternion.dtype) + xyz = quaternion[..., 1:] + A = xyz[..., :, None] * xyz[..., None, :] - I * (xyz ** 2).sum(axis=-1)[..., None, None] + B = np.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros + ], axis=-1).reshape(*quaternion.shape[:-1], 3, 3) + rot_mat = I + 2 * (A + w[..., None, None] * B) + return rot_mat + + +def matrix_to_quaternion(rot_mat: np.ndarray, eps: float = 1e-12) -> np.ndarray: + """Convert 3x3 rotation matrix to quaternion (w, x, y, z) + + Args: + rot_mat (np.ndarray): shape (..., 3, 3), the rotation matrices to convert + + Returns: + np.ndarray: shape (..., 4), the quaternions corresponding to the given rotation matrices + """ + # Extract the diagonal and off-diagonal elements of the rotation matrix + m00, m01, m02, m10, m11, m12, m20, m21, m22 = [rot_mat[..., i, j] for i in range(3) for j in range(3)] + + diag = np.diagonal(rot_mat, axis1=-2, axis2=-1) + M = np.array([ + [1, 1, 1], + [1, -1, -1], + [-1, 1, -1], + [-1, -1, 1] + ], dtype=rot_mat.dtype) + wxyz = 0.5 * np.clip(1 + diag @ M.T, 0.0, None) ** 0.5 + max_idx = np.argmax(wxyz, axis=-1) + xw = np.sign(m21 - m12) + yw = np.sign(m02 - m20) + zw = np.sign(m10 - m01) + yz = np.sign(m21 + m12) + xz = np.sign(m02 + m20) + xy = np.sign(m01 + m10) + ones = np.ones_like(xw) + sign = np.where( + max_idx[..., None] == 0, + np.stack([ones, xw, yw, zw], axis=-1), + np.where( + max_idx[..., None] == 1, + np.stack([xw, ones, xy, xz], axis=-1), + np.where( + max_idx[..., None] == 2, + np.stack([yw, xy, ones, yz], axis=-1), + np.stack([zw, xz, yz, ones], axis=-1) + ) + ) + ) + quat = sign * wxyz + quat = quat / np.linalg.norm(quat, axis=-1, keepdims=True).clip(min=eps) + return quat + + +def extrinsics_to_essential(extrinsics: np.ndarray): + """ + extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0` + + Args: + extrinsics (np.ndaray): [..., 4, 4] extrinsics matrix + + Returns: + (np.ndaray): [..., 3, 3] essential matrix + """ + assert extrinsics.shape[-2:] == (4, 4) + R = extrinsics[..., :3, :3] + t = extrinsics[..., :3, 3] + zeros = np.zeros_like(t[..., 0]) + t_x = np.stack([ + zeros, -t[..., 2], t[..., 1], + t[..., 2], zeros, -t[..., 0], + -t[..., 1], t[..., 0], zeros + ]).reshape(*t.shape[:-1], 3, 3) + return t_x @ R + + +def euler_axis_angle_rotation(axis: str, angle: np.ndarray) -> np.ndarray: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = np.cos(angle) + sin = np.sin(angle) + one = np.ones_like(angle) + zero = np.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return np.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: np.ndarray, convention: str = 'XYZ') -> np.ndarray: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as ndarray of shape (..., 3), XYZ + convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply. + + Returns: + Rotation matrices as ndarray of shape (..., 3, 3). + """ + if euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + euler_axis_angle_rotation(c, euler_angles[..., 'XYZ'.index(c)]) + for c in convention + ] + return matrices[2] @ matrices[1] @ matrices[0] + + +def skew_symmetric(v: np.ndarray): + "Skew symmetric matrix from a 3D vector" + assert v.shape[-1] == 3, "v must be 3D" + x, y, z = v[..., 0], v[..., 1], v[..., 2] + zeros = np.zeros_like(x) + return np.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros, + ], axis=-1).reshape(*v.shape[:-1], 3, 3) + + +def rotation_matrix_from_vectors(v1: np.ndarray, v2: np.ndarray): + "Rotation matrix that rotates v1 to v2" + I = np.eye(3, dtype=v1.dtype) + v1 = v1 / np.linalg.norm(v1, axis=-1) + v2 = v2 / np.linalg.norm(v2, axis=-1) + v = np.cross(v1, v2, axis=-1) + c = np.sum(v1 * v2, axis=-1) + K = skew_symmetric(v) + R = I + K + (1 / (1 + c)).astype(v1.dtype)[None, None] * (K @ K) # Avoid numpy's default type casting for scalars + return R + + +def axis_angle_to_matrix(axis_angle: np.ndarray, eps: float = 1e-12) -> np.ndarray: + """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation + + Args: + axis_angle (np.ndarray): shape (..., 3), axis-angle vcetors + + Returns: + np.ndarray: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters + """ + batch_shape = axis_angle.shape[:-1] + dtype = axis_angle.dtype + + angle = np.linalg.norm(axis_angle, axis=-1, keepdims=True) + axis = axis_angle / (angle + eps) + + cos = np.cos(angle)[..., None, :] + sin = np.sin(angle)[..., None, :] + + rx, ry, rz = np.split(axis, 3, axis=-1) + zeros = np.zeros((*batch_shape, 1), dtype=dtype) + K = np.concatenate([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], axis=-1).reshape((*batch_shape, 3, 3)) + + ident = np.eye(3, dtype=dtype) + rot_mat = ident + sin * K + (1 - cos) * (K @ K) + return rot_mat + + +def ray_intersection(p1: np.ndarray, d1: np.ndarray, p2: np.ndarray, d2: np.ndarray): + """ + Compute the intersection/closest point of two D-dimensional rays + If the rays are intersecting, the closest point is the intersection point. + + Args: + p1 (np.ndarray): (..., D) origin of ray 1 + d1 (np.ndarray): (..., D) direction of ray 1 + p2 (np.ndarray): (..., D) origin of ray 2 + d2 (np.ndarray): (..., D) direction of ray 2 + + Returns: + (np.ndarray): (..., N) intersection point + """ + p1, d1, p2, d2 = np.broadcast_arrays(p1, d1, p2, d2) + dtype = p1.dtype + dim = p1.shape[-1] + d = np.stack([d1, d2], axis=-2) # (..., 2, D) + p = np.stack([p1, p2], axis=-2) # (..., 2, D) + A = np.concatenate([ + (np.eye(dim, dtype=dtype) * np.ones((*p.shape[:-2], 2, 1, 1))).reshape(*d.shape[:-2], 2 * dim, dim), # (..., 2 * D, D) + -(np.eye(2, dtype=dtype)[..., None] * d[..., None, :]).swapaxes(-2, -1).reshape(*d.shape[:-2], 2 * dim, 2) # (..., 2 * D, 2) + ], axis=-1) # (..., 2 * D, D + 2) + b = p.reshape(*p.shape[:-2], 2 * dim) # (..., 2 * D) + x = np.linalg.solve(A.swapaxes(-1, -2) @ A + 1e-12 * np.eye(dim + 2, dtype=dtype), (A.swapaxes(-1, -2) @ b[..., :, None])[..., 0]) + return x[..., :dim], (x[..., dim], x[..., dim + 1]) + + +def se3_matrix(R: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Convert rotation matrix and translation vector to 4x4 transformation matrix. + + Args: + R (np.ndarray): [..., 3, 3] rotation matrix + t (np.ndarray): [..., 3] translation vector + + Returns: + np.ndarray: [..., 4, 4] transformation matrix + """ + assert R.shape[:-2] == t.shape[:-1] + assert R.shape[-1] == 3 and R.shape[-2] == 3 + return np.concatenate([ + np.concatenate([R, t[..., None]], axis=-1), + np.concatenate([np.zeros_like(t), np.ones_like(t[..., :1])], axis=-1)[..., None, :] + ], axis=-2) + + +def slerp_quaternion(q1: np.ndarray, q2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Spherical linear interpolation between two unit quaternions. + + Args: + q1 (np.ndarray): [..., d] unit vector 1 + q2 (np.ndarray): [..., d] unit vector 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., 3] interpolated unit vector + """ + q1 = q1 / np.linalg.norm(q1, axis=-1, keepdims=True) + q2 = q2 / np.linalg.norm(q2, axis=-1, keepdims=True) + dot = np.sum(q1 * q2, axis=-1, keepdims=True) + + dot = np.where(dot < 0, -dot, dot) # handle negative dot product + + dot = np.minimum(dot, 1.) + theta = np.arccos(dot) * t + + q_ortho = q2 - q1 * dot + q_ortho = q_ortho / np.maximum(np.linalg.norm(q_ortho, axis=-1, keepdims=True), 1e-12) + q = q1 * np.cos(theta) + q_ortho * np.sin(theta) + return q + + +def slerp_rotation_matrix(R1: np.ndarray, R2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Spherical linear interpolation between two rotation matrices. + + Args: + R1 (np.ndarray): [..., 3, 3] rotation matrix 1 + R2 (np.ndarray): [..., 3, 3] rotation matrix 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., 3, 3] interpolated rotation matrix + """ + quat1 = matrix_to_quaternion(R1) + quat2 = matrix_to_quaternion(R2) + quat = slerp_quaternion(quat1, quat2, t) + return quaternion_to_matrix(quat) + + +def slerp_vector(v1: np.ndarray, v2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Spherical linear interpolation between two unit vectors. The vectors are assumed to be normalized. + + Args: + v1 (np.ndarray): [..., d] unit vector 1 + v2 (np.ndarray): [..., d] unit vector 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., d] interpolated unit vector + """ + dot = np.sum(v1 * v2, axis=-1, keepdims=True) + + dot = np.minimum(dot, 1.) + theta = np.arccos(dot) * t + + v_ortho = v2 - v1 * dot + v_ortho = v_ortho / np.maximum(np.linalg.norm(v_ortho, axis=-1, keepdims=True), 1e-12) + v = v1 * np.cos(theta) + v_ortho * np.sin(theta) + return v + + +def lerp(x1: np.ndarray, x2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Linear interpolation between two vectors. + + Args: + x1 (np.ndarray): [..., d] vector 1 + x2 (np.ndarray): [..., d] vector 2 + t (np.ndarray): [...] interpolation parameter. [0, 1] for interpolation between x1 and x2, otherwise for extrapolation. + + Returns: + np.ndarray: [..., d] interpolated vector + """ + return x1 + np.asarray(t)[..., None] * (x2 - x1) + + +def lerp_se3_matrix(T1: np.ndarray, T2: np.ndarray, t: np.ndarray) -> np.ndarray: + """ + Linear interpolation between two SE(3) matrices. + + Args: + T1 (np.ndarray): [..., 4, 4] SE(3) matrix 1 + T2 (np.ndarray): [..., 4, 4] SE(3) matrix 2 + t (np.ndarray): [...] interpolation parameter in [0, 1] + + Returns: + np.ndarray: [..., 4, 4] interpolated SE(3) matrix + """ + R1 = T1[..., :3, :3] + R2 = T2[..., :3, :3] + trans1 = T1[..., :3, 3] + trans2 = T2[..., :3, 3] + R = slerp_rotation_matrix(R1, R2, t) + trans = lerp(trans1, trans2, t) + return se3_matrix(R, trans) + + +def piecewise_lerp(x: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray: + """ + Linear spline interpolation. + + ### Parameters: + - `x`: np.ndarray, shape (n, d): the values of data points. + - `t`: np.ndarray, shape (n,): the times of the data points. + - `s`: np.ndarray, shape (m,): the times to be interpolated. + - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + + ### Returns: + - `y`: np.ndarray, shape (..., m, d): the interpolated values. + """ + i = np.searchsorted(t, s, side='left') + if extrapolation_mode == 'constant': + prev = np.clip(i - 1, 0, len(t) - 1) + suc = np.clip(i, 0, len(t) - 1) + elif extrapolation_mode == 'linear': + prev = np.clip(i - 1, 0, len(t) - 2) + suc = np.clip(i, 1, len(t) - 1) + else: + raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}') + + u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12) + y = lerp(x[prev], x[suc], u) + + return y + + +def piecewise_lerp_se3_matrix(T: np.ndarray, t: np.ndarray, s: np.ndarray, extrapolation_mode: Literal['constant', 'linear'] = 'constant') -> np.ndarray: + """ + Linear spline interpolation for SE(3) matrices. + + ### Parameters: + - `T`: np.ndarray, shape (n, 4, 4): the SE(3) matrices. + - `t`: np.ndarray, shape (n,): the times of the data points. + - `s`: np.ndarray, shape (m,): the times to be interpolated. + - `extrapolation_mode`: str, the mode of extrapolation. 'constant' means extrapolate the boundary values, 'linear' means extrapolate linearly. + + ### Returns: + - `T_interp`: np.ndarray, shape (..., m, 4, 4): the interpolated SE(3) matrices. + """ + i = np.searchsorted(t, s, side='left') + if extrapolation_mode == 'constant': + prev = np.clip(i - 1, 0, len(t) - 1) + suc = np.clip(i, 0, len(t) - 1) + elif extrapolation_mode == 'linear': + prev = np.clip(i - 1, 0, len(t) - 2) + suc = np.clip(i, 1, len(t) - 1) + else: + raise ValueError(f'Invalid extrapolation_mode: {extrapolation_mode}') + + u = (s - t[prev]) / np.maximum(t[suc] - t[prev], 1e-12) + T = lerp_se3_matrix(T[prev], T[suc], u) + + return T + + +def apply_transform(T: np.ndarray, x: np.ndarray) -> np.ndarray: + """ + Apply SE(3) transformation to a point or a set of points. + + ### Parameters: + - `T`: np.ndarray, shape (..., 4, 4): the SE(3) matrix. + - `x`: np.ndarray, shape (..., 3): the point or a set of points to be transformed. + + ### Returns: + - `x_transformed`: np.ndarray, shape (..., 3): the transformed point or a set of points. + """ + x = np.asarray(x) + assert x.shape[-1] == 3 + T = np.asarray(T) + assert T.shape[-2:] == (4, 4) + x_transformed = (T[..., :3, :3] @ x[..., :, None]) + T[..., :3, 3][..., None] + return x_transformed[..., 0] \ No newline at end of file diff --git a/utils3d/numpy/utils.py b/utils3d/numpy/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..1a33c3c7288b0b7a63566e2255cc45b929e67b49 --- /dev/null +++ b/utils3d/numpy/utils.py @@ -0,0 +1,557 @@ +import numpy as np +from typing import * +from numbers import Number + +from ._helpers import batched +from . import transforms +from . import mesh + +__all__ = [ + 'sliding_window_1d', + 'sliding_window_nd', + 'sliding_window_2d', + 'max_pool_1d', + 'max_pool_2d', + 'max_pool_nd', + 'depth_edge', + 'depth_aliasing', + 'interpolate', + 'image_scrcoord', + 'image_uv', + 'image_pixel_center', + 'image_pixel', + 'image_mesh', + 'image_mesh_from_depth', + 'depth_to_normal', + 'point_to_normal', + 'chessboard', + 'cube', + 'square', + 'camera_frustum', + 'to4x4' +] + + +def sliding_window_1d(x: np.ndarray, window_size: int, stride: int, axis: int = -1): + """ + Return x view of the input array with x sliding window of the given kernel size and stride. + The sliding window is performed over the given axis, and the window dimension is append to the end of the output array's shape. + + Args: + x (np.ndarray): input array with shape (..., axis_size, ...) + kernel_size (int): size of the sliding window + stride (int): stride of the sliding window + axis (int): axis to perform sliding window over + + Returns: + a_sliding (np.ndarray): view of the input array with shape (..., n_windows, ..., kernel_size), where n_windows = (axis_size - kernel_size + 1) // stride + """ + assert x.shape[axis] >= window_size, f"kernel_size ({window_size}) is larger than axis_size ({x.shape[axis]})" + axis = axis % x.ndim + shape = (*x.shape[:axis], (x.shape[axis] - window_size + 1) // stride, *x.shape[axis + 1:], window_size) + strides = (*x.strides[:axis], stride * x.strides[axis], *x.strides[axis + 1:], x.strides[axis]) + x_sliding = np.lib.stride_tricks.as_strided(x, shape=shape, strides=strides) + return x_sliding + + +def sliding_window_nd(x: np.ndarray, window_size: Tuple[int,...], stride: Tuple[int,...], axis: Tuple[int,...]) -> np.ndarray: + axis = [axis[i] % x.ndim for i in range(len(axis))] + for i in range(len(axis)): + x = sliding_window_1d(x, window_size[i], stride[i], axis[i]) + return x + + +def sliding_window_2d(x: np.ndarray, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)) -> np.ndarray: + if isinstance(window_size, int): + window_size = (window_size, window_size) + if isinstance(stride, int): + stride = (stride, stride) + return sliding_window_nd(x, window_size, stride, axis) + + +def max_pool_1d(x: np.ndarray, kernel_size: int, stride: int, padding: int = 0, axis: int = -1): + axis = axis % x.ndim + if padding > 0: + fill_value = np.nan if x.dtype.kind == 'f' else np.iinfo(x.dtype).min + padding_arr = np.full((*x.shape[:axis], padding, *x.shape[axis + 1:]), fill_value=fill_value, dtype=x.dtype) + x = np.concatenate([padding_arr, x, padding_arr], axis=axis) + a_sliding = sliding_window_1d(x, kernel_size, stride, axis) + max_pool = np.nanmax(a_sliding, axis=-1) + return max_pool + + +def max_pool_nd(x: np.ndarray, kernel_size: Tuple[int,...], stride: Tuple[int,...], padding: Tuple[int,...], axis: Tuple[int,...]) -> np.ndarray: + for i in range(len(axis)): + x = max_pool_1d(x, kernel_size[i], stride[i], padding[i], axis[i]) + return x + + +def max_pool_2d(x: np.ndarray, kernel_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], padding: Union[int, Tuple[int, int]], axis: Tuple[int, int] = (-2, -1)): + if isinstance(kernel_size, Number): + kernel_size = (kernel_size, kernel_size) + if isinstance(stride, Number): + stride = (stride, stride) + if isinstance(padding, Number): + padding = (padding, padding) + axis = tuple(axis) + return max_pool_nd(x, kernel_size, stride, padding, axis) + + +def depth_edge(depth: np.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray: + """ + Compute the edge mask of x depth map. The edge is defined as the pixels whose neighbors have x large difference in depth. + + Args: + depth (np.ndarray): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool + """ + if mask is None: + diff = (max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)) + else: + diff = (max_pool_2d(np.where(mask, depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) + max_pool_2d(np.where(mask, -depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2)) + + edge = np.zeros_like(depth, dtype=bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= diff / depth > rtol + return edge + + +def depth_aliasing(depth: np.ndarray, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: np.ndarray = None) -> np.ndarray: + """ + Compute the map that indicates the aliasing of x depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. + Args: + depth (np.ndarray): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (np.ndarray): shape (..., height, width) of dtype torch.bool + """ + if mask is None: + diff_max = max_pool_2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = max_pool_2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth + else: + diff_max = max_pool_2d(np.where(mask, depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = max_pool_2d(np.where(mask, -depth, -np.inf), kernel_size, stride=1, padding=kernel_size // 2) + depth + diff = np.minimum(diff_max, diff_min) + + edge = np.zeros_like(depth, dtype=bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= diff / depth > rtol + return edge + +def point_to_normal(point: np.ndarray, mask: np.ndarray = None) -> np.ndarray: + """ + Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + point (np.ndarray): shape (height, width, 3), point map + Returns: + normal (np.ndarray): shape (height, width, 3), normal map. + """ + height, width = point.shape[-3:-1] + has_mask = mask is not None + + if mask is None: + mask = np.ones_like(point[..., 0], dtype=bool) + mask_pad = np.zeros((height + 2, width + 2), dtype=bool) + mask_pad[1:-1, 1:-1] = mask + mask = mask_pad + + pts = np.zeros((height + 2, width + 2, 3), dtype=point.dtype) + pts[1:-1, 1:-1, :] = point + up = pts[:-2, 1:-1, :] - pts[1:-1, 1:-1, :] + left = pts[1:-1, :-2, :] - pts[1:-1, 1:-1, :] + down = pts[2:, 1:-1, :] - pts[1:-1, 1:-1, :] + right = pts[1:-1, 2:, :] - pts[1:-1, 1:-1, :] + normal = np.stack([ + np.cross(up, left, axis=-1), + np.cross(left, down, axis=-1), + np.cross(down, right, axis=-1), + np.cross(right, up, axis=-1), + ]) + normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12) + valid = np.stack([ + mask[:-2, 1:-1] & mask[1:-1, :-2], + mask[1:-1, :-2] & mask[2:, 1:-1], + mask[2:, 1:-1] & mask[1:-1, 2:], + mask[1:-1, 2:] & mask[:-2, 1:-1], + ]) & mask[None, 1:-1, 1:-1] + normal = (normal * valid[..., None]).sum(axis=0) + normal = normal / (np.linalg.norm(normal, axis=-1, keepdims=True) + 1e-12) + + if has_mask: + return normal, valid.any(axis=0) + else: + return normal + + +def depth_to_normal(depth: np.ndarray, intrinsics: np.ndarray, mask: np.ndarray = None) -> np.ndarray: + """ + Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + depth (np.ndarray): shape (height, width), linear depth map + intrinsics (np.ndarray): shape (3, 3), intrinsics matrix + Returns: + normal (np.ndarray): shape (height, width, 3), normal map. + """ + has_mask = mask is not None + + height, width = depth.shape[-2:] + if mask is None: + mask = np.ones_like(depth, dtype=bool) + + uv = image_uv(width=width, height=height, dtype=np.float32) + pts = transforms.unproject_cv(uv, depth, intrinsics=intrinsics, extrinsics=None) + + return point_to_normal(pts, mask) + +def interpolate(bary: np.ndarray, tri_id: np.ndarray, attr: np.ndarray, faces: np.ndarray) -> np.ndarray: + """Interpolate with given barycentric coordinates and triangle indices + + Args: + bary (np.ndarray): shape (..., 3), barycentric coordinates + tri_id (np.ndarray): int array of shape (...), triangle indices + attr (np.ndarray): shape (N, M), vertices attributes + faces (np.ndarray): int array of shape (T, 3), face vertex indices + + Returns: + np.ndarray: shape (..., M) interpolated result + """ + faces_ = np.concatenate([np.zeros((1, 3), dtype=faces.dtype), faces + 1], axis=0) + attr_ = np.concatenate([np.zeros((1, attr.shape[1]), dtype=attr.dtype), attr], axis=0) + return np.sum(bary[..., None] * attr_[faces_[tri_id + 1]], axis=-2) + + +def image_scrcoord( + width: int, + height: int, +) -> np.ndarray: + """ + Get OpenGL's screen space coordinates, ranging in [0, 1]. + [0, 0] is the bottom-left corner of the image. + + Args: + width (int): image width + height (int): image height + + Returns: + (np.ndarray): shape (height, width, 2) + """ + x, y = np.meshgrid( + np.linspace(0.5 / width, 1 - 0.5 / width, width, dtype=np.float32), + np.linspace(1 - 0.5 / height, 0.5 / height, height, dtype=np.float32), + indexing='xy' + ) + return np.stack([x, y], axis=2) + + +def image_uv( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: np.dtype = np.float32 +) -> np.ndarray: + """ + Get image space UV grid, ranging in [0, 1]. + + >>> image_uv(10, 10): + [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], + [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], + ... ... ... + [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = np.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, dtype=dtype) + v = np.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + return np.stack([u, v], axis=2) + + +def image_pixel_center( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: np.dtype = np.float32 +) -> np.ndarray: + """ + Get image pixel center coordinates, ranging in [0, width] and [0, height]. + `image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`. + + >>> image_pixel_center(10, 10): + [[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]], + [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]], + ... ... ... + [[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = np.linspace(left + 0.5, right - 0.5, right - left, dtype=dtype) + v = np.linspace(top + 0.5, bottom - 0.5, bottom - top, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + return np.stack([u, v], axis=2) + +def image_pixel( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: np.dtype = np.int32 +) -> np.ndarray: + """ + Get image pixel coordinates grid, ranging in [0, width - 1] and [0, height - 1]. + `image[i, j]` has pixel center coordinates `(j, i)`. + + >>> image_pixel_center(10, 10): + [[[0, 0], [1, 0], ..., [9, 0]], + [[0, 1.5], [1, 1], ..., [9, 1]], + ... ... ... + [[0, 9.5], [1, 9], ..., [9, 9 ]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = np.arange(left, right, dtype=dtype) + v = np.arange(top, bottom, dtype=dtype) + u, v = np.meshgrid(u, v, indexing='xy') + return np.stack([u, v], axis=2) + + +def image_mesh( + height: int, + width: int, + mask: np.ndarray = None, + tri: bool = False +) -> Tuple[np.ndarray, np.ndarray]: + """ + Get x quad mesh regarding image pixel uv coordinates as vertices and image grid as faces. + + Args: + width (int): image width + height (int): image height + mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None. + + Returns: + uv (np.ndarray): uv corresponding to pixels as described in image_uv() + faces (np.ndarray): quad faces connecting neighboring pixels + indices (np.ndarray, optional): indices of vertices in the original mesh + """ + if mask is not None: + assert mask.shape[0] == height and mask.shape[1] == width + assert mask.dtype == np.bool_ + uv = image_uv(height, width).reshape((-1, 2)) + row_faces = np.stack([np.arange(0, width - 1, dtype=np.int32), np.arange(width, 2 * width - 1, dtype=np.int32), np.arange(1 + width, 2 * width, dtype=np.int32), np.arange(1, width, dtype=np.int32)], axis=1) + faces = (np.arange(0, (height - 1) * width, width, dtype=np.int32)[:, None, None] + row_faces[None, :, :]).reshape((-1, 4)) + if mask is not None: + quad_mask = (mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]).ravel() + faces = faces[quad_mask] + faces, uv, indices = mesh.remove_unreferenced_vertices(faces, uv, return_indices=True) + if tri: + faces = mesh.triangulate(faces) + return uv, faces, indices + if tri: + faces = mesh.triangulate(faces) + return uv, faces + + +def image_mesh_from_depth( + depth: np.ndarray, + extrinsics: np.ndarray = None, + intrinsics: np.ndarray = None, + *vertice_attrs: np.ndarray, + atol: float = None, + rtol: float = None, + remove_by_depth: bool = False, + return_uv: bool = False, + return_indices: bool = False +) -> Tuple[np.ndarray, ...]: + """ + Get x triangle mesh by lifting depth map to 3D. + + Args: + depth (np.ndarray): [H, W] depth map + extrinsics (np.ndarray, optional): [4, 4] extrinsics matrix. Defaults to None. + intrinsics (np.ndarray, optional): [3, 3] intrinsics matrix. Defaults to None. + *vertice_attrs (np.ndarray): [H, W, C] vertex attributes. Defaults to None. + atol (float, optional): absolute tolerance. Defaults to None. + rtol (float, optional): relative tolerance. Defaults to None. + triangles with vertices having depth difference larger than atol + rtol * depth will be marked. + remove_by_depth (bool, optional): whether to remove triangles with large depth difference. Defaults to True. + return_uv (bool, optional): whether to return uv coordinates. Defaults to False. + return_indices (bool, optional): whether to return indices of vertices in the original mesh. Defaults to False. + + Returns: + vertices (np.ndarray): [N, 3] vertices + faces (np.ndarray): [T, 3] faces + *vertice_attrs (np.ndarray): [N, C] vertex attributes + image_uv (np.ndarray, optional): [N, 2] uv coordinates + ref_indices (np.ndarray, optional): [N] indices of vertices in the original mesh + """ + height, width = depth.shape + image_uv, image_face = image_mesh(height, width) + depth = depth.reshape(-1) + pts = transforms.unproject_cv(image_uv, depth, extrinsics, intrinsics) + image_face = mesh.triangulate(image_face, vertices=pts) + ref_indices = None + ret = [] + if atol is not None or rtol is not None: + atol = 0 if atol is None else atol + rtol = 0 if rtol is None else rtol + mean = depth[image_face].mean(axis=1) + diff = np.max(np.abs(depth[image_face] - depth[image_face[:, [1, 2, 0]]]), axis=1) + mask = (diff <= atol + rtol * mean) + image_face_ = image_face[mask] + image_face_, ref_indices = mesh.remove_unreferenced_vertices(image_face_, return_indices=True) + + remove = remove_by_depth and ref_indices is not None + if remove: + pts = pts[ref_indices] + image_face = image_face_ + ret += [pts, image_face] + for attr in vertice_attrs: + ret.append(attr.reshape(-1, attr.shape[-1]) if not remove else attr.reshape(-1, attr.shape[-1])[ref_indices]) + if return_uv: + ret.append(image_uv if not remove else image_uv[ref_indices]) + if return_indices and ref_indices is not None: + ret.append(ref_indices) + return tuple(ret) + + +def chessboard(width: int, height: int, grid_size: int, color_a: np.ndarray, color_b: np.ndarray) -> np.ndarray: + """get x chessboard image + + Args: + width (int): image width + height (int): image height + grid_size (int): size of chessboard grid + color_a (np.ndarray): color of the grid at the top-left corner + color_b (np.ndarray): color in complementary grid cells + + Returns: + image (np.ndarray): shape (height, width, channels), chessboard image + """ + x = np.arange(width) // grid_size + y = np.arange(height) // grid_size + mask = (x[None, :] + y[:, None]) % 2 + image = (1 - mask[..., None]) * color_a + mask[..., None] * color_b + return image + + +def square(tri: bool = False) -> Tuple[np.ndarray, np.ndarray]: + """ + Get a square mesh of area 1 centered at origin in the xy-plane. + + ### Returns + vertices (np.ndarray): shape (4, 3) + faces (np.ndarray): shape (1, 4) + """ + vertices = np.array([ + [-0.5, 0.5, 0], [0.5, 0.5, 0], [0.5, -0.5, 0], [-0.5, -0.5, 0] # v0-v1-v2-v3 + ], dtype=np.float32) + if tri: + faces = np.array([[0, 1, 2], [0, 2, 3]], dtype=np.int32) + else: + faces = np.array([[0, 1, 2, 3]], dtype=np.int32) + return vertices, faces + + +def cube(tri: bool = False) -> Tuple[np.ndarray, np.ndarray]: + """ + Get x cube mesh of size 1 centered at origin. + + ### Parameters + tri (bool, optional): return triangulated mesh. Defaults to False, which returns quad mesh. + + ### Returns + vertices (np.ndarray): shape (8, 3) + faces (np.ndarray): shape (12, 3) + """ + vertices = np.array([ + [-0.5, 0.5, 0.5], [0.5, 0.5, 0.5], [0.5, -0.5, 0.5], [-0.5, -0.5, 0.5], # v0-v1-v2-v3 + [-0.5, 0.5, -0.5], [0.5, 0.5, -0.5], [0.5, -0.5, -0.5], [-0.5, -0.5, -0.5] # v4-v5-v6-v7 + ], dtype=np.float32).reshape((-1, 3)) + + faces = np.array([ + [0, 1, 2, 3], # v0-v1-v2-v3 (front) + [4, 5, 1, 0], # v4-v5-v1-v0 (top) + [3, 2, 6, 7], # v3-v2-v6-v7 (bottom) + [5, 4, 7, 6], # v5-v4-v7-v6 (back) + [1, 5, 6, 2], # v1-v5-v6-v2 (right) + [4, 0, 3, 7] # v4-v0-v3-v7 (left) + ], dtype=np.int32) + + if tri: + faces = mesh.triangulate(faces, vertices=vertices) + + return vertices, faces + + +def camera_frustum(extrinsics: np.ndarray, intrinsics: np.ndarray, depth: float = 1.0) -> Tuple[np.ndarray, np.ndarray, np.ndarray]: + """ + Get x triangle mesh of camera frustum. + """ + assert extrinsics.shape == (4, 4) and intrinsics.shape == (3, 3) + vertices = transforms.unproject_cv( + np.array([[0, 0], [0, 0], [0, 1], [1, 1], [1, 0]], dtype=np.float32), + np.array([0] + [depth] * 4, dtype=np.float32), + extrinsics, + intrinsics + ).astype(np.float32) + edges = np.array([ + [0, 1], [0, 2], [0, 3], [0, 4], + [1, 2], [2, 3], [3, 4], [4, 1] + ], dtype=np.int32) + faces = np.array([ + [0, 1, 2], + [0, 2, 3], + [0, 3, 4], + [0, 4, 1], + [1, 2, 3], + [1, 3, 4] + ], dtype=np.int32) + return vertices, edges, faces + diff --git a/utils3d/torch/__init__.py b/utils3d/torch/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c0d10527e924a2fc62ab0862e90894b1aef6d143 --- /dev/null +++ b/utils3d/torch/__init__.py @@ -0,0 +1,133 @@ +import importlib +import itertools +import torch + + +__modules_all__ = { + 'mesh': [ + 'triangulate', + 'compute_face_normal', + 'compute_face_angles', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'remove_unreferenced_vertices', + 'remove_corrupted_faces', + 'merge_duplicate_vertices', + 'subdivide_mesh_simple', + 'compute_face_tbn', + 'compute_vertex_tbn', + 'laplacian', + 'laplacian_smooth_mesh', + 'taubin_smooth_mesh', + 'laplacian_hc_smooth_mesh', + ], + 'nerf': [ + 'get_rays', + 'get_image_rays', + 'get_mipnerf_cones', + 'volume_rendering', + 'bin_sample', + 'importance_sample', + 'nerf_render_rays', + 'mipnerf_render_rays', + 'nerf_render_view', + 'mipnerf_render_view', + 'InstantNGP', + ], + 'utils': [ + 'sliding_window_1d', + 'sliding_window_2d', + 'sliding_window_nd', + 'image_uv', + 'image_pixel_center', + 'image_mesh', + 'chessboard', + 'depth_edge', + 'depth_aliasing', + 'image_mesh_from_depth', + 'point_to_normal', + 'depth_to_normal', + 'masked_min', + 'masked_max', + 'bounding_rect' + ], + 'transforms': [ + 'perspective', + 'perspective_from_fov', + 'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'intrinsics_from_fov_xy', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'project_gl', + 'project_cv', + 'unproject_gl', + 'unproject_cv', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'matrix_to_euler_angles', + 'matrix_to_quaternion', + 'quaternion_to_matrix', + 'matrix_to_axis_angle', + 'axis_angle_to_matrix', + 'axis_angle_to_quaternion', + 'quaternion_to_axis_angle', + 'slerp', + 'interpolate_extrinsics', + 'interpolate_view', + 'extrinsics_to_essential', + 'to4x4', + 'rotation_matrix_2d', + 'rotate_2d', + 'translate_2d', + 'scale_2d', + 'apply_2d', + ], + 'rasterization': [ + 'RastContext', + 'rasterize_triangle_faces', + 'warp_image_by_depth', + 'warp_image_by_forward_flow', + ], +} + + +__all__ = list(itertools.chain(*__modules_all__.values())) + +def __getattr__(name): + try: + return globals()[name] + except KeyError: + pass + + try: + module_name = next(m for m in __modules_all__ if name in __modules_all__[m]) + except StopIteration: + raise AttributeError(f"module '{__name__}' has no attribute '{name}'") + module = importlib.import_module(f'.{module_name}', __name__) + for key in __modules_all__[module_name]: + globals()[key] = getattr(module, key) + + return globals()[name] + + +if __name__ == '__main__': + from .transforms import * + from .mesh import * + from .utils import * + from .nerf import * + from .rasterization import * \ No newline at end of file diff --git a/utils3d/torch/_helpers.py b/utils3d/torch/_helpers.py new file mode 100644 index 0000000000000000000000000000000000000000..ba742b26618d643afb4b00f35e7d337789345174 --- /dev/null +++ b/utils3d/torch/_helpers.py @@ -0,0 +1,102 @@ +# decorator +import torch +from numbers import Number +import inspect +from functools import wraps + + +def get_device(args, kwargs): + device = None + for arg in (list(args) + list(kwargs.values())): + if isinstance(arg, torch.Tensor): + if device is None: + device = arg.device + elif device != arg.device: + raise ValueError("All tensors must be on the same device.") + return device + + +def get_args_order(func, args, kwargs): + """ + Get the order of the arguments of a function. + """ + names = inspect.getfullargspec(func).args + names_idx = {name: i for i, name in enumerate(names)} + args_order = [] + kwargs_order = {} + for name, arg in kwargs.items(): + if name in names: + kwargs_order[name] = names_idx[name] + names.remove(name) + for i, arg in enumerate(args): + if i < len(names): + args_order.append(names_idx[names[i]]) + return args_order, kwargs_order + + +def broadcast_args(args, kwargs, args_dim, kwargs_dim): + spatial = [] + for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())): + if isinstance(arg, torch.Tensor) and arg_dim is not None: + arg_spatial = arg.shape[:arg.ndim-arg_dim] + if len(arg_spatial) > len(spatial): + spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial + for j in range(len(arg_spatial)): + if spatial[-j] < arg_spatial[-j]: + if spatial[-j] == 1: + spatial[-j] = arg_spatial[-j] + else: + raise ValueError("Cannot broadcast arguments.") + for i, arg in enumerate(args): + if isinstance(arg, torch.Tensor) and args_dim[i] is not None: + args[i] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]]) + for key, arg in kwargs.items(): + if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: + kwargs[key] = torch.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + return args, kwargs, spatial + + +def batched(*dims): + """ + Decorator that allows a function to be called with batched arguments. + """ + def decorator(func): + @wraps(func) + def wrapper(*args, device=torch.device('cpu'), **kwargs): + args = list(args) + # get arguments dimensions + args_order, kwargs_order = get_args_order(func, args, kwargs) + args_dim = [dims[i] for i in args_order] + kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()} + # convert to torch tensor + device = get_device(args, kwargs) or device + for i, arg in enumerate(args): + if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None: + args[i] = torch.tensor(arg, device=device) + for key, arg in kwargs.items(): + if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None: + kwargs[key] = torch.tensor(arg, device=device) + # broadcast arguments + args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim) + for i, (arg, arg_dim) in enumerate(zip(args, args_dim)): + if isinstance(arg, torch.Tensor) and arg_dim is not None: + args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]]) + for key, arg in kwargs.items(): + if isinstance(arg, torch.Tensor) and kwargs_dim[key] is not None: + kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]]) + # call function + results = func(*args, **kwargs) + type_results = type(results) + results = list(results) if isinstance(results, (tuple, list)) else [results] + # restore spatial dimensions + for i, result in enumerate(results): + results[i] = result.reshape([*spatial, *result.shape[1:]]) + if type_results == tuple: + results = tuple(results) + elif type_results == list: + results = list(results) + else: + results = results[0] + return results + return wrapper + return decorator \ No newline at end of file diff --git a/utils3d/torch/mesh.py b/utils3d/torch/mesh.py new file mode 100644 index 0000000000000000000000000000000000000000..ab4eef469fe72fd627c932cc3cb75303e4e316ff --- /dev/null +++ b/utils3d/torch/mesh.py @@ -0,0 +1,401 @@ +import torch +import torch.nn.functional as F +from typing import * +from ._helpers import batched + + +__all__ = [ + 'triangulate', + 'compute_face_normal', + 'compute_face_angles', + 'compute_vertex_normal', + 'compute_vertex_normal_weighted', + 'remove_unreferenced_vertices', + 'remove_corrupted_faces', + 'merge_duplicate_vertices', + 'subdivide_mesh_simple', + 'compute_face_tbn', + 'compute_vertex_tbn', + 'laplacian', + 'laplacian_smooth_mesh', + 'taubin_smooth_mesh', + 'laplacian_hc_smooth_mesh', +] + + +def triangulate( + faces: torch.Tensor, + vertices: torch.Tensor = None, + backslash: bool = None +) -> torch.Tensor: + """ + Triangulate a polygonal mesh. + + Args: + faces (torch.Tensor): [..., L, P] polygonal faces + vertices (torch.Tensor, optional): [..., N, 3] 3-dimensional vertices. + If given, the triangulation is performed according to the distance + between vertices. Defaults to None. + backslash (torch.Tensor, optional): [..., L] boolean array indicating + how to triangulate the quad faces. Defaults to None. + + + Returns: + (torch.Tensor): [L * (P - 2), 3] triangular faces + """ + if faces.shape[-1] == 3: + return faces + P = faces.shape[-1] + if vertices is not None: + assert faces.shape[-1] == 4, "now only support quad mesh" + if backslash is None: + faces_idx = faces.long() + backslash = torch.norm(vertices[faces_idx[..., 0]] - vertices[faces_idx[..., 2]], p=2, dim=-1) < \ + torch.norm(vertices[faces_idx[..., 1]] - vertices[faces_idx[..., 3]], p=2, dim=-1) + if backslash is None: + loop_indice = torch.stack([ + torch.zeros(P - 2, dtype=int), + torch.arange(1, P - 1, 1, dtype=int), + torch.arange(2, P, 1, dtype=int) + ], axis=1) + return faces[:, loop_indice].reshape(-1, 3) + else: + assert faces.shape[-1] == 4, "now only support quad mesh" + if isinstance(backslash, bool): + if backslash: + faces = faces[:, [0, 1, 2, 0, 2, 3]].reshape(-1, 3) + else: + faces = faces[:, [0, 1, 3, 3, 1, 2]].reshape(-1, 3) + else: + faces = torch.where( + backslash[:, None], + faces[:, [0, 1, 2, 0, 2, 3]], + faces[:, [0, 1, 3, 3, 1, 2]] + ).reshape(-1, 3) + return faces + + +@batched(2, None) +def compute_face_normal( + vertices: torch.Tensor, + faces: torch.Tensor +) -> torch.Tensor: + """ + Compute face normals of a triangular mesh + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [..., T, 3] triangular face indices + + Returns: + normals (torch.Tensor): [..., T, 3] face normals + """ + N = vertices.shape[0] + index = torch.arange(N)[:, None] + normal = torch.cross( + vertices[index, faces[..., 1].long()] - vertices[index, faces[..., 0].long()], + vertices[index, faces[..., 2].long()] - vertices[index, faces[..., 0].long()], + dim=-1 + ) + return F.normalize(normal, p=2, dim=-1) + + +@batched(2, None) +def compute_face_angles( + vertices: torch.Tensor, + faces: torch.Tensor +) -> torch.Tensor: + """ + Compute face angles of a triangular mesh + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + + Returns: + angles (torch.Tensor): [..., T, 3] face angles + """ + face_angles = [] + for i in range(3): + edge1 = torch.index_select(vertices, dim=-2, index=faces[:, (i + 1) % 3]) - torch.index_select(vertices, dim=-2, index=faces[:, i]) + edge2 = torch.index_select(vertices, dim=-2, index=faces[:, (i + 2) % 3]) - torch.index_select(vertices, dim=-2, index=faces[:, i]) + face_angle = torch.arccos(torch.sum(F.normalize(edge1, p=2, dim=-1) * F.normalize(edge2, p=2, dim=-1), dim=-1)) + face_angles.append(face_angle) + face_angles = torch.stack(face_angles, dim=-1) + return face_angles + + +@batched(2, None, 2) +def compute_vertex_normal( + vertices: torch.Tensor, + faces: torch.Tensor, + face_normal: torch.Tensor = None +) -> torch.Tensor: + """ + Compute vertex normals of a triangular mesh by averaging neightboring face normals + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + face_normal (torch.Tensor, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + + Returns: + normals (torch.Tensor): [..., N, 3] vertex normals + """ + N = vertices.shape[0] + assert faces.shape[-1] == 3, "Only support triangular mesh" + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + face_normal = face_normal[:, :, None, :].expand(-1, -1, 3, -1).flatten(-3, -2) + faces = faces.flatten() + vertex_normal = torch.index_put(torch.zeros_like(vertices), (torch.arange(N)[:, None], faces[None, :]), face_normal, accumulate=True) + vertex_normal = F.normalize(vertex_normal, p=2, dim=-1) + return vertex_normal + + +@batched(2, None, 2) +def compute_vertex_normal_weighted( + vertices: torch.Tensor, + faces: torch.Tensor, + face_normal: torch.Tensor = None +) -> torch.Tensor: + """ + Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals + according to the angles + + Args: + vertices (torch.Tensor): [..., N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + face_normal (torch.Tensor, optional): [..., T, 3] face normals. + None to compute face normals from vertices and faces. Defaults to None. + + Returns: + normals (torch.Tensor): [..., N, 3] vertex normals + """ + N = vertices.shape[0] + if face_normal is None: + face_normal = compute_face_normal(vertices, faces) + face_angle = compute_face_angles(vertices, faces) + face_normal = face_normal[:, :, None, :].expand(-1, -1, 3, -1) * face_angle[..., None] + vertex_normal = torch.index_put(torch.zeros_like(vertices), (torch.arange(N)[:, None], faces.view(N, -1)), face_normal.view(N, -1, 3), accumulate=True) + vertex_normal = F.normalize(vertex_normal, p=2, dim=-1) + return vertex_normal + + +def remove_unreferenced_vertices( + faces: torch.Tensor, + *vertice_attrs, + return_indices: bool = False +) -> Tuple[torch.Tensor, ...]: + """ + Remove unreferenced vertices of a mesh. + Unreferenced vertices are removed, and the face indices are updated accordingly. + + Args: + faces (torch.Tensor): [T, P] face indices + *vertice_attrs: vertex attributes + + Returns: + faces (torch.Tensor): [T, P] face indices + *vertice_attrs: vertex attributes + indices (torch.Tensor, optional): [N] indices of vertices that are kept. Defaults to None. + """ + P = faces.shape[-1] + fewer_indices, inv_map = torch.unique(faces, return_inverse=True) + faces = inv_map.to(torch.int32).reshape(-1, P) + ret = [faces] + for attr in vertice_attrs: + ret.append(attr[fewer_indices]) + if return_indices: + ret.append(fewer_indices) + return tuple(ret) + + +def remove_corrupted_faces( + faces: torch.Tensor +) -> torch.Tensor: + """ + Remove corrupted faces (faces with duplicated vertices) + + Args: + faces (torch.Tensor): [T, 3] triangular face indices + + Returns: + torch.Tensor: [T_, 3] triangular face indices + """ + corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0]) + return faces[~corrupted] + + +def merge_duplicate_vertices( + vertices: torch.Tensor, + faces: torch.Tensor, + tol: float = 1e-6 +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Merge duplicate vertices of a triangular mesh. + Duplicate vertices are merged by selecte one of them, and the face indices are updated accordingly. + + Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + tol (float, optional): tolerance for merging. Defaults to 1e-6. + + Returns: + vertices (torch.Tensor): [N_, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + """ + vertices_round = torch.round(vertices / tol) + uni, uni_inv = torch.unique(vertices_round, dim=0, return_inverse=True) + uni[uni_inv] = vertices + faces = uni_inv[faces] + return uni, faces + + +def subdivide_mesh_simple(vertices: torch.Tensor, faces: torch.Tensor, n: int = 1) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles. + NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list. + + Args: + vertices (torch.Tensor): [N, 3] 3-dimensional vertices + faces (torch.Tensor): [T, 3] triangular face indices + n (int, optional): number of subdivisions. Defaults to 1. + + Returns: + vertices (torch.Tensor): [N_, 3] subdivided 3-dimensional vertices + faces (torch.Tensor): [4 * T, 3] subdivided triangular face indices + """ + for _ in range(n): + edges = torch.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], dim=0) + edges = torch.sort(edges, dim=2) + uni_edges, uni_inv = torch.unique(edges, return_inverse=True, dim=0) + midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2 + + n_vertices = vertices.shape[0] + vertices = torch.cat([vertices, midpoints], dim=0) + faces = torch.cat([ + torch.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1), + torch.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1), + torch.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1), + torch.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1), + ], dim=0) + return vertices, faces + + +def compute_face_tbn(pos: torch.Tensor, faces_pos: torch.Tensor, uv: torch.Tensor, faces_uv: torch.Tensor, eps: float = 1e-7) -> torch.Tensor: + """compute TBN matrix for each face + + Args: + pos (torch.Tensor): shape (..., N_pos, 3), positions + faces_pos (torch.Tensor): shape(T, 3) + uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates, + faces_uv (torch.Tensor): shape(T, 3) + + Returns: + torch.Tensor: (..., T, 3, 3) TBN matrix for each face. Note TBN vectors are normalized but not necessarily orthognal + """ + e01 = torch.index_select(pos, dim=-2, index=faces_pos[:, 1]) - torch.index_select(pos, dim=-2, index=faces_pos[:, 0]) + e02 = torch.index_select(pos, dim=-2, index=faces_pos[:, 2]) - torch.index_select(pos, dim=-2, index=faces_pos[:, 0]) + uv01 = torch.index_select(uv, dim=-2, index=faces_uv[:, 1]) - torch.index_select(uv, dim=-2, index=faces_uv[:, 0]) + uv02 = torch.index_select(uv, dim=-2, index=faces_uv[:, 2]) - torch.index_select(uv, dim=-2, index=faces_uv[:, 0]) + normal = torch.cross(e01, e02) + tangent_bitangent = torch.stack([e01, e02], dim=-1) @ torch.inverse(torch.stack([uv01, uv02], dim=-1)) + tbn = torch.cat([tangent_bitangent, normal.unsqueeze(-1)], dim=-1) + tbn = tbn / (torch.norm(tbn, p=2, dim=-2, keepdim=True) + eps) + return tbn + + +def compute_vertex_tbn(faces_topo: torch.Tensor, pos: torch.Tensor, faces_pos: torch.Tensor, uv: torch.Tensor, faces_uv: torch.Tensor) -> torch.Tensor: + """compute TBN matrix for each face + + Args: + faces_topo (torch.Tensor): (T, 3), face indice of topology + pos (torch.Tensor): shape (..., N_pos, 3), positions + faces_pos (torch.Tensor): shape(T, 3) + uv (torch.Tensor): shape (..., N_uv, 3) uv coordinates, + faces_uv (torch.Tensor): shape(T, 3) + + Returns: + torch.Tensor: (..., V, 3, 3) TBN matrix for each face. Note TBN vectors are normalized but not necessarily orthognal + """ + n_vertices = faces_topo.max().item() + 1 + n_tri = faces_topo.shape[-2] + batch_shape = pos.shape[:-2] + face_tbn = compute_face_tbn(pos, faces_pos, uv, faces_uv) # (..., T, 3, 3) + face_tbn = face_tbn[..., :, None, :, :].repeat(*[1] * len(batch_shape), 1, 3, 1, 1).view(*batch_shape, n_tri * 3, 3, 3) # (..., T * 3, 3, 3) + vertex_tbn = torch.index_add(torch.zeros(*batch_shape, n_vertices, 3, 3).to(face_tbn), dim=-3, index=faces_topo.view(-1), source=face_tbn) + vertex_tbn = vertex_tbn / (torch.norm(vertex_tbn, p=2, dim=-2, keepdim=True) + 1e-7) + return vertex_tbn + + +def laplacian(vertices: torch.Tensor, faces: torch.Tensor, weight: str = 'uniform') -> torch.Tensor: + """Laplacian smooth with cotangent weights + + Args: + vertices (torch.Tensor): shape (..., N, 3) + faces (torch.Tensor): shape (T, 3) + weight (str): 'uniform' or 'cotangent' + """ + sum_verts = torch.zeros_like(vertices) # (..., N, 3) + sum_weights = torch.zeros(*vertices.shape[:-1]).to(vertices) # (..., N) + face_verts = torch.index_select(vertices, -2, faces.view(-1)).view(*vertices.shape[:-2], *faces.shape, vertices.shape[-1]) # (..., T, 3) + if weight == 'cotangent': + for i in range(3): + e1 = face_verts[..., (i + 1) % 3, :] - face_verts[..., i, :] + e2 = face_verts[..., (i + 2) % 3, :] - face_verts[..., i, :] + cot_angle = (e1 * e2).sum(dim=-1) / torch.cross(e1, e2, dim=-1).norm(p=2, dim=-1) # (..., T, 3) + sum_verts = torch.index_add(sum_verts, -2, faces[:, (i + 1) % 3], face_verts[..., (i + 2) % 3, :] * cot_angle[..., None]) + sum_weights = torch.index_add(sum_weights, -1, faces[:, (i + 1) % 3], cot_angle) + sum_verts = torch.index_add(sum_verts, -2, faces[:, (i + 2) % 3], face_verts[..., (i + 1) % 3, :] * cot_angle[..., None]) + sum_weights = torch.index_add(sum_weights, -1, faces[:, (i + 2) % 3], cot_angle) + elif weight == 'uniform': + for i in range(3): + sum_verts = torch.index_add(sum_verts, -2, faces[:, i], face_verts[..., (i + 1) % 3, :]) + sum_weights = torch.index_add(sum_weights, -1, faces[:, i], torch.ones_like(face_verts[..., i, 0])) + else: + raise NotImplementedError + return sum_verts / (sum_weights[..., None] + 1e-7) + + +def laplacian_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, weight: str = 'uniform', times: int = 5) -> torch.Tensor: + """Laplacian smooth with cotangent weights + + Args: + vertices (torch.Tensor): shape (..., N, 3) + faces (torch.Tensor): shape (T, 3) + weight (str): 'uniform' or 'cotangent' + """ + for _ in range(times): + vertices = laplacian(vertices, faces, weight) + return vertices + + +def taubin_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, lambda_: float = 0.5, mu_: float = -0.51) -> torch.Tensor: + """Taubin smooth mesh + + Args: + vertices (torch.Tensor): _description_ + faces (torch.Tensor): _description_ + lambda_ (float, optional): _description_. Defaults to 0.5. + mu_ (float, optional): _description_. Defaults to -0.51. + + Returns: + torch.Tensor: _description_ + """ + pt = vertices + lambda_ * laplacian_smooth_mesh(vertices, faces) + p = pt + mu_ * laplacian_smooth_mesh(pt, faces) + return p + + +def laplacian_hc_smooth_mesh(vertices: torch.Tensor, faces: torch.Tensor, times: int = 5, alpha: float = 0.5, beta: float = 0.5, weight: str = 'uniform'): + """HC algorithm from Improved Laplacian Smoothing of Noisy Surface Meshes by J.Vollmer et al. + """ + p = vertices + for i in range(times): + q = p + p = laplacian_smooth_mesh(vertices, faces, weight) + b = p - (alpha * vertices + (1 - alpha) * q) + p = p - (beta * b + (1 - beta) * laplacian_smooth_mesh(b, faces, weight)) * 0.8 + return p diff --git a/utils3d/torch/nerf.py b/utils3d/torch/nerf.py new file mode 100644 index 0000000000000000000000000000000000000000..7d20bc747255dbb1a68191f93a395a824d76e108 --- /dev/null +++ b/utils3d/torch/nerf.py @@ -0,0 +1,749 @@ +from typing import * +from numbers import Number +import math + +import torch +import torch.nn as nn +import torch.nn.functional as F +from torch import Tensor +from .utils import image_uv + + +__all__ = [ + 'get_rays', + 'get_image_rays', + 'get_mipnerf_cones', + 'volume_rendering', + 'bin_sample', + 'importance_sample', + 'nerf_render_rays', + 'mipnerf_render_rays', + 'nerf_render_view', + 'mipnerf_render_view', + 'InstantNGP', +] + + +def get_rays(extrinsics: Tensor, intrinsics: Tensor, uv: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + extrinsics: (..., 4, 4) extrinsics matrices. + intrinsics: (..., 3, 3) intrinsics matrices. + uv: (..., n_rays, 2) uv coordinates of the rays. + + Returns: + rays_o: (..., 1, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + NOTE: ray directions are NOT normalized. They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth. + """ + uvz = torch.cat([uv, torch.ones_like(uv[..., :1])], dim=-1).to(extrinsics) # (n_batch, n_views, n_rays, 3) + + with torch.cuda.amp.autocast(enabled=False): + inv_transformation = (intrinsics @ extrinsics[..., :3, :3]).inverse() + inv_extrinsics = extrinsics.inverse() + rays_d = uvz @ inv_transformation.transpose(-1, -2) + rays_o = inv_extrinsics[..., None, :3, 3] # (n_batch, n_views, 1, 3) + return rays_o, rays_d + + +def get_image_rays(extrinsics: Tensor, intrinsics: Tensor, width: int, height: int) -> Tuple[Tensor, Tensor]: + """ + Args: + extrinsics: (..., 4, 4) extrinsics matrices. + intrinsics: (..., 3, 3) intrinsics matrices. + width: width of the image. + height: height of the image. + + Returns: + rays_o: (..., 1, 1, 3) ray origins + rays_d: (..., height, width, 3) ray directions. + NOTE: ray directions are NOT normalized. They actuallys makes rays_o + rays_d * z = world coordinates, where z is the depth. + """ + uv = image_uv(height, width).to(extrinsics).flatten(0, 1) + rays_o, rays_d = get_rays(extrinsics, intrinsics, uv) + rays_o = rays_o.unflatten(-2, (1, 1)) + rays_d = rays_d.unflatten(-2, (height, width)) + return rays_o, rays_d + + +def get_mipnerf_cones(rays_o: Tensor, rays_d: Tensor, z_vals: Tensor, pixel_width: Tensor) -> Tuple[Tensor, Tensor]: + """ + Args: + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + z_vals: (..., n_rays, n_samples) z values. + pixel_width: (...) pixel width. = 1 / (normalized focal length * width) + + Returns: + mu: (..., n_rays, n_samples, 3) cone mu. + sigma: (..., n_rays, n_samples, 3, 3) cone sigma. + """ + t_mu = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) + t_delta = (z_vals[..., 1:] - z_vals[..., :-1]).mul_(0.5) + t_mu_square = t_mu.square() + t_delta_square = t_delta.square() + t_delta_quad = t_delta_square.square() + mu_t = t_mu + 2.0 * t_mu * t_delta_square / (3.0 * t_mu_square + t_delta_square) + sigma_t = t_delta_square / 3.0 - (4.0 / 15.0) * t_delta_quad / (3.0 * t_mu_square + t_delta_square).square() * (12.0 * t_mu_square - t_delta_square) + sigma_r = (pixel_width[..., None, None].square() / 3.0) * (t_mu_square / 4.0 + (5.0 / 12.0) * t_delta_square - (4.0 / 15.0) * t_delta_quad / (3.0 * t_mu_square + t_delta_square)) + points_mu = rays_o[:, :, :, None, :] + rays_d[:, :, :, None, :] * mu_t[..., None] + d_dt = rays_d[..., :, None] * rays_d[..., None, :] # (..., n_rays, 3, 3) + points_sigma = sigma_t[..., None, None] * d_dt[..., None, :, :] + sigma_r[..., None, None] * (torch.eye(3).to(rays_o) - d_dt[..., None, :, :]) + return points_mu, points_sigma + + +def get_pixel_width(intrinsics: Tensor, width: int, height: int) -> Tensor: + """ + Args: + intrinsics: (..., 3, 3) intrinsics matrices. + width: width of the image. + height: height of the image. + + Returns: + pixel_width: (...) pixel width. = 1 / (normalized focal length * width) + """ + assert width == height, "Currently, only square images are supported." + pixel_width = torch.reciprocal((intrinsics[..., 0, 0] * intrinsics[..., 1, 1]).sqrt() * width) + return pixel_width + + +def volume_rendering(color: Tensor, sigma: Tensor, z_vals: Tensor, ray_length: Tensor, rgb: bool = True, depth: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + """ + Given color, sigma and z_vals (linear depth of the sampling points), render the volume. + + NOTE: By default, color and sigma should have one less sample than z_vals, in correspondence with the average value in intervals. + If queried color are aligned with z_vals, we use trapezoidal rule to calculate the average values in intervals. + + Args: + color: (..., n_samples or n_samples - 1, 3) color values. + sigma: (..., n_samples or n_samples - 1) density values. + z_vals: (..., n_samples) z values. + ray_length: (...) length of the ray + + Returns: + rgb: (..., 3) rendered color values. + depth: (...) rendered depth values. + weights (..., n_samples) weights. + """ + dists = (z_vals[..., 1:] - z_vals[..., :-1]) * ray_length[..., None] + if color.shape[-2] == z_vals.shape[-1]: + color = (color[..., 1:, :] + color[..., :-1, :]).mul_(0.5) + sigma = (sigma[..., 1:] + sigma[..., :-1]).mul_(0.5) + sigma_delta = sigma * dists + transparancy = (-torch.cat([torch.zeros_like(sigma_delta[..., :1]), sigma_delta[..., :-1]], dim=-1).cumsum(dim=-1)).exp_() # First cumsum then exp for numerical stability + alpha = 1.0 - (-sigma_delta).exp_() + weights = alpha * transparancy + if rgb: + rgb = torch.sum(weights[..., None] * color, dim=-2) if rgb else None + if depth: + z_vals = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) + depth = torch.sum(weights * z_vals, dim=-1) / weights.sum(dim=-1).clamp_min_(1e-8) if depth else None + return rgb, depth, weights + + +def neus_volume_rendering(color: Tensor, sdf: Tensor, s: torch.Tensor, z_vals: Tensor = None, rgb: bool = True, depth: bool = True) -> Tuple[Tensor, Tensor, Tensor]: + """ + Given color, sdf values and z_vals (linear depth of the sampling points), do volume rendering. (NeuS) + + Args: + color: (..., n_samples or n_samples - 1, 3) color values. + sdf: (..., n_samples) sdf values. + s: (..., n_samples) S values of S-density function in NeuS. The standard deviation of such S-density distribution is 1 / s. + z_vals: (..., n_samples) z values. + ray_length: (...) length of the ray + + Returns: + rgb: (..., 3) rendered color values. + depth: (...) rendered depth values. + weights (..., n_samples) weights. + """ + + if color.shape[-2] == z_vals.shape[-1]: + color = (color[..., 1:, :] + color[..., :-1, :]).mul_(0.5) + + sigmoid_sdf = torch.sigmoid(s * sdf) + alpha = F.relu(1 - sigmoid_sdf[..., :-1] / sigmoid_sdf[..., :-1]) + transparancy = torch.cumprod(torch.cat([torch.ones_like(alpha[..., :1]), alpha], dim=-1), dim=-1) + weights = alpha * transparancy + + if rgb: + rgb = torch.sum(weights[..., None] * color, dim=-2) if rgb else None + if depth: + z_vals = (z_vals[..., 1:] + z_vals[..., :-1]).mul_(0.5) + depth = torch.sum(weights * z_vals, dim=-1) / weights.sum(dim=-1).clamp_min_(1e-8) if depth else None + return rgb, depth, weights + + +def bin_sample(size: Union[torch.Size, Tuple[int, ...]], n_samples: int, min_value: Number, max_value: Number, spacing: Literal['linear', 'inverse_linear'], dtype: torch.dtype = None, device: torch.device = None) -> Tensor: + """ + Uniformly (or uniformly in inverse space) sample z values in `n_samples` bins in range [min_value, max_value]. + Args: + size: size of the rays + n_samples: number of samples to be sampled, also the number of bins + min_value: minimum value of the range + max_value: maximum value of the range + space: 'linear' or 'inverse_linear'. If 'inverse_linear', the sampling is uniform in inverse space. + + Returns: + z_rand: (*size, n_samples) sampled z values, sorted in ascending order. + """ + if spacing == 'linear': + pass + elif spacing == 'inverse_linear': + min_value = 1.0 / min_value + max_value = 1.0 / max_value + bin_length = (max_value - min_value) / n_samples + z_rand = (torch.rand(*size, n_samples, device=device, dtype=dtype) - 0.5) * bin_length + torch.linspace(min_value + bin_length * 0.5, max_value - bin_length * 0.5, n_samples, device=device, dtype=dtype) + if spacing == 'inverse_linear': + z_rand = 1.0 / z_rand + return z_rand + + +def importance_sample(z_vals: Tensor, weights: Tensor, n_samples: int) -> Tuple[Tensor, Tensor]: + """ + Importance sample z values. + + NOTE: By default, weights should have one less sample than z_vals, in correspondence with the intervals. + If weights has the same number of samples as z_vals, we use trapezoidal rule to calculate the average weights in intervals. + + Args: + z_vals: (..., n_rays, n_input_samples) z values, sorted in ascending order. + weights: (..., n_rays, n_input_samples or n_input_samples - 1) weights. + n_samples: number of output samples for importance sampling. + + Returns: + z_importance: (..., n_rays, n_samples) importance sampled z values, unsorted. + """ + if weights.shape[-1] == z_vals.shape[-1]: + weights = (weights[..., 1:] + weights[..., :-1]).mul_(0.5) + weights = weights / torch.sum(weights, dim=-1, keepdim=True) # (..., n_rays, n_input_samples - 1) + bins_a, bins_b = z_vals[..., :-1], z_vals[..., 1:] + + pdf = weights / torch.sum(weights, dim=-1, keepdim=True) # (..., n_rays, n_input_samples - 1) + cdf = torch.cumsum(pdf, dim=-1) + u = torch.rand(*z_vals.shape[:-1], n_samples, device=z_vals.device, dtype=z_vals.dtype) + + inds = torch.searchsorted(cdf, u, right=True).clamp(0, cdf.shape[-1] - 1) # (..., n_rays, n_samples) + + bins_a = torch.gather(bins_a, dim=-1, index=inds) + bins_b = torch.gather(bins_b, dim=-1, index=inds) + z_importance = bins_a + (bins_b - bins_a) * torch.rand_like(u) + return z_importance + + +def nerf_render_rays( + nerf: Union[Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]], Tuple[Callable[[Tensor], Tuple[Tensor, Tensor]], Callable[[Tensor], Tuple[Tensor, Tensor]]]], + rays_o: Tensor, rays_d: Tensor, + *, + return_dict: bool = False, + n_coarse: int = 64, n_fine: int = 64, + near: float = 0.1, far: float = 100.0, + z_spacing: Literal['linear', 'inverse_linear'] = 'linear', +): + """ + NeRF rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + nerf: nerf model, which takes (points, directions) as input and returns (color, density) as output. + If nerf is a tuple, it should be (nerf_coarse, nerf_fine), where nerf_coarse and nerf_fine are two nerf models for coarse and fine stages respectively. + + nerf args: + points: (..., n_rays, n_samples, 3) + directions: (..., n_rays, n_samples, 3) + nerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + + Returns + if return_dict is False, return rendered rgb and depth for short cut. (If there are separate coarse and fine results, return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0` or `nerf` is a single model, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + ``` + If there are two models for coarse and fine stages, the dict contains both coarse and fine results: + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ``` + """ + if isinstance(nerf, tuple): + nerf_coarse, nerf_fine = nerf + else: + nerf_coarse = nerf_fine = nerf + # 1. Coarse: bin sampling + z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, device=rays_o.device, dtype=rays_o.dtype, spacing=z_spacing) # (n_batch, n_views, n_rays, n_samples) + points_coarse = rays_o[..., None, :] + rays_d[..., None, :] * z_coarse[..., None] # (n_batch, n_views, n_rays, n_samples, 3) + ray_length = rays_d.norm(dim=-1) + + # Query color and density + color_coarse, density_coarse = nerf_coarse(points_coarse, rays_d[..., None, :].expand_as(points_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples) + + # Volume rendering + with torch.no_grad(): + rgb_coarse, depth_coarse, weights = volume_rendering(color_coarse, density_coarse, z_coarse, ray_length) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples) + + if n_fine == 0: + if return_dict: + return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse} + else: + return rgb_coarse, depth_coarse + + # 2. Fine: Importance sampling + if nerf_coarse is nerf_fine: + # If coarse and fine stages share the same model, the points of coarse stage can be reused, + # and we only need to query the importance samples of fine stage. + z_fine = importance_sample(z_coarse, weights, n_fine) + points_fine = rays_o[..., None, :] + rays_d[..., None, :] * z_fine[..., None] + color_fine, density_fine = nerf_fine(points_fine, rays_d[..., None, :].expand_as(points_fine)) + + # Merge & volume rendering + z_vals = torch.cat([z_coarse, z_fine], dim=-1) + color = torch.cat([color_coarse, color_fine], dim=-2) + density = torch.cat([density_coarse, density_fine], dim=-1) + z_vals, sort_inds = torch.sort(z_vals, dim=-1) + color = torch.gather(color, dim=-2, index=sort_inds[..., None].expand_as(color)) + density = torch.gather(density, dim=-1, index=sort_inds) + rgb, depth, weights = volume_rendering(color, density, z_vals, ray_length) + + if return_dict: + return {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'density': density} + else: + return rgb, depth + else: + # If coarse and fine stages use different models, we need to query the importance samples of both stages. + z_fine = importance_sample(z_coarse, weights, n_fine) + z_vals = torch.cat([z_coarse, z_fine], dim=-1) + points = rays_o[..., None, :] + rays_d[..., None, :] * z_vals[..., None] + color, density = nerf_fine(points) + rgb, depth, weights = volume_rendering(color, density, z_vals, ray_length) + + if return_dict: + return { + 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse}, + 'fine': {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'density': density} + } + else: + return rgb, depth + + +def mipnerf_render_rays( + mipnerf: Callable[[Tensor, Tensor, Tensor], Tuple[Tensor, Tensor]], + rays_o: Tensor, rays_d: Tensor, pixel_width: Tensor, + *, + return_dict: bool = False, + n_coarse: int = 64, n_fine: int = 64, uniform_ratio: float = 0.4, + near: float = 0.1, far: float = 100.0, + z_spacing: Literal['linear', 'inverse_linear'] = 'linear', +) -> Union[Tuple[Tensor, Tensor], Dict[str, Tensor]]: + """ + MipNeRF rendering. + + Args: + mipnerf: mipnerf model, which takes (points_mu, points_sigma) as input and returns (color, density) as output. + + mipnerf args: + points_mu: (..., n_rays, n_samples, 3) cone mu. + points_sigma: (..., n_rays, n_samples, 3, 3) cone sigma. + directions: (..., n_rays, n_samples, 3) + mipnerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + + Returns + if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0`, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + ``` + If n_fine > 0, the dict contains both coarse and fine results : + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ``` + """ + # 1. Coarse: bin sampling + z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, spacing=z_spacing, device=rays_o.device, dtype=rays_o.dtype) + points_mu_coarse, points_sigma_coarse = get_mipnerf_cones(rays_o, rays_d, z_coarse, pixel_width) + ray_length = rays_d.norm(dim=-1) + + # Query color and density + color_coarse, density_coarse = mipnerf(points_mu_coarse, points_sigma_coarse, rays_d[..., None, :].expand_as(points_mu_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples) + + # Volume rendering + rgb_coarse, depth_coarse, weights_coarse = volume_rendering(color_coarse, density_coarse, z_coarse, ray_length) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples) + + if n_fine == 0: + if return_dict: + return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights_coarse, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse} + else: + return rgb_coarse, depth_coarse + + # 2. Fine: Importance sampling. (NOTE: coarse stages and fine stages always share the same model, but coarse stage points can not be reused) + with torch.no_grad(): + weights_coarse = (1.0 - uniform_ratio) * weights_coarse + uniform_ratio / weights_coarse.shape[-1] + z_fine = importance_sample(z_coarse, weights_coarse, n_fine) + z_fine, _ = torch.sort(z_fine, dim=-2) + points_mu_fine, points_sigma_fine = get_mipnerf_cones(rays_o, rays_d, z_fine, pixel_width) + color_fine, density_fine = mipnerf(points_mu_fine, points_sigma_fine, rays_d[..., None, :].expand_as(points_mu_fine)) + + # Volume rendering + rgb_fine, depth_fine, weights_fine = volume_rendering(color_fine, density_fine, z_fine, ray_length) + + if return_dict: + return { + 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights_coarse, 'z_vals': z_coarse, 'color': color_coarse, 'density': density_coarse}, + 'fine': {'rgb': rgb_fine, 'depth': depth_fine, 'weights': weights_fine, 'z_vals': z_fine, 'color': color_fine, 'density': density_fine} + } + else: + return rgb_fine, depth_fine + + +def neus_render_rays( + neus: Callable[[Tensor, Tensor], Tuple[Tensor, Tensor]], + s: Union[Number, Tensor], + rays_o: Tensor, rays_d: Tensor, + *, + compute_normal: bool = True, + return_dict: bool = False, + n_coarse: int = 64, n_fine: int = 64, + near: float = 0.1, far: float = 100.0, + z_spacing: Literal['linear', 'inverse_linear'] = 'linear', +): + """ + TODO + NeuS rendering of rays. Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + neus: neus model, which takes (points, directions) as input and returns (color, density) as output. + + nerf args: + points: (..., n_rays, n_samples, 3) + directions: (..., n_rays, n_samples, 3) + nerf returns: + color: (..., n_rays, n_samples, 3) color values. + density: (..., n_rays, n_samples) density values. + + rays_o: (..., n_rays, 3) ray origins + rays_d: (..., n_rays, 3) ray directions. + pixel_width: (..., n_rays) pixel width. How to compute? pixel_width = 1 / (normalized focal length * width) + + Returns + if return_dict is False, return rendered results only: (If `n_fine == 0`, return coarse results, otherwise return fine results) + rgb: (..., n_rays, 3) rendered color values. + depth: (..., n_rays) rendered depth values. + else, return a dict. If `n_fine == 0`, the dict only contains coarse results: + ``` + {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'sdf': ..., 'normal': ...} + ``` + If n_fine > 0, the dict contains both coarse and fine results: + ``` + { + "coarse": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..}, + "fine": {'rgb': .., 'depth': .., 'weights': .., 'z_vals': .., 'color': .., 'density': ..} + } + ``` + """ + + # 1. Coarse: bin sampling + z_coarse = bin_sample(rays_d.shape[:-1], n_coarse, near, far, device=rays_o.device, dtype=rays_o.dtype, spacing=z_spacing) # (n_batch, n_views, n_rays, n_samples) + points_coarse = rays_o[..., None, :] + rays_d[..., None, :] * z_coarse[..., None] # (n_batch, n_views, n_rays, n_samples, 3) + + # Query color and density + color_coarse, sdf_coarse = neus(points_coarse, rays_d[..., None, :].expand_as(points_coarse)) # (n_batch, n_views, n_rays, n_samples, 3), (n_batch, n_views, n_rays, n_samples) + + # Volume rendering + with torch.no_grad(): + rgb_coarse, depth_coarse, weights = neus_volume_rendering(color_coarse, sdf_coarse, s, z_coarse) # (n_batch, n_views, n_rays, 3), (n_batch, n_views, n_rays, 1), (n_batch, n_views, n_rays, n_samples) + + if n_fine == 0: + if return_dict: + return {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'sdf': sdf_coarse} + else: + return rgb_coarse, depth_coarse + + # If coarse and fine stages share the same model, the points of coarse stage can be reused, + # and we only need to query the importance samples of fine stage. + z_fine = importance_sample(z_coarse, weights, n_fine) + points_fine = rays_o[..., None, :] + rays_d[..., None, :] * z_fine[..., None] + color_fine, sdf_fine = neus(points_fine, rays_d[..., None, :].expand_as(points_fine)) + + # Merge & volume rendering + z_vals = torch.cat([z_coarse, z_fine], dim=-1) + color = torch.cat([color_coarse, color_fine], dim=-2) + sdf = torch.cat([sdf_coarse, sdf_fine], dim=-1) + z_vals, sort_inds = torch.sort(z_vals, dim=-1) + color = torch.gather(color, dim=-2, index=sort_inds[..., None].expand_as(color)) + sdf = torch.gather(sdf, dim=-1, index=sort_inds) + rgb, depth, weights = neus_volume_rendering(color, sdf, s, z_vals) + + if return_dict: + return { + 'coarse': {'rgb': rgb_coarse, 'depth': depth_coarse, 'weights': weights, 'z_vals': z_coarse, 'color': color_coarse, 'sdf': sdf_coarse}, + 'fine': {'rgb': rgb, 'depth': depth, 'weights': weights, 'z_vals': z_vals, 'color': color, 'sdf': sdf} + } + else: + return rgb, depth + + +def nerf_render_view( + nerf: Tensor, + extrinsics: Tensor, + intrinsics: Tensor, + width: int, + height: int, + *, + patchify: bool = False, + patch_size: Tuple[int, int] = (64, 64), + **options: Dict[str, Any] +) -> Tuple[Tensor, Tensor]: + """ + NeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + extrinsics: (..., 4, 4) extrinsics matrice of the rendered views + intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. + width (optional): image width of the rendered views. + height (optional): image height of the rendered views. + patchify (optional): If the image is too large, render it patch by patch + **options: rendering options. + + Returns: + rgb: (..., channels, height, width) rendered color values. + depth: (..., height, width) rendered depth values. + """ + if patchify: + # Patchified rendering + max_patch_width, max_patch_height = patch_size + n_rows, n_columns = math.ceil(height / max_patch_height), math.ceil(width / max_patch_width) + + rgb_rows, depth_rows = [], [] + for i_row in range(n_rows): + rgb_row, depth_row = [], [] + for i_column in range(n_columns): + patch_shape = patch_height, patch_width = min(max_patch_height, height - i_row * max_patch_height), min(max_patch_width, width - i_column * max_patch_width) + uv = image_uv(height, width, i_column * max_patch_width, i_row * max_patch_height, i_column * max_patch_width + patch_width, i_row * max_patch_height + patch_height).to(extrinsics) + uv = uv.flatten(0, 1) # (patch_height * patch_width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb_, depth_ = nerf_render_rays(nerf, ray_o_, ray_d_, **options, return_dict=False) + rgb_ = rgb_.transpose(-1, -2).unflatten(-1, patch_shape) # (..., 3, patch_height, patch_width) + depth_ = depth_.unflatten(-1, patch_shape) # (..., patch_height, patch_width) + + rgb_row.append(rgb_) + depth_row.append(depth_) + rgb_rows.append(torch.cat(rgb_row, dim=-1)) + depth_rows.append(torch.cat(depth_row, dim=-1)) + rgb = torch.cat(rgb_rows, dim=-2) + depth = torch.cat(depth_rows, dim=-2) + + return rgb, depth + else: + # Full rendering + uv = image_uv(height, width).to(extrinsics) + uv = uv.flatten(0, 1) # (height * width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb, depth = nerf_render_rays(nerf, ray_o_, ray_d_, **options, return_dict=False) + rgb = rgb.transpose(-1, -2).unflatten(-1, (height, width)) # (..., 3, height, width) + depth = depth.unflatten(-1, (height, width)) # (..., height, width) + + return rgb, depth + + +def mipnerf_render_view( + mipnerf: Tensor, + extrinsics: Tensor, + intrinsics: Tensor, + width: int, + height: int, + *, + patchify: bool = False, + patch_size: Tuple[int, int] = (64, 64), + **options: Dict[str, Any] +) -> Tuple[Tensor, Tensor]: + """ + MipNeRF rendering of views. Note that it supports arbitrary batch dimensions (denoted as `...`) + + Args: + extrinsics: (..., 4, 4) extrinsics matrice of the rendered views + intrinsics (optional): (..., 3, 3) intrinsics matrice of the rendered views. + width (optional): image width of the rendered views. + height (optional): image height of the rendered views. + patchify (optional): If the image is too large, render it patch by patch + **options: rendering options. + + Returns: + rgb: (..., 3, height, width) rendered color values. + depth: (..., height, width) rendered depth values. + """ + pixel_width = get_pixel_width(intrinsics, width, height) + + if patchify: + # Patchified rendering + max_patch_width, max_patch_height = patch_size + n_rows, n_columns = math.ceil(height / max_patch_height), math.ceil(width / max_patch_width) + + rgb_rows, depth_rows = [], [] + for i_row in range(n_rows): + rgb_row, depth_row = [], [] + for i_column in range(n_columns): + patch_shape = patch_height, patch_width = min(max_patch_height, height - i_row * max_patch_height), min(max_patch_width, width - i_column * max_patch_width) + uv = image_uv(height, width, i_column * max_patch_width, i_row * max_patch_height, i_column * max_patch_width + patch_width, i_row * max_patch_height + patch_height).to(extrinsics) + uv = uv.flatten(0, 1) # (patch_height * patch_width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb_, depth_ = mipnerf_render_rays(mipnerf, ray_o_, ray_d_, pixel_width, **options) + rgb_ = rgb_.transpose(-1, -2).unflatten(-1, patch_shape) # (..., 3, patch_height, patch_width) + depth_ = depth_.unflatten(-1, patch_shape) # (..., patch_height, patch_width) + + rgb_row.append(rgb_) + depth_row.append(depth_) + rgb_rows.append(torch.cat(rgb_row, dim=-1)) + depth_rows.append(torch.cat(depth_row, dim=-1)) + rgb = torch.cat(rgb_rows, dim=-2) + depth = torch.cat(depth_rows, dim=-2) + + return rgb, depth + else: + # Full rendering + uv = image_uv(height, width).to(extrinsics) + uv = uv.flatten(0, 1) # (height * width, 2) + ray_o_, ray_d_ = get_rays(extrinsics, intrinsics, uv) + rgb, depth = mipnerf_render_rays(mipnerf, ray_o_, ray_d_, pixel_width, **options) + rgb = rgb.transpose(-1, -2).unflatten(-1, (height, width)) # (..., 3, height, width) + depth = depth.unflatten(-1, (height, width)) # (..., height, width) + + return rgb, depth + + +class InstantNGP(nn.Module): + """ + An implementation of InstantNGP, Müller et. al., https://nvlabs.github.io/instant-ngp/. + Requires `tinycudann` package. + Install it by: + ``` + pip install git+https://github.com/NVlabs/tiny-cuda-nn/#subdirectory=bindings/torch + ``` + """ + def __init__(self, + view_dependent: bool = True, + base_resolution: int = 16, + finest_resolution: int = 2048, + n_levels: int = 16, + num_layers_density: int = 2, + hidden_dim_density: int = 64, + num_layers_color: int = 3, + hidden_dim_color: int = 64, + log2_hashmap_size: int = 19, + bound: float = 1.0, + color_channels: int = 3, + ): + super().__init__() + import tinycudann + N_FEATURES_PER_LEVEL = 2 + GEO_FEAT_DIM = 15 + + self.bound = bound + self.color_channels = color_channels + + # density network + self.num_layers_density = num_layers_density + self.hidden_dim_density = hidden_dim_density + + per_level_scale = (finest_resolution / base_resolution) ** (1 / (n_levels - 1)) + + self.encoder = tinycudann.Encoding( + n_input_dims=3, + encoding_config={ + "otype": "HashGrid", + "n_levels": n_levels, + "n_features_per_level": N_FEATURES_PER_LEVEL, + "log2_hashmap_size": log2_hashmap_size, + "base_resolution": base_resolution, + "per_level_scale": per_level_scale, + }, + ) + + self.density_net = tinycudann.Network( + n_input_dims=N_FEATURES_PER_LEVEL * n_levels, + n_output_dims=1 + GEO_FEAT_DIM, + network_config={ + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": "None", + "n_neurons": hidden_dim_density, + "n_hidden_layers": num_layers_density - 1, + }, + ) + + # color network + self.num_layers_color = num_layers_color + self.hidden_dim_color = hidden_dim_color + + self.view_dependent = view_dependent + if view_dependent: + self.encoder_dir = tinycudann.Encoding( + n_input_dims=3, + encoding_config={ + "otype": "SphericalHarmonics", + "degree": 4, + }, + ) + self.in_dim_color = self.encoder_dir.n_output_dims + GEO_FEAT_DIM + else: + self.in_dim_color = GEO_FEAT_DIM + + self.color_net = tinycudann.Network( + n_input_dims=self.in_dim_color, + n_output_dims=color_channels, + network_config={ + "otype": "FullyFusedMLP", + "activation": "ReLU", + "output_activation": "None", + "n_neurons": hidden_dim_color, + "n_hidden_layers": num_layers_color - 1, + }, + ) + + def forward(self, x: torch.Tensor, d: torch.Tensor) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Args: + x: (..., 3) points + d: (..., 3) directions + Returns: + color: (..., 3) color values. + density: (..., 1) density values. + """ + batch_shape = x.shape[:-1] + x, d = x.reshape(-1, 3), d.reshape(-1, 3) + + # density + x = (x + self.bound) / (2 * self.bound) # to [0, 1] + x = self.encoder(x) + density, geo_feat = self.density_net(x).split([1, 15], dim=-1) + density = F.softplus(density).squeeze(-1) + + # color + if self.view_dependent: + d = (F.normalize(d, dim=-1) + 1) / 2 # tcnn SH encoding requires inputs to be in [0, 1] + d = self.encoder_dir(d) + h = torch.cat([d, geo_feat], dim=-1) + else: + h = geo_feat + color = self.color_net(h) + + return color.reshape(*batch_shape, self.color_channels), density.reshape(*batch_shape) + diff --git a/utils3d/torch/rasterization.py b/utils3d/torch/rasterization.py new file mode 100644 index 0000000000000000000000000000000000000000..bdba0c8d5fcc60bd34c07bb84ee10b931ffda281 --- /dev/null +++ b/utils3d/torch/rasterization.py @@ -0,0 +1,362 @@ +from typing import * + +import torch +import nvdiffrast.torch as dr + +from . import utils, transforms, mesh +from ._helpers import batched + + +__all__ = [ + 'RastContext', + 'rasterize_triangle_faces', + 'warp_image_by_depth', + 'warp_image_by_forward_flow', +] + + +class RastContext: + """ + Create a rasterization context. Nothing but a wrapper of nvdiffrast.torch.RasterizeCudaContext or nvdiffrast.torch.RasterizeGLContext. + """ + def __init__(self, nvd_ctx: Union[dr.RasterizeCudaContext, dr.RasterizeGLContext] = None, *, backend: Literal['cuda', 'gl'] = 'gl', device: Union[str, torch.device] = None): + import nvdiffrast.torch as dr + if nvd_ctx is not None: + self.nvd_ctx = nvd_ctx + return + + if backend == 'gl': + self.nvd_ctx = dr.RasterizeGLContext(device=device) + elif backend == 'cuda': + self.nvd_ctx = dr.RasterizeCudaContext(device=device) + else: + raise ValueError(f'Unknown backend: {backend}') + + +def rasterize_triangle_faces( + ctx: RastContext, + vertices: torch.Tensor, + faces: torch.Tensor, + attr: torch.Tensor, + width: int, + height: int, + model: torch.Tensor = None, + view: torch.Tensor = None, + projection: torch.Tensor = None, + antialiasing: Union[bool, List[int]] = True, + diff_attrs: Union[None, List[int]] = None, +) -> Tuple[torch.Tensor, torch.Tensor, Optional[torch.Tensor]]: + """ + Rasterize a mesh with vertex attributes. + + Args: + ctx (GLContext): rasterizer context + vertices (np.ndarray): (B, N, 2 or 3 or 4) + faces (torch.Tensor): (T, 3) + attr (torch.Tensor): (B, N, C) + width (int): width of the output image + height (int): height of the output image + model (torch.Tensor, optional): ([B,] 4, 4) model matrix. Defaults to None (identity). + view (torch.Tensor, optional): ([B,] 4, 4) view matrix. Defaults to None (identity). + projection (torch.Tensor, optional): ([B,] 4, 4) projection matrix. Defaults to None (identity). + antialiasing (Union[bool, List[int]], optional): whether to perform antialiasing. Defaults to True. If a list of indices is provided, only those channels will be antialiased. + diff_attrs (Union[None, List[int]], optional): indices of attributes to compute screen-space derivatives. Defaults to None. + + Returns: + image: (torch.Tensor): (B, C, H, W) + depth: (torch.Tensor): (B, H, W) screen space depth, ranging from 0 (near) to 1. (far) + NOTE: Empty pixels will have depth 1., i.e. far plane. + """ + assert vertices.ndim == 3 + assert faces.ndim == 2 + + if vertices.shape[-1] == 2: + vertices = torch.cat([vertices, torch.zeros_like(vertices[..., :1]), torch.ones_like(vertices[..., :1])], dim=-1) + elif vertices.shape[-1] == 3: + vertices = torch.cat([vertices, torch.ones_like(vertices[..., :1])], dim=-1) + elif vertices.shape[-1] == 4: + pass + else: + raise ValueError(f'Wrong shape of vertices: {vertices.shape}') + + mvp = projection if projection is not None else torch.eye(4).to(vertices) + if view is not None: + mvp = mvp @ view + if model is not None: + mvp = mvp @ model + + pos_clip = vertices @ mvp.transpose(-1, -2) + faces = faces.contiguous() + attr = attr.contiguous() + + rast_out, rast_db = dr.rasterize(ctx.nvd_ctx, pos_clip, faces, resolution=[height, width], grad_db=True) + image, image_dr = dr.interpolate(attr, rast_out, faces, rast_db, diff_attrs=diff_attrs) + if antialiasing == True: + image = dr.antialias(image, rast_out, pos_clip, faces) + elif isinstance(antialiasing, list): + aa_image = dr.antialias(image[..., antialiasing], rast_out, pos_clip, faces) + image[..., antialiasing] = aa_image + + image = image.flip(1).permute(0, 3, 1, 2) + + depth = rast_out[..., 2].flip(1) + depth = (depth * 0.5 + 0.5) * (depth > 0).float() + (depth == 0).float() + if diff_attrs is not None: + image_dr = image_dr.flip(1).permute(0, 3, 1, 2) + return image, depth, image_dr + return image, depth + + +def texture( + ctx: RastContext, + uv: torch.Tensor, + uv_da: torch.Tensor, + texture: torch.Tensor, +) -> torch.Tensor: + dr.texture(ctx.nvd_ctx, uv, texture) + + +def warp_image_by_depth( + ctx: RastContext, + depth: torch.FloatTensor, + image: torch.FloatTensor = None, + mask: torch.BoolTensor = None, + width: int = None, + height: int = None, + *, + extrinsics_src: torch.FloatTensor = None, + extrinsics_tgt: torch.FloatTensor = None, + intrinsics_src: torch.FloatTensor = None, + intrinsics_tgt: torch.FloatTensor = None, + near: float = 0.1, + far: float = 100.0, + antialiasing: bool = True, + backslash: bool = False, + padding: int = 0, + return_uv: bool = False, + return_dr: bool = False, +) -> Tuple[torch.FloatTensor, torch.FloatTensor, torch.BoolTensor, Optional[torch.FloatTensor], Optional[torch.FloatTensor]]: + """ + Warp image by depth. + NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results. + Otherwise, image mesh will be triangulated simply for batch rendering. + + Args: + ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context + depth (torch.Tensor): (B, H, W) linear depth + image (torch.Tensor): (B, C, H, W). None to use image space uv. Defaults to None. + width (int, optional): width of the output image. None to use the same as depth. Defaults to None. + height (int, optional): height of the output image. Defaults the same as depth.. + extrinsics_src (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for source. None to use identity. Defaults to None. + extrinsics_tgt (torch.Tensor, optional): (B, 4, 4) extrinsics matrix for target. None to use identity. Defaults to None. + intrinsics_src (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for source. None to use the same as target. Defaults to None. + intrinsics_tgt (torch.Tensor, optional): (B, 3, 3) intrinsics matrix for target. None to use the same as source. Defaults to None. + near (float, optional): near plane. Defaults to 0.1. + far (float, optional): far plane. Defaults to 100.0. + antialiasing (bool, optional): whether to perform antialiasing. Defaults to True. + backslash (bool, optional): whether to use backslash triangulation. Defaults to False. + padding (int, optional): padding of the image. Defaults to 0. + return_uv (bool, optional): whether to return the uv. Defaults to False. + return_dr (bool, optional): whether to return the image-space derivatives of uv. Defaults to False. + + Returns: + image: (torch.FloatTensor): (B, C, H, W) rendered image + depth: (torch.FloatTensor): (B, H, W) linear depth, ranging from 0 to inf + mask: (torch.BoolTensor): (B, H, W) mask of valid pixels + uv: (torch.FloatTensor): (B, 2, H, W) image-space uv + dr: (torch.FloatTensor): (B, 4, H, W) image-space derivatives of uv + """ + assert depth.ndim == 3 + batch_size = depth.shape[0] + + if width is None: + width = depth.shape[-1] + if height is None: + height = depth.shape[-2] + if image is not None: + assert image.shape[-2:] == depth.shape[-2:], f'Shape of image {image.shape} does not match shape of depth {depth.shape}' + + if extrinsics_src is None: + extrinsics_src = torch.eye(4).to(depth) + if extrinsics_tgt is None: + extrinsics_tgt = torch.eye(4).to(depth) + if intrinsics_src is None: + intrinsics_src = intrinsics_tgt + if intrinsics_tgt is None: + intrinsics_tgt = intrinsics_src + + assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters." + + view_tgt = transforms.extrinsics_to_view(extrinsics_tgt) + perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far) + + if padding > 0: + uv, faces = utils.image_mesh(width=width+2, height=height+2) + uv = (uv - 1 / (width + 2)) * ((width + 2) / width) + uv_ = uv.clone().reshape(height+2, width+2, 2) + uv_[0, :, 1] -= padding / height + uv_[-1, :, 1] += padding / height + uv_[:, 0, 0] -= padding / width + uv_[:, -1, 0] += padding / width + uv_ = uv_.reshape(-1, 2) + depth = torch.nn.functional.pad(depth, [1, 1, 1, 1], mode='replicate') + if image is not None: + image = torch.nn.functional.pad(image, [1, 1, 1, 1], mode='replicate') + uv, uv_, faces = uv.to(depth.device), uv_.to(depth.device), faces.to(depth.device) + pts = transforms.unproject_cv( + uv_, + depth.flatten(-2, -1), + extrinsics_src, + intrinsics_src, + ) + else: + uv, faces = utils.image_mesh(width=depth.shape[-1], height=depth.shape[-2]) + if mask is not None: + depth = torch.where(mask, depth, torch.tensor(far, dtype=depth.dtype, device=depth.device)) + uv, faces = uv.to(depth.device), faces.to(depth.device) + pts = transforms.unproject_cv( + uv, + depth.flatten(-2, -1), + extrinsics_src, + intrinsics_src, + ) + + # triangulate + if batch_size == 1: + faces = mesh.triangulate(faces, vertices=pts[0]) + else: + faces = mesh.triangulate(faces, backslash=backslash) + + # rasterize attributes + diff_attrs = None + if image is not None: + attr = image.permute(0, 2, 3, 1).flatten(1, 2) + if return_dr or return_uv: + if return_dr: + diff_attrs = [image.shape[1], image.shape[1]+1] + if return_uv and antialiasing: + antialiasing = list(range(image.shape[1])) + attr = torch.cat([attr, uv.expand(batch_size, -1, -1)], dim=-1) + else: + attr = uv.expand(batch_size, -1, -1) + if antialiasing: + print("\033[93mWarning: you are performing antialiasing on uv. This may cause artifacts.\033[0m") + if return_uv: + return_uv = False + print("\033[93mWarning: image is None, return_uv is ignored.\033[0m") + if return_dr: + diff_attrs = [0, 1] + + if mask is not None: + attr = torch.cat([attr, mask.float().flatten(1, 2).unsqueeze(-1)], dim=-1) + + rast = rasterize_triangle_faces( + ctx, + pts, + faces, + attr, + width, + height, + view=view_tgt, + perspective=perspective_tgt, + antialiasing=antialiasing, + diff_attrs=diff_attrs, + ) + if return_dr: + output_image, screen_depth, output_dr = rast + else: + output_image, screen_depth = rast + output_mask = screen_depth < 1.0 + + if mask is not None: + output_image, rast_mask = output_image[..., :-1, :, :], output_image[..., -1, :, :] + output_mask &= (rast_mask > 0.9999).reshape(-1, height, width) + + if (return_dr or return_uv) and image is not None: + output_image, output_uv = output_image[..., :-2, :, :], output_image[..., -2:, :, :] + + output_depth = transforms.depth_buffer_to_linear(screen_depth, near=near, far=far) * output_mask + output_image = output_image * output_mask.unsqueeze(1) + + outs = [output_image, output_depth, output_mask] + if return_uv: + outs.append(output_uv) + if return_dr: + outs.append(output_dr) + return tuple(outs) + + +def warp_image_by_forward_flow( + ctx: RastContext, + image: torch.FloatTensor, + flow: torch.FloatTensor, + depth: torch.FloatTensor = None, + *, + antialiasing: bool = True, + backslash: bool = False, +) -> Tuple[torch.FloatTensor, torch.BoolTensor]: + """ + Warp image by forward flow. + NOTE: if batch size is 1, image mesh will be triangulated aware of the depth, yielding less distorted results. + Otherwise, image mesh will be triangulated simply for batch rendering. + + Args: + ctx (Union[dr.RasterizeCudaContext, dr.RasterizeGLContext]): rasterization context + image (torch.Tensor): (B, C, H, W) image + flow (torch.Tensor): (B, 2, H, W) forward flow + depth (torch.Tensor, optional): (B, H, W) linear depth. If None, will use the same for all pixels. Defaults to None. + antialiasing (bool, optional): whether to perform antialiasing. Defaults to True. + backslash (bool, optional): whether to use backslash triangulation. Defaults to False. + + Returns: + image: (torch.FloatTensor): (B, C, H, W) rendered image + mask: (torch.BoolTensor): (B, H, W) mask of valid pixels + """ + assert image.ndim == 4, f'Wrong shape of image: {image.shape}' + batch_size, _, height, width = image.shape + + if depth is None: + depth = torch.ones_like(flow[:, 0]) + + extrinsics = torch.eye(4).to(image) + fov = torch.deg2rad(torch.tensor([45.0], device=image.device)) + intrinsics = transforms.intrinsics_from_fov(fov, width, height, normalize=True)[0] + + view = transforms.extrinsics_to_view(extrinsics) + perspective = transforms.intrinsics_to_perspective(intrinsics, near=0.1, far=100) + + uv, faces = utils.image_mesh(width=width, height=height) + uv, faces = uv.to(image.device), faces.to(image.device) + uv = uv + flow.permute(0, 2, 3, 1).flatten(1, 2) + pts = transforms.unproject_cv( + uv, + depth.flatten(-2, -1), + extrinsics, + intrinsics, + ) + + # triangulate + if batch_size == 1: + faces = mesh.triangulate(faces, vertices=pts[0]) + else: + faces = mesh.triangulate(faces, backslash=backslash) + + # rasterize attributes + attr = image.permute(0, 2, 3, 1).flatten(1, 2) + rast = rasterize_triangle_faces( + ctx, + pts, + faces, + attr, + width, + height, + view=view, + perspective=perspective, + antialiasing=antialiasing, + ) + output_image, screen_depth = rast + output_mask = screen_depth < 1.0 + output_image = output_image * output_mask.unsqueeze(1) + + outs = [output_image, output_mask] + return tuple(outs) diff --git a/utils3d/torch/transforms.py b/utils3d/torch/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..46e61d741c6ef000c80aa65201b55e31ed4246c6 --- /dev/null +++ b/utils3d/torch/transforms.py @@ -0,0 +1,1189 @@ +from typing import * +from numbers import Number + +import torch +import torch.nn.functional as F + +from ._helpers import batched + + +__all__ = [ + 'perspective', + 'perspective_from_fov', + 'perspective_from_fov_xy', + 'intrinsics_from_focal_center', + 'intrinsics_from_fov', + 'intrinsics_from_fov_xy', + 'view_look_at', + 'extrinsics_look_at', + 'perspective_to_intrinsics', + 'intrinsics_to_perspective', + 'extrinsics_to_view', + 'view_to_extrinsics', + 'normalize_intrinsics', + 'crop_intrinsics', + 'pixel_to_uv', + 'pixel_to_ndc', + 'uv_to_pixel', + 'project_depth', + 'depth_buffer_to_linear', + 'project_gl', + 'project_cv', + 'unproject_gl', + 'unproject_cv', + 'skew_symmetric', + 'rotation_matrix_from_vectors', + 'euler_axis_angle_rotation', + 'euler_angles_to_matrix', + 'matrix_to_euler_angles', + 'matrix_to_quaternion', + 'quaternion_to_matrix', + 'matrix_to_axis_angle', + 'axis_angle_to_matrix', + 'axis_angle_to_quaternion', + 'quaternion_to_axis_angle', + 'slerp', + 'interpolate_extrinsics', + 'interpolate_view', + 'extrinsics_to_essential', + 'to4x4', + 'rotation_matrix_2d', + 'rotate_2d', + 'translate_2d', + 'scale_2d', + 'apply_2d', +] + + +@batched(0,0,0,0) +def perspective( + fov_y: Union[float, torch.Tensor], + aspect: Union[float, torch.Tensor], + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenGL perspective matrix + + Args: + fov_y (float | torch.Tensor): field of view in y axis + aspect (float | torch.Tensor): aspect ratio + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + + Returns: + (torch.Tensor): [..., 4, 4] perspective matrix + """ + N = fov_y.shape[0] + ret = torch.zeros((N, 4, 4), dtype=fov_y.dtype, device=fov_y.device) + ret[:, 0, 0] = 1. / (torch.tan(fov_y / 2) * aspect) + ret[:, 1, 1] = 1. / (torch.tan(fov_y / 2)) + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. + return ret + + +def perspective_from_fov( + fov: Union[float, torch.Tensor], + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor], + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenGL perspective matrix from field of view in largest dimension + + Args: + fov (float | torch.Tensor): field of view in largest dimension + width (int | torch.Tensor): image width + height (int | torch.Tensor): image height + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + + Returns: + (torch.Tensor): [..., 4, 4] perspective matrix + """ + fov_y = 2 * torch.atan(torch.tan(fov / 2) * height / torch.maximum(width, height)) + aspect = width / height + return perspective(fov_y, aspect, near, far) + + +def perspective_from_fov_xy( + fov_x: Union[float, torch.Tensor], + fov_y: Union[float, torch.Tensor], + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenGL perspective matrix from field of view in x and y axis + + Args: + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + near (float | torch.Tensor): near plane to clip + far (float | torch.Tensor): far plane to clip + + Returns: + (torch.Tensor): [..., 4, 4] perspective matrix + """ + aspect = torch.tan(fov_x / 2) / torch.tan(fov_y / 2) + return perspective(fov_y, aspect, near, far) + + +@batched(0,0,0,0) +def intrinsics_from_focal_center( + fx: Union[float, torch.Tensor], + fy: Union[float, torch.Tensor], + cx: Union[float, torch.Tensor], + cy: Union[float, torch.Tensor] +) -> torch.Tensor: + """ + Get OpenCV intrinsics matrix + + Args: + focal_x (float | torch.Tensor): focal length in x axis + focal_y (float | torch.Tensor): focal length in y axis + cx (float | torch.Tensor): principal point in x axis + cy (float | torch.Tensor): principal point in y axis + + Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + """ + N = fx.shape[0] + ret = torch.zeros((N, 3, 3), dtype=fx.dtype, device=fx.device) + zeros, ones = torch.zeros(N, dtype=fx.dtype, device=fx.device), torch.ones(N, dtype=fx.dtype, device=fx.device) + ret = torch.stack([fx, zeros, cx, zeros, fy, cy, zeros, zeros, ones], dim=-1).unflatten(-1, (3, 3)) + return ret + + +@batched(0, 0, 0, 0, 0, 0) +def intrinsics_from_fov( + fov_max: Union[float, torch.Tensor] = None, + fov_min: Union[float, torch.Tensor] = None, + fov_x: Union[float, torch.Tensor] = None, + fov_y: Union[float, torch.Tensor] = None, + width: Union[int, torch.Tensor] = None, + height: Union[int, torch.Tensor] = None, +) -> torch.Tensor: + """ + Get normalized OpenCV intrinsics matrix from given field of view. + You can provide either fov_max, fov_min, fov_x or fov_y + + Args: + width (int | torch.Tensor): image width + height (int | torch.Tensor): image height + fov_max (float | torch.Tensor): field of view in largest dimension + fov_min (float | torch.Tensor): field of view in smallest dimension + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + + Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + """ + if fov_max is not None: + fx = torch.maximum(width, height) / width / (2 * torch.tan(fov_max / 2)) + fy = torch.maximum(width, height) / height / (2 * torch.tan(fov_max / 2)) + elif fov_min is not None: + fx = torch.minimum(width, height) / width / (2 * torch.tan(fov_min / 2)) + fy = torch.minimum(width, height) / height / (2 * torch.tan(fov_min / 2)) + elif fov_x is not None and fov_y is not None: + fx = 1 / (2 * torch.tan(fov_x / 2)) + fy = 1 / (2 * torch.tan(fov_y / 2)) + elif fov_x is not None: + fx = 1 / (2 * torch.tan(fov_x / 2)) + fy = fx * width / height + elif fov_y is not None: + fy = 1 / (2 * torch.tan(fov_y / 2)) + fx = fy * height / width + cx = 0.5 + cy = 0.5 + ret = intrinsics_from_focal_center(fx, fy, cx, cy) + return ret + + + +def intrinsics_from_fov_xy( + fov_x: Union[float, torch.Tensor], + fov_y: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Get OpenCV intrinsics matrix from field of view in x and y axis + + Args: + fov_x (float | torch.Tensor): field of view in x axis + fov_y (float | torch.Tensor): field of view in y axis + + Returns: + (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + """ + focal_x = 0.5 / torch.tan(fov_x / 2) + focal_y = 0.5 / torch.tan(fov_y / 2) + cx = cy = 0.5 + return intrinsics_from_focal_center(focal_x, focal_y, cx, cy) + + +@batched(1,1,1) +def view_look_at( + eye: torch.Tensor, + look_at: torch.Tensor, + up: torch.Tensor + ) -> torch.Tensor: + """ + Get OpenGL view matrix looking at something + + Args: + eye (torch.Tensor): [..., 3] the eye position + look_at (torch.Tensor): [..., 3] the position to look at + up (torch.Tensor): [..., 3] head up direction (y axis in screen space). Not necessarily othogonal to view direction + + Returns: + (torch.Tensor): [..., 4, 4], view matrix + """ + N = eye.shape[0] + z = eye - look_at + x = torch.cross(up, z, dim=-1) + y = torch.cross(z, x, dim=-1) + # x = torch.cross(y, z, dim=-1) + x = x / x.norm(dim=-1, keepdim=True) + y = y / y.norm(dim=-1, keepdim=True) + z = z / z.norm(dim=-1, keepdim=True) + R = torch.stack([x, y, z], dim=-2) + t = -torch.matmul(R, eye[..., None]) + ret = torch.zeros((N, 4, 4), dtype=eye.dtype, device=eye.device) + ret[:, :3, :3] = R + ret[:, :3, 3] = t[:, :, 0] + ret[:, 3, 3] = 1. + return ret + + +@batched(1, 1, 1) +def extrinsics_look_at( + eye: torch.Tensor, + look_at: torch.Tensor, + up: torch.Tensor +) -> torch.Tensor: + """ + Get OpenCV extrinsics matrix looking at something + + Args: + eye (torch.Tensor): [..., 3] the eye position + look_at (torch.Tensor): [..., 3] the position to look at + up (torch.Tensor): [..., 3] head up direction (-y axis in screen space). Not necessarily othogonal to view direction + + Returns: + (torch.Tensor): [..., 4, 4], extrinsics matrix + """ + N = eye.shape[0] + z = look_at - eye + x = torch.cross(-up, z, dim=-1) + y = torch.cross(z, x, dim=-1) + # x = torch.cross(y, z, dim=-1) + x = x / x.norm(dim=-1, keepdim=True) + y = y / y.norm(dim=-1, keepdim=True) + z = z / z.norm(dim=-1, keepdim=True) + R = torch.stack([x, y, z], dim=-2) + t = -torch.matmul(R, eye[..., None]) + ret = torch.zeros((N, 4, 4), dtype=eye.dtype, device=eye.device) + ret[:, :3, :3] = R + ret[:, :3, 3] = t[:, :, 0] + ret[:, 3, 3] = 1. + return ret + + +@batched(2) +def perspective_to_intrinsics( + perspective: torch.Tensor +) -> torch.Tensor: + """ + OpenGL perspective matrix to OpenCV intrinsics + + Args: + perspective (torch.Tensor): [..., 4, 4] OpenGL perspective matrix + + Returns: + (torch.Tensor): shape [..., 3, 3] OpenCV intrinsics + """ + assert torch.allclose(perspective[:, [0, 1, 3], 3], 0), "The perspective matrix is not a projection matrix" + ret = torch.tensor([[0.5, 0., 0.5], [0., -0.5, 0.5], [0., 0., 1.]], dtype=perspective.dtype, device=perspective.device) \ + @ perspective[:, [0, 1, 3], :3] \ + @ torch.diag(torch.tensor([1, -1, -1], dtype=perspective.dtype, device=perspective.device)) + return ret / ret[:, 2, 2, None, None] + + +@batched(2,0,0) +def intrinsics_to_perspective( + intrinsics: torch.Tensor, + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor], + ) -> torch.Tensor: + """ + OpenCV intrinsics to OpenGL perspective matrix + + Args: + intrinsics (torch.Tensor): [..., 3, 3] OpenCV intrinsics matrix + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + Returns: + (torch.Tensor): [..., 4, 4] OpenGL perspective matrix + """ + N = intrinsics.shape[0] + fx, fy = intrinsics[:, 0, 0], intrinsics[:, 1, 1] + cx, cy = intrinsics[:, 0, 2], intrinsics[:, 1, 2] + ret = torch.zeros((N, 4, 4), dtype=intrinsics.dtype, device=intrinsics.device) + ret[:, 0, 0] = 2 * fx + ret[:, 1, 1] = 2 * fy + ret[:, 0, 2] = -2 * cx + 1 + ret[:, 1, 2] = 2 * cy - 1 + ret[:, 2, 2] = (near + far) / (near - far) + ret[:, 2, 3] = 2. * near * far / (near - far) + ret[:, 3, 2] = -1. + return ret + + +@batched(2) +def extrinsics_to_view( + extrinsics: torch.Tensor + ) -> torch.Tensor: + """ + OpenCV camera extrinsics to OpenGL view matrix + + Args: + extrinsics (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix + + Returns: + (torch.Tensor): [..., 4, 4] OpenGL view matrix + """ + return extrinsics * torch.tensor([1, -1, -1, 1], dtype=extrinsics.dtype, device=extrinsics.device)[:, None] + + +@batched(2) +def view_to_extrinsics( + view: torch.Tensor + ) -> torch.Tensor: + """ + OpenGL view matrix to OpenCV camera extrinsics + + Args: + view (torch.Tensor): [..., 4, 4] OpenGL view matrix + + Returns: + (torch.Tensor): [..., 4, 4] OpenCV camera extrinsics matrix + """ + return view * torch.tensor([1, -1, -1, 1], dtype=view.dtype, device=view.device)[:, None] + + +@batched(2,0,0) +def normalize_intrinsics( + intrinsics: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] + ) -> torch.Tensor: + """ + Normalize camera intrinsics(s) to uv space + + Args: + intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to normalize + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + + Returns: + (torch.Tensor): [..., 3, 3] normalized camera intrinsics(s) + """ + zeros = torch.zeros_like(width) + ones = torch.ones_like(width) + transform = torch.stack([ + 1 / width, zeros, 0.5 / width, + zeros, 1 / height, 0.5 / height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3).to(intrinsics) + return transform @ intrinsics + + + +@batched(2,0,0,0,0,0,0) +def crop_intrinsics( + intrinsics: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor], + left: Union[int, torch.Tensor], + top: Union[int, torch.Tensor], + crop_width: Union[int, torch.Tensor], + crop_height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Evaluate the new intrinsics(s) after crop the image: cropped_img = img[top:top+crop_height, left:left+crop_width] + + Args: + intrinsics (torch.Tensor): [..., 3, 3] camera intrinsics(s) to crop + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + left (int | torch.Tensor): [...] left crop boundary + top (int | torch.Tensor): [...] top crop boundary + crop_width (int | torch.Tensor): [...] crop width + crop_height (int | torch.Tensor): [...] crop height + + Returns: + (torch.Tensor): [..., 3, 3] cropped camera intrinsics(s) + """ + zeros = torch.zeros_like(width) + ones = torch.ones_like(width) + transform = torch.stack([ + width / crop_width, zeros, -left / crop_width, + zeros, height / crop_height, -top / crop_height, + zeros, zeros, ones + ]).reshape(*zeros.shape, 3, 3).to(intrinsics) + return transform @ intrinsics + + +@batched(1,0,0) +def pixel_to_uv( + pixel: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Args: + pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + + Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + if not torch.is_floating_point(pixel): + pixel = pixel.float() + uv = (pixel + 0.5) / torch.stack([width, height], dim=-1).to(pixel) + return uv + + +@batched(1,0,0) +def uv_to_pixel( + uv: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Args: + uv (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + + Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in uv space, the range is (0, 1) + """ + pixel = uv * torch.stack([width, height], dim=-1).to(uv) - 0.5 + return pixel + + +@batched(1,0,0) +def pixel_to_ndc( + pixel: torch.Tensor, + width: Union[int, torch.Tensor], + height: Union[int, torch.Tensor] +) -> torch.Tensor: + """ + Args: + pixel (torch.Tensor): [..., 2] pixel coordinrates defined in image space, x range is (0, W - 1), y range is (0, H - 1) + width (int | torch.Tensor): [...] image width(s) + height (int | torch.Tensor): [...] image height(s) + + Returns: + (torch.Tensor): [..., 2] pixel coordinrates defined in ndc space, the range is (-1, 1) + """ + if not torch.is_floating_point(pixel): + pixel = pixel.float() + ndc = (pixel + 0.5) / (torch.stack([width, height], dim=-1).to(pixel) * torch.tensor([2, -2], dtype=pixel.dtype, device=pixel.device)) \ + + torch.tensor([-1, 1], dtype=pixel.dtype, device=pixel.device) + return ndc + + +@batched(0,0,0) +def project_depth( + depth: torch.Tensor, + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Project linear depth to depth value in screen space + + Args: + depth (torch.Tensor): [...] depth value + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + + Returns: + (torch.Tensor): [..., 1] depth value in screen space, value ranging in [0, 1] + """ + return (far - near * far / depth) / (far - near) + + +@batched(0,0,0) +def depth_buffer_to_linear( + depth: torch.Tensor, + near: Union[float, torch.Tensor], + far: Union[float, torch.Tensor] + ) -> torch.Tensor: + """ + Linearize depth value to linear depth + + Args: + depth (torch.Tensor): [...] screen depth value, ranging in [0, 1] + near (float | torch.Tensor): [...] near plane to clip + far (float | torch.Tensor): [...] far plane to clip + + Returns: + (torch.Tensor): [...] linear depth + """ + return near * far / (far - (far - near) * depth) + + +@batched(2, 2, 2, 2) +def project_gl( + points: torch.Tensor, + model: torch.Tensor = None, + view: torch.Tensor = None, + perspective: torch.Tensor = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Project 3D points to 2D following the OpenGL convention (except for row major matrice) + + Args: + points (torch.Tensor): [..., N, 3 or 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + model (torch.Tensor): [..., 4, 4] model matrix + view (torch.Tensor): [..., 4, 4] view matrix + perspective (torch.Tensor): [..., 4, 4] perspective matrix + + Returns: + scr_coord (torch.Tensor): [..., N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + linear_depth (torch.Tensor): [..., N] linear depth + """ + assert perspective is not None, "perspective matrix is required" + + if points.shape[-1] == 3: + points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + mvp = perspective if perspective is not None else torch.eye(4).to(points) + if view is not None: + mvp = mvp @ view + if model is not None: + mvp = mvp @ model + clip_coord = points @ mvp.transpose(-1, -2) + ndc_coord = clip_coord[..., :3] / clip_coord[..., 3:] + scr_coord = ndc_coord * 0.5 + 0.5 + linear_depth = clip_coord[..., 3] + return scr_coord, linear_depth + + +@batched(2, 2, 2) +def project_cv( + points: torch.Tensor, + extrinsics: torch.Tensor = None, + intrinsics: torch.Tensor = None +) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Project 3D points to 2D following the OpenCV convention + + Args: + points (torch.Tensor): [..., N, 3] or [..., N, 4] 3D points to project, if the last + dimension is 4, the points are assumed to be in homogeneous coordinates + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + + Returns: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + linear_depth (torch.Tensor): [..., N] linear depth + """ + assert intrinsics is not None, "intrinsics matrix is required" + if points.shape[-1] == 3: + points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + if extrinsics is not None: + points = points @ extrinsics.transpose(-1, -2) + points = points[..., :3] @ intrinsics.transpose(-2, -1) + uv_coord = points[..., :2] / points[..., 2:] + linear_depth = points[..., 2] + return uv_coord, linear_depth + + +@batched(2, 2, 2, 2) +def unproject_gl( + screen_coord: torch.Tensor, + model: torch.Tensor = None, + view: torch.Tensor = None, + perspective: torch.Tensor = None + ) -> torch.Tensor: + """ + Unproject screen space coordinates to 3D view space following the OpenGL convention (except for row major matrice) + + Args: + screen_coord (torch.Tensor): [... N, 3] screen space coordinates, value ranging in [0, 1]. + The origin (0., 0., 0.) is corresponding to the left & bottom & nearest + model (torch.Tensor): [..., 4, 4] model matrix + view (torch.Tensor): [..., 4, 4] view matrix + perspective (torch.Tensor): [..., 4, 4] perspective matrix + + Returns: + points (torch.Tensor): [..., N, 3] 3d points + """ + assert perspective is not None, "perspective matrix is required" + ndc_xy = screen_coord * 2 - 1 + clip_coord = torch.cat([ndc_xy, torch.ones_like(ndc_xy[..., :1])], dim=-1) + transform = perspective + if view is not None: + transform = transform @ view + if model is not None: + transform = transform @ model + transform = torch.inverse(transform) + points = clip_coord @ transform.transpose(-1, -2) + points = points[..., :3] / points[..., 3:] + return points + + +@batched(2, 1, 2, 2) +def unproject_cv( + uv_coord: torch.Tensor, + depth: torch.Tensor, + extrinsics: torch.Tensor = None, + intrinsics: torch.Tensor = None +) -> torch.Tensor: + """ + Unproject uv coordinates to 3D view space following the OpenCV convention + + Args: + uv_coord (torch.Tensor): [..., N, 2] uv coordinates, value ranging in [0, 1]. + The origin (0., 0.) is corresponding to the left & top + depth (torch.Tensor): [..., N] depth value + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + intrinsics (torch.Tensor): [..., 3, 3] intrinsics matrix + + Returns: + points (torch.Tensor): [..., N, 3] 3d points + """ + assert intrinsics is not None, "intrinsics matrix is required" + points = torch.cat([uv_coord, torch.ones_like(uv_coord[..., :1])], dim=-1) + points = points @ torch.inverse(intrinsics).transpose(-2, -1) + points = points * depth[..., None] + if extrinsics is not None: + points = torch.cat([points, torch.ones_like(points[..., :1])], dim=-1) + points = (points @ torch.inverse(extrinsics).transpose(-2, -1))[..., :3] + return points + + +def euler_axis_angle_rotation(axis: str, angle: torch.Tensor) -> torch.Tensor: + """ + Return the rotation matrices for one of the rotations about an axis + of which Euler angles describe, for each value of the angle given. + + Args: + axis: Axis label "X" or "Y or "Z". + angle: any shape tensor of Euler angles in radians + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + + cos = torch.cos(angle) + sin = torch.sin(angle) + one = torch.ones_like(angle) + zero = torch.zeros_like(angle) + + if axis == "X": + R_flat = (one, zero, zero, zero, cos, -sin, zero, sin, cos) + elif axis == "Y": + R_flat = (cos, zero, sin, zero, one, zero, -sin, zero, cos) + elif axis == "Z": + R_flat = (cos, -sin, zero, sin, cos, zero, zero, zero, one) + else: + raise ValueError("letter must be either X, Y or Z.") + + return torch.stack(R_flat, -1).reshape(angle.shape + (3, 3)) + + +def euler_angles_to_matrix(euler_angles: torch.Tensor, convention: str = 'XYZ') -> torch.Tensor: + """ + Convert rotations given as Euler angles in radians to rotation matrices. + + Args: + euler_angles: Euler angles in radians as tensor of shape (..., 3), XYZ + convention: permutation of "X", "Y" or "Z", representing the order of Euler rotations to apply. + + Returns: + Rotation matrices as tensor of shape (..., 3, 3). + """ + if euler_angles.dim() == 0 or euler_angles.shape[-1] != 3: + raise ValueError("Invalid input euler angles.") + if len(convention) != 3: + raise ValueError("Convention must have 3 letters.") + if convention[1] in (convention[0], convention[2]): + raise ValueError(f"Invalid convention {convention}.") + for letter in convention: + if letter not in ("X", "Y", "Z"): + raise ValueError(f"Invalid letter {letter} in convention string.") + matrices = [ + euler_axis_angle_rotation(c, euler_angles[..., 'XYZ'.index(c)]) + for c in convention + ] + # return functools.reduce(torch.matmul, matrices) + return matrices[2] @ matrices[1] @ matrices[0] + + +def skew_symmetric(v: torch.Tensor): + "Skew symmetric matrix from a 3D vector" + assert v.shape[-1] == 3, "v must be 3D" + x, y, z = v.unbind(dim=-1) + zeros = torch.zeros_like(x) + return torch.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros, + ], dim=-1).reshape(*v.shape[:-1], 3, 3) + + +def rotation_matrix_from_vectors(v1: torch.Tensor, v2: torch.Tensor): + "Rotation matrix that rotates v1 to v2" + I = torch.eye(3).to(v1) + v1 = F.normalize(v1, dim=-1) + v2 = F.normalize(v2, dim=-1) + v = torch.cross(v1, v2, dim=-1) + c = torch.sum(v1 * v2, dim=-1) + K = skew_symmetric(v) + R = I + K + (1 / (1 + c))[None, None] * (K @ K) + return R + + +def _angle_from_tan( + axis: str, other_axis: str, data, horizontal: bool, tait_bryan: bool +) -> torch.Tensor: + """ + Extract the first or third Euler angle from the two members of + the matrix which are positive constant times its sine and cosine. + + Args: + axis: Axis label "X" or "Y or "Z" for the angle we are finding. + other_axis: Axis label "X" or "Y or "Z" for the middle axis in the + convention. + data: Rotation matrices as tensor of shape (..., 3, 3). + horizontal: Whether we are looking for the angle for the third axis, + which means the relevant entries are in the same row of the + rotation matrix. If not, they are in the same column. + tait_bryan: Whether the first and third axes in the convention differ. + + Returns: + Euler Angles in radians for each matrix in data as a tensor + of shape (...). + """ + + i1, i2 = {"X": (2, 1), "Y": (0, 2), "Z": (1, 0)}[axis] + if horizontal: + i2, i1 = i1, i2 + even = (axis + other_axis) in ["XY", "YZ", "ZX"] + if horizontal == even: + return torch.atan2(data[..., i1], data[..., i2]) + if tait_bryan: + return torch.atan2(-data[..., i2], data[..., i1]) + return torch.atan2(data[..., i2], -data[..., i1]) + + +def matrix_to_euler_angles(matrix: torch.Tensor, convention: str) -> torch.Tensor: + """ + Convert rotations given as rotation matrices to Euler angles in radians. + NOTE: The composition order eg. `XYZ` means `Rz * Ry * Rx` (like blender), instead of `Rx * Ry * Rz` (like pytorch3d) + + Args: + matrix: Rotation matrices as tensor of shape (..., 3, 3). + convention: Convention string of three uppercase letters. + + Returns: + Euler angles in radians as tensor of shape (..., 3), in the order of XYZ (like blender), instead of convention (like pytorch3d) + """ + if not all(c in 'XYZ' for c in convention) or not all(c in convention for c in 'XYZ'): + raise ValueError(f"Invalid convention {convention}.") + if not matrix.shape[-2:] == (3, 3): + raise ValueError(f"Invalid rotation matrix shape {matrix.shape}.") + + i0 = 'XYZ'.index(convention[0]) + i2 = 'XYZ'.index(convention[2]) + tait_bryan = i0 != i2 + if tait_bryan: + central_angle = torch.asin(matrix[..., i2, i0] * (-1.0 if i2 - i0 in [-1, 2] else 1.0)) + else: + central_angle = torch.acos(matrix[..., i2, i2]) + + # Angles in composition order + o = [ + _angle_from_tan( + convention[0], convention[1], matrix[..., i2, :], True, tait_bryan + ), + central_angle, + _angle_from_tan( + convention[2], convention[1], matrix[..., i0], False, tait_bryan + ), + ] + return torch.stack([o[convention.index(c)] for c in 'XYZ'], -1) + + +def axis_angle_to_matrix(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert axis-angle representation (rotation vector) to rotation matrix, whose direction is the axis of rotation and length is the angle of rotation + + Args: + axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors + + Returns: + torch.Tensor: shape (..., 3, 3) The rotation matrices for the given axis-angle parameters + """ + batch_shape = axis_angle.shape[:-1] + device, dtype = axis_angle.device, axis_angle.dtype + + angle = torch.norm(axis_angle + eps, dim=-1, keepdim=True) + axis = axis_angle / angle + + cos = torch.cos(angle)[..., None, :] + sin = torch.sin(angle)[..., None, :] + + rx, ry, rz = torch.split(axis, 3, dim=-1) + zeros = torch.zeros((*batch_shape, 1), dtype=dtype, device=device) + K = torch.cat([zeros, -rz, ry, rz, zeros, -rx, -ry, rx, zeros], dim=-1).view((*batch_shape, 3, 3)) + + ident = torch.eye(3, dtype=dtype, device=device) + rot_mat = ident + sin * K + (1 - cos) * torch.matmul(K, K) + return rot_mat + + +def matrix_to_axis_angle(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert a batch of 3x3 rotation matrices to axis-angle representation (rotation vector) + + Args: + rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert + + Returns: + torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given rotation matrices + """ + quat = matrix_to_quaternion(rot_mat) + axis_angle = quaternion_to_axis_angle(quat, eps=eps) + return axis_angle + + +def quaternion_to_axis_angle(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert a batch of quaternions (w, x, y, z) to axis-angle representation (rotation vector) + + Args: + quaternion (torch.Tensor): shape (..., 4), the quaternions to convert + + Returns: + torch.Tensor: shape (..., 3), the axis-angle vectors corresponding to the given quaternions + """ + assert quaternion.shape[-1] == 4 + norm = torch.norm(quaternion[..., 1:], dim=-1, keepdim=True) + axis = quaternion[..., 1:] / norm.clamp(min=eps) + angle = 2 * torch.atan2(norm, quaternion[..., 0:1]) + return angle * axis + + +def axis_angle_to_quaternion(axis_angle: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert axis-angle representation (rotation vector) to quaternion (w, x, y, z) + + Args: + axis_angle (torch.Tensor): shape (..., 3), axis-angle vcetors + + Returns: + torch.Tensor: shape (..., 4) The quaternions for the given axis-angle parameters + """ + axis = F.normalize(axis_angle, dim=-1, eps=eps) + angle = torch.norm(axis_angle, dim=-1, keepdim=True) + quat = torch.cat([torch.cos(angle / 2), torch.sin(angle / 2) * axis], dim=-1) + return quat + + +def matrix_to_quaternion(rot_mat: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Convert 3x3 rotation matrix to quaternion (w, x, y, z) + + Args: + rot_mat (torch.Tensor): shape (..., 3, 3), the rotation matrices to convert + + Returns: + torch.Tensor: shape (..., 4), the quaternions corresponding to the given rotation matrices + """ + # Extract the diagonal and off-diagonal elements of the rotation matrix + m00, m01, m02, m10, m11, m12, m20, m21, m22 = rot_mat.flatten(-2).unbind(dim=-1) + + diag = torch.diagonal(rot_mat, dim1=-2, dim2=-1) + M = torch.tensor([ + [1, 1, 1], + [1, -1, -1], + [-1, 1, -1], + [-1, -1, 1] + ], dtype=rot_mat.dtype, device=rot_mat.device) + wxyz = (1 + diag @ M.transpose(-1, -2)).clamp_(0).sqrt().mul(0.5) + _, max_idx = wxyz.max(dim=-1) + xw = torch.sign(m21 - m12) + yw = torch.sign(m02 - m20) + zw = torch.sign(m10 - m01) + yz = torch.sign(m21 + m12) + xz = torch.sign(m02 + m20) + xy = torch.sign(m01 + m10) + ones = torch.ones_like(xw) + sign = torch.where( + max_idx[..., None] == 0, + torch.stack([ones, xw, yw, zw], dim=-1), + torch.where( + max_idx[..., None] == 1, + torch.stack([xw, ones, xy, xz], dim=-1), + torch.where( + max_idx[..., None] == 2, + torch.stack([yw, xy, ones, yz], dim=-1), + torch.stack([zw, xz, yz, ones], dim=-1) + ) + ) + ) + quat = sign * wxyz + quat = F.normalize(quat, dim=-1, eps=eps) + return quat + + +def quaternion_to_matrix(quaternion: torch.Tensor, eps: float = 1e-12) -> torch.Tensor: + """Converts a batch of quaternions (w, x, y, z) to rotation matrices + + Args: + quaternion (torch.Tensor): shape (..., 4), the quaternions to convert + + Returns: + torch.Tensor: shape (..., 3, 3), the rotation matrices corresponding to the given quaternions + """ + assert quaternion.shape[-1] == 4 + quaternion = F.normalize(quaternion, dim=-1, eps=eps) + w, x, y, z = quaternion.unbind(dim=-1) + zeros = torch.zeros_like(w) + I = torch.eye(3, dtype=quaternion.dtype, device=quaternion.device) + xyz = quaternion[..., 1:] + A = xyz[..., :, None] * xyz[..., None, :] - I * (xyz ** 2).sum(dim=-1)[..., None, None] + B = torch.stack([ + zeros, -z, y, + z, zeros, -x, + -y, x, zeros + ], dim=-1).unflatten(-1, (3, 3)) + rot_mat = I + 2 * (A + w[..., None, None] * B) + return rot_mat + + +def slerp(rot_mat_1: torch.Tensor, rot_mat_2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor: + """Spherical linear interpolation between two rotation matrices + + Args: + rot_mat_1 (torch.Tensor): shape (..., 3, 3), the first rotation matrix + rot_mat_2 (torch.Tensor): shape (..., 3, 3), the second rotation matrix + t (torch.Tensor): scalar or shape (...,), the interpolation factor + + Returns: + torch.Tensor: shape (..., 3, 3), the interpolated rotation matrix + """ + assert rot_mat_1.shape[-2:] == (3, 3) + rot_vec_1 = matrix_to_axis_angle(rot_mat_1) + rot_vec_2 = matrix_to_axis_angle(rot_mat_2) + if isinstance(t, Number): + t = torch.tensor(t, dtype=rot_mat_1.dtype, device=rot_mat_1.device) + rot_vec = (1 - t[..., None]) * rot_vec_1 + t[..., None] * rot_vec_2 + rot_mat = axis_angle_to_matrix(rot_vec) + return rot_mat + + +def interpolate_extrinsics(ext1: torch.Tensor, ext2: torch.Tensor, t: Union[Number, torch.Tensor]) -> torch.Tensor: + """Interpolate extrinsics between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation. + + Args: + ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose + ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose + t (torch.Tensor): scalar or shape (...,), the interpolation factor + + Returns: + torch.Tensor: shape (..., 4, 4), the interpolated camera pose + """ + return torch.inverse(interpolate_transform(torch.inverse(ext1), torch.inverse(ext2), t)) + + +def interpolate_view(view1: torch.Tensor, view2: torch.Tensor, t: Union[Number, torch.Tensor]): + """Interpolate view matrices between two camera poses. Linear interpolation for translation, spherical linear interpolation for rotation. + + Args: + ext1 (torch.Tensor): shape (..., 4, 4), the first camera pose + ext2 (torch.Tensor): shape (..., 4, 4), the second camera pose + t (torch.Tensor): scalar or shape (...,), the interpolation factor + + Returns: + torch.Tensor: shape (..., 4, 4), the interpolated camera pose + """ + return interpolate_extrinsics(view1, view2, t) + + +def interpolate_transform(transform1: torch.Tensor, transform2: torch.Tensor, t: Union[Number, torch.Tensor]): + assert transform1.shape[-2:] == (4, 4) and transform2.shape[-2:] == (4, 4) + if isinstance(t, Number): + t = torch.tensor(t, dtype=transform1.dtype, device=transform1.device) + pos = (1 - t[..., None]) * transform1[..., :3, 3] + t[..., None] * transform2[..., :3, 3] + rot = slerp(transform1[..., :3, :3], transform2[..., :3, :3], t) + transform = torch.cat([rot, pos[..., None]], dim=-1) + transform = torch.cat([ext, torch.tensor([0, 0, 0, 1], dtype=transform.dtype, device=transform.device).expand_as(transform[..., :1, :])], dim=-2) + return transform + + +def extrinsics_to_essential(extrinsics: torch.Tensor): + """ + extrinsics matrix `[[R, t] [0, 0, 0, 1]]` such that `x' = R (x - t)` to essential matrix such that `x' E x = 0` + + Args: + extrinsics (torch.Tensor): [..., 4, 4] extrinsics matrix + + Returns: + (torch.Tensor): [..., 3, 3] essential matrix + """ + assert extrinsics.shape[-2:] == (4, 4) + R = extrinsics[..., :3, :3] + t = extrinsics[..., :3, 3] + zeros = torch.zeros_like(t) + t_x = torch.stack([ + zeros, -t[..., 2], t[..., 1], + t[..., 2], zeros, -t[..., 0], + -t[..., 1], t[..., 0], zeros + ]).reshape(*t.shape[:-1], 3, 3) + return R @ t_x + + +def to4x4(R: torch.Tensor, t: torch.Tensor): + """ + Compose rotation matrix and translation vector to 4x4 transformation matrix + + Args: + R (torch.Tensor): [..., 3, 3] rotation matrix + t (torch.Tensor): [..., 3] translation vector + + Returns: + (torch.Tensor): [..., 4, 4] transformation matrix + """ + assert R.shape[-2:] == (3, 3) + assert t.shape[-1] == 3 + assert R.shape[:-2] == t.shape[:-1] + return torch.cat([ + torch.cat([R, t[..., None]], dim=-1), + torch.tensor([0, 0, 0, 1], dtype=R.dtype, device=R.device).expand(*R.shape[:-2], 1, 4) + ], dim=-2) + + +def rotation_matrix_2d(theta: Union[float, torch.Tensor]): + """ + 2x2 matrix for 2D rotation + + Args: + theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) + + Returns: + (torch.Tensor): (..., 2, 2) rotation matrix + """ + if isinstance(theta, float): + theta = torch.tensor(theta) + return torch.stack([ + torch.cos(theta), -torch.sin(theta), + torch.sin(theta), torch.cos(theta), + ], dim=-1).unflatten(-1, (2, 2)) + + +def rotate_2d(theta: Union[float, torch.Tensor], center: torch.Tensor = None): + """ + 3x3 matrix for 2D rotation around a center + ``` + [[Rxx, Rxy, tx], + [Ryx, Ryy, ty], + [0, 0, 1]] + ``` + Args: + theta (float | torch.Tensor): rotation angle in radians, arbitrary shape (...,) + center (torch.Tensor): rotation center, arbitrary shape (..., 2). Default to (0, 0) + + Returns: + (torch.Tensor): (..., 3, 3) transformation matrix + """ + if isinstance(theta, float): + theta = torch.tensor(theta) + if center is not None: + theta = theta.to(center) + if center is None: + center = torch.zeros(2).to(theta).expand(*theta.shape, -1) + R = rotation_matrix_2d(theta) + return torch.cat([ + torch.cat([ + R, + center[..., :, None] - R @ center[..., :, None], + ], dim=-1), + torch.tensor([[0, 0, 1]], dtype=center.dtype, device=center.device).expand(*center.shape[:-1], -1, -1), + ], dim=-2) + + +def translate_2d(translation: torch.Tensor): + """ + Translation matrix for 2D translation + ``` + [[1, 0, tx], + [0, 1, ty], + [0, 0, 1]] + ``` + Args: + translation (torch.Tensor): translation vector, arbitrary shape (..., 2) + + Returns: + (torch.Tensor): (..., 3, 3) transformation matrix + """ + return torch.cat([ + torch.cat([ + torch.eye(2, dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1), + translation[..., None], + ], dim=-1), + torch.tensor([[0, 0, 1]], dtype=translation.dtype, device=translation.device).expand(*translation.shape[:-1], -1, -1), + ], dim=-2) + + +def scale_2d(scale: Union[float, torch.Tensor], center: torch.Tensor = None): + """ + Scale matrix for 2D scaling + ``` + [[s, 0, tx], + [0, s, ty], + [0, 0, 1]] + ``` + Args: + scale (float | torch.Tensor): scale factor, arbitrary shape (...,) + center (torch.Tensor): scale center, arbitrary shape (..., 2). Default to (0, 0) + + Returns: + (torch.Tensor): (..., 3, 3) transformation matrix + """ + if isinstance(scale, float): + scale = torch.tensor(scale) + if center is not None: + scale = scale.to(center) + if center is None: + center = torch.zeros(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape, -1) + return torch.cat([ + torch.cat([ + scale * torch.eye(2, dtype=scale.dtype, device=scale.device).expand(*scale.shape[:-1], -1, -1), + center[..., :, None] - center[..., :, None] * scale[..., None, None], + ], dim=-1), + torch.tensor([[0, 0, 1]], dtype=scale.dtype, device=scale.device).expand(*center.shape[:-1], -1, -1), + ], dim=-2) + + +def apply_2d(transform: torch.Tensor, points: torch.Tensor): + """ + Apply (3x3 or 2x3) 2D affine transformation to points + ``` + p = R @ p + t + ``` + Args: + transform (torch.Tensor): (..., 2 or 3, 3) transformation matrix + points (torch.Tensor): (..., N, 2) points to transform + + Returns: + (torch.Tensor): (..., N, 2) transformed points + """ + assert transform.shape[-2:] == (3, 3) or transform.shape[-2:] == (2, 3), "transform must be 3x3 or 2x3" + assert points.shape[-1] == 2, "points must be 2D" + return points @ transform[..., :2, :2].mT + transform[..., :2, None, 2] \ No newline at end of file diff --git a/utils3d/torch/utils.py b/utils3d/torch/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..877ffb8a60a7f5206fbeb5a9e4a584758b875da4 --- /dev/null +++ b/utils3d/torch/utils.py @@ -0,0 +1,351 @@ +from typing import * + +import torch +import torch.nn.functional as F + +from . import transforms +from . import mesh +from ._helpers import batched + + +__all__ = [ + 'sliding_window_1d', + 'sliding_window_2d', + 'sliding_window_nd', + 'image_uv', + 'image_pixel_center', + 'image_mesh', + 'chessboard', + 'depth_edge', + 'depth_aliasing', + 'image_mesh_from_depth', + 'point_to_normal', + 'depth_to_normal', + 'masked_min', + 'masked_max', + 'bounding_rect' +] + + +def sliding_window_1d(x: torch.Tensor, window_size: int, stride: int = 1, dim: int = -1) -> torch.Tensor: + """ + Sliding window view of the input tensor. The dimension of the sliding window is appended to the end of the input tensor's shape. + NOTE: Since Pytorch has `unfold` function, 1D sliding window view is just a wrapper of it. + """ + return x.unfold(dim, window_size, stride) + + +def sliding_window_nd(x: torch.Tensor, window_size: Tuple[int, ...], stride: Tuple[int, ...], dim: Tuple[int, ...]) -> torch.Tensor: + dim = [dim[i] % x.ndim for i in range(len(dim))] + assert len(window_size) == len(stride) == len(dim) + for i in range(len(window_size)): + x = sliding_window_1d(x, window_size[i], stride[i], dim[i]) + return x + + +def sliding_window_2d(x: torch.Tensor, window_size: Union[int, Tuple[int, int]], stride: Union[int, Tuple[int, int]], dim: Union[int, Tuple[int, int]] = (-2, -1)) -> torch.Tensor: + if isinstance(window_size, int): + window_size = (window_size, window_size) + if isinstance(stride, int): + stride = (stride, stride) + return sliding_window_nd(x, window_size, stride, dim) + + +def image_uv(height: int, width: int, left: int = None, top: int = None, right: int = None, bottom: int = None, device: torch.device = None, dtype: torch.dtype = None) -> torch.Tensor: + """ + Get image space UV grid, ranging in [0, 1]. + + >>> image_uv(10, 10): + [[[0.05, 0.05], [0.15, 0.05], ..., [0.95, 0.05]], + [[0.05, 0.15], [0.15, 0.15], ..., [0.95, 0.15]], + ... ... ... + [[0.05, 0.95], [0.15, 0.95], ..., [0.95, 0.95]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = torch.linspace((left + 0.5) / width, (right - 0.5) / width, right - left, device=device, dtype=dtype) + v = torch.linspace((top + 0.5) / height, (bottom - 0.5) / height, bottom - top, device=device, dtype=dtype) + u, v = torch.meshgrid(u, v, indexing='xy') + uv = torch.stack([u, v], dim=-1) + return uv + + +def image_pixel_center( + height: int, + width: int, + left: int = None, + top: int = None, + right: int = None, + bottom: int = None, + dtype: torch.dtype = None, + device: torch.device = None +) -> torch.Tensor: + """ + Get image pixel center coordinates, ranging in [0, width] and [0, height]. + `image[i, j]` has pixel center coordinates `(j + 0.5, i + 0.5)`. + + >>> image_pixel_center(10, 10): + [[[0.5, 0.5], [1.5, 0.5], ..., [9.5, 0.5]], + [[0.5, 1.5], [1.5, 1.5], ..., [9.5, 1.5]], + ... ... ... + [[0.5, 9.5], [1.5, 9.5], ..., [9.5, 9.5]]] + + Args: + width (int): image width + height (int): image height + + Returns: + np.ndarray: shape (height, width, 2) + """ + if left is None: left = 0 + if top is None: top = 0 + if right is None: right = width + if bottom is None: bottom = height + u = torch.linspace(left + 0.5, right - 0.5, right - left, dtype=dtype, device=device) + v = torch.linspace(top + 0.5, bottom - 0.5, bottom - top, dtype=dtype, device=device) + u, v = torch.meshgrid(u, v, indexing='xy') + return torch.stack([u, v], dim=2) + + +def image_mesh(height: int, width: int, mask: torch.Tensor = None, device: torch.device = None, dtype: torch.dtype = None) -> Tuple[torch.Tensor, torch.Tensor]: + """ + Get a quad mesh regarding image pixel uv coordinates as vertices and image grid as faces. + + Args: + width (int): image width + height (int): image height + mask (np.ndarray, optional): binary mask of shape (height, width), dtype=bool. Defaults to None. + + Returns: + uv (np.ndarray): uv corresponding to pixels as described in image_uv() + faces (np.ndarray): quad faces connecting neighboring pixels + indices (np.ndarray, optional): indices of vertices in the original mesh + """ + if device is None and mask is not None: + device = mask.device + if mask is not None: + assert mask.shape[0] == height and mask.shape[1] == width + assert mask.dtype == torch.bool + uv = image_uv(height, width, device=device, dtype=dtype).reshape((-1, 2)) + row_faces = torch.stack([ + torch.arange(0, width - 1, dtype=torch.int32, device=device), + torch.arange(width, 2 * width - 1, dtype=torch.int32, device=device), + torch.arange(1 + width, 2 * width, dtype=torch.int32, device=device), + torch.arange(1, width, dtype=torch.int32, device=device) + ], dim=1) + faces = (torch.arange(0, (height - 1) * width, width, device=device, dtype=torch.int32)[:, None, None] + row_faces[None, :, :]).reshape((-1, 4)) + if mask is not None: + quad_mask = (mask[:-1, :-1] & mask[1:, :-1] & mask[1:, 1:] & mask[:-1, 1:]).ravel() + faces = faces[quad_mask] + faces, uv, indices = mesh.remove_unreferenced_vertices(faces, uv, return_indices=True) + return uv, faces, indices + return uv, faces + + +def depth_edge(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor: + """ + Compute the edge mask of a depth map. The edge is defined as the pixels whose neighbors have a large difference in depth. + + Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool + """ + shape = depth.shape + depth = depth.reshape(-1, 1, *shape[-2:]) + if mask is not None: + mask = mask.reshape(-1, 1, *shape[-2:]) + + if mask is None: + diff = (F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2)) + else: + diff = (F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2)) + + edge = torch.zeros_like(depth, dtype=torch.bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= (diff / depth).nan_to_num_() > rtol + edge = edge.reshape(*shape) + return edge + + +def depth_aliasing(depth: torch.Tensor, atol: float = None, rtol: float = None, kernel_size: int = 3, mask: torch.Tensor = None) -> torch.BoolTensor: + """ + Compute the map that indicates the aliasing of a depth map. The aliasing is defined as the pixels which neither close to the maximum nor the minimum of its neighbors. + Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + atol (float): absolute tolerance + rtol (float): relative tolerance + + Returns: + edge (torch.Tensor): shape (..., height, width) of dtype torch.bool + """ + shape = depth.shape + depth = depth.reshape(-1, 1, *shape[-2:]) + if mask is not None: + mask = mask.reshape(-1, 1, *shape[-2:]) + + if mask is None: + diff_max = F.max_pool2d(depth, kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = F.max_pool2d(-depth, kernel_size, stride=1, padding=kernel_size // 2) + depth + else: + diff_max = F.max_pool2d(torch.where(mask, depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) - depth + diff_min = F.max_pool2d(torch.where(mask, -depth, -torch.inf), kernel_size, stride=1, padding=kernel_size // 2) + depth + diff = torch.minimum(diff_max, diff_min) + + edge = torch.zeros_like(depth, dtype=torch.bool) + if atol is not None: + edge |= diff > atol + if rtol is not None: + edge |= (diff / depth).nan_to_num_() > rtol + edge = edge.reshape(*shape) + return edge + + +def image_mesh_from_depth( + depth: torch.Tensor, + extrinsics: torch.Tensor = None, + intrinsics: torch.Tensor = None +) -> Tuple[torch.Tensor, torch.Tensor]: + height, width = depth.shape + uv, faces = image_mesh(height, width) + faces = faces.reshape(-1, 4) + depth = depth.reshape(-1) + pts = transforms.unproject_cv(image_uv, depth, extrinsics, intrinsics) + faces = mesh.triangulate(faces, vertices=pts) + return pts, faces + + +@batched(3, 2, 2) +def point_to_normal(point: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """ + Calculate normal map from point map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + point (torch.Tensor): shape (..., height, width, 3), point map + Returns: + normal (torch.Tensor): shape (..., height, width, 3), normal map. + """ + has_mask = mask is not None + + if mask is None: + mask = torch.ones_like(point[..., 0], dtype=torch.bool) + mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0) + + pts = F.pad(point.permute(0, 3, 1, 2), (1, 1, 1, 1), mode='constant', value=1).permute(0, 2, 3, 1) + up = pts[:, :-2, 1:-1, :] - pts[:, 1:-1, 1:-1, :] + left = pts[:, 1:-1, :-2, :] - pts[:, 1:-1, 1:-1, :] + down = pts[:, 2:, 1:-1, :] - pts[:, 1:-1, 1:-1, :] + right = pts[:, 1:-1, 2:, :] - pts[:, 1:-1, 1:-1, :] + normal = torch.stack([ + torch.cross(up, left, dim=-1), + torch.cross(left, down, dim=-1), + torch.cross(down, right, dim=-1), + torch.cross(right, up, dim=-1), + ]) + normal = F.normalize(normal, dim=-1) + valid = torch.stack([ + mask[:, :-2, 1:-1] & mask[:, 1:-1, :-2], + mask[:, 1:-1, :-2] & mask[:, 2:, 1:-1], + mask[:, 2:, 1:-1] & mask[:, 1:-1, 2:], + mask[:, 1:-1, 2:] & mask[:, :-2, 1:-1], + ]) & mask[None, :, 1:-1, 1:-1] + normal = (normal * valid[..., None]).sum(dim=0) + normal = F.normalize(normal, dim=-1) + + if has_mask: + return normal, valid.any(dim=0) + else: + return normal + + +@batched(2, 2, 2) +def depth_to_normal(depth: torch.Tensor, intrinsics: torch.Tensor, mask: torch.Tensor = None) -> torch.Tensor: + """ + Calculate normal map from depth map. Value range is [-1, 1]. Normal direction in OpenGL identity camera's coordinate system. + + Args: + depth (torch.Tensor): shape (..., height, width), linear depth map + intrinsics (torch.Tensor): shape (..., 3, 3), intrinsics matrix + Returns: + normal (torch.Tensor): shape (..., 3, height, width), normal map. + """ + has_mask = mask is not None + + height, width = depth.shape[-2:] + if mask is None: + mask = torch.ones_like(depth, dtype=torch.bool) + mask = F.pad(mask, (1, 1, 1, 1), mode='constant', value=0) + + uv = image_uv(*depth.shape[-2:]).unsqueeze(0).to(depth) + pts = transforms.unproject_cv(uv.reshape(-1, 2), depth.flatten(-2), intrinsics=intrinsics, extrinsics=None).unflatten(-2, (height, width)) + + return point_to_normal(pts, mask) + + +def masked_min(input: torch.Tensor, mask: torch.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Similar to torch.min, but with mask + """ + if dim is None: + return torch.where(mask, input, torch.tensor(torch.inf, dtype=input.dtype, device=input.device)).min() + else: + return torch.where(mask, input, torch.tensor(torch.inf, dtype=input.dtype, device=input.device)).min(dim=dim, keepdim=keepdim) + + +def masked_max(input: torch.Tensor, mask: torch.BoolTensor, dim: int = None, keepdim: bool = False) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor]]: + """Similar to torch.max, but with mask + """ + if dim is None: + return torch.where(mask, input, torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)).max() + else: + return torch.where(mask, input, torch.tensor(-torch.inf, dtype=input.dtype, device=input.device)).max(dim=dim, keepdim=keepdim) + + +def bounding_rect(mask: torch.BoolTensor): + """get bounding rectangle of a mask + + Args: + mask (torch.Tensor): shape (..., height, width), mask + + Returns: + rect (torch.Tensor): shape (..., 4), bounding rectangle (left, top, right, bottom) + """ + height, width = mask.shape[-2:] + mask = mask.flatten(-2).unsqueeze(-1) + uv = image_uv(height, width).to(mask.device).reshape(-1, 2) + left_top = masked_min(uv, mask, dim=-2)[0] + right_bottom = masked_max(uv, mask, dim=-2)[0] + return torch.cat([left_top, right_bottom], dim=-1) + + +def chessboard(width: int, height: int, grid_size: int, color_a: torch.Tensor, color_b: torch.Tensor) -> torch.Tensor: + """get a chessboard image + + Args: + width (int): image width + height (int): image height + grid_size (int): size of chessboard grid + color_a (torch.Tensor): shape (chanenls,), color of the grid at the top-left corner + color_b (torch.Tensor): shape (chanenls,), color in complementary grids + + Returns: + image (torch.Tensor): shape (height, width, channels), chessboard image + """ + x = torch.div(torch.arange(width), grid_size, rounding_mode='floor') + y = torch.div(torch.arange(height), grid_size, rounding_mode='floor') + mask = ((x[None, :] + y[:, None]) % 2).to(color_a) + image = (1 - mask[..., None]) * color_a + mask[..., None] * color_b + return image \ No newline at end of file