Ruicheng committed on
Commit
ec0c8fa
1 Parent(s): 119634a

first commit

This view is limited to 50 files because it contains too many changes.
Files changed (50)
  1. .gitignore +25 -0
  2. app.py +111 -0
  3. moge/model/__init__.py +1 -0
  4. moge/model/dinov2/__init__.py +6 -0
  5. moge/model/dinov2/hub/__init__.py +4 -0
  6. moge/model/dinov2/hub/backbones.py +156 -0
  7. moge/model/dinov2/hub/utils.py +39 -0
  8. moge/model/dinov2/layers/__init__.py +11 -0
  9. moge/model/dinov2/layers/attention.py +89 -0
  10. moge/model/dinov2/layers/block.py +259 -0
  11. moge/model/dinov2/layers/dino_head.py +58 -0
  12. moge/model/dinov2/layers/drop_path.py +34 -0
  13. moge/model/dinov2/layers/layer_scale.py +27 -0
  14. moge/model/dinov2/layers/mlp.py +40 -0
  15. moge/model/dinov2/layers/patch_embed.py +88 -0
  16. moge/model/dinov2/layers/swiglu_ffn.py +72 -0
  17. moge/model/dinov2/models/__init__.py +43 -0
  18. moge/model/dinov2/models/vision_transformer.py +396 -0
  19. moge/model/dinov2/utils/__init__.py +4 -0
  20. moge/model/dinov2/utils/cluster.py +95 -0
  21. moge/model/dinov2/utils/config.py +72 -0
  22. moge/model/dinov2/utils/dtype.py +37 -0
  23. moge/model/dinov2/utils/param_groups.py +103 -0
  24. moge/model/dinov2/utils/utils.py +95 -0
  25. moge/model/moge_model.py +376 -0
  26. moge/model/utils.py +38 -0
  27. moge/utils/__init__.py +0 -0
  28. moge/utils/blob.py +314 -0
  29. moge/utils/download.py +55 -0
  30. moge/utils/geometry_numpy.py +175 -0
  31. moge/utils/geometry_torch.py +231 -0
  32. moge/utils/io.py +347 -0
  33. moge/utils/pipeline.py +503 -0
  34. moge/utils/tools.py +240 -0
  35. moge/utils/vis.py +51 -0
  36. moge/utils/webfile.py +73 -0
  37. moge/utils/webzipfile.py +128 -0
  38. packages.txt +1 -0
  39. requirements.txt +5 -0
  40. utils3d/__init__.py +14 -0
  41. utils3d/io/__init__.py +4 -0
  42. utils3d/io/colmap.py +139 -0
  43. utils3d/io/glb.py +105 -0
  44. utils3d/io/ply.py +104 -0
  45. utils3d/io/wavefront_obj.py +146 -0
  46. utils3d/numpy/__init__.py +135 -0
  47. utils3d/numpy/_helpers.py +88 -0
  48. utils3d/numpy/mesh.py +355 -0
  49. utils3d/numpy/quadmesh.py +472 -0
  50. utils3d/numpy/rasterization.py +471 -0
.gitignore ADDED
@@ -0,0 +1,25 @@
+ /data
+ /download
+ /extract
+ /view_point_cloud
+ /view_depth_map
+ /blobcache
+ /snapshot
+ /reference_embeddings
+ /.gradio
+ /debug
+ /workspace
+ /mlruns
+ /infer_output
+ /video_output
+ /eval_output
+ /.blobcache
+ /test_images
+ /test_videos
+ /vis
+ /videos
+ /raid
+ /blobmnt
+ /eval_dump
+ /pretrained
+ __pycache__/
app.py ADDED
@@ -0,0 +1,111 @@
+ import os
+ import time
+ from pathlib import Path
+ import uuid
+ import tempfile
+ from typing import Union
+ import spaces
+ import atexit
+ from concurrent.futures import ThreadPoolExecutor
+
+ import gradio as gr
+ import cv2
+ import torch
+ import numpy as np
+
+ from moge.model import MoGeModel
+ from moge.utils.vis import colorize_depth
+ import utils3d
+
+ model = MoGeModel.from_pretrained('Ruicheng/moge-vitl').cuda().eval()
+ thread_pool_executor = ThreadPoolExecutor(max_workers=1)
+
+
+ def delete_later(path: Union[str, os.PathLike], delay: int = 300):
+     def _delete():
+         try:
+             os.remove(path)
+         except:
+             pass
+     def _wait_and_delete():
+         time.sleep(delay)
+         _delete()  # _delete takes no arguments; it closes over `path`
+     thread_pool_executor.submit(_wait_and_delete)
+     atexit.register(_delete)
+
+ @spaces.GPU
+ def run(image: np.ndarray, remove_edge: bool = True):
+     run_id = str(uuid.uuid4())
+
+     larger_size = max(image.shape[:2])
+     if larger_size > 1024:
+         scale = 1024 / larger_size
+         image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
+
+     image_tensor = torch.tensor(image, dtype=torch.float32, device=torch.device('cuda')).permute(2, 0, 1) / 255
+     output = model.infer(image_tensor, resolution_level=9, apply_mask=True)
+     points, depth, mask = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy()
+
+     if remove_edge:
+         mask = mask & ~utils3d.numpy.depth_edge(depth, mask=mask, rtol=0.02)
+     mask = mask & (depth > 0)
+
+     _, faces, indices = utils3d.numpy.image_mesh(width=image.shape[1], height=image.shape[0], mask=mask)
+     faces = utils3d.numpy.triangulate(faces)
+
+     tempdir = Path(tempfile.gettempdir(), 'moge')
+     tempdir.mkdir(exist_ok=True)
+
+     output_glb_path = Path(tempdir, f'{run_id}.glb')
+     output_glb_path.parent.mkdir(exist_ok=True)
+     tempfile.TemporaryFile()
+     utils3d.io.write_glb(
+         output_glb_path,
+         vertices=points.reshape(-1, 3)[indices] * [-1, -1, 1],
+         faces=faces,
+         vertex_colors=image.reshape(-1, 3)[indices] / 255,
+     )
+
+     output_ply_path = Path(tempdir, f'{run_id}.ply')
+     output_ply_path.parent.mkdir(exist_ok=True)
+     utils3d.io.write_ply(
+         output_ply_path,
+         vertices=points.reshape(-1, 3)[indices] * [-1, -1, 1],
+         faces=faces,
+         vertex_colors=image.reshape(-1, 3)[indices] / 255,
+     )
+
+     colorized_depth = colorize_depth(depth)
+
+     delete_later(output_glb_path, delay=300)
+     delete_later(output_ply_path, delay=300)
+
+     return colorized_depth, output_glb_path, output_ply_path.as_posix()
+
+
+ DESCRIPTION = """
+ MoGe turns 2D images into 3D point maps.
+
+ NOTE:
+ * If the image is larger than 1024 px on its longer side, it will be resized accordingly.
+ * Colors in the 3D viewer may look dark because of the viewer's rendering. You can download the model as a .glb or .ply file and open it in another 3D viewer.
+ """
+
+ if __name__ == '__main__':
+
+     gr.Interface(
+         fn=run,
+         inputs=[
+             gr.Image(type="numpy", image_mode="RGB"),
+             gr.Checkbox(True, label="Remove edges"),
+         ],
+         outputs=[
+             gr.Image(type="numpy", label="Depth map (colorized)"),
+             gr.Model3D(display_mode="solid", clear_color=[1.0, 1.0, 1.0, 1.0], label="3D Viewer"),
+             gr.File(type="filepath", label="Download the model as .ply file"),
+         ],
+         title="MoGe Live Demo",
+         description=DESCRIPTION,
+         clear_btn=None,
+         allow_flagging="never",
+     ).launch(share=False)
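
For reference, the same inference path can be exercised outside Gradio. This is a minimal sketch, not part of the commit; it assumes a CUDA device, downloads the checkpoint on first use, and 'example.jpg' is a placeholder path:

    import cv2, torch
    from moge.model import MoGeModel

    # Same checkpoint and call pattern as run() in app.py above.
    model = MoGeModel.from_pretrained('Ruicheng/moge-vitl').cuda().eval()
    image = cv2.cvtColor(cv2.imread('example.jpg'), cv2.COLOR_BGR2RGB)
    image_tensor = torch.tensor(image, dtype=torch.float32, device='cuda').permute(2, 0, 1) / 255
    output = model.infer(image_tensor, resolution_level=9, apply_mask=True)
    depth = output['depth'].cpu().numpy()  # 'points' and 'mask' are returned alongside, as used above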
moge/model/__init__.py ADDED
@@ -0,0 +1 @@
+ from .moge_model import MoGeModel
moge/model/dinov2/__init__.py ADDED
@@ -0,0 +1,6 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ __version__ = "0.0.1"
moge/model/dinov2/hub/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
moge/model/dinov2/hub/backbones.py ADDED
@@ -0,0 +1,156 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ from enum import Enum
+ from typing import Union
+
+ import torch
+
+ from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
+
+
+ class Weights(Enum):
+     LVD142M = "LVD142M"
+
+
+ def _make_dinov2_model(
+     *,
+     arch_name: str = "vit_large",
+     img_size: int = 518,
+     patch_size: int = 14,
+     init_values: float = 1.0,
+     ffn_layer: str = "mlp",
+     block_chunks: int = 0,
+     num_register_tokens: int = 0,
+     interpolate_antialias: bool = False,
+     interpolate_offset: float = 0.1,
+     pretrained: bool = True,
+     weights: Union[Weights, str] = Weights.LVD142M,
+     **kwargs,
+ ):
+     from ..models import vision_transformer as vits
+
+     if isinstance(weights, str):
+         try:
+             weights = Weights[weights]
+         except KeyError:
+             raise AssertionError(f"Unsupported weights: {weights}")
+
+     model_base_name = _make_dinov2_model_name(arch_name, patch_size)
+     vit_kwargs = dict(
+         img_size=img_size,
+         patch_size=patch_size,
+         init_values=init_values,
+         ffn_layer=ffn_layer,
+         block_chunks=block_chunks,
+         num_register_tokens=num_register_tokens,
+         interpolate_antialias=interpolate_antialias,
+         interpolate_offset=interpolate_offset,
+     )
+     vit_kwargs.update(**kwargs)
+     model = vits.__dict__[arch_name](**vit_kwargs)
+
+     if pretrained:
+         model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
+         url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
+         state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
+         model.load_state_dict(state_dict, strict=True)
+
+     return model
+
+
+ def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
+
+
+ def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
+
+
+ def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
+
+
+ def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(
+         arch_name="vit_giant2",
+         ffn_layer="swiglufused",
+         weights=weights,
+         pretrained=pretrained,
+         **kwargs,
+     )
+
+
+ def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(
+         arch_name="vit_small",
+         pretrained=pretrained,
+         weights=weights,
+         num_register_tokens=4,
+         interpolate_antialias=True,
+         interpolate_offset=0.0,
+         **kwargs,
+     )
+
+
+ def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(
+         arch_name="vit_base",
+         pretrained=pretrained,
+         weights=weights,
+         num_register_tokens=4,
+         interpolate_antialias=True,
+         interpolate_offset=0.0,
+         **kwargs,
+     )
+
+
+ def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(
+         arch_name="vit_large",
+         pretrained=pretrained,
+         weights=weights,
+         num_register_tokens=4,
+         interpolate_antialias=True,
+         interpolate_offset=0.0,
+         **kwargs,
+     )
+
+
+ def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
+     """
+     DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
+     """
+     return _make_dinov2_model(
+         arch_name="vit_giant2",
+         ffn_layer="swiglufused",
+         weights=weights,
+         pretrained=pretrained,
+         num_register_tokens=4,
+         interpolate_antialias=True,
+         interpolate_offset=0.0,
+         **kwargs,
+     )
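
As a usage note (not part of the diff), the factory functions above can be called directly. A minimal sketch, assuming the repository root is on the Python path and skipping the pretrained-weight download:

    from moge.model.dinov2.hub.backbones import dinov2_vitl14

    # Build the ViT-L/14 architecture only; pretrained=True would fetch the
    # LVD-142M checkpoint from _DINOV2_BASE_URL via torch.hub.
    backbone = dinov2_vitl14(pretrained=False)
    print(backbone.embed_dim)  # 1024 for vit_large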
moge/model/dinov2/hub/utils.py ADDED
@@ -0,0 +1,39 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import itertools
+ import math
+
+ import torch
+ import torch.nn as nn
+ import torch.nn.functional as F
+
+
+ _DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
+
+
+ def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
+     compact_arch_name = arch_name.replace("_", "")[:4]
+     registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
+     return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
+
+
+ class CenterPadding(nn.Module):
+     def __init__(self, multiple):
+         super().__init__()
+         self.multiple = multiple
+
+     def _get_pad(self, size):
+         new_size = math.ceil(size / self.multiple) * self.multiple
+         pad_size = new_size - size
+         pad_size_left = pad_size // 2
+         pad_size_right = pad_size - pad_size_left
+         return pad_size_left, pad_size_right
+
+     @torch.inference_mode()
+     def forward(self, x):
+         pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
+         output = F.pad(x, pads)
+         return output
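
CenterPadding pads the trailing spatial dimensions up to the next multiple of `multiple`, which keeps arbitrary image sizes compatible with the patch grid. A short sketch (not part of the commit):

    import torch
    from moge.model.dinov2.hub.utils import CenterPadding

    pad = CenterPadding(multiple=14)
    x = torch.randn(1, 3, 500, 700)
    y = pad(x)
    print(y.shape)  # torch.Size([1, 3, 504, 700]): 500 is rounded up to 504, 700 is already a multiple of 14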
moge/model/dinov2/layers/__init__.py ADDED
@@ -0,0 +1,11 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ from .dino_head import DINOHead
+ from .mlp import Mlp
+ from .patch_embed import PatchEmbed
+ from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+ from .block import NestedTensorBlock
+ from .attention import MemEffAttention
moge/model/dinov2/layers/attention.py ADDED
@@ -0,0 +1,89 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+ import logging
+ import os
+ import warnings
+
+ from torch import Tensor
+ from torch import nn
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+ try:
+     if XFORMERS_ENABLED:
+         from xformers.ops import memory_efficient_attention, unbind
+
+         XFORMERS_AVAILABLE = True
+         # warnings.warn("xFormers is available (Attention)")
+     else:
+         # warnings.warn("xFormers is disabled (Attention)")
+         raise ImportError
+ except ImportError:
+     XFORMERS_AVAILABLE = False
+     # warnings.warn("xFormers is not available (Attention)")
+
+
+ class Attention(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int = 8,
+         qkv_bias: bool = False,
+         proj_bias: bool = True,
+         attn_drop: float = 0.0,
+         proj_drop: float = 0.0,
+     ) -> None:
+         super().__init__()
+         self.num_heads = num_heads
+         head_dim = dim // num_heads
+         self.scale = head_dim**-0.5
+
+         self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+         self.attn_drop = nn.Dropout(attn_drop)
+         self.proj = nn.Linear(dim, dim, bias=proj_bias)
+         self.proj_drop = nn.Dropout(proj_drop)
+
+     def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+         q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+         attn = q @ k.transpose(-2, -1)
+
+         attn = attn.softmax(dim=-1)
+         attn = self.attn_drop(attn)
+
+         x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
+
+
+ class MemEffAttention(Attention):
+     def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+         if not XFORMERS_AVAILABLE:
+             if attn_bias is not None:
+                 raise AssertionError("xFormers is required for using nested tensors")
+             return super().forward(x)
+
+         B, N, C = x.shape
+         qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+         q, k, v = unbind(qkv, 2)
+
+         x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+         x = x.reshape([B, N, C])
+
+         x = self.proj(x)
+         x = self.proj_drop(x)
+         return x
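
Both attention classes share the same interface; MemEffAttention dispatches to xFormers' memory_efficient_attention when the package is importable and otherwise falls back to the plain implementation above. A quick shape check (not part of the commit; on a machine without xFormers this exercises the fallback path):

    import torch
    from moge.model.dinov2.layers.attention import MemEffAttention

    attn = MemEffAttention(dim=384, num_heads=6)
    tokens = torch.randn(2, 197, 384)  # (batch, sequence, embedding)
    out = attn(tokens)                 # same shape as the input: (2, 197, 384)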
moge/model/dinov2/layers/block.py ADDED
@@ -0,0 +1,259 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+ import logging
+ import os
+ from typing import Callable, List, Any, Tuple, Dict
+ import warnings
+
+ import torch
+ from torch import nn, Tensor
+
+ from .attention import Attention, MemEffAttention
+ from .drop_path import DropPath
+ from .layer_scale import LayerScale
+ from .mlp import Mlp
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+ try:
+     if XFORMERS_ENABLED:
+         from xformers.ops import fmha, scaled_index_add, index_select_cat
+
+         XFORMERS_AVAILABLE = True
+         # warnings.warn("xFormers is available (Block)")
+     else:
+         # warnings.warn("xFormers is disabled (Block)")
+         raise ImportError
+ except ImportError:
+     XFORMERS_AVAILABLE = False
+     # warnings.warn("xFormers is not available (Block)")
+
+
+ class Block(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         num_heads: int,
+         mlp_ratio: float = 4.0,
+         qkv_bias: bool = False,
+         proj_bias: bool = True,
+         ffn_bias: bool = True,
+         drop: float = 0.0,
+         attn_drop: float = 0.0,
+         init_values=None,
+         drop_path: float = 0.0,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+         attn_class: Callable[..., nn.Module] = Attention,
+         ffn_layer: Callable[..., nn.Module] = Mlp,
+     ) -> None:
+         super().__init__()
+         # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+         self.norm1 = norm_layer(dim)
+         self.attn = attn_class(
+             dim,
+             num_heads=num_heads,
+             qkv_bias=qkv_bias,
+             proj_bias=proj_bias,
+             attn_drop=attn_drop,
+             proj_drop=drop,
+         )
+         self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.norm2 = norm_layer(dim)
+         mlp_hidden_dim = int(dim * mlp_ratio)
+         self.mlp = ffn_layer(
+             in_features=dim,
+             hidden_features=mlp_hidden_dim,
+             act_layer=act_layer,
+             drop=drop,
+             bias=ffn_bias,
+         )
+         self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+         self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+         self.sample_drop_ratio = drop_path
+
+     def forward(self, x: Tensor) -> Tensor:
+         def attn_residual_func(x: Tensor) -> Tensor:
+             return self.ls1(self.attn(self.norm1(x)))
+
+         def ffn_residual_func(x: Tensor) -> Tensor:
+             return self.ls2(self.mlp(self.norm2(x)))
+
+         if self.training and self.sample_drop_ratio > 0.1:
+             # the overhead is compensated only for a drop path rate larger than 0.1
+             x = drop_add_residual_stochastic_depth(
+                 x,
+                 residual_func=attn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+             )
+             x = drop_add_residual_stochastic_depth(
+                 x,
+                 residual_func=ffn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+             )
+         elif self.training and self.sample_drop_ratio > 0.0:
+             x = x + self.drop_path1(attn_residual_func(x))
+             x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
+         else:
+             x = x + attn_residual_func(x)
+             x = x + ffn_residual_func(x)
+         return x
+
+
+ def drop_add_residual_stochastic_depth(
+     x: Tensor,
+     residual_func: Callable[[Tensor], Tensor],
+     sample_drop_ratio: float = 0.0,
+ ) -> Tensor:
+     # 1) extract subset using permutation
+     b, n, d = x.shape
+     sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+     brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+     x_subset = x[brange]
+
+     # 2) apply residual_func to get residual
+     residual = residual_func(x_subset)
+
+     x_flat = x.flatten(1)
+     residual = residual.flatten(1)
+
+     residual_scale_factor = b / sample_subset_size
+
+     # 3) add the residual
+     x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+     return x_plus_residual.view_as(x)
+
+
+ def get_branges_scales(x, sample_drop_ratio=0.0):
+     b, n, d = x.shape
+     sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+     brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+     residual_scale_factor = b / sample_subset_size
+     return brange, residual_scale_factor
+
+
+ def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+     if scaling_vector is None:
+         x_flat = x.flatten(1)
+         residual = residual.flatten(1)
+         x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+     else:
+         x_plus_residual = scaled_index_add(
+             x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+         )
+     return x_plus_residual
+
+
+ attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+ def get_attn_bias_and_cat(x_list, branges=None):
+     """
+     this will perform the index select, cat the tensors, and provide the attn_bias from cache
+     """
+     batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+     all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+     if all_shapes not in attn_bias_cache.keys():
+         seqlens = []
+         for b, x in zip(batch_sizes, x_list):
+             for _ in range(b):
+                 seqlens.append(x.shape[1])
+         attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+         attn_bias._batch_sizes = batch_sizes
+         attn_bias_cache[all_shapes] = attn_bias
+
+     if branges is not None:
+         cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+     else:
+         tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+         cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+     return attn_bias_cache[all_shapes], cat_tensors
+
+
+ def drop_add_residual_stochastic_depth_list(
+     x_list: List[Tensor],
+     residual_func: Callable[[Tensor, Any], Tensor],
+     sample_drop_ratio: float = 0.0,
+     scaling_vector=None,
+ ) -> Tensor:
+     # 1) generate random set of indices for dropping samples in the batch
+     branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+     branges = [s[0] for s in branges_scales]
+     residual_scale_factors = [s[1] for s in branges_scales]
+
+     # 2) get attention bias and index+concat the tensors
+     attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+     # 3) apply residual_func to get residual, and split the result
+     residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
+
+     outputs = []
+     for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+         outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+     return outputs
+
+
+ class NestedTensorBlock(Block):
+     def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+         """
+         x_list contains a list of tensors to nest together and run
+         """
+         assert isinstance(self.attn, MemEffAttention)
+
+         if self.training and self.sample_drop_ratio > 0.0:
+
+             def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+             def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.mlp(self.norm2(x))
+
+             x_list = drop_add_residual_stochastic_depth_list(
+                 x_list,
+                 residual_func=attn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+                 scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+             )
+             x_list = drop_add_residual_stochastic_depth_list(
+                 x_list,
+                 residual_func=ffn_residual_func,
+                 sample_drop_ratio=self.sample_drop_ratio,
+                 scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+             )
+             return x_list
+         else:
+
+             def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+             def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                 return self.ls2(self.mlp(self.norm2(x)))
+
+             attn_bias, x = get_attn_bias_and_cat(x_list)
+             x = x + attn_residual_func(x, attn_bias=attn_bias)
+             x = x + ffn_residual_func(x)
+             return attn_bias.split(x)
+
+     def forward(self, x_or_x_list):
+         if isinstance(x_or_x_list, Tensor):
+             return super().forward(x_or_x_list)
+         elif isinstance(x_or_x_list, list):
+             if not XFORMERS_AVAILABLE:
+                 raise AssertionError("xFormers is required for using nested tensors")
+             return self.forward_nested(x_or_x_list)
+         else:
+             raise AssertionError
moge/model/dinov2/layers/dino_head.py ADDED
@@ -0,0 +1,58 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import torch
+ import torch.nn as nn
+ from torch.nn.init import trunc_normal_
+ from torch.nn.utils import weight_norm
+
+
+ class DINOHead(nn.Module):
+     def __init__(
+         self,
+         in_dim,
+         out_dim,
+         use_bn=False,
+         nlayers=3,
+         hidden_dim=2048,
+         bottleneck_dim=256,
+         mlp_bias=True,
+     ):
+         super().__init__()
+         nlayers = max(nlayers, 1)
+         self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
+         self.apply(self._init_weights)
+         self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
+         self.last_layer.weight_g.data.fill_(1)
+
+     def _init_weights(self, m):
+         if isinstance(m, nn.Linear):
+             trunc_normal_(m.weight, std=0.02)
+             if isinstance(m, nn.Linear) and m.bias is not None:
+                 nn.init.constant_(m.bias, 0)
+
+     def forward(self, x):
+         x = self.mlp(x)
+         eps = 1e-6 if x.dtype == torch.float16 else 1e-12
+         x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
+         x = self.last_layer(x)
+         return x
+
+
+ def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
+     if nlayers == 1:
+         return nn.Linear(in_dim, bottleneck_dim, bias=bias)
+     else:
+         layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
+         if use_bn:
+             layers.append(nn.BatchNorm1d(hidden_dim))
+         layers.append(nn.GELU())
+         for _ in range(nlayers - 2):
+             layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
+             if use_bn:
+                 layers.append(nn.BatchNorm1d(hidden_dim))
+             layers.append(nn.GELU())
+         layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
+         return nn.Sequential(*layers)
moge/model/dinov2/layers/drop_path.py ADDED
@@ -0,0 +1,34 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+ from torch import nn
+
+
+ def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+     if drop_prob == 0.0 or not training:
+         return x
+     keep_prob = 1 - drop_prob
+     shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+     random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+     if keep_prob > 0.0:
+         random_tensor.div_(keep_prob)
+     output = x * random_tensor
+     return output
+
+
+ class DropPath(nn.Module):
+     """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+     def __init__(self, drop_prob=None):
+         super(DropPath, self).__init__()
+         self.drop_prob = drop_prob
+
+     def forward(self, x):
+         return drop_path(x, self.drop_prob, self.training)
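
DropPath is the per-sample stochastic depth used by Block above: in eval mode it is a no-op, and in training mode it zeroes each sample's residual branch with probability drop_prob and rescales survivors by 1/keep_prob so the expectation is unchanged. A short sketch (not part of the commit):

    import torch
    from moge.model.dinov2.layers.drop_path import DropPath

    dp = DropPath(drop_prob=0.2)
    x = torch.ones(8, 4, 16)

    dp.eval()
    assert torch.equal(dp(x), x)  # identity at inference time

    dp.train()
    y = dp(x)                     # each sample is either zeroed or scaled by 1 / 0.8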
moge/model/dinov2/layers/layer_scale.py ADDED
@@ -0,0 +1,27 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+ from typing import Union
+
+ import torch
+ from torch import Tensor
+ from torch import nn
+
+
+ class LayerScale(nn.Module):
+     def __init__(
+         self,
+         dim: int,
+         init_values: Union[float, Tensor] = 1e-5,
+         inplace: bool = False,
+     ) -> None:
+         super().__init__()
+         self.inplace = inplace
+         self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+     def forward(self, x: Tensor) -> Tensor:
+         return x.mul_(self.gamma) if self.inplace else x * self.gamma
moge/model/dinov2/layers/mlp.py ADDED
@@ -0,0 +1,40 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+ from typing import Callable, Optional
+
+ from torch import Tensor, nn
+
+
+ class Mlp(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = nn.GELU,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+         self.act = act_layer()
+         self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+         self.drop = nn.Dropout(drop)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x = self.fc1(x)
+         x = self.act(x)
+         x = self.drop(x)
+         x = self.fc2(x)
+         x = self.drop(x)
+         return x
moge/model/dinov2/layers/patch_embed.py ADDED
@@ -0,0 +1,88 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ #   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+ #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+ from typing import Callable, Optional, Tuple, Union
+
+ from torch import Tensor
+ import torch.nn as nn
+
+
+ def make_2tuple(x):
+     if isinstance(x, tuple):
+         assert len(x) == 2
+         return x
+
+     assert isinstance(x, int)
+     return (x, x)
+
+
+ class PatchEmbed(nn.Module):
+     """
+     2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+     Args:
+         img_size: Image size.
+         patch_size: Patch token size.
+         in_chans: Number of input image channels.
+         embed_dim: Number of linear projection output channels.
+         norm_layer: Normalization layer.
+     """
+
+     def __init__(
+         self,
+         img_size: Union[int, Tuple[int, int]] = 224,
+         patch_size: Union[int, Tuple[int, int]] = 16,
+         in_chans: int = 3,
+         embed_dim: int = 768,
+         norm_layer: Optional[Callable] = None,
+         flatten_embedding: bool = True,
+     ) -> None:
+         super().__init__()
+
+         image_HW = make_2tuple(img_size)
+         patch_HW = make_2tuple(patch_size)
+         patch_grid_size = (
+             image_HW[0] // patch_HW[0],
+             image_HW[1] // patch_HW[1],
+         )
+
+         self.img_size = image_HW
+         self.patch_size = patch_HW
+         self.patches_resolution = patch_grid_size
+         self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+         self.in_chans = in_chans
+         self.embed_dim = embed_dim
+
+         self.flatten_embedding = flatten_embedding
+
+         self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+         self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+     def forward(self, x: Tensor) -> Tensor:
+         _, _, H, W = x.shape
+         patch_H, patch_W = self.patch_size
+
+         assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+         assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+         x = self.proj(x)  # B C H W
+         H, W = x.size(2), x.size(3)
+         x = x.flatten(2).transpose(1, 2)  # B HW C
+         x = self.norm(x)
+         if not self.flatten_embedding:
+             x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+         return x
+
+     def flops(self) -> float:
+         Ho, Wo = self.patches_resolution
+         flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+         if self.norm is not None:
+             flops += Ho * Wo * self.embed_dim
+         return flops
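
PatchEmbed is a strided Conv2d followed by flattening, so a (B, C, H, W) image becomes (B, H/p * W/p, embed_dim) tokens. A shape sketch (not part of the commit), using the ViT-L/14 settings from the hub entry points:

    import torch
    from moge.model.dinov2.layers.patch_embed import PatchEmbed

    embed = PatchEmbed(img_size=518, patch_size=14, in_chans=3, embed_dim=1024)
    x = torch.randn(1, 3, 518, 518)
    tokens = embed(x)
    print(tokens.shape)  # torch.Size([1, 1369, 1024]): a 37 x 37 grid of 14-pixel patches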
moge/model/dinov2/layers/swiglu_ffn.py ADDED
@@ -0,0 +1,72 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import os
+ from typing import Callable, Optional
+ import warnings
+
+ from torch import Tensor, nn
+ import torch.nn.functional as F
+
+
+ class SwiGLUFFN(nn.Module):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = None,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         super().__init__()
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+         self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+     def forward(self, x: Tensor) -> Tensor:
+         x12 = self.w12(x)
+         x1, x2 = x12.chunk(2, dim=-1)
+         hidden = F.silu(x1) * x2
+         return self.w3(hidden)
+
+
+ XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
+ try:
+     if XFORMERS_ENABLED:
+         from xformers.ops import SwiGLU
+
+         XFORMERS_AVAILABLE = True
+         # warnings.warn("xFormers is available (SwiGLU)")
+     else:
+         # warnings.warn("xFormers is disabled (SwiGLU)")
+         raise ImportError
+ except ImportError:
+     SwiGLU = SwiGLUFFN
+     XFORMERS_AVAILABLE = False
+
+     # warnings.warn("xFormers is not available (SwiGLU)")
+
+
+ class SwiGLUFFNFused(SwiGLU):
+     def __init__(
+         self,
+         in_features: int,
+         hidden_features: Optional[int] = None,
+         out_features: Optional[int] = None,
+         act_layer: Callable[..., nn.Module] = None,
+         drop: float = 0.0,
+         bias: bool = True,
+     ) -> None:
+         out_features = out_features or in_features
+         hidden_features = hidden_features or in_features
+         hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+         super().__init__(
+             in_features=in_features,
+             hidden_features=hidden_features,
+             out_features=out_features,
+             bias=bias,
+         )
moge/model/dinov2/models/__init__.py ADDED
@@ -0,0 +1,43 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ import logging
+
+ from . import vision_transformer as vits
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ def build_model(args, only_teacher=False, img_size=224):
+     args.arch = args.arch.removesuffix("_memeff")
+     if "vit" in args.arch:
+         vit_kwargs = dict(
+             img_size=img_size,
+             patch_size=args.patch_size,
+             init_values=args.layerscale,
+             ffn_layer=args.ffn_layer,
+             block_chunks=args.block_chunks,
+             qkv_bias=args.qkv_bias,
+             proj_bias=args.proj_bias,
+             ffn_bias=args.ffn_bias,
+             num_register_tokens=args.num_register_tokens,
+             interpolate_offset=args.interpolate_offset,
+             interpolate_antialias=args.interpolate_antialias,
+         )
+         teacher = vits.__dict__[args.arch](**vit_kwargs)
+         if only_teacher:
+             return teacher, teacher.embed_dim
+         student = vits.__dict__[args.arch](
+             **vit_kwargs,
+             drop_path_rate=args.drop_path_rate,
+             drop_path_uniform=args.drop_path_uniform,
+         )
+         embed_dim = student.embed_dim
+     return student, teacher, embed_dim
+
+
+ def build_model_from_cfg(cfg, only_teacher=False):
+     return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
moge/model/dinov2/models/vision_transformer.py ADDED
@@ -0,0 +1,396 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
+
+ # References:
+ #   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+ #   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+ from functools import partial
+ import math
+ import logging
+ from typing import Sequence, Tuple, Union, Callable
+
+ import torch
+ import torch.nn as nn
+ import torch.utils.checkpoint
+ from torch.nn.init import trunc_normal_
+
+ from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+ logger = logging.getLogger("dinov2")
+
+
+ def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+     if not depth_first and include_root:
+         fn(module=module, name=name)
+     for child_name, child_module in module.named_children():
+         child_name = ".".join((name, child_name)) if name else child_name
+         named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+     if depth_first and include_root:
+         fn(module=module, name=name)
+     return module
+
+
+ class BlockChunk(nn.ModuleList):
+     def forward(self, x):
+         for b in self:
+             x = b(x)
+         return x
+
+
+ class DinoVisionTransformer(nn.Module):
+     def __init__(
+         self,
+         img_size=224,
+         patch_size=16,
+         in_chans=3,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4.0,
+         qkv_bias=True,
+         ffn_bias=True,
+         proj_bias=True,
+         drop_path_rate=0.0,
+         drop_path_uniform=False,
+         init_values=None,  # for layerscale: None or 0 => no layerscale
+         embed_layer=PatchEmbed,
+         act_layer=nn.GELU,
+         block_fn=Block,
+         ffn_layer="mlp",
+         block_chunks=1,
+         num_register_tokens=0,
+         interpolate_antialias=False,
+         interpolate_offset=0.1,
+     ):
+         """
+         Args:
+             img_size (int, tuple): input image size
+             patch_size (int, tuple): patch size
+             in_chans (int): number of input channels
+             embed_dim (int): embedding dimension
+             depth (int): depth of transformer
+             num_heads (int): number of attention heads
+             mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+             qkv_bias (bool): enable bias for qkv if True
+             proj_bias (bool): enable bias for proj in attn if True
+             ffn_bias (bool): enable bias for ffn if True
+             drop_path_rate (float): stochastic depth rate
+             drop_path_uniform (bool): apply uniform drop rate across blocks
+             weight_init (str): weight init scheme
+             init_values (float): layer-scale init values
+             embed_layer (nn.Module): patch embedding layer
+             act_layer (nn.Module): MLP activation layer
+             block_fn (nn.Module): transformer block class
+             ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+             block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+             num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+             interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+             interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+         """
+         super().__init__()
+         norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+         self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+         self.num_tokens = 1
+         self.n_blocks = depth
+         self.num_heads = num_heads
+         self.patch_size = patch_size
+         self.num_register_tokens = num_register_tokens
+         self.interpolate_antialias = interpolate_antialias
+         self.interpolate_offset = interpolate_offset
+
+         self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+         num_patches = self.patch_embed.num_patches
+
+         self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+         self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+         assert num_register_tokens >= 0
+         self.register_tokens = (
+             nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+         )
+
+         if drop_path_uniform is True:
+             dpr = [drop_path_rate] * depth
+         else:
+             dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+         if ffn_layer == "mlp":
+             logger.info("using MLP layer as FFN")
+             ffn_layer = Mlp
+         elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+             logger.info("using SwiGLU layer as FFN")
+             ffn_layer = SwiGLUFFNFused
+         elif ffn_layer == "identity":
+             logger.info("using Identity layer as FFN")
+
+             def f(*args, **kwargs):
+                 return nn.Identity()
+
+             ffn_layer = f
+         else:
+             raise NotImplementedError
+
+         blocks_list = [
+             block_fn(
+                 dim=embed_dim,
+                 num_heads=num_heads,
+                 mlp_ratio=mlp_ratio,
+                 qkv_bias=qkv_bias,
+                 proj_bias=proj_bias,
+                 ffn_bias=ffn_bias,
+                 drop_path=dpr[i],
+                 norm_layer=norm_layer,
+                 act_layer=act_layer,
+                 ffn_layer=ffn_layer,
+                 init_values=init_values,
+             )
+             for i in range(depth)
+         ]
+         if block_chunks > 0:
+             self.chunked_blocks = True
+             chunked_blocks = []
+             chunksize = depth // block_chunks
+             for i in range(0, depth, chunksize):
+                 # this is to keep the block index consistent if we chunk the block list
+                 chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+             self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+         else:
+             self.chunked_blocks = False
+             self.blocks = nn.ModuleList(blocks_list)
+
+         self.norm = norm_layer(embed_dim)
+         self.head = nn.Identity()
+
+         self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+         self.init_weights()
+
+     def init_weights(self):
+         trunc_normal_(self.pos_embed, std=0.02)
+         nn.init.normal_(self.cls_token, std=1e-6)
+         if self.register_tokens is not None:
+             nn.init.normal_(self.register_tokens, std=1e-6)
+         named_apply(init_weights_vit_timm, self)
+
+     def interpolate_pos_encoding(self, x, w, h):
+         previous_dtype = x.dtype
+         npatch = x.shape[1] - 1
+         N = self.pos_embed.shape[1] - 1
+         if npatch == N and w == h:
+             return self.pos_embed
+         pos_embed = self.pos_embed.float()
+         class_pos_embed = pos_embed[:, 0]
+         patch_pos_embed = pos_embed[:, 1:]
+         dim = x.shape[-1]
+         w0 = w // self.patch_size
+         h0 = h // self.patch_size
+         M = int(math.sqrt(N))  # Recover the number of patches in each dimension
+         assert N == M * M
+         kwargs = {}
+         if self.interpolate_offset:
+             # Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
+             # Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
+             sx = float(w0 + self.interpolate_offset) / M
+             sy = float(h0 + self.interpolate_offset) / M
+             kwargs["scale_factor"] = (sx, sy)
+         else:
+             # Simply specify an output size instead of a scale factor
+             kwargs["size"] = (w0, h0)
+         patch_pos_embed = nn.functional.interpolate(
+             patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
+             mode="bicubic",
+             antialias=self.interpolate_antialias,
+             **kwargs,
+         )
+         assert (w0, h0) == patch_pos_embed.shape[-2:]
+         patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+         return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+     def prepare_tokens_with_masks(self, x, masks=None):
+         B, nc, w, h = x.shape
+         x = self.patch_embed(x)
+         if masks is not None:
+             x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+         x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+         x = x + self.interpolate_pos_encoding(x, w, h)
+
+         if self.register_tokens is not None:
+             x = torch.cat(
+                 (
+                     x[:, :1],
+                     self.register_tokens.expand(x.shape[0], -1, -1),
+                     x[:, 1:],
+                 ),
+                 dim=1,
+             )
+
+         return x
+
+     def forward_features_list(self, x_list, masks_list):
+         x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+         for blk in self.blocks:
+             x = blk(x)
+
+         all_x = x
+         output = []
+         for x, masks in zip(all_x, masks_list):
+             x_norm = self.norm(x)
+             output.append(
+                 {
+                     "x_norm_clstoken": x_norm[:, 0],
+                     "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+                     "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+                     "x_prenorm": x,
+                     "masks": masks,
+                 }
+             )
+         return output
+
+     def forward_features(self, x, masks=None):
+         if isinstance(x, list):
+             return self.forward_features_list(x, masks)
+
+         x = self.prepare_tokens_with_masks(x, masks)
+
+         for blk in self.blocks:
+             x = blk(x)
+
+         x_norm = self.norm(x)
+         return {
+             "x_norm_clstoken": x_norm[:, 0],
+             "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+             "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+             "x_prenorm": x,
+             "masks": masks,
+         }
+
+     def _get_intermediate_layers_not_chunked(self, x, n=1):
+         x = self.prepare_tokens_with_masks(x)
+         # If n is an int, take the n last blocks. If it's a list, take them
+         output, total_block_len = [], len(self.blocks)
+         blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+         for i, blk in enumerate(self.blocks):
+             x = blk(x)
+             if i in blocks_to_take:
+                 output.append(x)
+         assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+         return output
+
+     def _get_intermediate_layers_chunked(self, x, n=1):
+         x = self.prepare_tokens_with_masks(x)
+         output, i, total_block_len = [], 0, len(self.blocks[-1])
+         # If n is an int, take the n last blocks. If it's a list, take them
+         blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+         for block_chunk in self.blocks:
+             for blk in block_chunk[i:]:  # Passing the nn.Identity()
+                 x = blk(x)
+                 if i in blocks_to_take:
+                     output.append(x)
+                 i += 1
+         assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+         return output
+
+     def get_intermediate_layers(
+         self,
+         x: torch.Tensor,
+         n: Union[int, Sequence] = 1,  # Layers or n last layers to take
+         reshape: bool = False,
+         return_class_token: bool = False,
+         norm=True,
+     ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+         if self.chunked_blocks:
+             outputs = self._get_intermediate_layers_chunked(x, n)
+         else:
+             outputs = self._get_intermediate_layers_not_chunked(x, n)
+         if norm:
+             outputs = [self.norm(out) for out in outputs]
+         class_tokens = [out[:, 0] for out in outputs]
+         outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
+         if reshape:
+             B, _, w, h = x.shape
+             outputs = [
+                 out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+                 for out in outputs
+             ]
+         if return_class_token:
+             return tuple(zip(outputs, class_tokens))
+         return tuple(outputs)
+
+     def forward(self, *args, is_training=False, **kwargs):
+         ret = self.forward_features(*args, **kwargs)
+         if is_training:
+             return ret
+         else:
+             return self.head(ret["x_norm_clstoken"])
+
+
+ def init_weights_vit_timm(module: nn.Module, name: str = ""):
+     """ViT weight initialization, original timm impl (for reproducibility)"""
+     if isinstance(module, nn.Linear):
+         trunc_normal_(module.weight, std=0.02)
+         if module.bias is not None:
+             nn.init.zeros_(module.bias)
+
+
+ def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=384,
+         depth=12,
+         num_heads=6,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=768,
+         depth=12,
+         num_heads=12,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=1024,
+         depth=24,
+         num_heads=16,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
+
+
+ def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+     """
+     Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+     """
+     model = DinoVisionTransformer(
+         patch_size=patch_size,
+         embed_dim=1536,
+         depth=40,
+         num_heads=24,
+         mlp_ratio=4,
+         block_fn=partial(Block, attn_class=MemEffAttention),
+         num_register_tokens=num_register_tokens,
+         **kwargs,
+     )
+     return model
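
The entry point downstream code typically uses is get_intermediate_layers, which returns patch-token features from selected blocks, optionally reshaped into a feature map. A minimal sketch (not part of the commit; assumes a CPU build without xFormers, and the input size must be a multiple of the patch size):

    import torch
    from moge.model.dinov2.models.vision_transformer import vit_small

    model = vit_small(patch_size=14, img_size=518, block_chunks=0).eval()
    x = torch.randn(1, 3, 252, 252)  # 252 = 18 * 14
    with torch.no_grad():
        (feat,) = model.get_intermediate_layers(x, n=1, reshape=True)
    print(feat.shape)  # torch.Size([1, 384, 18, 18])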
moge/model/dinov2/utils/__init__.py ADDED
@@ -0,0 +1,4 @@
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
+ #
+ # This source code is licensed under the Apache License, Version 2.0
+ # found in the LICENSE file in the root directory of this source tree.
moge/model/dinov2/utils/cluster.py ADDED
@@ -0,0 +1,95 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from enum import Enum
7
+ import os
8
+ from pathlib import Path
9
+ from typing import Any, Dict, Optional
10
+
11
+
12
+ class ClusterType(Enum):
13
+ AWS = "aws"
14
+ FAIR = "fair"
15
+ RSC = "rsc"
16
+
17
+
18
+ def _guess_cluster_type() -> ClusterType:
19
+ uname = os.uname()
20
+ if uname.sysname == "Linux":
21
+ if uname.release.endswith("-aws"):
22
+ # Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
23
+ return ClusterType.AWS
24
+ elif uname.nodename.startswith("rsc"):
25
+ # Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
26
+ return ClusterType.RSC
27
+
28
+ return ClusterType.FAIR
29
+
30
+
31
+ def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
32
+ if cluster_type is None:
33
+ return _guess_cluster_type()
34
+
35
+ return cluster_type
36
+
37
+
38
+ def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
39
+ cluster_type = get_cluster_type(cluster_type)
40
+ if cluster_type is None:
41
+ return None
42
+
43
+ CHECKPOINT_DIRNAMES = {
44
+ ClusterType.AWS: "checkpoints",
45
+ ClusterType.FAIR: "checkpoint",
46
+ ClusterType.RSC: "checkpoint/dino",
47
+ }
48
+ return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
49
+
50
+
51
+ def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
52
+ checkpoint_path = get_checkpoint_path(cluster_type)
53
+ if checkpoint_path is None:
54
+ return None
55
+
56
+ username = os.environ.get("USER")
57
+ assert username is not None
58
+ return checkpoint_path / username
59
+
60
+
61
+ def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
62
+ cluster_type = get_cluster_type(cluster_type)
63
+ if cluster_type is None:
64
+ return None
65
+
66
+ SLURM_PARTITIONS = {
67
+ ClusterType.AWS: "learnlab",
68
+ ClusterType.FAIR: "learnlab",
69
+ ClusterType.RSC: "learn",
70
+ }
71
+ return SLURM_PARTITIONS[cluster_type]
72
+
73
+
74
+ def get_slurm_executor_parameters(
75
+ nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
76
+ ) -> Dict[str, Any]:
77
+ # create default parameters
78
+ params = {
79
+ "mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
80
+ "gpus_per_node": num_gpus_per_node,
81
+ "tasks_per_node": num_gpus_per_node, # one task per GPU
82
+ "cpus_per_task": 10,
83
+ "nodes": nodes,
84
+ "slurm_partition": get_slurm_partition(cluster_type),
85
+ }
86
+ # apply cluster-specific adjustments
87
+ cluster_type = get_cluster_type(cluster_type)
88
+ if cluster_type == ClusterType.AWS:
89
+ params["cpus_per_task"] = 12
90
+ del params["mem_gb"]
91
+ elif cluster_type == ClusterType.RSC:
92
+ params["cpus_per_task"] = 12
93
+ # set additional parameters / apply overrides
94
+ params.update(kwargs)
95
+ return params
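A brief, hypothetical sketch of how these helpers are combined when submitting jobs (the actual launcher lives in the upstream DINOv2 training code; partition names and checkpoint roots depend on which cluster `_guess_cluster_type` detects):

    # Hypothetical usage; the values in comments assume the FAIR-cluster defaults above.
    params = get_slurm_executor_parameters(nodes=2, num_gpus_per_node=8)
    print(params["slurm_partition"], params["tasks_per_node"])   # e.g. "learnlab", 8

    ckpt_root = get_user_checkpoint_path()                        # e.g. /checkpoint/<USER>
    print(ckpt_root)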
moge/model/dinov2/utils/config.py ADDED
@@ -0,0 +1,72 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import math
7
+ import logging
8
+ import os
9
+
10
+ from omegaconf import OmegaConf
11
+
12
+ import dinov2.distributed as distributed
13
+ from dinov2.logging import setup_logging
14
+ from dinov2.utils import utils
15
+ from dinov2.configs import dinov2_default_config
16
+
17
+
18
+ logger = logging.getLogger("dinov2")
19
+
20
+
21
+ def apply_scaling_rules_to_cfg(cfg): # to fix
22
+ if cfg.optim.scaling_rule == "sqrt_wrt_1024":
23
+ base_lr = cfg.optim.base_lr
24
+ cfg.optim.lr = base_lr
25
+ cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
26
+ logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
27
+ else:
28
+ raise NotImplementedError
29
+ return cfg
30
+
31
+
32
+ def write_config(cfg, output_dir, name="config.yaml"):
33
+ logger.info(OmegaConf.to_yaml(cfg))
34
+ saved_cfg_path = os.path.join(output_dir, name)
35
+ with open(saved_cfg_path, "w") as f:
36
+ OmegaConf.save(config=cfg, f=f)
37
+ return saved_cfg_path
38
+
39
+
40
+ def get_cfg_from_args(args):
41
+ args.output_dir = os.path.abspath(args.output_dir)
42
+ args.opts += [f"train.output_dir={args.output_dir}"]
43
+ default_cfg = OmegaConf.create(dinov2_default_config)
44
+ cfg = OmegaConf.load(args.config_file)
45
+ cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
46
+ return cfg
47
+
48
+
49
+ def default_setup(args):
50
+ distributed.enable(overwrite=True)
51
+ seed = getattr(args, "seed", 0)
52
+ rank = distributed.get_global_rank()
53
+
54
+ global logger
55
+ setup_logging(output=args.output_dir, level=logging.INFO)
56
+ logger = logging.getLogger("dinov2")
57
+
58
+ utils.fix_random_seeds(seed + rank)
59
+ logger.info("git:\n {}\n".format(utils.get_sha()))
60
+ logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
61
+
62
+
63
+ def setup(args):
64
+ """
65
+ Create configs and perform basic setups.
66
+ """
67
+ cfg = get_cfg_from_args(args)
68
+ os.makedirs(args.output_dir, exist_ok=True)
69
+ default_setup(args)
70
+ apply_scaling_rules_to_cfg(cfg)
71
+ write_config(cfg, args.output_dir)
72
+ return cfg
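These helpers assume the full upstream `dinov2` training package (its `distributed`, `logging` and `configs` modules). A minimal, hypothetical sketch of the OmegaConf merge that `get_cfg_from_args` performs, with a toy mapping standing in for `dinov2_default_config`:

    from omegaconf import OmegaConf

    # Toy default standing in for dinov2_default_config (assumption: any OmegaConf mapping works here).
    default_cfg = OmegaConf.create({"optim": {"base_lr": 4e-3, "scaling_rule": "sqrt_wrt_1024"},
                                    "train": {"batch_size_per_gpu": 32, "output_dir": "."}})
    file_cfg = OmegaConf.create({"train": {"batch_size_per_gpu": 64}})   # stands in for the YAML config file
    cli_cfg = OmegaConf.from_cli(["optim.base_lr=2e-3", "train.output_dir=/tmp/run"])
    cfg = OmegaConf.merge(default_cfg, file_cfg, cli_cfg)
    print(OmegaConf.to_yaml(cfg))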
moge/model/dinov2/utils/dtype.py ADDED
@@ -0,0 +1,37 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+
7
+ from typing import Dict, Union
8
+
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
+ TypeSpec = Union[str, np.dtype, torch.dtype]
14
+
15
+
16
+ _NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
17
+ np.dtype("bool"): torch.bool,
18
+ np.dtype("uint8"): torch.uint8,
19
+ np.dtype("int8"): torch.int8,
20
+ np.dtype("int16"): torch.int16,
21
+ np.dtype("int32"): torch.int32,
22
+ np.dtype("int64"): torch.int64,
23
+ np.dtype("float16"): torch.float16,
24
+ np.dtype("float32"): torch.float32,
25
+ np.dtype("float64"): torch.float64,
26
+ np.dtype("complex64"): torch.complex64,
27
+ np.dtype("complex128"): torch.complex128,
28
+ }
29
+
30
+
31
+ def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
32
+ if isinstance(dtype, torch.dtype):
33
+ return dtype
34
+ if isinstance(dtype, str):
35
+ dtype = np.dtype(dtype)
36
+ assert isinstance(dtype, np.dtype), f"Expected an instance of numpy dtype, got {type(dtype)}"
37
+ return _NUMPY_TO_TORCH_DTYPE[dtype]
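A quick sketch of `as_torch_dtype`, which accepts strings, NumPy dtypes, or torch dtypes (torch dtypes pass through unchanged; note it expects `np.dtype` instances rather than scalar types such as `np.float32`):

    import numpy as np
    import torch

    assert as_torch_dtype("float16") == torch.float16
    assert as_torch_dtype(np.dtype("int64")) == torch.int64
    assert as_torch_dtype(torch.bfloat16) == torch.bfloat16   # torch dtypes are returned as-is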
moge/model/dinov2/utils/param_groups.py ADDED
@@ -0,0 +1,103 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ from collections import defaultdict
7
+ import logging
8
+
9
+
10
+ logger = logging.getLogger("dinov2")
11
+
12
+
13
+ def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
14
+ """
15
+ Calculate lr decay rate for different ViT blocks.
16
+ Args:
17
+ name (string): parameter name.
18
+ lr_decay_rate (float): base lr decay rate.
19
+ num_layers (int): number of ViT blocks.
20
+ Returns:
21
+ lr decay rate for the given parameter.
22
+ """
23
+ layer_id = num_layers + 1
24
+ if name.startswith("backbone") or force_is_backbone:
25
+ if (
26
+ ".pos_embed" in name
27
+ or ".patch_embed" in name
28
+ or ".mask_token" in name
29
+ or ".cls_token" in name
30
+ or ".register_tokens" in name
31
+ ):
32
+ layer_id = 0
33
+ elif force_is_backbone and (
34
+ "pos_embed" in name
35
+ or "patch_embed" in name
36
+ or "mask_token" in name
37
+ or "cls_token" in name
38
+ or "register_tokens" in name
39
+ ):
40
+ layer_id = 0
41
+ elif ".blocks." in name and ".residual." not in name:
42
+ layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
43
+ elif chunked_blocks and "blocks." in name and "residual." not in name:
44
+ layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
45
+ elif "blocks." in name and "residual." not in name:
46
+ layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
47
+
48
+ return lr_decay_rate ** (num_layers + 1 - layer_id)
49
+
50
+
51
+ def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
52
+ chunked_blocks = False
53
+ if hasattr(model, "n_blocks"):
54
+ logger.info("chunked fsdp")
55
+ n_blocks = model.n_blocks
56
+ chunked_blocks = model.chunked_blocks
57
+ elif hasattr(model, "blocks"):
58
+ logger.info("first code branch")
59
+ n_blocks = len(model.blocks)
60
+ elif hasattr(model, "backbone"):
61
+ logger.info("second code branch")
62
+ n_blocks = len(model.backbone.blocks)
63
+ else:
64
+ logger.info("else code branch")
65
+ n_blocks = 0
66
+ all_param_groups = []
67
+
68
+ for name, param in model.named_parameters():
69
+ name = name.replace("_fsdp_wrapped_module.", "")
70
+ if not param.requires_grad:
71
+ continue
72
+ decay_rate = get_vit_lr_decay_rate(
73
+ name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
74
+ )
75
+ d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
76
+
77
+ if "last_layer" in name:
78
+ d.update({"is_last_layer": True})
79
+
80
+ if name.endswith(".bias") or "norm" in name or "gamma" in name:
81
+ d.update({"wd_multiplier": 0.0})
82
+
83
+ if "patch_embed" in name:
84
+ d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
85
+
86
+ all_param_groups.append(d)
87
+ logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
88
+
89
+ return all_param_groups
90
+
91
+
92
+ def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
93
+ fused_params_groups = defaultdict(lambda: {"params": []})
94
+ for d in all_params_groups:
95
+ identifier = ""
96
+ for k in keys:
97
+ identifier += k + str(d[k]) + "_"
98
+
99
+ for k in keys:
100
+ fused_params_groups[identifier][k] = d[k]
101
+ fused_params_groups[identifier]["params"].append(d["params"])
102
+
103
+ return fused_params_groups.values()
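A hedged sketch of feeding these groups to an optimizer. The `lr_multiplier`/`wd_multiplier`/`is_last_layer` fields are consumed by the upstream DINOv2 trainer; below they are simply folded into per-group AdamW hyperparameters for illustration (the backbone constructor is a hypothetical import from this repo's `models/vision_transformer.py`):

    import torch

    # Hypothetical: any module exposing .blocks (e.g. the DinoVisionTransformer in this repo) works.
    model = vit_small(patch_size=14)
    groups = get_params_groups_with_decay(model, lr_decay_rate=0.9, patch_embed_lr_mult=0.2)
    fused = fuse_params_groups(groups)

    base_lr, base_wd = 1e-3, 0.04
    optimizer = torch.optim.AdamW([
        {"params": g["params"], "lr": base_lr * g["lr_multiplier"], "weight_decay": base_wd * g["wd_multiplier"]}
        for g in fused
    ])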
moge/model/dinov2/utils/utils.py ADDED
@@ -0,0 +1,95 @@
1
+ # Copyright (c) Meta Platforms, Inc. and affiliates.
2
+ #
3
+ # This source code is licensed under the Apache License, Version 2.0
4
+ # found in the LICENSE file in the root directory of this source tree.
5
+
6
+ import logging
7
+ import os
8
+ import random
9
+ import subprocess
10
+ from urllib.parse import urlparse
11
+
12
+ import numpy as np
13
+ import torch
14
+ from torch import nn
15
+
16
+
17
+ logger = logging.getLogger("dinov2")
18
+
19
+
20
+ def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
21
+ if urlparse(pretrained_weights).scheme: # If it looks like a URL
22
+ state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
23
+ else:
24
+ state_dict = torch.load(pretrained_weights, map_location="cpu")
25
+ if checkpoint_key is not None and checkpoint_key in state_dict:
26
+ logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
27
+ state_dict = state_dict[checkpoint_key]
28
+ # remove `module.` prefix
29
+ state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
30
+ # remove `backbone.` prefix induced by multicrop wrapper
31
+ state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
32
+ msg = model.load_state_dict(state_dict, strict=False)
33
+ logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
34
+
35
+
36
+ def fix_random_seeds(seed=31):
37
+ """
38
+ Fix random seeds.
39
+ """
40
+ torch.manual_seed(seed)
41
+ torch.cuda.manual_seed_all(seed)
42
+ np.random.seed(seed)
43
+ random.seed(seed)
44
+
45
+
46
+ def get_sha():
47
+ cwd = os.path.dirname(os.path.abspath(__file__))
48
+
49
+ def _run(command):
50
+ return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
51
+
52
+ sha = "N/A"
53
+ diff = "clean"
54
+ branch = "N/A"
55
+ try:
56
+ sha = _run(["git", "rev-parse", "HEAD"])
57
+ subprocess.check_output(["git", "diff"], cwd=cwd)
58
+ diff = _run(["git", "diff-index", "HEAD"])
59
+ diff = "has uncommitted changes" if diff else "clean"
60
+ branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
61
+ except Exception:
62
+ pass
63
+ message = f"sha: {sha}, status: {diff}, branch: {branch}"
64
+ return message
65
+
66
+
67
+ class CosineScheduler(object):
68
+ def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
69
+ super().__init__()
70
+ self.final_value = final_value
71
+ self.total_iters = total_iters
72
+
73
+ freeze_schedule = np.zeros((freeze_iters))
74
+
75
+ warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
76
+
77
+ iters = np.arange(total_iters - warmup_iters - freeze_iters)
78
+ schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
79
+ self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
80
+
81
+ assert len(self.schedule) == self.total_iters
82
+
83
+ def __getitem__(self, it):
84
+ if it >= self.total_iters:
85
+ return self.final_value
86
+ else:
87
+ return self.schedule[it]
88
+
89
+
90
+ def has_batchnorms(model):
91
+ bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
92
+ for name, module in model.named_modules():
93
+ if isinstance(module, bn_types):
94
+ return True
95
+ return False
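A short sketch of `CosineScheduler`: it precomputes one value per iteration (optional freeze, linear warm-up, then cosine decay), and indexing past `total_iters` clamps to `final_value`:

    lr_schedule = CosineScheduler(
        base_value=1e-3, final_value=1e-5, total_iters=10_000,
        warmup_iters=500, start_warmup_value=0.0,
    )
    print(lr_schedule[0])        # 0.0 (start of warm-up)
    print(lr_schedule[500])      # 1e-3 (peak value, right after warm-up)
    print(lr_schedule[20_000])   # 1e-5 (clamped to final_value past total_iters)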
moge/model/moge_model.py ADDED
@@ -0,0 +1,376 @@
1
+ from typing import *
2
+ from numbers import Number
3
+ from functools import partial
4
+ from pathlib import Path
5
+ import importlib
6
+ import warnings
7
+ import json
8
+
9
+ import torch
10
+ import torch.nn as nn
11
+ import torch.nn.functional as F
12
+ import torch.utils
13
+ import torch.utils.checkpoint
14
+ import torch.version
15
+ import utils3d
16
+ from huggingface_hub import hf_hub_download
17
+
18
+ from ..utils.geometry_torch import image_plane_uv, point_map_to_depth, gaussian_blur_2d
19
+ from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
20
+ from ..utils.tools import timeit
21
+
22
+
23
+ class ResidualConvBlock(nn.Module):
24
+ def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'):
25
+ super(ResidualConvBlock, self).__init__()
26
+ if out_channels is None:
27
+ out_channels = in_channels
28
+ if hidden_channels is None:
29
+ hidden_channels = in_channels
30
+
31
+ if activation == 'relu':
32
+ activation_cls = lambda: nn.ReLU(inplace=True)
33
+ elif activation == 'leaky_relu':
34
+ activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True)
35
+ elif activation == 'silu':
36
+ activation_cls = lambda: nn.SiLU(inplace=True)
37
+ elif activation == 'elu':
38
+ activation_cls = lambda: nn.ELU(inplace=True)
39
+ else:
40
+ raise ValueError(f'Unsupported activation function: {activation}')
41
+
42
+ self.layers = nn.Sequential(
43
+ nn.GroupNorm(1, in_channels),
44
+ activation_cls(),
45
+ nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
46
+ nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels),
47
+ activation_cls(),
48
+ nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode)
49
+ )
50
+
51
+ self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
52
+
53
+ def forward(self, x):
54
+ skip = self.skip_connection(x)
55
+ x = self.layers(x)
56
+ x = x + skip
57
+ return x
58
+
59
+
60
+ class Head(nn.Module):
61
+ def __init__(
62
+ self,
63
+ num_features: int,
64
+ dim_in: int,
65
+ dim_out: List[int],
66
+ dim_proj: int = 512,
67
+ dim_upsample: List[int] = [256, 128, 128],
68
+ dim_times_res_block_hidden: int = 1,
69
+ num_res_blocks: int = 1,
70
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
71
+ last_res_blocks: int = 0,
72
+ last_conv_channels: int = 32,
73
+ last_conv_size: int = 1
74
+ ):
75
+ super().__init__()
76
+
77
+ self.projects = nn.ModuleList([
78
+ nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features)
79
+ ])
80
+
81
+ self.upsample_blocks = nn.ModuleList([
82
+ nn.Sequential(
83
+ self._make_upsampler(in_ch + 2, out_ch),
84
+ *(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks))
85
+ ) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
86
+ ])
87
+
88
+ self.output_block = nn.ModuleList([
89
+ self._make_output_block(
90
+ dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,
91
+ ) for dim_out_ in dim_out
92
+ ])
93
+
94
+ def _make_upsampler(self, in_channels: int, out_channels: int):
95
+ upsampler = nn.Sequential(
96
+ nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
97
+ nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
98
+ )
99
+ # Copy the top-left value across the 2x2 transposed-conv kernel so it initially behaves like nearest-neighbor upsampling
+ upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
100
+ return upsampler
101
+
102
+ def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
103
+ return nn.Sequential(
104
+ nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
105
+ *(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
106
+ nn.ReLU(inplace=True),
107
+ nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
108
+ )
109
+
110
+ def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
111
+ img_h, img_w = image.shape[-2:]
112
+ patch_h, patch_w = img_h // 14, img_w // 14
113
+
114
+ # Process the hidden states
115
+ x = torch.stack([
116
+ proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
117
+ for proj, (feat, clstoken) in zip(self.projects, hidden_states)
118
+ ], dim=1).sum(dim=1)
119
+
120
+ # Upsample stage
121
+ # (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
122
+ for i, block in enumerate(self.upsample_blocks):
123
+ # UV coordinates are concatenated so the decoder is aware of the image aspect ratio
124
+ uv = image_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
125
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
126
+ x = torch.cat([x, uv], dim=1)
127
+ for layer in block:
128
+ x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
129
+
130
+ # (patch_h * 8, patch_w * 8) -> (img_h, img_w)
131
+ x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
132
+ uv = image_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
133
+ uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
134
+ x = torch.cat([x, uv], dim=1)
135
+
136
+ if isinstance(self.output_block, nn.ModuleList):
137
+ output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block]
138
+ else:
139
+ output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False)
140
+
141
+ return output
142
+
143
+
144
+ class MoGeModel(nn.Module):
145
+ image_mean: torch.Tensor
146
+ image_std: torch.Tensor
147
+
148
+ def __init__(self,
149
+ encoder: str = 'dinov2_vitb14',
150
+ intermediate_layers: Union[int, List[int]] = 4,
151
+ dim_proj: int = 512,
152
+ dim_upsample: List[int] = [256, 128, 128],
153
+ dim_times_res_block_hidden: int = 1,
154
+ num_res_blocks: int = 1,
155
+ output_mask: bool = False,
156
+ split_head: bool = False,
157
+ remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
158
+ res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
159
+ trained_diagonal_size_range: Tuple[Number, Number] = (600, 900),
160
+ trained_area_range: Tuple[Number, Number] = (500 * 500, 700 * 700),
161
+ last_res_blocks: int = 0,
162
+ last_conv_channels: int = 32,
163
+ last_conv_size: int = 1,
164
+ **deprecated_kwargs
165
+ ):
166
+ super(MoGeModel, self).__init__()
167
+ if deprecated_kwargs:
168
+ warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
169
+
170
+ self.encoder = encoder
171
+ self.remap_output = remap_output
172
+ self.intermediate_layers = intermediate_layers
173
+ self.trained_diagonal_size_range = trained_diagonal_size_range
174
+ self.trained_area_range = trained_area_range
175
+ self.output_mask = output_mask
176
+ self.split_head = split_head
177
+
178
+ # NOTE: We have copied the DINOv2 code in torchhub to this repository.
179
+ # Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues.
180
+ hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
181
+ self.backbone = hub_loader(pretrained=False)
182
+ dim_feature = self.backbone.blocks[0].attn.qkv.in_features
183
+
184
+ self.head = Head(
185
+ num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers),
186
+ dim_in=dim_feature,
187
+ dim_out=3 if not output_mask else 4 if output_mask and not split_head else [3, 1],
188
+ dim_proj=dim_proj,
189
+ dim_upsample=dim_upsample,
190
+ dim_times_res_block_hidden=dim_times_res_block_hidden,
191
+ num_res_blocks=num_res_blocks,
192
+ res_block_norm=res_block_norm,
193
+ last_res_blocks=last_res_blocks,
194
+ last_conv_channels=last_conv_channels,
195
+ last_conv_size=last_conv_size
196
+ )
197
+
198
+ image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
199
+ image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
200
+
201
+ self.register_buffer("image_mean", image_mean)
202
+ self.register_buffer("image_std", image_std)
203
+
204
+ if torch.__version__ >= '2.0':
205
+ self.enable_pytorch_native_sdpa()
206
+
207
+ @classmethod
208
+ def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
209
+ """
210
+ Load a model from a checkpoint file.
211
+
212
+ ### Parameters:
213
+ - `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
214
+ - `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
215
+ - `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
216
+
217
+ ### Returns:
218
+ - A new instance of `MoGeModel` with the parameters loaded from the checkpoint.
219
+ """
220
+ if Path(pretrained_model_name_or_path).exists():
221
+ checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True)
222
+ else:
223
+ cached_checkpoint_path = hf_hub_download(
224
+ repo_id=pretrained_model_name_or_path,
225
+ repo_type="model",
226
+ filename="model.pt",
227
+ **hf_kwargs
228
+ )
229
+ checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True)
230
+ model_config = checkpoint['model_config']
231
+ if model_kwargs is not None:
232
+ model_config.update(model_kwargs)
233
+ model = cls(**model_config)
234
+ model.load_state_dict(checkpoint['model'])
235
+ return model
236
+
237
+ @staticmethod
238
+ def cache_pretrained_backbone(encoder: str, pretrained: bool):
239
+ _ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained)
240
+
241
+ def load_pretrained_backbone(self):
242
+ "Load the backbone with pretrained dinov2 weights from torch hub"
243
+ state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
244
+ self.backbone.load_state_dict(state_dict)
245
+
246
+ def enable_backbone_gradient_checkpointing(self):
247
+ for i in range(len(self.backbone.blocks)):
248
+ self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
249
+
250
+ def enable_pytorch_native_sdpa(self):
251
+ for i in range(len(self.backbone.blocks)):
252
+ self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
253
+
254
+ def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]:
255
+ raw_img_h, raw_img_w = image.shape[-2:]
256
+ patch_h, patch_w = raw_img_h // 14, raw_img_w // 14
257
+
258
+ image = (image - self.image_mean) / self.image_std
259
+
260
+ # Apply image transformation for DINOv2
261
+ image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True)
262
+
263
+ # Get intermediate layers from the backbone
264
+ with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision):
265
+ features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)
266
+
267
+ # Predict points (and mask)
268
+ output = self.head(features, image)
269
+ if self.output_mask:
270
+ if self.split_head:
271
+ points, mask = output
272
+ else:
273
+ points, mask = output.split([3, 1], dim=1)
274
+ points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
275
+ else:
276
+ points = output.permute(0, 2, 3, 1)
277
+
278
+ if self.remap_output == 'linear' or self.remap_output == False:
279
+ pass
280
+ elif self.remap_output == 'sinh' or self.remap_output == True:
281
+ points = torch.sinh(points)
282
+ elif self.remap_output == 'exp':
283
+ xy, z = points.split([2, 1], dim=-1)
284
+ z = torch.exp(z)
285
+ points = torch.cat([xy * z, z], dim=-1)
286
+ elif self.remap_output == 'sinh_exp':
287
+ xy, z = points.split([2, 1], dim=-1)
288
+ points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
289
+ else:
290
+ raise ValueError(f"Invalid remap output type: {self.remap_output}")
291
+
292
+ return_dict = {'points': points}
293
+ if self.output_mask:
294
+ return_dict['mask'] = mask
295
+ return return_dict
296
+
297
+ @torch.inference_mode()
298
+ def infer(
299
+ self,
300
+ image: torch.Tensor,
301
+ force_projection: bool = True,
302
+ resolution_level: int = 9,
303
+ apply_mask: bool = True,
304
+ ) -> Dict[str, torch.Tensor]:
305
+ """
306
+ User-friendly inference function
307
+
308
+ ### Parameters
309
+ - `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
310
+ - `resolution_level`: the resolution level to use for the output point map in 0-9. Default: 9 (highest)
311
+ - `force_projection`: if True, recompute the point map from the recovered depth and intrinsics so it satisfies a perspective projection. Default: True
+ - `apply_mask`: if True, set invalid pixels to inf using the predicted mask. Default: True
312
+
313
+ ### Returns
314
+
315
+ A dictionary containing the following keys:
316
+ - `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
317
+ - `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
318
+ - `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
+ - `mask`: boolean tensor of shape (B, H, W) or (H, W) marking valid pixels (present only when the model predicts a mask).
319
+ """
320
+ if image.dim() == 3:
321
+ omit_batch_dim = True
322
+ image = image.unsqueeze(0)
323
+ else:
324
+ omit_batch_dim = False
325
+
326
+ original_height, original_width = image.shape[-2:]
327
+ area = original_height * original_width
328
+
329
+ min_area, max_area = self.trained_area_range
330
+ expected_area = min_area + (max_area - min_area) * (resolution_level / 9)
331
+
332
+ if expected_area != area:
333
+ expected_width, expected_height = int(original_width * (expected_area / area) ** 0.5), int(original_height * (expected_area / area) ** 0.5)
334
+ image = F.interpolate(image, (expected_height, expected_width), mode="bicubic", align_corners=False, antialias=True)
335
+
336
+ output = self.forward(image)
337
+ points, mask = output['points'], output.get('mask', None)
338
+
339
+ # Get camera-origin-centered point map
340
+ depth, fov_x, fov_y, z_shift = point_map_to_depth(points, None if mask is None else mask > 0.5)
341
+ intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov_x, fov_y)
342
+
343
+ # If the projection constraint is enforced, recompute the point map from the recovered depth map and intrinsics
344
+ if force_projection:
345
+ points = utils3d.torch.unproject_cv(utils3d.torch.image_uv(width=expected_width, height=expected_height, dtype=points.dtype, device=points.device), depth, extrinsics=None, intrinsics=intrinsics[..., None, :, :])
346
+ else:
347
+ points = points + torch.stack([torch.zeros_like(z_shift), torch.zeros_like(z_shift), z_shift], dim=-1)[..., None, None, :]
348
+
349
+ # Resize the output to the original resolution
350
+ if expected_area != area:
351
+ points = F.interpolate(points.permute(0, 3, 1, 2), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).permute(0, 2, 3, 1)
352
+ depth = F.interpolate(depth.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1)
353
+ mask = None if mask is None else F.interpolate(mask.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1)
354
+
355
+ # Apply mask if needed
356
+ if self.output_mask and apply_mask:
357
+ mask_binary = (depth > 0) & (mask > 0.5)
358
+ points = torch.where(mask_binary[..., None], points, torch.inf)
359
+ depth = torch.where(mask_binary, depth, torch.inf)
360
+
361
+ if omit_batch_dim:
362
+ points = points.squeeze(0)
363
+ intrinsics = intrinsics.squeeze(0)
364
+ depth = depth.squeeze(0)
365
+ if self.output_mask:
366
+ mask = mask.squeeze(0)
367
+
368
+ return_dict = {
369
+ 'points': points,
370
+ 'intrinsics': intrinsics,
371
+ 'depth': depth,
372
+ }
373
+ if self.output_mask:
374
+ return_dict['mask'] = mask > 0.5
375
+
376
+ return return_dict
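A hedged end-to-end sketch of `infer` (the checkpoint path and image file are placeholders; any MoGe checkpoint containing the `model_config` and `model` entries loads this way). The input is an RGB tensor with values in [0, 1]:

    import cv2
    import torch

    model = MoGeModel.from_pretrained("path/to/model.pt").cuda().eval()        # placeholder checkpoint path

    image = cv2.cvtColor(cv2.imread("example.jpg"), cv2.COLOR_BGR2RGB)         # placeholder image
    image = torch.from_numpy(image).float().permute(2, 0, 1).cuda() / 255.0    # (3, H, W), values in [0, 1]

    output = model.infer(image, resolution_level=9, apply_mask=True)
    points, depth, intrinsics = output["points"], output["depth"], output["intrinsics"]
    print(points.shape, depth.shape, intrinsics.shape)                          # (H, W, 3), (H, W), (3, 3)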
moge/model/utils.py ADDED
@@ -0,0 +1,38 @@
1
+ from typing import *
2
+
3
+ import torch
4
+ import torch.nn as nn
5
+ import torch.nn.functional as F
6
+
7
+ def wrap_module_with_gradient_checkpointing(module: nn.Module):
8
+ from torch.utils.checkpoint import checkpoint
9
+ class _CheckpointingWrapper(module.__class__):
10
+ _restore_cls = module.__class__
11
+ def forward(self, *args, **kwargs):
12
+ return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
13
+
14
+ module.__class__ = _CheckpointingWrapper
15
+ return module
16
+
17
+
18
+ def unwrap_module_with_gradient_checkpointing(module: nn.Module):
19
+ module.__class__ = module.__class__._restore_cls
20
+
21
+
22
+ def wrap_dinov2_attention_with_sdpa(module: nn.Module):
23
+ assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
24
+ class _AttentionWrapper(module.__class__):
25
+ def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
26
+ B, N, C = x.shape
27
+ qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
28
+
29
+ q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
30
+
31
+ x = F.scaled_dot_product_attention(q, k, v, attn_bias)
32
+ x = x.permute(0, 2, 1, 3).reshape(B, N, C)
33
+
34
+ x = self.proj(x)
35
+ x = self.proj_drop(x)
36
+ return x
37
+ module.__class__ = _AttentionWrapper
38
+ return module
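A small sketch of the wrapping helpers above. They swap the instance's class at runtime, so they apply to an already-constructed module without reloading weights (the toy MLP is only illustrative; checkpointing pays off for heavier blocks):

    import torch
    import torch.nn as nn

    block = nn.Sequential(nn.Linear(64, 64), nn.GELU(), nn.Linear(64, 64))
    block = wrap_module_with_gradient_checkpointing(block)    # forward now recomputes activations during backward

    x = torch.randn(8, 64, requires_grad=True)
    block(x).sum().backward()                                  # usual autograd, with lower peak activation memory

    unwrap_module_with_gradient_checkpointing(block)           # restores the original class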
moge/utils/__init__.py ADDED
File without changes
moge/utils/blob.py ADDED
@@ -0,0 +1,314 @@
1
+ from typing import IO, Generator, Tuple, Union, overload
2
+ from pathlib import Path, PosixPath, PurePosixPath
3
+ import io
4
+ import os
5
+ import re
6
+ import requests
7
+ import fnmatch
8
+
9
+ from azure.identity import DefaultAzureCredential
10
+ from azure.storage.blob import ContainerClient, BlobClient
11
+ import requests.adapters
12
+ import requests.packages
13
+ from urllib3.util.retry import Retry
14
+
15
+
16
+ __all__ = [
17
+ 'download_blob', 'upload_blob',
18
+ 'download_blob_with_cache',
19
+ 'open_blob', 'open_blob_with_cache',
20
+ 'blob_file_exists',
21
+ 'AzureBlobPath','SmartPath'
22
+ ]
23
+
24
+ DEFAULT_CREDENTIAL = DefaultAzureCredential()
25
+
26
+ BLOB_CACHE_DIR = './.blobcache'
27
+
28
+ def download_blob(blob: Union[str, BlobClient]) -> bytes:
29
+ if isinstance(blob, str):
30
+ blob_client = BlobClient.from_blob_url(blob)
31
+ else:
32
+ blob_client = blob
33
+ return blob_client.download_blob().read()
34
+
35
+
36
+ def upload_blob(blob: Union[str, BlobClient], data: Union[str, bytes]):
37
+ if isinstance(blob, str):
38
+ blob_client = BlobClient.from_blob_url(blob)
39
+ else:
40
+ blob_client = blob
41
+ blob_client.upload_blob(data, overwrite=True)
42
+
43
+
44
+ def download_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> bytes:
45
+ """
46
+ Download a blob file from a container and return its content as bytes.
47
+ If the file is already present in the cache, it is read from there.
48
+ """
49
+ cache_path = Path(cache_dir) / blob_name
50
+ if cache_path.exists():
51
+ return cache_path.read_bytes()
52
+ if isinstance(container, str):
+ container = ContainerClient.from_container_url(container)
+ data = download_blob(container.get_blob_client(blob_name))
53
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
54
+ cache_path.write_bytes(data)
55
+ return data
56
+
57
+
58
+ def open_blob(container: Union[str, ContainerClient], blob_name: str) -> io.BytesIO:
59
+ """
60
+ Open a blob file for reading from a container and return its content as a BytesIO object.
61
+ """
62
+ if isinstance(container, str):
+ container = ContainerClient.from_container_url(container)
+ return io.BytesIO(download_blob(container.get_blob_client(blob_name)))
63
+
64
+
65
+ def open_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> io.BytesIO:
66
+ """
67
+ Open a blob file for reading from a container and return its content as a BytesIO object.
68
+ If the file is already present in the cache, it is read from there.
69
+ """
70
+ return io.BytesIO(download_blob_with_cache(container, blob_name, cache_dir=cache_dir))
71
+
72
+
73
+ def blob_file_exists(container: Union[str, ContainerClient], blob_name: str) -> bool:
74
+ """
75
+ Check if a blob file exists in a container.
76
+ """
77
+ if isinstance(container, str):
78
+ container = ContainerClient.from_container_url(container)
79
+ blob_client = container.get_blob_client(blob_name)
80
+ return blob_client.exists()
81
+
82
+ def is_blob_url(url: str) -> bool:
83
+ return re.match(r'https://[^/]+blob.core.windows.net/+', url) is not None
84
+
85
+
86
+ def split_blob_url(url: str) -> Tuple[str, str, str]:
87
+ match = re.match(r'(https://[^/]+blob.core.windows.net/[^/?]+)(/([^\?]*))?(\?.+)?', url)
88
+ if match:
89
+ container, _, path, sas = match.groups()
90
+ return container, path or '', sas or ''
91
+ raise ValueError(f'Not a valid blob URL: {url}')
92
+
93
+
94
+ def join_blob_path(url: str, *others: str) -> str:
95
+ container, path, sas = split_blob_url(url)
96
+ return container + '/' + os.path.join(path, *others) + sas
97
+
98
+
99
+ class AzureBlobStringWriter(io.StringIO):
100
+ def __init__(self, blob_client: BlobClient, encoding: str = 'utf-8', **kwargs):
101
+ self._encoding = encoding
102
+ self.blob_client = blob_client
103
+ self.kwargs = kwargs
104
+ super().__init__()
105
+
106
+ def close(self):
107
+ self.blob_client.upload_blob(self.getvalue().encode(self._encoding), blob_type='BlockBlob', overwrite=True, **self.kwargs)
108
+
109
+
110
+ class AzureBlobBytesWriter(io.BytesIO):
111
+ def __init__(self, blob_client: BlobClient, **kwargs):
112
+ super().__init__()
113
+ self.blob_client = blob_client
114
+ self.kwargs = kwargs
115
+
116
+ def close(self):
117
+ self.blob_client.upload_blob(self.getvalue(), blob_type='BlockBlob', overwrite=True, **self.kwargs)
118
+
119
+
120
+ def open_azure_blob(blob: Union[str, BlobClient], mode: str = 'r', encoding: str = 'utf-8', newline: str = None, cache_blob: bool = False, **kwargs) -> IO:
121
+ if isinstance(blob, str):
122
+ blob_client = BlobClient.from_blob_url(blob)
123
+ elif isinstance(blob, BlobClient):
124
+ blob_client = blob
125
+ else:
126
+ raise ValueError(f'Must be a blob URL or a BlobClient object: {blob}')
127
+
128
+ if cache_blob:
129
+ cache_path = Path(BLOB_CACHE_DIR, blob_client.account_name, blob_client.container_name, blob_client.blob_name)
130
+
131
+ if mode == 'r' or mode == 'rb':
132
+ if cache_blob:
133
+ if cache_path.exists():
134
+ data = cache_path.read_bytes()
135
+ else:
136
+ data = blob_client.download_blob(**kwargs).read()
137
+ cache_path.parent.mkdir(parents=True, exist_ok=True)
138
+ cache_path.write_bytes(data)
139
+ else:
140
+ data = blob_client.download_blob(**kwargs).read()
141
+ if mode == 'r':
142
+ return io.StringIO(data.decode(encoding), newline=newline)
143
+ else:
144
+ return io.BytesIO(data)
145
+ elif mode == 'w':
146
+ return AzureBlobStringWriter(blob_client, **kwargs)
147
+ elif mode == 'wb':
148
+ return AzureBlobBytesWriter(blob_client, **kwargs)
149
+ else:
150
+ raise ValueError(f'Unsupported mode: {mode}')
151
+
152
+
153
+ def smart_open(path_or_url: Union[Path, str], mode: str = 'r', encoding: str = 'utf-8') -> IO:
154
+ if is_blob_url(str(path_or_url)):
155
+ return open_azure_blob(str(path_or_url), mode, encoding)
156
+ return open(path_or_url, mode, encoding=None if 'b' in mode else encoding)
157
+
158
+
159
+ class AzureBlobPath(PurePosixPath):
160
+ """
161
+ Implementation of pathlib.Path like interface for Azure Blob Storage.
162
+ """
163
+ container_client: ContainerClient
164
+ _parse_path = PurePosixPath._parse_args if hasattr(PurePosixPath, '_parse_args') else PurePosixPath._parse_path
165
+
166
+ def __new__(cls, *args, **kwargs):
167
+ """Override the old __new__ method. Parts are parsed in __init__"""
168
+ return object.__new__(cls)
169
+
170
+ def __init__(self, root: Union[str, 'AzureBlobPath', ContainerClient], *others: Union[str, PurePosixPath], pool_maxsize: int = 256, retries: int = 3):
171
+ if isinstance(root, AzureBlobPath):
172
+ self.container_client = root.container_client
173
+ parts = root.parts + others
174
+ elif isinstance(root, str):
175
+ url = root
176
+ container, path, sas = split_blob_url(url)
177
+ session = self._get_session(pool_maxsize=pool_maxsize, retries=retries)
178
+ if sas:
179
+ self.container_client = ContainerClient.from_container_url(container + sas, session=session)
180
+ else:
181
+ self.container_client = ContainerClient.from_container_url(container, credential=DEFAULT_CREDENTIAL, session=session)
182
+ parts = (path, *others)
183
+ elif isinstance(root, ContainerClient):
184
+ self.container_client = root
185
+ parts = others
186
+ else:
187
+ raise ValueError(f'Invalid root: {root}')
188
+
189
+ if hasattr(PurePosixPath, '_parse_args'):
190
+ # For compatibility with Python 3.10
191
+ drv, root, parts = PurePosixPath._parse_args(parts)
192
+ self._drv = drv
193
+ self._root = root
194
+ self._parts = parts
195
+ else:
196
+ super().__init__(*parts)
197
+
198
+ def _get_session(self, pool_maxsize: int = 1024, retries: int = 3) -> requests.Session:
199
+ session = requests.Session()
200
+ retry_strategy = Retry(
201
+ total=retries,
202
+ status_forcelist=[429, 500, 502, 503, 504],
203
+ allowed_methods=["HEAD", "GET", "PUT", "DELETE"],
204
+ backoff_factor=1,
205
+ raise_on_status=False,
206
+ read=retries,
207
+ connect=retries,
208
+ redirect=retries,
209
+ )
210
+ adapter = requests.adapters.HTTPAdapter(pool_connections=pool_maxsize, pool_maxsize=pool_maxsize, max_retries=retry_strategy)
211
+ session.mount('http://', adapter)
212
+ session.mount('https://', adapter)
213
+ return session
214
+
215
+ def _from_parsed_parts(self, drv, root, parts):
216
+ "For compatibility with Python 3.10"
217
+ return AzureBlobPath(self.container_client, drv, root, *parts)
218
+
219
+ def with_segments(self, *pathsegments):
220
+ return AzureBlobPath(self.container_client, *pathsegments)
221
+
222
+ @property
223
+ def path(self) -> str:
224
+ return '/'.join(self.parts)
225
+
226
+ @property
227
+ def blob_client(self) -> BlobClient:
228
+ return self.container_client.get_blob_client(self.path)
229
+
230
+ @property
231
+ def url(self) -> str:
232
+ if len(self.parts) == 0:
233
+ return self.container_client.url
234
+ return self.container_client.get_blob_client(self.path).url
235
+
236
+ @property
237
+ def container_name(self) -> str:
238
+ return self.container_client.container_name
239
+
240
+ @property
241
+ def account_name(self) -> str:
242
+ return self.container_client.account_name
243
+
244
+ def __str__(self):
245
+ return self.url
246
+
247
+ def __repr__(self):
248
+ return self.url
249
+
250
+ def open(self, mode: str = 'r', encoding: str = 'utf-8', cache_blob: bool = False, **kwargs) -> IO:
251
+ return open_azure_blob(self.blob_client, mode, encoding, cache_blob=cache_blob, **kwargs)
252
+
253
+ def __truediv__(self, other: Union[str, Path]) -> 'AzureBlobPath':
254
+ return self.joinpath(other)
255
+
256
+ def mkdir(self, parents: bool = False, exist_ok: bool = False):
257
+ pass
258
+
259
+ def iterdir(self) -> Generator['AzureBlobPath', None, None]:
260
+ path = self.path
261
+ if not path.endswith('/'):
262
+ path += '/'
263
+ for item in self.container_client.walk_blobs(self.path):
264
+ yield AzureBlobPath(self.container_client, item.name)
265
+
266
+ def glob(self, pattern: str) -> Generator['AzureBlobPath', None, None]:
267
+ special_chars = ".^$+{}[]()|/"
268
+ for char in special_chars:
269
+ pattern = pattern.replace(char, "\\" + char)
270
+ pattern = pattern.replace('**', './/.')
271
+ pattern = pattern.replace('*', '[^/]*')
272
+ pattern = pattern.replace('.//.', '.*')
273
+ pattern = "^" + pattern + "$"
274
+ reg = re.compile(pattern)
275
+
276
+ for item in self.container_client.list_blobs(self.path):
277
+ if reg.match(os.path.relpath(item.name, self.path)):
278
+ yield AzureBlobPath(self.container_client, item.name)
279
+
280
+ def exists(self) -> bool:
281
+ return self.blob_client.exists()
282
+
283
+ def read_bytes(self, cache_blob: bool = False) -> bytes:
284
+ with self.open('rb', cache_blob=cache_blob) as f:
285
+ return f.read()
286
+
287
+ def read_text(self, encoding: str = 'utf-8', cache_blob: bool = False) -> str:
288
+ with self.open('r', encoding=encoding, cache_blob=cache_blob) as f:
289
+ return f.read()
290
+
291
+ def write_bytes(self, data: bytes):
292
+ self.blob_client.upload_blob(data, overwrite=True)
293
+
294
+ def write_text(self, data: str, encoding: str = 'utf-8'):
295
+ self.blob_client.upload_blob(data.encode(encoding), overwrite=True)
296
+
297
+ def unlink(self):
298
+ self.blob_client.delete_blob()
299
+
300
+ def new_client(self) -> 'AzureBlobPath':
301
+ return AzureBlobPath(self.container_client.url, self.path)
302
+
303
+
304
+ class SmartPath(Path, AzureBlobPath):
305
+ """
306
+ Supports both local file paths and Azure Blob Storage URLs.
307
+ """
308
+ def __new__(cls, first: Union[Path, str], *others: Union[str, PurePosixPath]) -> Union[Path, AzureBlobPath]:
309
+ if is_blob_url(str(first)):
310
+ return AzureBlobPath(str(first), *others)
311
+ return Path(first, *others)
312
+
313
+
314
+
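A hedged sketch of `SmartPath`, which returns a plain `pathlib.Path` for local paths and an `AzureBlobPath` for blob URLs (account, container and blob names below are placeholders):

    # Local paths behave exactly like pathlib.Path.
    p = SmartPath("./workspace/example.txt")
    p.parent.mkdir(parents=True, exist_ok=True)
    p.write_text("hello")

    # Blob URLs (optionally carrying a SAS token) behave like AzureBlobPath.
    b = SmartPath("https://myaccount.blob.core.windows.net/mycontainer/some/blob.txt")
    if b.exists():
        print(b.read_text(cache_blob=True))   # optionally caches the blob under ./.blobcache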
moge/utils/download.py ADDED
@@ -0,0 +1,55 @@
1
+ from pathlib import Path
2
+ from typing import *
3
+ import requests
4
+
5
+ from tqdm import tqdm
6
+
7
+
8
+ __all__ = ["download_file", "download_bytes"]
9
+
10
+
11
+ def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None:
12
+ # Ensure headers is a dict if not provided
13
+ headers = headers or {}
14
+
15
+ # Initialize local variables
16
+ file_path = Path(filepath)
17
+ downloaded_bytes = 0
18
+
19
+ # Check if we should resume the download
20
+ if resume and file_path.exists():
21
+ downloaded_bytes = file_path.stat().st_size
22
+ headers['Range'] = f"bytes={downloaded_bytes}-"
23
+
24
+ # Make a GET request to fetch the file
25
+ with requests.get(url, stream=True, headers=headers) as response:
26
+ response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
27
+
28
+ # Calculate the total size to download
29
+ total_size = downloaded_bytes + int(response.headers.get('content-length', 0))
30
+
31
+ # Display a progress bar while downloading
32
+ with (
33
+ tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar,
34
+ open(file_path, 'ab') as file,
35
+ ):
36
+ # Set the initial position of the progress bar
37
+ pbar.update(downloaded_bytes)
38
+
39
+ # Write the content to the file in chunks
40
+ for chunk in response.iter_content(chunk_size=4096):
41
+ file.write(chunk)
42
+ pbar.update(len(chunk))
43
+
44
+
45
+ def download_bytes(url: str, headers: dict = None) -> bytes:
46
+ # Ensure headers is a dict if not provided
47
+ headers = headers or {}
48
+
49
+ # Make a GET request to fetch the file
50
+ with requests.get(url, stream=True, headers=headers) as response:
51
+ response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
52
+
53
+ # Read the content of the response
54
+ return response.content
55
+
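A brief sketch of `download_file`; when `resume=True` and a partial file already exists, the remaining bytes are requested with an HTTP `Range` header (URL and paths are placeholders):

    from pathlib import Path

    Path("checkpoints").mkdir(parents=True, exist_ok=True)
    download_file(
        "https://example.com/files/model.pt",   # placeholder URL; the server must support Range requests to resume
        "checkpoints/model.pt",
        resume=True,
    )
    meta = download_bytes("https://example.com/files/meta.json")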
moge/utils/geometry_numpy.py ADDED
@@ -0,0 +1,175 @@
1
+ from typing import *
2
+ from functools import partial
3
+ import math
4
+
5
+ import numpy as np
6
+ import utils3d
7
+
8
+ from .tools import timeit
9
+
10
+ def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
11
+ if w is None:
12
+ return np.mean(x, axis=axis)
13
+ else:
14
+ w = w.astype(x.dtype)
15
+ return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None)
16
+
17
+
18
+ def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
19
+ if w is None:
20
+ return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis)
21
+ else:
22
+ w = w.astype(x.dtype)
23
+ return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
24
+
25
+
26
+ def image_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
27
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
28
+ if aspect_ratio is None:
29
+ aspect_ratio = width / height
30
+
31
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
32
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
33
+
34
+ u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
35
+ v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
36
+ u, v = np.meshgrid(u, v, indexing='xy')
37
+ uv = np.stack([u, v], axis=-1)
38
+ return uv
39
+
40
+
41
+ def focal_to_fov_numpy(focal: np.ndarray):
42
+ return 2 * np.arctan(0.5 / focal)
43
+
44
+
45
+ def fov_to_focal_numpy(fov: np.ndarray):
46
+ return 0.5 / np.tan(fov / 2)
47
+
48
+
49
+ def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
50
+ fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
51
+ fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
52
+ return fov_x, fov_y
53
+
54
+
55
+ def solve_optimal_shift_focal(uv: np.ndarray, xyz: np.ndarray, ransac_iters: int = None, ransac_hypothetical_size: float = 0.1, ransac_threshold: float = 0.1):
56
+ "Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
57
+ from scipy.optimize import least_squares
58
+ uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
59
+
60
+ def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
61
+ xy_proj = xy / (z + shift)[: , None]
62
+ f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
63
+ err = (f * xy_proj - uv).ravel()
64
+ return err
65
+
66
+ initial_shift = 0 #-z.min(keepdims=True) + 1.0
67
+
68
+ if ransac_iters is None:
69
+ solution = least_squares(partial(fn, uv, xy, z), x0=initial_shift, ftol=1e-3, method='lm')
70
+ optim_shift = solution['x'].squeeze().astype(np.float32)
71
+ else:
72
+ best_err, best_shift = np.inf, None
73
+ for _ in range(ransac_iters):
74
+ maybe_inliers = np.random.choice(len(z), size=int(ransac_hypothetical_size * len(z)), replace=False)
75
+ solution = least_squares(partial(fn, uv[maybe_inliers], xy[maybe_inliers], z[maybe_inliers]), x0=initial_shift, ftol=1e-3, method='lm')
76
+ maybe_shift = solution['x'].squeeze().astype(np.float32)
77
+ confirmed_inliers = np.linalg.norm(fn(uv, xy, z, maybe_shift).reshape(-1, 2), axis=-1) < ransac_threshold
78
+ if confirmed_inliers.sum() > 10:
79
+ solution = least_squares(partial(fn, uv[confirmed_inliers], xy[confirmed_inliers], z[confirmed_inliers]), x0=maybe_shift, ftol=1e-3, method='lm')
80
+ better_shift = solution['x'].squeeze().astype(np.float32)
81
+ else:
82
+ better_shift = maybe_shift
83
+ err = np.linalg.norm(fn(uv, xy, z, better_shift).reshape(-1, 2), axis=-1).clip(max=ransac_threshold).mean()
84
+ if err < best_err:
85
+ best_err, best_shift = err, better_shift
86
+ initial_shift = best_shift
87
+
88
+ optim_shift = best_shift
89
+
90
+ xy_proj = xy / (z + optim_shift)[: , None]
91
+ optim_focal = (xy_proj * uv).sum() / (xy_proj * xy_proj).sum()
92
+
93
+ return optim_shift, optim_focal
94
+
95
+
96
+ def point_map_to_depth_numpy(points: np.ndarray, mask: np.ndarray = None, downsample_size: Tuple[int, int] = (64, 64)):
97
+ import cv2
98
+ assert points.shape[-1] == 3, "Points should be of shape (H, W, 3)"
99
+
100
+ height, width = points.shape[-3], points.shape[-2]
101
+ diagonal = (height ** 2 + width ** 2) ** 0.5
102
+
103
+ uv = image_plane_uv_numpy(width=width, height=height)
104
+
105
+ if mask is None:
106
+ points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
107
+ uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
108
+ else:
109
+ index, mask_lr = mask_aware_nearest_resize_numpy(mask, *downsample_size)
110
+ points_lr, uv_lr = points[index][mask_lr], uv[index][mask_lr]
111
+
112
+ if points_lr.size == 0:
113
+ return np.zeros((height, width)), 0, 0, 0
114
+
115
+ optim_shift, optim_focal = solve_optimal_shift_focal(uv_lr, points_lr, ransac_iters=None)
116
+
117
+ fov_x = 2 * np.arctan(width / diagonal / optim_focal)
118
+ fov_y = 2 * np.arctan(height / diagonal / optim_focal)
119
+
120
+ depth = points[:, :, 2] + optim_shift
121
+ return depth, fov_x, fov_y, optim_shift
122
+
123
+
124
+ def mask_aware_nearest_resize_numpy(mask: np.ndarray, target_width: int, target_height: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
125
+ """
126
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
127
+
128
+ ### Parameters
129
+ - `mask`: Input 2D mask of shape (..., H, W)
130
+ - `target_width`: target width of the resized map
131
+ - `target_height`: target height of the resized map
132
+
133
+ ### Returns
134
+ - `nearest_idx`: tuple of index arrays (batch indices..., row index i, column index j), each of shape (..., target_height, target_width), suitable for advanced indexing into the original map.
135
+ - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
136
+ """
137
+ height, width = mask.shape[-2:]
138
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
139
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
140
+ filter_size = filter_h_i * filter_w_i
141
+ padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2)
142
+
143
+ # Window the original mask and uv
144
+ uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
145
+ indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
146
+ padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
147
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
148
+ padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
149
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
150
+ padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
151
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
152
+ windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
153
+ windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
154
+ windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
155
+
156
+ # Gather the target pixels's local window
157
+ target_uv = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
158
+ target_corner = target_uv - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
159
+ target_corner = np.round(target_corner - 0.5).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
160
+
161
+ target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
162
+ target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
163
+ target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
164
+
165
+ # Compute nearest neighbor in the local window for each pixel
166
+ dist = np.square(target_window_uv - target_uv[..., None])
167
+ dist = dist[..., 0, :] + dist[..., 1, :]
168
+ dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size)
169
+ nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1)
170
+ nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width)
171
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
172
+ target_mask = np.any(target_window_mask, axis=-1)
173
+ batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]
174
+
175
+ return (*batch_indices, nearest_i, nearest_j), target_mask
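A hedged sanity-check sketch of `point_map_to_depth_numpy`. It solves `min |focal * xy / (z + shift) - uv|` jointly over focal and shift on a downsampled grid, so for a point map synthesized from a known focal and shift it should recover them (and hence the FoV) up to solver tolerance:

    import numpy as np

    H, W = 240, 320
    uv = image_plane_uv_numpy(width=W, height=H)                      # (H, W, 2), aspect-normalized
    f_true, t_true = 1.2, 0.3
    depth_true = 2.0 + np.random.rand(H, W).astype(np.float32)        # ground-truth z in camera space

    xy = uv * (depth_true[..., None] / f_true)                        # consistent pinhole projection
    points = np.concatenate([xy, (depth_true - t_true)[..., None]], axis=-1)

    depth, fov_x, fov_y, shift = point_map_to_depth_numpy(points)
    print(abs(shift - t_true), np.abs(depth - depth_true).max())      # both should be close to 0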
moge/utils/geometry_torch.py ADDED
@@ -0,0 +1,231 @@
1
+ from typing import *
2
+ import math
3
+ from collections import namedtuple
4
+
5
+ import numpy as np
6
+ import torch
7
+ import torch.nn as nn
8
+ import torch.nn.functional as F
9
+ import torch.types
10
+ import utils3d
11
+
12
+ from .tools import timeit
13
+ from .geometry_numpy import solve_optimal_shift_focal
14
+
15
+
16
+ def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
17
+ if w is None:
18
+ return x.mean(dim=dim, keepdim=keepdim)
19
+ else:
20
+ w = w.to(x.dtype)
21
+ return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)
22
+
23
+
24
+ def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
25
+ if w is None:
26
+ return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal()
27
+ else:
28
+ w = w.to(x.dtype)
29
+ return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()
30
+
31
+
32
+ def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
33
+ if w is None:
34
+ return x.add(eps).log().mean(dim=dim).exp()
35
+ else:
36
+ w = w.to(x.dtype)
37
+ return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
38
+
39
+
40
+ def image_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
41
+ "UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
42
+ if aspect_ratio is None:
43
+ aspect_ratio = width / height
44
+
45
+ span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
46
+ span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
47
+
48
+ u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
49
+ v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
50
+ u, v = torch.meshgrid(u, v, indexing='xy')
51
+ uv = torch.stack([u, v], dim=-1)
52
+ return uv
53
+
54
+
55
+ def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
56
+ kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
57
+ kernel = kernel / kernel.sum()
58
+ kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
59
+ input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
60
+ input = F.conv2d(input, kernel.repeat(input.shape[1], 1, 1, 1), groups=input.shape[1])  # depthwise blur, one kernel per channel
61
+ return input
62
+
63
+
64
+ def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
65
+ batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
66
+ n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
67
+ splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
68
+ splited_kwargs = {k: v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks for k, v in kwargs.items()}
69
+ results = []
70
+ for i in range(n_chunks):
71
+ chunk_args = tuple(arg[i] for arg in splited_args)
72
+ chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
73
+ results.append(fn(*chunk_args, **chunk_kwargs))
74
+
75
+ if isinstance(results[0], tuple):
76
+ return tuple(torch.cat(r, dim=0) for r in zip(*results))
77
+ else:
78
+ return torch.cat(results, dim=0)
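A sketch of how this chunked-forward helper might be called to bound peak memory; `model` and `images` are placeholders:

images = torch.randn(100, 3, 518, 518)
# Equivalent to model(images), but runs the forward pass 8 samples at a time
# and concatenates the per-chunk outputs along dim 0.
features = split_batch_fwd(model, 8, images)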
79
+
80
+
81
+ def focal_to_fov(focal: torch.Tensor):
82
+ return 2 * torch.atan(0.5 / focal)
83
+
84
+
85
+ def fov_to_focal(fov: torch.Tensor):
86
+ return 0.5 / torch.tan(fov / 2)
87
+
88
+
89
+ def intrinsics_to_fov(intrinsics: torch.Tensor):
90
+ """
91
+ Returns field of view in radians from normalized intrinsics matrix.
92
+ ### Parameters:
93
+ - intrinsics: torch.Tensor of shape (..., 3, 3)
94
+
95
+ ### Returns:
96
+ - fov_x: torch.Tensor of shape (...)
97
+ - fov_y: torch.Tensor of shape (...)
98
+ """
99
+ focal_x = intrinsics[..., 0, 0]
100
+ focal_y = intrinsics[..., 1, 1]
101
+ return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y)
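A small round-trip check of these conversions; focal lengths here are normalized, i.e. expressed in units of the corresponding image dimension:

fov = torch.deg2rad(torch.tensor(60.0))
focal = fov_to_focal(fov)                        # 0.5 / tan(30 deg), roughly 0.866
assert torch.allclose(focal_to_fov(focal), fov)  # round trip recovers the FoV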
102
+
103
+
104
+ def point_map_to_depth_legacy(points: torch.Tensor):
105
+ height, width = points.shape[-3:-1]
106
+ diagonal = (height ** 2 + width ** 2) ** 0.5
107
+ uv = image_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
108
+
109
+ # Solve least squares problem
110
+ b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2)
111
+ A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2)
112
+
113
+ M = A.transpose(-2, -1) @ A
114
+ solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
115
+ focal, shift = solution.unbind(-1)
116
+
117
+ depth = points[..., 2] + shift[..., None, None]
118
+ fov_x = torch.atan(width / diagonal / focal) * 2
119
+ fov_y = torch.atan(height / diagonal / focal) * 2
120
+ return depth, fov_x, fov_y, shift
121
+
122
+
123
+ def point_map_to_depth(points: torch.Tensor, mask: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
124
+ """
125
+ Recover the depth map and FoV from a point map with unknown z shift and focal.
126
+
127
+ Note that it assumes:
128
+ - the optical center is at the center of the map
129
+ - the map is undistorted
130
+ - the map is isometric in the x and y directions
131
+
132
+ ### Parameters:
133
+ - `points: torch.Tensor` of shape (..., H, W, 3)
134
+ - `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces an approximate solution but is efficient for large maps.
135
+
136
+ ### Returns:
137
+ - `depth: torch.Tensor` of shape (..., H, W)
138
+ - `fov_x: torch.Tensor` of shape (...)
139
+ - `fov_y: torch.Tensor` of shape (...)
140
+ - `shift: torch.Tensor` of shape (...), the z shift, making `depth = points[..., 2] + shift`
141
+ """
142
+ shape = points.shape
143
+ height, width = points.shape[-3], points.shape[-2]
144
+ diagonal = (height ** 2 + width ** 2) ** 0.5
145
+
146
+ points = points.reshape(-1, *shape[-3:])
147
+ mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
148
+ uv = image_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
149
+
150
+ points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
151
+ uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
152
+ mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0
153
+
154
+ uv_lr_np = uv_lr.cpu().numpy()
155
+ points_lr_np = points_lr.detach().cpu().numpy()
156
+ mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
157
+ optim_shift, optim_focal = [], []
158
+ for i in range(points.shape[0]):
159
+ points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
160
+ uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
161
+ optim_shift_i, optim_focal_i = solve_optimal_shift_focal(uv_lr_i_np, points_lr_i_np, ransac_iters=None)
162
+ optim_shift.append(float(optim_shift_i))
163
+ optim_focal.append(float(optim_focal_i))
164
+ optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype)
165
+ optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype)
166
+
167
+ fov_x = 2 * torch.atan(width / diagonal / optim_focal)
168
+ fov_y = 2 * torch.atan(height / diagonal / optim_focal)
169
+
170
+ depth = (points[..., 2] + optim_shift[:, None, None]).reshape(shape[:-1])
171
+ fov_x = fov_x.reshape(shape[:-3])
172
+ fov_y = fov_y.reshape(shape[:-3])
173
+ optim_shift = optim_shift.reshape(shape[:-3])
174
+
175
+ return depth, fov_x, fov_y, optim_shift
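A brief usage sketch; the shapes and the final intrinsics step are illustrative assumptions rather than part of this module:

# points: (B, H, W, 3) predicted point map, mask: (B, H, W) bool
depth, fov_x, fov_y, shift = point_map_to_depth(points, mask)
# depth equals points[..., 2] + shift; fov_x / fov_y are in radians.
fx = 0.5 / torch.tan(fov_x / 2)   # normalized focal lengths, if an intrinsics matrix is needed
fy = 0.5 / torch.tan(fov_y / 2)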
176
+
177
+
178
+ def mask_aware_nearest_resize(mask: torch.BoolTensor, target_width: int, target_height: int) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]:
179
+ """
180
+ Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
181
+
182
+ ### Parameters
183
+ - `mask`: Input 2D mask of shape (..., H, W)
184
+ - `target_width`: target width of the resized map
185
+ - `target_height`: target height of the resized map
186
+
187
+ ### Returns
188
+ - `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension
189
+ - `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
190
+ """
191
+ height, width = mask.shape[-2:]
192
+ device = mask.device
193
+ filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
194
+ filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
195
+ filter_size = filter_h_i * filter_w_i
196
+ padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2)
197
+
198
+ # Window the original mask and uv
199
+ uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
200
+ indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
201
+ padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
202
+ padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
203
+ padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
204
+ padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
205
+ padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
206
+ padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
207
+ windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
208
+ windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
209
+ windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))
210
+
211
+ # Gather each target pixel's local window
212
+ target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
213
+ target_corner = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
214
+ target_corner = torch.round(target_corner - 0.5).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)
215
+
216
+ target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
217
+ target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
218
+ target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
219
+ target_window_indices = target_window_indices.expand_as(target_window_mask)
220
+
221
+ # Compute nearest neighbor in the local window for each pixel
222
+ dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size)
223
+ nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1)
224
+ nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width)
225
+ target_mask = torch.any(target_window_mask, dim=-1)
226
+ nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
227
+ batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]
228
+
229
+ return (*batch_indices, nearest_i, nearest_j), target_mask
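A batched usage sketch mirroring the numpy version above, e.g. downsampling a point map and its mask together (shapes are illustrative):

# mask: (B, H, W) bool, points: (B, H, W, 3)
index, target_mask = mask_aware_nearest_resize(mask, target_width=64, target_height=64)
points_lr = points[index]     # (B, 64, 64, 3), advanced indexing with the returned batch/row/column indices
mask_lr = target_mask         # (B, 64, 64)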
230
+
231
+
moge/utils/io.py ADDED
@@ -0,0 +1,347 @@
1
+ import os
2
+ os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
3
+ from typing import IO
4
+ import zipfile
5
+ import json
6
+ import io
7
+ from typing import *
8
+ from pathlib import Path
9
+ import re
10
+
11
+ import numpy as np
12
+ import cv2
13
+
14
+ from .tools import timeit
15
+
16
+
17
+ LEGACY_SEGFORMER_CLASSES = [
18
+ 'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
19
+ 'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
20
+ 'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
21
+ 'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
22
+ 'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
23
+ 'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
24
+ 'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
25
+ 'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
26
+ 'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
27
+ 'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
28
+ 'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
29
+ 'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
30
+ 'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
31
+ 'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
32
+ 'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
33
+ 'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
34
+ 'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
35
+ 'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
36
+ 'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
37
+ 'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
38
+ 'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
39
+ 'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
40
+ 'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
41
+ 'clock', 'flag'
42
+ ]
43
+ LEGACY_SEGFORMER_LABELS = {k: i for i, k in enumerate(LEGACY_SEGFORMER_CLASSES)}
44
+
45
+
46
+ def write_rgbd_zip(
47
+ file: Union[IO, os.PathLike],
48
+ image: Union[np.ndarray, bytes],
49
+ depth: Union[np.ndarray, bytes], mask: Union[np.ndarray, bytes],
50
+ segmentation_mask: Union[np.ndarray, bytes] = None, segmentation_labels: Union[Dict[str, int], bytes] = None,
51
+ intrinsics: np.ndarray = None,
52
+ normal: np.ndarray = None, normal_mask: np.ndarray = None,
53
+ meta: Union[Dict[str, Any], bytes] = None,
54
+ *, image_quality: int = 95, depth_type: Literal['linear', 'log', 'disparity'] = 'linear', depth_format: Literal['png', 'exr'] = 'png', depth_max_dynamic_range: float = 1e4, png_compression: int = 7
55
+ ):
56
+ """
57
+ Write RGBD data as zip archive containing the image, depth, mask, segmentation_mask, and meta data.
58
+ In the zip file there will be:
59
+ - `meta.json`: The meta data as a JSON file.
60
+ - `image.jpg`: The RGB image as a JPEG file.
61
+ - `depth.png/exr`: The depth map as a PNG or EXR file, depending on the `depth_type`.
62
+ - `mask.png` (optional): The mask as a uint8 PNG file.
63
+ - `segmentation_mask.png` (optional): The segformer mask as a uint8/uint16 PNG file.
64
+
65
+ You can provide the data either as np.ndarray or as bytes. If provided as np.ndarray, they will be properly processed and encoded.
66
+ If you provide them as bytes, they will be written as is, assuming they are already encoded.
67
+ """
68
+ if meta is None:
69
+ meta = {}
70
+ elif isinstance(meta, bytes):
71
+ meta = json.loads(meta.decode())
72
+
73
+ if isinstance(image, bytes):
74
+ image_bytes = image
75
+ elif isinstance(image, np.ndarray):
76
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
77
+ image_bytes = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, image_quality])[1].tobytes()
78
+
79
+ if isinstance(depth, bytes):
80
+ depth_bytes = depth
81
+ elif isinstance(depth, np.ndarray):
82
+ meta['depth_type'] = depth_type
83
+ if depth_type == 'linear':
84
+ if depth.dtype == np.float16:
85
+ depth_format = 'exr'
86
+ depth_bytes = cv2.imencode('.exr', depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])[1].tobytes()
87
+ elif np.issubdtype(depth.dtype, np.floating):
88
+ depth_format = 'exr'
89
+ depth_bytes = cv2.imencode('.exr', depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])[1].tobytes()
90
+ elif depth.dtype in [np.uint8, np.uint16]:
91
+ depth_format = 'png'
92
+ depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes()
93
+ elif depth_type == 'log':
94
+ depth_format = 'png'
95
+ depth = depth.astype(np.float32)
96
+ near = max(depth[mask].min(), 1e-3)
97
+ far = min(depth[mask].max(), near * depth_max_dynamic_range)
98
+ depth = ((np.log(depth.clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65535).astype(np.uint16)
99
+ depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes()
100
+ meta['depth_near'] = float(near)
101
+ meta['depth_far'] = float(far)
102
+ elif depth_type == 'disparity':
103
+ depth_format = 'png'
104
+ depth = depth.astype(np.float32)
105
+ depth = 1 / (depth + 1e-12)
106
+ depth = (depth / depth[mask].max()).clip(0, 1)
107
+ if len(np.unique(depth)) < 200:
108
+ depth = (depth * 255).astype(np.uint8)
109
+ else:
110
+ depth = (depth * 65535).astype(np.uint16)
111
+ depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes()
112
+
113
+ if isinstance(mask, bytes):
114
+ mask_bytes = mask
115
+ elif isinstance(mask, np.ndarray):
116
+ mask_bytes = cv2.imencode('.png', mask.astype(np.uint8) * 255)[1].tobytes()
117
+
118
+ if segmentation_mask is not None:
119
+ if isinstance(segmentation_mask, bytes):
120
+ segmentation_mask_bytes = segmentation_mask
121
+ else:
122
+ segmentation_mask_bytes = cv2.imencode('.png', segmentation_mask)[1].tobytes()
123
+ assert segmentation_labels is not None, "You provided a segmentation mask, but not the corresponding labels."
124
+ if isinstance(segmentation_labels, bytes):
125
+ segmentation_labels = json.loads(segmentation_labels)
126
+ meta['segmentation_labels'] = segmentation_labels
127
+
128
+ if intrinsics is not None:
129
+ meta['intrinsics'] = intrinsics.tolist()
130
+
131
+ if normal is not None:
132
+ if isinstance(normal, bytes):
133
+ normal_bytes = normal
134
+ elif isinstance(normal, np.ndarray):
135
+ normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
136
+ normal = cv2.cvtColor(normal, cv2.COLOR_RGB2BGR)
137
+ normal_bytes = cv2.imencode('.png', normal, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes()
138
+ if normal_mask is None:
139
+ normal_mask = np.ones(image.shape[:2], dtype=bool)
140
+ normal_mask_bytes = cv2.imencode('.png', normal_mask.astype(np.uint8) * 255)[1].tobytes()
141
+
142
+ meta_bytes = meta if isinstance(meta, bytes) else json.dumps(meta).encode()
143
+
144
+ with zipfile.ZipFile(file, 'w') as z:
145
+ z.writestr('meta.json', meta_bytes)
146
+ z.writestr('image.jpg', image_bytes)
147
+ z.writestr(f'depth.{depth_format}', depth_bytes)
148
+ z.writestr('mask.png', mask_bytes)
149
+ if segmentation_mask is not None:
150
+ z.writestr('segmentation_mask.png', segmentation_mask_bytes)
151
+ if normal is not None:
152
+ z.writestr('normal.png', normal_bytes)
153
+ z.writestr('normal_mask.png', normal_mask_bytes)
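A usage sketch for writing one sample with log-encoded depth; the arrays and file name below are placeholders:

write_rgbd_zip(
    'sample.rgbd.zip',
    image=image,                    # (H, W, 3) uint8 RGB
    depth=depth, mask=depth > 0,    # (H, W) float32 depth and validity mask
    intrinsics=intrinsics,          # (3, 3) normalized intrinsics, stored in meta.json
    meta={'depth_unit': 'm'},
    depth_type='log',
)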
154
+
155
+
156
+ def read_rgbd_zip(file: Union[str, Path, IO], return_bytes: bool = False) -> Dict[str, Union[np.ndarray, Dict[str, Any], bytes]]:
157
+ """
158
+ Read an RGBD zip file and return the image, depth, mask, segmentation_mask, intrinsics, and meta data.
159
+
160
+ ### Parameters:
161
+ - `file: Union[str, Path, IO]`
162
+ The file path or file object to read from.
163
+ - `return_bytes: bool = False`
164
+ If True, return the image, depth, mask, and segmentation_mask as raw bytes.
165
+
166
+ ### Returns:
167
+ - `Dict[str, Union[np.ndarray, Dict[str, Any], bytes]]`
168
+ A dictionary with the following entries (missing items are omitted; if return_bytes is True, values are the raw encoded bytes):
169
+ - `image`: RGB numpy.ndarray of shape (H, W, 3).
170
+ - `depth`: float32 numpy.ndarray of shape (H, W).
171
+ - `mask`: bool numpy.ndarray of shape (H, W).
172
+ - `segmentation_mask`: uint8 numpy.ndarray of shape (H, W).
173
+ - `intrinsics`: float32 numpy.ndarray of shape (3, 3).
174
+ - `meta`: Dict[str, Any].
175
+ """
176
+ # Load & extract archive
177
+ with zipfile.ZipFile(file, 'r') as z:
178
+ meta = z.read('meta.json')
179
+ if not return_bytes:
180
+ meta = json.loads(z.read('meta.json'))
181
+
182
+ image = z.read('image.jpg')
183
+ if not return_bytes:
184
+ image = cv2.imdecode(np.frombuffer(z.read('image.jpg'), np.uint8), cv2.IMREAD_COLOR)
185
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
186
+
187
+ depth_name = next(s for s in z.namelist() if s.startswith('depth'))
188
+ depth = z.read(depth_name)
189
+ if not return_bytes:
190
+ depth = cv2.imdecode(np.frombuffer(z.read(depth_name), np.uint8), cv2.IMREAD_UNCHANGED)
191
+
192
+ if 'mask.png' in z.namelist():
193
+ mask = z.read('mask.png')
194
+ if not return_bytes:
195
+ mask = cv2.imdecode(np.frombuffer(z.read('mask.png'), np.uint8), cv2.IMREAD_UNCHANGED) > 0
196
+ else:
197
+ mask = None
198
+
199
+ if 'segformer_mask.png' in z.namelist():
200
+ # NOTE: Legacy support for segformer_mask.png
201
+ segmentation_mask = z.read('segformer_mask.png')
202
+ segmentation_labels = None
203
+ if not return_bytes:
204
+ segmentation_mask = cv2.imdecode(np.frombuffer(segmentation_mask, np.uint8), cv2.IMREAD_UNCHANGED)
205
+ segmentation_labels = LEGACY_SEGFORMER_LABELS
206
+ elif 'segmentation_mask.png' in z.namelist():
207
+ segmentation_mask = z.read('segmentation_mask.png')
208
+ segmentation_labels = None
209
+ if not return_bytes:
210
+ segmentation_mask = cv2.imdecode(np.frombuffer(segmentation_mask, np.uint8), cv2.IMREAD_UNCHANGED)
211
+ segmentation_labels = meta['segmentation_labels']
212
+ else:
213
+ segmentation_mask = None
214
+ segmentation_labels = None
215
+
216
+ if 'normal.png' in z.namelist():
217
+ normal = z.read('normal.png')
218
+ if not return_bytes:
219
+ normal = cv2.imdecode(np.frombuffer(z.read('normal.png'), np.uint8), cv2.IMREAD_UNCHANGED)
220
+ normal = cv2.cvtColor(normal, cv2.COLOR_BGR2RGB)
221
+ normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
222
+ normal = normal / np.linalg.norm(normal, axis=-1, keepdims=True)
223
+
224
+ if 'normal_mask.png' in z.namelist():
225
+ normal_mask = z.read('normal_mask.png')
226
+ normal_mask = cv2.imdecode(np.frombuffer(normal_mask, np.uint8), cv2.IMREAD_UNCHANGED) > 0
227
+ else:
228
+ normal_mask = np.ones(image.shape[:2], dtype=bool)
229
+ else:
230
+ normal, normal_mask = None, None
231
+
232
+ # recover linear depth
233
+ if not return_bytes:
234
+ if mask is None:
235
+ mask = np.ones(image.shape[:2], dtype=bool)
236
+ if meta['depth_type'] == 'linear':
237
+ depth = depth.astype(np.float32)
238
+ mask = mask & (depth > 0)
239
+ elif meta['depth_type'] == 'log':
240
+ near, far = meta['depth_near'], meta['depth_far']
241
+ if depth.dtype == np.uint16:
242
+ depth = depth.astype(np.float32) / 65535
243
+ elif depth.dtype == np.uint8:
244
+ depth = depth.astype(np.float32) / 255
245
+ depth = near ** (1 - depth) * far ** depth
246
+ mask = mask & ~np.isnan(depth)
247
+ elif meta['depth_type'] == 'disparity':
248
+ mask = mask & (depth > 0)
249
+ if depth.dtype == np.uint16:
250
+ depth = depth.astype(np.float32) / 65535
251
+ elif depth.dtype == np.uint8:
252
+ depth = depth.astype(np.float32) / 255
253
+ depth = 1 / (depth + 1e-12)
254
+
255
+ # intrinsics
256
+ if not return_bytes and 'intrinsics' in meta:
257
+ intrinsics = np.array(meta['intrinsics'], dtype=np.float32)
258
+ else:
259
+ intrinsics = None
260
+
261
+ # depth unit
262
+ if not return_bytes and 'depth_unit' in meta:
263
+ depth_unit_str = meta['depth_unit']
264
+ if r := re.match(r'([\d.]*)(\w*)', depth_unit_str):
265
+ digits, unit = r.groups()
266
+ depth_unit = float(digits or 1) * {'m': 1, 'cm': 0.01, 'mm': 0.001}[unit]
267
+ else:
268
+ depth_unit = None
269
+ else:
270
+ depth_unit = None
271
+
272
+ return_dict = {
273
+ 'image': image,
274
+ 'depth': depth,
275
+ 'mask': mask,
276
+ 'segmentation_mask': segmentation_mask,
277
+ 'segmentation_labels': segmentation_labels,
278
+ 'normal': normal,
279
+ 'normal_mask': normal_mask,
280
+ 'intrinsics': intrinsics,
281
+ 'depth_unit': depth_unit,
282
+ 'meta': meta,
283
+ }
284
+ return_dict = {k: v for k, v in return_dict.items() if v is not None}
285
+
286
+ return return_dict
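And the corresponding read-back; `depth_unit`, when present, converts the stored depth to meters (file name is the same placeholder as above):

sample = read_rgbd_zip('sample.rgbd.zip')
depth, valid = sample['depth'], sample['mask']
if 'depth_unit' in sample:
    depth = depth * sample['depth_unit']   # metric depth in meters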
287
+
288
+ def write_rgbxyz(file: Union[IO, Path], image: np.ndarray, points: np.ndarray, mask: np.ndarray = None, image_quality: int = 95):
289
+ if isinstance(image, bytes):
290
+ image_bytes = image
291
+ elif isinstance(image, np.ndarray):
292
+ image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
293
+ image_bytes = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, image_quality])[1].tobytes()
294
+
295
+ if isinstance(points, bytes):
296
+ points_bytes = points
297
+ elif isinstance(points, np.ndarray):
298
+ points_bytes = cv2.imencode('.exr', points.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])[1].tobytes()
299
+
300
+ if mask is None:
301
+ mask = np.ones(image.shape[:2], dtype=bool)
302
+ if isinstance(mask, bytes):
303
+ mask_bytes = mask
304
+ elif isinstance(mask, np.ndarray):
305
+ mask_bytes = cv2.imencode('.png', mask.astype(np.uint8) * 255)[1].tobytes()
306
+
307
+ is_archive = hasattr(file, 'write') or Path(file).suffix == '.zip'
308
+ if is_archive:
309
+ with zipfile.ZipFile(file, 'w') as z:
310
+ z.writestr('image.jpg', image_bytes)
311
+ z.writestr('points.exr', points_bytes)
312
+ if mask is not None:
313
+ z.writestr('mask.png', mask_bytes)
314
+ else:
315
+ file = Path(file)
316
+ file.mkdir(parents=True, exist_ok=True)
317
+ with open(file / 'image.jpg', 'wb') as f:
318
+ f.write(image_bytes)
319
+ with open(file / 'points.exr', 'wb') as f:
320
+ f.write(points_bytes)
321
+ if mask is not None:
322
+ with open(file / 'mask.png', 'wb') as f:
323
+ f.write(mask_bytes)
324
+
325
+
326
+ def read_rgbxyz(file: Union[IO, str, Path]) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
327
+ is_archive = hasattr(file, 'read') or Path(file).suffix == '.zip'
328
+ if is_archive:
329
+ with zipfile.ZipFile(file, 'r') as z:
330
+ image = cv2.imdecode(np.frombuffer(z.read('image.jpg'), np.uint8), cv2.IMREAD_COLOR)
331
+ points = cv2.imdecode(np.frombuffer(z.read('points.exr'), np.uint8), cv2.IMREAD_UNCHANGED)
332
+ if 'mask.png' in z.namelist():
333
+ mask = cv2.imdecode(np.frombuffer(z.read('mask.png'), np.uint8), cv2.IMREAD_UNCHANGED) > 0
334
+ else:
335
+ mask = np.ones(image.shape[:2], dtype=bool)
336
+ else:
337
+ file = Path(file)
338
+ file.mkdir(parents=True, exist_ok=True)
339
+ image = cv2.imread(str(file / 'image.jpg'), cv2.IMREAD_COLOR)
340
+ image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
341
+ points = cv2.imread(str(file / 'points.exr'), cv2.IMREAD_UNCHANGED)
342
+ if (file /'mask.png').exists():
343
+ mask = cv2.imread(str(file / 'mask.png'), cv2.IMREAD_UNCHANGED) > 0
344
+ else:
345
+ mask = np.ones(image.shape[:2], dtype=bool)
346
+
347
+ return image, points, mask
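A round-trip sketch for the point-map variant (file name is a placeholder):

write_rgbxyz('pred.zip', image, points, mask)     # stores image.jpg, points.exr and mask.png in one zip
image2, points2, mask2 = read_rgbxyz('pred.zip')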
moge/utils/pipeline.py ADDED
@@ -0,0 +1,503 @@
1
+ from typing import *
2
+ from abc import abstractmethod
3
+ from queue import Empty, Full
4
+ from threading import Thread
5
+ from queue import Queue
6
+ from multiprocessing import Process
7
+ from threading import Thread, Event
8
+ import multiprocessing
9
+ import threading
10
+ import inspect
11
+ import time
12
+ import uuid
13
+ from copy import deepcopy
14
+ import itertools
15
+ import functools
16
+
17
+ __all__ = [
18
+ 'Node',
19
+ 'Link',
20
+ 'ConcurrentNode',
21
+ 'Worker',
22
+ 'WorkerFunction',
23
+ 'Provider',
24
+ 'ProviderFunction',
25
+ 'Sequential',
26
+ 'Batch',
27
+ 'Unbatch',
28
+ 'Parallel',
29
+ 'Graph',
30
+ 'Buffer',
31
+ ]
32
+
33
+ TERMINATE_CHECK_INTERVAL = 0.5
34
+
35
+
36
+ class _ItemWrapper:
37
+ def __init__(self, data: Any, id: Union[int, List[int]] = None):
38
+ self.data = data
39
+ self.id = id
40
+
41
+
42
+ class Terminate(Exception):
43
+ pass
44
+
45
+
46
+ def _get_queue_item(queue: Queue, terminate_flag: Event, timeout: float = None) -> _ItemWrapper:
47
+ while True:
48
+ try:
49
+ item: _ItemWrapper = queue.get(block=True, timeout=TERMINATE_CHECK_INTERVAL if timeout is None else min(timeout, TERMINATE_CHECK_INTERVAL))
50
+ if terminate_flag.is_set():
51
+ raise Terminate()
52
+ return item
53
+ except Empty:
54
+ if terminate_flag.is_set():
55
+ raise Terminate()
56
+
57
+ if timeout is not None:
58
+ timeout -= TERMINATE_CHECK_INTERVAL
59
+ if timeout <= 0:
60
+ raise Empty()
61
+
62
+
63
+ def _put_queue_item(queue: Queue, item: _ItemWrapper, terminate_flag: Event):
64
+ while True:
65
+ try:
66
+ queue.put(item, block=True, timeout=TERMINATE_CHECK_INTERVAL)
67
+ if terminate_flag.is_set():
68
+ raise Terminate()
69
+ return
70
+ except Full:
71
+ if terminate_flag.is_set():
72
+ raise Terminate()
73
+
74
+ class Node:
75
+ def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
76
+ self.input: Queue = Queue(maxsize=in_buffer_size)
77
+ self.output: Queue = Queue(maxsize=out_buffer_size)
78
+ self.in_buffer_size = in_buffer_size
79
+ self.out_buffer_size = out_buffer_size
80
+
81
+ @abstractmethod
82
+ def start(self):
83
+ pass
84
+
85
+ @abstractmethod
86
+ def terminate(self):
87
+ pass
88
+
89
+ def stop(self):
90
+ self.terminate()
91
+ self.join()
92
+
93
+ @abstractmethod
94
+ def join(self):
95
+ pass
96
+
97
+ def put(self, data: Any, key: str = None, block: bool = True) -> None:
98
+ item = _ItemWrapper(data)
99
+ self.input.put(item, block=block)
100
+
101
+ def get(self, key: str = None, block: bool = True) -> Any:
102
+ item: _ItemWrapper = self.output.get(block=block)
103
+ return item.data
104
+
105
+ def __enter__(self):
106
+ self.start()
107
+ return self
108
+
109
+ def __exit__(self, exc_type, exc_value, traceback):
110
+ self.terminate()
111
+ self.join()
112
+
113
+
114
+ class ConcurrentNode(Node):
115
+ job: Union[Thread, Process]
116
+
117
+ def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
118
+ super().__init__(in_buffer_size, out_buffer_size)
119
+ self.running_as = running_as
120
+
121
+ @abstractmethod
122
+ def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
123
+ pass
124
+
125
+ def start(self):
126
+ if self.running_as == 'thread':
127
+ terminate_flag = threading.Event()
128
+ job = Thread(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
129
+ elif self.running_as == 'process':
130
+ terminate_flag = multiprocessing.Event()
131
+ job = Process(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
132
+ job.start()
133
+ self.job = job
134
+ self.terminate_flag = terminate_flag
135
+
136
+ def terminate(self):
137
+ self.terminate_flag.set()
138
+
139
+ def join(self):
140
+ self.job.join()
141
+
142
+
143
+ class Worker(ConcurrentNode):
144
+ def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 0, out_buffer_size: int = 0) -> None:
145
+ super().__init__(running_as, in_buffer_size, out_buffer_size)
146
+
147
+ def init(self) -> None:
148
+ """
149
+ This method is called when the thread is started, to initialize any resources that are only held in the thread.
150
+ """
151
+ pass
152
+
153
+ @abstractmethod
154
+ def work(self, *args, **kwargs) -> Union[Any, Dict[str, Any]]:
155
+ """
156
+ This method defines the job that the node should do for each input item.
157
+ An item obtained from the input queue is passed as the argument to this method, and the result is placed in the output queue.
158
+ The method is executed concurrently with other nodes.
159
+ """
160
+ pass
161
+
162
+ def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
163
+ self.init()
164
+ try:
165
+ while True:
166
+ item = _get_queue_item(input, terminate_flag)
167
+ result = self.work(item.data)
168
+ _put_queue_item(output, _ItemWrapper(result, item.id), terminate_flag)
169
+
170
+ except Terminate:
171
+ return
172
+
173
+
174
+ class Provider(ConcurrentNode):
175
+ """
176
+ A node that provides data to successive nodes. It takes no input and provides data to the output queue.
177
+ """
178
+ def __init__(self, running_as: Literal['thread', 'process'], out_buffer_size: int = 1) -> None:
179
+ super().__init__(running_as, 0, out_buffer_size)
180
+
181
+ def init(self) -> None:
182
+ """
183
+ This method is called when the thread or process is started, to initialize any resources that are only held in the thread or process.
184
+ """
185
+ pass
186
+
187
+ @abstractmethod
188
+ def provide(self) -> Generator[Any, None, None]:
189
+ pass
190
+
191
+ def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
192
+ self.init()
193
+ try:
194
+ for data in self.provide():
195
+ _put_queue_item(output, _ItemWrapper(data), terminate_flag)
196
+ except Terminate:
197
+ return
198
+
199
+
200
+ class WorkerFunction(Worker):
201
+ def __init__(self, fn: Callable, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
202
+ super().__init__(running_as, in_buffer_size, out_buffer_size)
203
+ self.fn = fn
204
+
205
+ def work(self, *args, **kwargs):
206
+ return self.fn(*args, **kwargs)
207
+
208
+
209
+ class ProviderFunction(Provider):
210
+ def __init__(self, fn: Callable, running_as: Literal['thread', 'process'] = 'thread', out_buffer_size: int = 1) -> None:
211
+ super().__init__(running_as, out_buffer_size)
212
+ self.fn = fn
213
+
214
+ def provide(self):
215
+ for item in self.fn():
216
+ yield item
217
+
218
+
219
+ class Link:
220
+ def __init__(self, src: Queue, dst: Queue):
221
+ self.src = src
222
+ self.dst = dst
223
+
224
+ def _thread_fn(self):
225
+ try:
226
+ while True:
227
+ item = _get_queue_item(self.src, self.terminate_flag)
228
+ _put_queue_item(self.dst, item, self.terminate_flag)
229
+ except Terminate:
230
+ return
231
+
232
+ def start(self):
233
+ self.terminate_flag = threading.Event()
234
+ self.thread = Thread(target=self._thread_fn)
235
+ self.thread.start()
236
+
237
+ def terminate(self):
238
+ self.terminate_flag.set()
239
+
240
+ def join(self):
241
+ self.thread.join()
242
+
243
+
244
+ class Graph(Node):
245
+ """
246
+ Graph pipeline of nodes and links
247
+ """
248
+ nodes: List[Node]
249
+ links: List[Link]
250
+
251
+ def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
252
+ super().__init__(in_buffer_size, out_buffer_size)
253
+ self.nodes = []
254
+ self.links = []
255
+
256
+ def add(self, node: Node):
257
+ self.nodes.append(node)
258
+
259
+ def link(self, src: Union[Node, Tuple[Node, str]], dst: Union[Node, Tuple[Node, str]]):
260
+ """
261
+ Links the output of the source node to the input of the destination node.
262
+ If the source or destination node is None, the pipeline's input or output is used.
263
+ """
264
+ src_queue = self.input if src is None else src.output
265
+ dst_queue = self.output if dst is None else dst.input
266
+ self.links.append(Link(src_queue, dst_queue))
267
+
268
+ def chain(self, nodes: Iterable[Node]):
269
+ """
270
+ Link the output of each node to the input of the next node.
271
+ """
272
+ nodes = list(nodes)
273
+ for i in range(len(nodes) - 1):
274
+ self.link(nodes[i], nodes[i + 1])
275
+
276
+ def start(self):
277
+ for node in self.nodes:
278
+ node.start()
279
+ for link in self.links:
280
+ link.start()
281
+
282
+ def terminate(self):
283
+ for node in self.nodes:
284
+ node.terminate()
285
+ for link in self.links:
286
+ link.terminate()
287
+
288
+ def join(self):
289
+ for node in self.nodes:
290
+ node.join()
291
+ for link in self.links:
292
+ link.join()
293
+
294
+ def __iter__(self):
295
+ providers = [node for node in self.nodes if isinstance(node, Provider)]
296
+ if len(providers) == 0:
297
+ raise ValueError("No provider node found in the pipeline. If you want to iterate over the pipeline, the pipeline must be driven by a provider node.")
298
+ with self:
299
+ # while all(provider.job.is_alive() for provider in providers):
300
+ while True:
301
+ yield self.get()
302
+
303
+ def __call__(self, data: Any) -> Any:
304
+ """
305
+ Submit data to the pipeline's input queue, and return the output data asynchronously.
306
+ NOTE: The pipeline must be streamed (i.e., every output item is uniquely associated with an input item) for this to work.
307
+ """
308
+ # TODO
309
+
310
+
311
+ class Sequential(Graph):
312
+ """
313
+ Pipeline of nodes in sequential order, where each node takes the output of the previous node as input.
314
+ The order of input and output items is preserved (FIFO)
315
+ """
316
+ def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
317
+ """
318
+ Initialize the pipeline with a list of nodes to execute sequentially.
319
+ ### Parameters:
320
+ - nodes: List of nodes or functions to execute sequentially. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes.
321
+ - function_running_as: Whether to wrap the function as a thread or process worker. Default is 'thread'.
322
+ - in_buffer_size: Maximum size of the input queue of the pipeline. Default is 1.
323
+ - out_buffer_size: Maximum size of the output queue of the pipeline. Default is 1.
324
+ """
325
+ super().__init__(in_buffer_size, out_buffer_size)
326
+ for node in nodes:
327
+ if isinstance(node, Node):
328
+ pass
329
+ elif isinstance(node, Callable):
330
+ if inspect.isgeneratorfunction(node):
331
+ node = ProviderFunction(node, function_running_as)
332
+ else:
333
+ node = WorkerFunction(node, function_running_as)
334
+ else:
335
+ raise ValueError(f"Invalid node type: {type(node)}")
336
+ self.add(node)
337
+ self.chain([None, *self.nodes, None])
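A sketch of building a pipeline from plain functions; `load_image` and `predict` are hypothetical callables, and the generator becomes the driving Provider node:

def list_paths():                       # generator function -> Provider node
    for path in ['a.jpg', 'b.jpg']:
        yield path

pipeline = Sequential([list_paths, load_image, predict], function_running_as='thread')
for result in pipeline:                 # outputs are produced in FIFO order
    print(result)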
338
+
339
+
340
+ class Parallel(Node):
341
+ """
342
+ A FIFO node that runs multiple nodes in parallel to process the input items. Each input item is handed to whichever node becomes available.
343
+ NOTE: It is FIFO if and only if all the nested nodes are FIFO.
344
+ """
345
+ nodes: List[Node]
346
+
347
+ def __init__(self, nodes: Iterable[Node], in_buffer_size: int = 1, out_buffer_size: int = 1, function_running_as: Literal['thread', 'process'] = 'thread'):
348
+ super().__init__(in_buffer_size, out_buffer_size)
349
+ self.nodes = []
350
+ for node in nodes:
351
+ if isinstance(node, Node):
352
+ pass
353
+ elif isinstance(node, Callable):
354
+ if inspect.isgeneratorfunction(node):
355
+ node = ProviderFunction(node, function_running_as)
356
+ else:
357
+ node = WorkerFunction(node, function_running_as)
358
+ else:
359
+ raise ValueError(f"Invalid node type: {type(node)}")
360
+ self.nodes.append(node)
361
+ self.output_order = Queue()
362
+ self.lock = threading.Lock()
363
+
364
+ def _in_thread_fn(self, node: Node):
365
+ try:
366
+ while True:
367
+ with self.lock:
368
+ # A better idea: first make sure its node is vacant, then get it a new item.
369
+ # Currently we will not be able to know which node is busy until there is at least one item already waiting in the queue of the node.
370
+ # This could lead to suboptimal scheduling.
371
+ item = _get_queue_item(self.input, self.terminate_flag)
372
+ self.output_order.put(node.output)
373
+ _put_queue_item(node.input, item, self.terminate_flag)
374
+ except Terminate:
375
+ return
376
+
377
+ def _out_thread_fn(self):
378
+ try:
379
+ while True:
380
+ queue = _get_queue_item(self.output_order, self.terminate_flag)
381
+ item = _get_queue_item(queue, self.terminate_flag)
382
+ _put_queue_item(self.output, item, self.terminate_flag)
383
+ except Terminate:
384
+ return
385
+
386
+ def start(self):
387
+ self.terminate_flag = threading.Event()
388
+ self.in_threads = []
389
+ for node in self.nodes:
390
+ thread = Thread(target=self._in_thread_fn, args=(node,))
391
+ thread.start()
392
+ self.in_threads.append(thread)
393
+ thread = Thread(target=self._out_thread_fn)
394
+ thread.start()
395
+ self.out_thread = thread
396
+ for node in self.nodes:
397
+ node.start()
398
+
399
+ def terminate(self):
400
+ self.terminate_flag.set()
401
+ for node in self.nodes:
402
+ node.terminate()
403
+
404
+ def join(self):
405
+ for thread in self.in_threads:
406
+ thread.join()
407
+ self.out_thread.join()
408
+
409
+
410
+ class UnorderedParallel(Graph):
411
+ """
412
+ Pipeline of nodes in parallel, where each input item is handed to whichever node becomes available.
413
+ NOTE: The order of the output items is NOT guaranteed to be the same as the input items, depending on how fast the nodes handle their input.
414
+ """
415
+ def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
416
+ """
417
+ Initialize the pipeline with a list of nodes to execute in parallel. If a function is given, it is wrapped in a worker node.
418
+ ### Parameters:
419
+ - nodes: List of nodes or functions to execute in parallel. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes.
420
+ - function_running_as: Whether to wrap the function as a thread or process worker. Default is 'thread'.
421
+ - in_buffer_size: Maximum size of the input queue of the pipeline. Default is 1.
422
+ - out_buffer_size: Maximum size of the output queue of the pipeline. Default is 1.
423
+ """
424
+ super().__init__(in_buffer_size, out_buffer_size)
425
+ for node in nodes:
426
+ if isinstance(node, Node):
427
+ pass
428
+ elif isinstance(node, Callable):
429
+ if inspect.isgeneratorfunction(node):
430
+ node = ProviderFunction(node, function_running_as)
431
+ else:
432
+ node = WorkerFunction(node, function_running_as)
433
+ else:
434
+ raise ValueError(f"Invalid node type: {type(node)}")
435
+ self.add(node)
436
+ for i in range(len(nodes)):
437
+ self.chain([None, self.nodes[i], None])
438
+
439
+
440
+ class Batch(ConcurrentNode):
441
+ """
442
+ Groups every `batch_size` items into a batch (a list of items) and passes the batch to successive nodes.
443
+ The `patience` parameter specifies the maximum time to wait for a batch to be filled before sending it to the next node,
444
+ i.e., when the earliest item in the batch is out of `patience` seconds, the batch is sent regardless of its size.
445
+ """
446
+ def __init__(self, batch_size: int, patience: float = None, in_buffer_size: int = 1, out_buffer_size: int = 1):
447
+ assert batch_size > 0, "Batch size must be greater than 0."
448
+ super().__init__('thread', in_buffer_size, out_buffer_size)
449
+ self.batch_size = batch_size
450
+ self.patience = patience
451
+
452
+ def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
453
+ try:
454
+ while True:
455
+ batch_id, batch_data = [], []
456
+ # Try to fill the batch
457
+ for i in range(self.batch_size):
458
+ if i == 0 or self.patience is None:
459
+ timeout = None
460
+ else:
461
+ timeout = self.patience - (time.time() - earliest_time)
462
+ if timeout < 0:
463
+ break
464
+ try:
465
+ item = _get_queue_item(input, terminate_flag, timeout)
466
+ except Empty:
467
+ break
468
+
469
+ if i == 0:
470
+ earliest_time = time.time()
471
+ batch_data.append(item.data)
472
+ batch_id.append(item.id)
473
+
474
+ batch = _ItemWrapper(batch_data, batch_id)
475
+ _put_queue_item(output, batch, terminate_flag)
476
+ except Terminate:
477
+ return
478
+
479
+
480
+ class Unbatch(ConcurrentNode):
481
+ """
482
+ Ungroups every batch (a list of items) into individual items and passes them to successive nodes.
483
+ """
484
+ def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
485
+ super().__init__('thread', in_buffer_size, out_buffer_size)
486
+
487
+ def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
488
+ try:
489
+ while True:
490
+ batch = _get_queue_item(input, terminate_flag)
491
+ for id, data in zip(batch.id or itertools.repeat(None), batch.data):
492
+ item = _ItemWrapper(data, id)
493
+ _put_queue_item(output, item, terminate_flag)
494
+ except Terminate:
495
+ return
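These two nodes are typically used as a pair around a worker that processes whole batches; `BatchedPredictor` below is a hypothetical Worker whose `work` method takes and returns a list:

pipeline = Sequential([
    Batch(batch_size=8, patience=0.5),   # wait at most 0.5 s to fill a batch of 8
    BatchedPredictor(),                  # work(List[item]) -> List[result]
    Unbatch(),                           # split each batch back into individual items
])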
496
+
497
+
498
+ class Buffer(Node):
499
+ "A FIFO node that buffers items in a queue. Useful for achieving better temporal balance when its successor node has variable processing time."
500
+ def __init__(self, size: int):
501
+ super().__init__(size, size)
502
+ self.size = size
503
+ self.input = self.output = Queue(maxsize=size)
moge/utils/tools.py ADDED
@@ -0,0 +1,240 @@
1
+ from typing import *
2
+ import time
3
+ from pathlib import Path
4
+ from numbers import Number
5
+
6
+
7
+ def catch_exception(fn):
8
+ def wrapper(*args, **kwargs):
9
+ try:
10
+ return fn(*args, **kwargs)
11
+ except Exception as e:
12
+ import traceback
13
+ print(f"Exception in {fn.__name__}({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})")
14
+ traceback.print_exc(chain=False)
15
+ time.sleep(0.1)
16
+ return None
17
+ return wrapper
18
+
19
+
20
+ class CallbackOnException:
21
+ def __init__(self, callback: Callable, exception: type):
22
+ self.exception = exception
23
+ self.callback = callback
24
+
25
+ def __enter__(self):
26
+ return self
27
+
28
+ def __exit__(self, exc_type, exc_val, exc_tb):
29
+ if isinstance(exc_val, self.exception):
30
+ self.callback()
31
+ return True
32
+ return False
33
+
34
+ def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
35
+ for k, v in d.items():
36
+ if isinstance(v, dict):
37
+ for sub_key in traverse_nested_dict_keys(v):
38
+ yield (k, ) + sub_key
39
+ else:
40
+ yield (k, )
41
+
42
+
43
+ def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
44
+ for k in keys:
45
+ d = d.get(k, default)
46
+ if d is None:
47
+ break
48
+ return d
49
+
50
+ def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
51
+ for k in keys[:-1]:
52
+ d = d.setdefault(k, {})
53
+ d[keys[-1]] = value
54
+
55
+
56
+ def key_average(list_of_dicts: list) -> Dict[str, Any]:
57
+ """
58
+ Returns a dictionary with the average value of each key in the input list of dictionaries.
59
+ """
60
+ _nested_dict_keys = set()
61
+ for d in list_of_dicts:
62
+ _nested_dict_keys.update(traverse_nested_dict_keys(d))
63
+ _nested_dict_keys = sorted(_nested_dict_keys)
64
+ result = {}
65
+ for k in _nested_dict_keys:
66
+ values = [
67
+ get_nested_dict(d, k) for d in list_of_dicts
68
+ if get_nested_dict(d, k) is not None
69
+ ]
70
+ avg = sum(values) / len(values) if values else float('nan')
71
+ set_nested_dict(result, k, avg)
72
+ return result
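For example, averaging per-sample metric dictionaries, including nested keys:

metrics = [
    {'abs_rel': 0.10, 'delta': {'1.25': 0.90}},
    {'abs_rel': 0.20, 'delta': {'1.25': 0.80}},
]
key_average(metrics)   # -> {'abs_rel': 0.15, 'delta': {'1.25': 0.85}} (up to float rounding)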
73
+
74
+
75
+ def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
76
+ """
77
+ Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
78
+ """
79
+ items = []
80
+ if parent_key is None:
81
+ parent_key = ()
82
+ for k, v in d.items():
83
+ new_key = parent_key + (k, )
84
+ if isinstance(v, MutableMapping):
85
+ items.extend(flatten_nested_dict(v, new_key).items())
86
+ else:
87
+ items.append((new_key, v))
88
+ return dict(items)
89
+
90
+
91
+ def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
92
+ """
93
+ Unflattens a single-level dictionary into a nested dictionary, with keys as tuples.
94
+ """
95
+ result = {}
96
+ for k, v in d.items():
97
+ sub_dict = result
98
+ for k_ in k[:-1]:
99
+ if k_ not in sub_dict:
100
+ sub_dict[k_] = {}
101
+ sub_dict = sub_dict[k_]
102
+ sub_dict[k[-1]] = v
103
+ return result
104
+
105
+
106
+ def read_jsonl(file):
107
+ import json
108
+ with open(file, 'r') as f:
109
+ data = f.readlines()
110
+ return [json.loads(line) for line in data]
111
+
112
+
113
+ def write_jsonl(data: List[dict], file):
114
+ import json
115
+ with open(file, 'w') as f:
116
+ for item in data:
117
+ f.write(json.dumps(item) + '\n')
118
+
119
+
120
+ def save_metrics(save_path: Union[str, Path], all_metrics: Dict[str, List[Dict]]):
121
+ import pandas as pd
122
+ import json
123
+
124
+ with open(save_path, 'w') as f:
125
+ json.dump(all_metrics, f, indent=4)
126
+
127
+
128
+ def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
129
+ import pandas as pd
130
+ data = [flatten_nested_dict(d) for d in data]
131
+ df = pd.DataFrame(data)
132
+ df = df.sort_index(axis=1)
133
+ df.columns = pd.MultiIndex.from_tuples(df.columns)
134
+ return df
135
+
136
+
137
+ def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
138
+ if isinstance(d, str):
139
+ for old, new in mapping.items():
140
+ d = d.replace(old, new)
141
+ elif isinstance(d, list):
142
+ for i, item in enumerate(d):
143
+ d[i] = recursive_replace(item, mapping)
144
+ elif isinstance(d, dict):
145
+ for k, v in d.items():
146
+ d[k] = recursive_replace(v, mapping)
147
+ return d
148
+
149
+
150
+ class timeit:
151
+ _history: Dict[str, List['timeit']] = {}
152
+
153
+ def __init__(self, name: str = None, verbose: bool = True, multiple: bool = False):
154
+ self.name = name
155
+ self.verbose = verbose
156
+ self.start = None
157
+ self.end = None
158
+ self.multiple = multiple
159
+ if multiple and name not in timeit._history:
160
+ timeit._history[name] = []
161
+
162
+ def __call__(self, func: Callable):
163
+ import inspect
164
+ if inspect.iscoroutinefunction(func):
165
+ async def wrapper(*args, **kwargs):
166
+ with timeit(self.name or func.__qualname__):
167
+ ret = await func(*args, **kwargs)
168
+ return ret
169
+ return wrapper
170
+ else:
171
+ def wrapper(*args, **kwargs):
172
+ with timeit(self.name or func.__qualname__):
173
+ ret = func(*args, **kwargs)
174
+ return ret
175
+ return wrapper
176
+
177
+ def __enter__(self):
178
+ self.start = time.time()
179
+
180
+ @property
181
+ def time(self) -> float:
182
+ assert self.start is not None, "Time not yet started."
183
+ assert self.end is not None, "Time not yet ended."
184
+ return self.end - self.start
185
+
186
+ @property
187
+ def history(self) -> List['timeit']:
188
+ return timeit._history.get(self.name, [])
189
+
190
+ def __exit__(self, exc_type, exc_val, exc_tb):
191
+ self.end = time.time()
192
+ if self.multiple:
193
+ timeit._history[self.name].append(self)
194
+ if self.verbose:
195
+ if self.multiple:
196
+ avg = sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])
197
+ print(f"{self.name or 'It'} took {avg} seconds on average.")
198
+ else:
199
+ print(f"{self.name or 'It'} took {self.time} seconds.")
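Both usages are supported, as a context manager or as a decorator:

with timeit('load'):
    data = read_jsonl('metrics.jsonl')

@timeit('inference')
def run():
    ...          # prints "inference took ... seconds." on each call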
200
+
201
+
202
+ def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
203
+ first = strings[0]
204
+
205
+ for start in range(len(first)):
206
+ if any(s[start] != strings[0][start] for s in strings):
207
+ break
208
+
209
+ for end in range(1, min(len(s) for s in strings)):
210
+ if any(s[-end] != first[-end] for s in strings):
211
+ break
212
+
213
+ return [s[start:len(s) - end + 1] for s in strings]
214
+
215
+
216
+ def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
217
+ from concurrent.futures import ThreadPoolExecutor
218
+ from contextlib import nullcontext
219
+ from tqdm import tqdm
220
+
221
+ if pbar is not None:
222
+ pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
223
+ else:
224
+ pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)
225
+
226
+ def decorator(fn: Callable):
227
+ with (
228
+ ThreadPoolExecutor(max_workers=num_workers) as executor,
229
+ pbar
230
+ ):
231
+ pbar.refresh()
232
+ @catch_exception
233
+ def _fn(input):
234
+ ret = fn(input)
235
+ pbar.update()
236
+ return ret
237
+ executor.map(_fn, inputs)
238
+ executor.shutdown(wait=True)
239
+
240
+ return decorator
moge/utils/vis.py ADDED
@@ -0,0 +1,51 @@
1
+ import numpy as np
2
+ import matplotlib
3
+
4
+
5
+ def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
6
+ if mask is None:
7
+ depth = np.where(depth > 0, depth, np.nan)
8
+ else:
9
+ depth = np.where((depth > 0) & mask, depth, np.nan)
10
+ disp = 1 / depth
11
+ if normalize:
12
+ min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.999)
13
+ disp = (disp - min_disp) / (max_disp - min_disp)
14
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp), 0)
15
+ colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3]
16
+ return colored
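A quick sketch of saving the visualization to disk; the cv2 round trip is only an assumed convenience, since the colorized output is RGB:

import cv2
vis = colorize_depth(depth, mask=depth > 0)     # (H, W, 3) uint8, RGB
cv2.imwrite('depth_vis.png', cv2.cvtColor(vis, cv2.COLOR_RGB2BGR))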
17
+
18
+
19
+ def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray:
20
+ if mask is not None:
21
+ depth = np.where(mask, depth, np.nan)
22
+
23
+ min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999)
24
+ depth = (depth - min_depth) / (max_depth - min_depth)
25
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](depth), 0)
26
+ colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3]
27
+ return colored
28
+
29
+
30
+ def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
31
+ if mask is not None:
32
+ disparity = np.where(mask, disparity, np.nan)
33
+
34
+ if normalize:
35
+ min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999)
36
+ disparity = (disparity - min_disp) / (max_disp - min_disp)
37
+ colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity), 0)
38
+ colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3]
39
+ return colored
40
+
41
+
42
+ def colorize_segmentation(segmentation: np.ndarray, cmap: str = 'Set1') -> np.ndarray:
43
+ colored = matplotlib.colormaps[cmap]((segmentation % 20) / 20)
44
+ colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3]
45
+ return colored
46
+
47
+
48
+ def colorize_normal(normal: np.ndarray) -> np.ndarray:
49
+ normal = normal * [0.5, -0.5, -0.5] + 0.5
50
+ normal = (normal.clip(0, 1) * 255).astype(np.uint8)
51
+ return normal
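A minimal sketch of `colorize_depth` above (hypothetical depth map): masked or non-positive pixels come out black, valid pixels are colored by normalized disparity.

    import numpy as np
    from moge.utils.vis import colorize_depth

    depth = np.random.uniform(1.0, 10.0, (480, 640)).astype(np.float32)
    mask = depth < 8.0
    vis = colorize_depth(depth, mask)   # (480, 640, 3) uint8, 'Spectral' colormap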
moge/utils/webfile.py ADDED
@@ -0,0 +1,73 @@
1
+ import requests
2
+ from typing import *
3
+
4
+ __all__ = ["WebFile"]
5
+
6
+
7
+ class WebFile:
8
+ def __init__(self, url: str, session: Optional[requests.Session] = None, headers: Optional[Dict[str, str]] = None, size: Optional[int] = None):
9
+ self.url = url
10
+ self.session = session or requests.Session()
11
+ self.session.headers.update(headers or {})
12
+ self._offset = 0
13
+ self.size = size if size is not None else self._fetch_size()
14
+
15
+ def _fetch_size(self):
16
+ with self.session.get(self.url, stream=True) as response:
17
+ response.raise_for_status()
18
+ content_length = response.headers.get("Content-Length")
19
+ if content_length is None:
20
+ raise ValueError("Missing Content-Length in header")
21
+ return int(content_length)
22
+
23
+ def _fetch_data(self, offset: int, n: int) -> bytes:
24
+         headers = {"Range": f"bytes={offset}-{min(offset + n - 1, self.size - 1)}"}
25
+ response = self.session.get(self.url, headers=headers)
26
+ response.raise_for_status()
27
+ return response.content
28
+
29
+ def seekable(self) -> bool:
30
+ return True
31
+
32
+ def tell(self) -> int:
33
+ return self._offset
34
+
35
+ def available(self) -> int:
36
+ return self.size - self._offset
37
+
38
+ def seek(self, offset: int, whence: int = 0) -> None:
39
+ if whence == 0:
40
+ new_offset = offset
41
+ elif whence == 1:
42
+ new_offset = self._offset + offset
43
+ elif whence == 2:
44
+ new_offset = self.size + offset
45
+ else:
46
+ raise ValueError("Invalid value for whence")
47
+
48
+ self._offset = max(0, min(new_offset, self.size))
49
+
50
+ def read(self, n: Optional[int] = None) -> bytes:
51
+ if n is None or n < 0:
52
+ n = self.available()
53
+ else:
54
+ n = min(n, self.available())
55
+
56
+ if n == 0:
57
+ return b''
58
+
59
+ data = self._fetch_data(self._offset, n)
60
+ self._offset += len(data)
61
+
62
+ return data
63
+
64
+ def close(self) -> None:
65
+ pass
66
+
67
+ def __enter__(self):
68
+ return self
69
+
70
+ def __exit__(self, exc_type, exc_value, traceback):
71
+ pass
72
+
73
+
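A short sketch of how `WebFile` is meant to be used (the URL is a placeholder): it exposes a seekable, read-only file interface backed by HTTP Range requests.

    from moge.utils.webfile import WebFile

    f = WebFile('https://example.com/big.bin')   # placeholder URL
    f.seek(1024)
    chunk = f.read(4096)   # downloads only bytes 1024..5119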
moge/utils/webzipfile.py ADDED
@@ -0,0 +1,128 @@
1
+ from typing import *
2
+ import io
3
+ import os
4
+ from zipfile import (
5
+ ZipInfo, BadZipFile, ZipFile, ZipExtFile,
6
+ sizeFileHeader, structFileHeader, stringFileHeader,
7
+ _FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS,
8
+ _MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED
9
+ )
10
+ import struct
11
+ from requests import Session
12
+
13
+ from .webfile import WebFile
14
+
15
+
16
+ class _SharedWebFile(WebFile):
17
+ def __init__(self, webfile: WebFile, pos: int):
18
+ super().__init__(webfile.url, webfile.session, size=webfile.size)
19
+ self.seek(pos)
20
+
21
+
22
+ class WebZipFile(ZipFile):
23
+ "Lock-free version of ZipFile that reads from a WebFile, allowing for concurrent reads."
24
+ def __init__(self, url: str, session: Optional[Session] = None, headers: Optional[Dict[str, str]] = None):
25
+ """Open the ZIP file with mode read 'r', write 'w', exclusive create 'x',
26
+ or append 'a'."""
27
+ webf = WebFile(url, session=session, headers=headers)
28
+ super().__init__(webf, mode='r')
29
+
30
+ def open(self, name, mode="r", pwd=None, *, force_zip64=False):
31
+ """Return file-like object for 'name'.
32
+
33
+ name is a string for the file name within the ZIP file, or a ZipInfo
34
+ object.
35
+
36
+ mode should be 'r' to read a file already in the ZIP file, or 'w' to
37
+ write to a file newly added to the archive.
38
+
39
+ pwd is the password to decrypt files (only used for reading).
40
+
41
+ When writing, if the file size is not known in advance but may exceed
42
+ 2 GiB, pass force_zip64 to use the ZIP64 format, which can handle large
43
+ files. If the size is known in advance, it is best to pass a ZipInfo
44
+ instance for name, with zinfo.file_size set.
45
+ """
46
+ if mode not in {"r", "w"}:
47
+ raise ValueError('open() requires mode "r" or "w"')
48
+ if pwd and (mode == "w"):
49
+ raise ValueError("pwd is only supported for reading files")
50
+ if not self.fp:
51
+ raise ValueError(
52
+ "Attempt to use ZIP archive that was already closed")
53
+
54
+ assert mode == "r", "Only read mode is supported for now"
55
+
56
+ # Make sure we have an info object
57
+ if isinstance(name, ZipInfo):
58
+ # 'name' is already an info object
59
+ zinfo = name
60
+ elif mode == 'w':
61
+ zinfo = ZipInfo(name)
62
+ zinfo.compress_type = self.compression
63
+ zinfo._compresslevel = self.compresslevel
64
+ else:
65
+ # Get info object for name
66
+ zinfo = self.getinfo(name)
67
+
68
+ if mode == 'w':
69
+ return self._open_to_write(zinfo, force_zip64=force_zip64)
70
+
71
+ if self._writing:
72
+ raise ValueError("Can't read from the ZIP file while there "
73
+ "is an open writing handle on it. "
74
+ "Close the writing handle before trying to read.")
75
+
76
+ # Open for reading:
77
+ self._fileRefCnt += 1
78
+ zef_file = _SharedWebFile(self.fp, zinfo.header_offset)
79
+
80
+ try:
81
+ # Skip the file header:
82
+ fheader = zef_file.read(sizeFileHeader)
83
+ if len(fheader) != sizeFileHeader:
84
+ raise BadZipFile("Truncated file header")
85
+ fheader = struct.unpack(structFileHeader, fheader)
86
+ if fheader[_FH_SIGNATURE] != stringFileHeader:
87
+ raise BadZipFile("Bad magic number for file header")
88
+
89
+ fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
90
+ if fheader[_FH_EXTRA_FIELD_LENGTH]:
91
+ zef_file.seek(fheader[_FH_EXTRA_FIELD_LENGTH], whence=1)
92
+
93
+ if zinfo.flag_bits & _MASK_COMPRESSED_PATCH:
94
+ # Zip 2.7: compressed patched data
95
+ raise NotImplementedError("compressed patched data (flag bit 5)")
96
+
97
+ if zinfo.flag_bits & _MASK_STRONG_ENCRYPTION:
98
+ # strong encryption
99
+ raise NotImplementedError("strong encryption (flag bit 6)")
100
+
101
+ if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & _MASK_UTF_FILENAME:
102
+ # UTF-8 filename
103
+ fname_str = fname.decode("utf-8")
104
+ else:
105
+ fname_str = fname.decode(self.metadata_encoding or "cp437")
106
+
107
+ if fname_str != zinfo.orig_filename:
108
+ raise BadZipFile(
109
+ 'File name in directory %r and header %r differ.'
110
+ % (zinfo.orig_filename, fname))
111
+
112
+ # check for encrypted flag & handle password
113
+ is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED
114
+ if is_encrypted:
115
+ if not pwd:
116
+ pwd = self.pwd
117
+ if pwd and not isinstance(pwd, bytes):
118
+ raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__)
119
+ if not pwd:
120
+ raise RuntimeError("File %r is encrypted, password "
121
+ "required for extraction" % name)
122
+ else:
123
+ pwd = None
124
+
125
+ return ZipExtFile(zef_file, mode, zinfo, pwd, True)
126
+ except:
127
+ zef_file.close()
128
+ raise
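A sketch of `WebZipFile` usage (placeholder URL): only the central directory and the requested member are fetched, so large remote archives can be read selectively.

    from moge.utils.webzipfile import WebZipFile

    zf = WebZipFile('https://example.com/archive.zip')   # placeholder URL
    names = zf.namelist()
    with zf.open(names[0]) as member:
        data = member.read()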
packages.txt ADDED
@@ -0,0 +1 @@
1
+ python3-opencv
requirements.txt ADDED
@@ -0,0 +1,5 @@
1
+ opencv-python
2
+ plyfile
3
+ pygltflib
4
+ transformers
5
+ scikit-learn
utils3d/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ """
2
+ A package of common utility functions for 3D computer graphics and vision. It provides NumPy utilities in `utils3d.numpy`, PyTorch utilities in `utils3d.torch`, and IO utilities in `utils3d.io`.
3
+ """
4
+ import importlib
5
+
6
+ __all__ = ['numpy', 'torch', 'io']
7
+
8
+ def __getattr__(module_name: str):
9
+ return importlib.import_module(f'.{module_name}', __package__)
10
+
11
+ if __name__ == '__main__':
12
+ from . import torch
13
+ from . import numpy
14
+ from . import io
utils3d/io/__init__.py ADDED
@@ -0,0 +1,4 @@
1
+ from .wavefront_obj import *
2
+ from .colmap import *
3
+ from .ply import *
4
+ from .glb import *
utils3d/io/colmap.py ADDED
@@ -0,0 +1,139 @@
1
+ from typing import *
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+ from scipy.spatial.transform import Rotation
6
+
7
+
8
+ __all__ = ['read_extrinsics_from_colmap', 'read_intrinsics_from_colmap', 'write_extrinsics_as_colmap', 'write_intrinsics_as_colmap']
9
+
10
+
11
+ def write_extrinsics_as_colmap(file: Union[str, Path], extrinsics: np.ndarray, image_names: Union[str, List[str]] = 'image_{i:04d}.png', camera_ids: List[int] = None):
12
+ """
13
+ Write extrinsics to colmap `images.txt` file.
14
+ Args:
15
+ file: Path to `images.txt` file.
16
+ extrinsics: (N, 4, 4) array of extrinsics.
17
+ image_names: str or List of str, image names. Length is N.
18
+ If str, it should be a format string with `i` as the index. (i starts from 1, in correspondence with IMAGE_ID in colmap)
19
+ camera_ids: List of int, camera ids. Length is N.
20
+ If None, it will be set to [1, 2, ..., N].
21
+ """
22
+ assert extrinsics.shape[1:] == (4, 4) and extrinsics.ndim == 3 or extrinsics.shape == (4, 4)
23
+ if extrinsics.ndim == 2:
24
+ extrinsics = extrinsics[np.newaxis, ...]
25
+ quats = Rotation.from_matrix(extrinsics[:, :3, :3]).as_quat()
26
+ trans = extrinsics[:, :3, 3]
27
+ if camera_ids is None:
28
+ camera_ids = list(range(1, len(extrinsics) + 1))
29
+ if isinstance(image_names, str):
30
+ image_names = [image_names.format(i=i) for i in range(1, len(extrinsics) + 1)]
31
+ assert len(extrinsics) == len(image_names) == len(camera_ids), \
32
+ f'Number of extrinsics ({len(extrinsics)}), image_names ({len(image_names)}), and camera_ids ({len(camera_ids)}) must be the same'
33
+ with open(file, 'w') as fp:
34
+ print("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME", file=fp)
35
+ for i, (quat, t, name, camera_id) in enumerate(zip(quats.tolist(), trans.tolist(), image_names, camera_ids)):
36
+             # COLMAP stores quaternions in wxyz order, while scipy.spatial.transform.Rotation uses xyzw.
37
+ qx, qy, qz, qw = quat
38
+ tx, ty, tz = t
39
+ print(f'{i + 1} {qw:f} {qx:f} {qy:f} {qz:f} {tx:f} {ty:f} {tz:f} {camera_id:d} {name}', file=fp)
40
+             print(file=fp)  # second line of each image entry (its 2D points) is left empty
41
+
42
+
43
+ def write_intrinsics_as_colmap(file: Union[str, Path], intrinsics: np.ndarray, width: int, height: int, normalized: bool = False):
44
+ """
45
+ Write intrinsics to colmap `cameras.txt` file. Currently only support PINHOLE model (no distortion)
46
+ Args:
47
+ file: Path to `cameras.txt` file.
48
+ intrinsics: (N, 3, 3) array of intrinsics.
49
+ width: Image width.
50
+ height: Image height.
51
+         normalized: Whether the intrinsics are normalized. If True, they will be unnormalized (rescaled to pixel units) before writing.
52
+ """
53
+ assert intrinsics.shape[1:] == (3, 3) and intrinsics.ndim == 3 or intrinsics.shape == (3, 3)
54
+ if intrinsics.ndim == 2:
55
+ intrinsics = intrinsics[np.newaxis, ...]
56
+ if normalized:
57
+ intrinsics = intrinsics * np.array([width, height, 1])[:, None]
58
+ with open(file, 'w') as fp:
59
+ print("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]", file=fp)
60
+ for i, intr in enumerate(intrinsics):
61
+ fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2]
62
+ print(f'{i + 1} PINHOLE {width:d} {height:d} {fx:f} {fy:f} {cx:f} {cy:f}', file=fp)
63
+
64
+
65
+ def read_extrinsics_from_colmap(file: Union[str, Path]) -> Union[np.ndarray, List[int], List[str]]:
66
+ """
67
+ Read extrinsics from colmap `images.txt` file.
68
+ Args:
69
+ file: Path to `images.txt` file.
70
+ Returns:
71
+ extrinsics: (N, 4, 4) array of extrinsics.
72
+ camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1.
73
+ image_names: List of str, image names. Length is N.
74
+ """
75
+ with open(file) as fp:
76
+ lines = fp.readlines()
77
+ image_names, quats, trans, camera_ids = [], [], [], []
78
+ i_line = 0
79
+ for line in lines:
80
+ line = line.strip()
81
+ if line.startswith('#'):
82
+ continue
83
+ i_line += 1
84
+ if i_line % 2 == 0:
85
+ continue
86
+ image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name = line.split()
87
+ quats.append([float(qx), float(qy), float(qz), float(qw)])
88
+ trans.append([float(tx), float(ty), float(tz)])
89
+ camera_ids.append(int(camera_id))
90
+ image_names.append(name)
91
+
92
+ quats = np.array(quats, dtype=np.float32)
93
+ trans = np.array(trans, dtype=np.float32)
94
+ rotation = Rotation.from_quat(quats).as_matrix()
95
+ extrinsics = np.concatenate([
96
+ np.concatenate([rotation, trans[..., None]], axis=-1),
97
+ np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :].repeat(len(quats), axis=0)
98
+ ], axis=-2)
99
+
100
+ return extrinsics, camera_ids, image_names
101
+
102
+
103
+ def read_intrinsics_from_colmap(file: Union[str, Path], normalize: bool = False) -> Tuple[List[int], np.ndarray, np.ndarray]:
104
+ """
105
+ Read intrinsics from colmap `cameras.txt` file.
106
+ Args:
107
+ file: Path to `cameras.txt` file.
108
+ normalize: Whether to normalize the intrinsics. If True, the intrinsics will be normalized. (mapping coordinates to [0, 1] range)
109
+ Returns:
110
+ camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1.
111
+ intrinsics: (N, 3, 3) array of intrinsics.
112
+ distortions: (N, 5) array of distortions.
113
+ """
114
+ with open(file) as fp:
115
+ lines = fp.readlines()
116
+ intrinsics, distortions, camera_ids = [], [], []
117
+ for line in lines:
118
+ line = line.strip()
119
+ if not line or line.startswith('#'):
120
+ continue
121
+ camera_id, model, width, height, *params = line.split()
122
+ camera_id, width, height = int(camera_id), int(width), int(height)
123
+ if model == 'PINHOLE':
124
+ fx, fy, cx, cy = map(float, params[:4])
125
+ k1 = k2 = k3 = p1 = p2 = 0.0
126
+ elif model == 'OPENCV':
127
+ fx, fy, cx, cy, k1, k2, p1, p2, k3 = *map(float, params[:8]), 0.0
128
+ elif model == 'SIMPLE_RADIAL':
129
+ f, cx, cy, k = map(float, params[:4])
130
+ fx = fy = f
131
+             k1, k2, p1, p2, k3 = k, 0.0, 0.0, 0.0, 0.0
+         else:
+             raise ValueError(f"Unsupported camera model: {model}")
132
+ camera_ids.append(camera_id)
133
+ if normalize:
134
+ fx, fy, cx, cy = fx / width, fy / height, cx / width, cy / height
135
+ intrinsics.append([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
136
+ distortions.append([k1, k2, p1, p2, k3])
137
+ intrinsics = np.array(intrinsics, dtype=np.float32)
138
+ distortions = np.array(distortions, dtype=np.float32)
139
+ return camera_ids, intrinsics, distortions
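A minimal sketch of the two writers above, using a single identity camera (hypothetical values):

    import numpy as np
    from utils3d.io.colmap import write_extrinsics_as_colmap, write_intrinsics_as_colmap

    extrinsics = np.eye(4, dtype=np.float32)[None]                                   # (1, 4, 4), world-to-camera
    intrinsics = np.array([[[500., 0., 320.], [0., 500., 240.], [0., 0., 1.]]])      # (1, 3, 3), in pixels
    write_extrinsics_as_colmap('images.txt', extrinsics)
    write_intrinsics_as_colmap('cameras.txt', intrinsics, width=640, height=480)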
utils3d/io/glb.py ADDED
@@ -0,0 +1,105 @@
1
+ from typing import *
2
+ from pathlib import Path
3
+
4
+ import numpy as np
5
+
6
+
7
+ def write_glb(path: Union[str, Path], vertices: np.ndarray, faces: np.ndarray, vertex_colors: np.ndarray = None, uv: np.ndarray = None):
8
+ import pygltflib
9
+
10
+ has_colors = vertex_colors is not None
11
+ has_uv = uv is not None
12
+
13
+ triangles_bytes = faces.astype(np.uint32).flatten().tobytes()
14
+ vertices_bytes = vertices.astype(np.float32).tobytes()
15
+ vertex_colors_bytes = vertex_colors.astype(np.float32).tobytes() if has_colors else None
16
+ uv_bytes = uv.astype(np.float32).tobytes() if has_uv else None
17
+
18
+
19
+ gltf = pygltflib.GLTF2(
20
+ scene=0,
21
+ scenes=[pygltflib.Scene(nodes=[0])],
22
+ nodes=[pygltflib.Node(mesh=0)],
23
+ meshes=[
24
+ pygltflib.Mesh(
25
+ primitives=[
26
+ pygltflib.Primitive(
27
+ attributes=pygltflib.Attributes(
28
+ POSITION=1,
29
+ COLOR_0=2 if has_colors else None,
30
+ TEXCOORD_0=2 + has_colors if has_uv else None
31
+ ),
32
+ indices=0
33
+ )
34
+ ]
35
+ )
36
+ ],
37
+ accessors=list(filter(None, [
38
+ pygltflib.Accessor( # triangles accessor
39
+ bufferView=0,
40
+ componentType=pygltflib.UNSIGNED_INT,
41
+ count=faces.size,
42
+ type=pygltflib.SCALAR,
43
+ max=[int(faces.max())],
44
+ min=[int(faces.min())],
45
+ ),
46
+ pygltflib.Accessor( # vertices accessor
47
+ bufferView=1,
48
+ componentType=pygltflib.FLOAT,
49
+ count=len(vertices),
50
+ type=pygltflib.VEC3,
51
+ max=vertices.max(axis=0).tolist(),
52
+ min=vertices.min(axis=0).tolist(),
53
+ ),
54
+ pygltflib.Accessor( # vertex colors accessor
55
+ bufferView=2,
56
+ componentType=pygltflib.FLOAT,
57
+ count=len(vertices),
58
+ type=pygltflib.VEC3,
59
+ max=vertex_colors.max(axis=0).tolist(),
60
+ min=vertex_colors.min(axis=0).tolist(),
61
+ ) if has_colors else None,
62
+ pygltflib.Accessor( # uv accessor
63
+ bufferView=3,
64
+ componentType=pygltflib.FLOAT,
65
+ count=len(uv),
66
+ type=pygltflib.VEC2,
67
+ max=uv.max(axis=0).tolist(),
68
+ min=uv.min(axis=0).tolist(),
69
+ ) if has_uv else None,
70
+ ])),
71
+ bufferViews=list(filter(None, [
72
+ pygltflib.BufferView( # triangles buffer view
73
+ buffer=0,
74
+ byteLength=len(triangles_bytes),
75
+ target=pygltflib.ELEMENT_ARRAY_BUFFER,
76
+ ),
77
+ pygltflib.BufferView( # vertices buffer view
78
+ buffer=0,
79
+ byteOffset=len(triangles_bytes),
80
+ byteLength=len(vertices_bytes),
81
+ target=pygltflib.ARRAY_BUFFER,
82
+ ),
83
+ pygltflib.BufferView( # vertex colors buffer view
84
+ buffer=0,
85
+ byteOffset=len(triangles_bytes) + len(vertices_bytes),
86
+ byteLength=len(vertex_colors_bytes),
87
+ target=pygltflib.ARRAY_BUFFER,
88
+ ) if has_colors else None,
89
+ pygltflib.BufferView( # uv buffer view
90
+ buffer=0,
91
+ byteOffset=len(triangles_bytes) + len(vertices_bytes) + (len(vertex_colors_bytes) if has_colors else 0),
92
+ byteLength=len(uv_bytes),
93
+ target=pygltflib.ARRAY_BUFFER,
94
+ ) if has_uv else None,
95
+ ])),
96
+ buffers=[
97
+ pygltflib.Buffer(
98
+ byteLength=len(triangles_bytes) + len(vertices_bytes) + (len(vertex_colors_bytes) if has_colors else 0) + (len(uv_bytes) if has_uv else 0),
99
+ )
100
+ ]
101
+ )
102
+ gltf.set_binary_blob(triangles_bytes + vertices_bytes + (vertex_colors_bytes or b'') + (uv_bytes or b''))
103
+ with open(path, 'wb') as f:
104
+ for chunk in gltf.save_to_bytes():
105
+ f.write(chunk)
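A minimal sketch of `write_glb` with a single vertex-colored triangle (hypothetical data):

    import numpy as np
    from utils3d.io.glb import write_glb

    vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32)
    faces = np.array([[0, 1, 2]], dtype=np.uint32)
    colors = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)
    write_glb('triangle.glb', vertices, faces, vertex_colors=colors)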
utils3d/io/ply.py ADDED
@@ -0,0 +1,104 @@
1
+ import numpy as np
2
+
3
+ from typing import *
4
+ from pathlib import Path
5
+
6
+
7
+ def read_ply(
8
+ file: Union[str, Path],
9
+ encoding: Union[str, None] = None,
10
+ ignore_unknown: bool = False
11
+ ) -> Tuple[np.ndarray, np.ndarray]:
12
+ """
13
+ Read .ply file, without preprocessing.
14
+
15
+ Args:
16
+ file (Any): filepath
17
+ encoding (str, optional):
18
+
19
+ Returns:
20
+ Tuple[np.ndarray, np.ndarray]: vertices, faces
21
+ """
22
+ import plyfile
23
+ plydata = plyfile.PlyData.read(file)
24
+ vertices = np.stack([plydata['vertex'][k] for k in ['x', 'y', 'z']], axis=-1)
25
+ if 'face' in plydata:
26
+ faces = np.array(plydata['face']['vertex_indices'].tolist())
27
+ else:
28
+ faces = None
29
+ return vertices, faces
30
+
31
+
32
+ def write_ply(
33
+ file: Union[str, Path],
34
+ vertices: np.ndarray,
35
+ faces: np.ndarray = None,
36
+ edges: np.ndarray = None,
37
+ vertex_colors: np.ndarray = None,
38
+ edge_colors: np.ndarray = None,
39
+ text: bool = False
40
+ ):
41
+ """
42
+ Write .ply file, without preprocessing.
43
+
44
+ Args:
45
+ file (Any): filepath
46
+ vertices (np.ndarray): [N, 3]
47
+ faces (np.ndarray): [T, E]
48
+ edges (np.ndarray): [E, 2]
49
+ vertex_colors (np.ndarray, optional): [N, 3]. Defaults to None.
50
+ edge_colors (np.ndarray, optional): [E, 3]. Defaults to None.
51
+ text (bool, optional): save data in text format. Defaults to False.
52
+ """
53
+ import plyfile
54
+ assert vertices.ndim == 2 and vertices.shape[1] == 3
55
+ vertices = vertices.astype(np.float32)
56
+ if faces is not None:
57
+ assert faces.ndim == 2
58
+ faces = faces.astype(np.int32)
59
+ if edges is not None:
60
+ assert edges.ndim == 2 and edges.shape[1] == 2
61
+ edges = edges.astype(np.int32)
62
+
63
+ if vertex_colors is not None:
64
+ assert vertex_colors.ndim == 2 and vertex_colors.shape[1] == 3
65
+ if vertex_colors.dtype in [np.float32, np.float64]:
66
+ vertex_colors = vertex_colors * 255
67
+ vertex_colors = np.clip(vertex_colors, 0, 255).astype(np.uint8)
68
+ vertices_data = np.zeros(len(vertices), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
69
+ vertices_data['x'] = vertices[:, 0]
70
+ vertices_data['y'] = vertices[:, 1]
71
+ vertices_data['z'] = vertices[:, 2]
72
+ vertices_data['red'] = vertex_colors[:, 0]
73
+ vertices_data['green'] = vertex_colors[:, 1]
74
+ vertices_data['blue'] = vertex_colors[:, 2]
75
+ else:
76
+ vertices_data = np.array([tuple(v) for v in vertices], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
77
+
78
+ if faces is not None:
79
+ faces_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (faces.shape[1],))])
80
+ faces_data['vertex_indices'] = faces
81
+
82
+ if edges is not None:
83
+ if edge_colors is not None:
84
+ assert edge_colors.ndim == 2 and edge_colors.shape[1] == 3
85
+ if edge_colors.dtype in [np.float32, np.float64]:
86
+ edge_colors = edge_colors * 255
87
+ edge_colors = np.clip(edge_colors, 0, 255).astype(np.uint8)
88
+ edges_data = np.zeros(len(edges), dtype=[('vertex1', 'i4'), ('vertex2', 'i4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
89
+ edges_data['vertex1'] = edges[:, 0]
90
+ edges_data['vertex2'] = edges[:, 1]
91
+ edges_data['red'] = edge_colors[:, 0]
92
+ edges_data['green'] = edge_colors[:, 1]
93
+ edges_data['blue'] = edge_colors[:, 2]
94
+ else:
95
+ edges_data = np.array([tuple(e) for e in edges], dtype=[('vertex1', 'i4'), ('vertex2', 'i4')])
96
+
97
+ ply_data = [plyfile.PlyElement.describe(vertices_data, 'vertex')]
98
+ if faces is not None:
99
+ ply_data.append(plyfile.PlyElement.describe(faces_data, 'face'))
100
+ if edges is not None:
101
+ ply_data.append(plyfile.PlyElement.describe(edges_data, 'edge'))
102
+
103
+ plyfile.PlyData(ply_data, text=text).write(file)
104
+
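A round-trip sketch for the PLY helpers above (hypothetical data):

    import numpy as np
    from utils3d.io.ply import write_ply, read_ply

    vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32)
    faces = np.array([[0, 1, 2]], dtype=np.int32)
    write_ply('triangle.ply', vertices, faces)
    vertices2, faces2 = read_ply('triangle.ply')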
utils3d/io/wavefront_obj.py ADDED
@@ -0,0 +1,146 @@
1
+ from io import TextIOWrapper
2
+ from typing import Dict, Any, Union, Iterable
3
+ import numpy as np
4
+ from pathlib import Path
5
+
6
+ __all__ = [
7
+ 'read_obj',
8
+ 'write_obj',
9
+ 'simple_write_obj'
10
+ ]
11
+
12
+ def read_obj(
13
+ file : Union[str, Path, TextIOWrapper],
14
+ encoding: Union[str, None] = None,
15
+ ignore_unknown: bool = False
16
+ ):
17
+ """
18
+ Read wavefront .obj file, without preprocessing.
19
+
20
+     Why bother having this read_obj() when there are already libraries like `trimesh`?
+     This function reads the raw .obj content and keeps the original order of vertices and faces,
+     whereas trimesh applies modifications such as merging/splitting vertices, which can break that order.
+     Such libraries mainly target geometry processing and rendering and support many formats.
+     If you need mesh geometry processing, `trimesh` offers more features.
25
+
26
+ ### Parameters
27
+ `file` (str, Path, TextIOWrapper): filepath or file object
28
+ encoding (str, optional):
29
+
30
+ ### Returns
31
+ obj (dict): A dict containing .obj components
32
+ {
33
+ 'mtllib': [],
34
+ 'v': [[0,1, 0.2, 1.0], [1.2, 0.0, 0.0], ...],
35
+ 'vt': [[0.5, 0.5], ...],
36
+ 'vn': [[0., 0.7, 0.7], [0., -0.7, 0.7], ...],
37
+ 'f': [[0, 1, 2], [2, 3, 4],...],
38
+ 'usemtl': [{'name': 'mtl1', 'f': 7}]
39
+ }
40
+ """
41
+ if hasattr(file,'read'):
42
+ lines = file.read().splitlines()
43
+ else:
44
+ with open(file, 'r', encoding=encoding) as fp:
45
+ lines = fp.read().splitlines()
46
+ mtllib = []
47
+ v, vt, vn, vp = [], [], [], [] # Vertex coordinates, Vertex texture coordinate, Vertex normal, Vertex parameter
48
+ f, ft, fn = [], [], [] # Face indices, Face texture indices, Face normal indices
49
+ o = []
50
+ s = []
51
+ usemtl = []
52
+
53
+ def pad(l: list, n: Any):
54
+ return l + [n] * (3 - len(l))
55
+
56
+ for i, line in enumerate(lines):
57
+ sq = line.strip().split()
58
+ if len(sq) == 0:
59
+ continue
60
+ if sq[0] == 'v':
61
+ assert 4 <= len(sq) <= 5, f'Invalid format of line {i}: {line}'
62
+ v.append([float(e) for e in sq[1:]][:3])
63
+ elif sq[0] == 'vt':
64
+ assert 3 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
65
+ vt.append([float(e) for e in sq[1:]][:2])
66
+ elif sq[0] == 'vn':
67
+ assert len(sq) == 4, f'Invalid format of line {i}: {line}'
68
+ vn.append([float(e) for e in sq[1:]])
69
+ elif sq[0] == 'vp':
70
+ assert 2 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
71
+ vp.append(pad([float(e) for e in sq[1:]], 0))
72
+ elif sq[0] == 'f':
73
+ spliting = [pad([int(j) - 1 for j in e.split('/')], -1) for e in sq[1:]]
74
+ f.append([e[0] for e in spliting])
75
+ ft.append([e[1] for e in spliting])
76
+ fn.append([e[2] for e in spliting])
77
+ elif sq[0] == 'usemtl':
78
+ assert len(sq) == 2
79
+ usemtl.append((sq[1], len(f)))
80
+ elif sq[0] == 'o':
81
+ assert len(sq) == 2
82
+ o.append((sq[1], len(f)))
83
+ elif sq[0] == 's':
84
+ s.append((sq[1], len(f)))
85
+ elif sq[0] == 'mtllib':
86
+ assert len(sq) == 2
87
+ mtllib.append(sq[1])
88
+ elif sq[0][0] == '#':
89
+ continue
90
+ else:
91
+ if not ignore_unknown:
92
+ raise Exception(f'Unknown keyword {sq[0]}')
93
+
94
+ min_poly_vertices = min(len(f) for f in f)
95
+ max_poly_vertices = max(len(f) for f in f)
96
+
97
+ return {
98
+ 'mtllib': mtllib,
99
+ 'v': np.array(v, dtype=np.float32),
100
+ 'vt': np.array(vt, dtype=np.float32),
101
+ 'vn': np.array(vn, dtype=np.float32),
102
+ 'vp': np.array(vp, dtype=np.float32),
103
+ 'f': np.array(f, dtype=np.int32) if min_poly_vertices == max_poly_vertices else f,
104
+ 'ft': np.array(ft, dtype=np.int32) if min_poly_vertices == max_poly_vertices else ft,
105
+ 'fn': np.array(fn, dtype=np.int32) if min_poly_vertices == max_poly_vertices else fn,
106
+ 'o': o,
107
+ 's': s,
108
+ 'usemtl': usemtl,
109
+ }
110
+
111
+
112
+ def write_obj(
113
+ file: Union[str, Path],
114
+ obj: Dict[str, Any],
115
+ encoding: Union[str, None] = None
116
+ ):
117
+ with open(file, 'w', encoding=encoding) as fp:
118
+ for k in ['v', 'vt', 'vn', 'vp']:
119
+ if k not in obj:
120
+ continue
121
+ for v in obj[k]:
122
+ print(k, *map(float, v), file=fp)
123
+ for f in obj['f']:
124
+             print('f', *('/'.join(str(int(j) + 1) for j in i) if isinstance(i, Iterable) else str(int(i) + 1) for i in f), file=fp)  # .obj face indices are 1-based
125
+
126
+
127
+ def simple_write_obj(
128
+ file: Union[str, Path],
129
+ vertices: np.ndarray,
130
+ faces: np.ndarray,
131
+ encoding: Union[str, None] = None
132
+ ):
133
+ """
134
+ Write wavefront .obj file, without preprocessing.
135
+
136
+ Args:
137
+ vertices (np.ndarray): [N, 3]
138
+ faces (np.ndarray): [T, 3]
139
+ file (Any): filepath
140
+ encoding (str, optional):
141
+ """
142
+ with open(file, 'w', encoding=encoding) as fp:
143
+ for v in vertices:
144
+ print('v', *map(float, v), file=fp)
145
+ for f in faces:
146
+ print('f', *map(int, f + 1), file=fp)
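A round-trip sketch for the OBJ helpers above (hypothetical data); indices are converted between the 1-based .obj convention on disk and the 0-based arrays in memory:

    import numpy as np
    from utils3d.io.wavefront_obj import simple_write_obj, read_obj

    vertices = np.array([[0, 0, 0], [1, 0, 0], [0, 1, 0]], dtype=np.float32)
    faces = np.array([[0, 1, 2]], dtype=np.int32)
    simple_write_obj('triangle.obj', vertices, faces)
    obj = read_obj('triangle.obj')   # obj['v'] -> (3, 3) float32, obj['f'] -> (1, 3) int32, 0-based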
utils3d/numpy/__init__.py ADDED
@@ -0,0 +1,135 @@
1
+ """
2
+ 3D utility functions workings with NumPy.
3
+ """
4
+ import importlib
5
+ import itertools
6
+ import numpy
7
+
8
+
9
+ __modules_all__ = {
10
+ 'mesh':[
11
+ 'triangulate',
12
+ 'compute_face_normal',
13
+ 'compute_face_angle',
14
+ 'compute_vertex_normal',
15
+ 'compute_vertex_normal_weighted',
16
+ 'remove_corrupted_faces',
17
+ 'merge_duplicate_vertices',
18
+ 'remove_unreferenced_vertices',
19
+ 'subdivide_mesh_simple',
20
+ 'mesh_relations',
21
+ 'flatten_mesh_indices'
22
+ ],
23
+ 'quadmesh': [
24
+ 'calc_quad_candidates',
25
+ 'calc_quad_distortion',
26
+ 'calc_quad_direction',
27
+ 'calc_quad_smoothness',
28
+ 'sovle_quad',
29
+ 'sovle_quad_qp',
30
+ 'tri_to_quad'
31
+ ],
32
+ 'utils': [
33
+ 'sliding_window_1d',
34
+ 'sliding_window_nd',
35
+ 'sliding_window_2d',
36
+ 'max_pool_1d',
37
+ 'max_pool_2d',
38
+ 'max_pool_nd',
39
+ 'depth_edge',
40
+ 'depth_aliasing',
41
+ 'interpolate',
42
+ 'image_scrcoord',
43
+ 'image_uv',
44
+ 'image_pixel_center',
45
+ 'image_pixel',
46
+ 'image_mesh',
47
+ 'image_mesh_from_depth',
48
+ 'depth_to_normal',
49
+ 'point_to_normal',
50
+ 'chessboard',
51
+ 'cube',
52
+ 'square',
53
+ 'camera_frustum',
54
+ ],
55
+ 'transforms': [
56
+ 'perspective',
57
+ 'perspective_from_fov',
58
+ 'perspective_from_fov_xy',
59
+ 'intrinsics_from_focal_center',
60
+ 'intrinsics_from_fov',
61
+ 'view_look_at',
62
+ 'extrinsics_look_at',
63
+ 'perspective_to_intrinsics',
64
+ 'perspective_to_near_far',
65
+ 'intrinsics_to_perspective',
66
+ 'extrinsics_to_view',
67
+ 'view_to_extrinsics',
68
+ 'normalize_intrinsics',
69
+ 'crop_intrinsics',
70
+ 'pixel_to_uv',
71
+ 'pixel_to_ndc',
72
+ 'uv_to_pixel',
73
+ 'project_depth',
74
+ 'depth_buffer_to_linear',
75
+ 'unproject_cv',
76
+ 'unproject_gl',
77
+ 'project_cv',
78
+ 'project_gl',
79
+ 'quaternion_to_matrix',
80
+ 'axis_angle_to_matrix',
81
+ 'matrix_to_quaternion',
82
+ 'extrinsics_to_essential',
83
+ 'euler_axis_angle_rotation',
84
+ 'euler_angles_to_matrix',
85
+ 'skew_symmetric',
86
+ 'rotation_matrix_from_vectors',
87
+ 'ray_intersection',
88
+ 'se3_matrix',
89
+ 'slerp_quaternion',
90
+ 'slerp_vector',
91
+ 'lerp',
92
+ 'lerp_se3_matrix',
93
+ 'piecewise_lerp',
94
+ 'piecewise_lerp_se3_matrix',
95
+ 'apply_transform'
96
+ ],
97
+ 'spline': [
98
+ 'linear_spline_interpolate',
99
+ ],
100
+ 'rasterization': [
101
+ 'RastContext',
102
+ 'rasterize_triangle_faces',
103
+ 'rasterize_edges',
104
+ 'texture',
105
+ 'warp_image_by_depth',
106
+ ],
107
+ }
108
+
109
+
110
+ __all__ = list(itertools.chain(*__modules_all__.values()))
111
+
112
+ def __getattr__(name):
113
+ try:
114
+ return globals()[name]
115
+ except KeyError:
116
+ pass
117
+
118
+ try:
119
+ module_name = next(m for m in __modules_all__ if name in __modules_all__[m])
120
+ except StopIteration:
121
+ raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
122
+ module = importlib.import_module(f'.{module_name}', __name__)
123
+ for key in __modules_all__[module_name]:
124
+ globals()[key] = getattr(module, key)
125
+
126
+ return globals()[name]
127
+
128
+
129
+ if __name__ == '__main__':
130
+ from .quadmesh import *
131
+ from .transforms import *
132
+ from .mesh import *
133
+ from .utils import *
134
+ from .rasterization import *
135
+ from .spline import *
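The `__getattr__` above gives lazy, on-demand imports: a submodule is only imported the first time one of its exported names is accessed, and its names are then cached in this module's globals. A tiny sketch of the effect:

    import utils3d

    fn = utils3d.numpy.triangulate   # 'utils3d.numpy.mesh' is imported only now, on first access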
utils3d/numpy/_helpers.py ADDED
@@ -0,0 +1,88 @@
1
+ # decorator
2
+ import numpy as np
3
+ from numbers import Number
4
+ import inspect
5
+
6
+
7
+ def get_args_order(func, args, kwargs):
8
+ """
9
+ Get the order of the arguments of a function.
10
+ """
11
+ names = inspect.getfullargspec(func).args
12
+ names_idx = {name: i for i, name in enumerate(names)}
13
+ args_order = []
14
+ kwargs_order = {}
15
+ for name, arg in kwargs.items():
16
+ if name in names:
17
+ kwargs_order[name] = names_idx[name]
18
+ names.remove(name)
19
+ for i, arg in enumerate(args):
20
+ if i < len(names):
21
+ args_order.append(names_idx[names[i]])
22
+ return args_order, kwargs_order
23
+
24
+
25
+ def broadcast_args(args, kwargs, args_dim, kwargs_dim):
26
+ spatial = []
27
+ for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())):
28
+ if isinstance(arg, np.ndarray) and arg_dim is not None:
29
+ arg_spatial = arg.shape[:arg.ndim-arg_dim]
30
+ if len(arg_spatial) > len(spatial):
31
+ spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
32
+ for j in range(len(arg_spatial)):
33
+ if spatial[-j] < arg_spatial[-j]:
34
+ if spatial[-j] == 1:
35
+ spatial[-j] = arg_spatial[-j]
36
+ else:
37
+ raise ValueError("Cannot broadcast arguments.")
38
+ for i, arg in enumerate(args):
39
+ if isinstance(arg, np.ndarray) and args_dim[i] is not None:
40
+ args[i] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]])
41
+ for key, arg in kwargs.items():
42
+ if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None:
43
+ kwargs[key] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]])
44
+ return args, kwargs, spatial
45
+
46
+
47
+ def batched(*dims):
48
+ """
49
+ Decorator that allows a function to be called with batched arguments.
50
+ """
51
+ def decorator(func):
52
+ def wrapper(*args, **kwargs):
53
+ args = list(args)
54
+ # get arguments dimensions
55
+ args_order, kwargs_order = get_args_order(func, args, kwargs)
56
+ args_dim = [dims[i] for i in args_order]
57
+ kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
58
+ # convert to numpy array
59
+ for i, arg in enumerate(args):
60
+ if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
61
+ args[i] = np.array(arg)
62
+ for key, arg in kwargs.items():
63
+ if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None:
64
+ kwargs[key] = np.array(arg)
65
+ # broadcast arguments
66
+ args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
67
+ for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
68
+ if isinstance(arg, np.ndarray) and arg_dim is not None:
69
+ args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]])
70
+ for key, arg in kwargs.items():
71
+ if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None:
72
+ kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]])
73
+ # call function
74
+ results = func(*args, **kwargs)
75
+ type_results = type(results)
76
+ results = list(results) if isinstance(results, (tuple, list)) else [results]
77
+ # restore spatial dimensions
78
+ for i, result in enumerate(results):
79
+ results[i] = result.reshape([*spatial, *result.shape[1:]])
80
+ if type_results == tuple:
81
+ results = tuple(results)
82
+ elif type_results == list:
83
+ results = list(results)
84
+ else:
85
+ results = results[0]
86
+ return results
87
+ return wrapper
88
+ return decorator
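A sketch of what the `batched` decorator does (hypothetical function): array arguments are flattened to a single leading batch dimension before the call, and the original leading dimensions are restored on the result.

    import numpy as np
    from utils3d.numpy._helpers import batched

    @batched(2, None)
    def first_corner(vertices, faces):
        # vertices arrives as [B, N, 3]; faces (dim=None) is passed through unchanged
        return vertices[:, faces[0, 0], :]

    v = np.random.rand(5, 7, 100, 3)
    f = np.zeros((4, 3), dtype=np.int32)
    out = first_corner(v, f)    # shape (5, 7, 3)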
utils3d/numpy/mesh.py ADDED
@@ -0,0 +1,355 @@
1
+ import numpy as np
2
+ from typing import *
3
+ from ._helpers import batched
4
+
5
+
6
+ __all__ = [
7
+ 'triangulate',
8
+ 'compute_face_normal',
9
+ 'compute_face_angle',
10
+ 'compute_vertex_normal',
11
+ 'compute_vertex_normal_weighted',
12
+ 'remove_corrupted_faces',
13
+ 'merge_duplicate_vertices',
14
+ 'remove_unreferenced_vertices',
15
+ 'subdivide_mesh_simple',
16
+ 'mesh_relations',
17
+ 'flatten_mesh_indices'
18
+ ]
19
+
20
+
21
+ def triangulate(
22
+ faces: np.ndarray,
23
+ vertices: np.ndarray = None,
24
+ backslash: np.ndarray = None
25
+ ) -> np.ndarray:
26
+ """
27
+ Triangulate a polygonal mesh.
28
+
29
+ Args:
30
+ faces (np.ndarray): [L, P] polygonal faces
31
+ vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices.
32
+ If given, the triangulation is performed according to the distance
33
+ between vertices. Defaults to None.
34
+ backslash (np.ndarray, optional): [L] boolean array indicating
35
+ how to triangulate the quad faces. Defaults to None.
36
+
37
+ Returns:
38
+ (np.ndarray): [L * (P - 2), 3] triangular faces
39
+ """
40
+ if faces.shape[-1] == 3:
41
+ return faces
42
+ P = faces.shape[-1]
43
+ if vertices is not None:
44
+ assert faces.shape[-1] == 4, "now only support quad mesh"
45
+ if backslash is None:
46
+ backslash = np.linalg.norm(vertices[faces[:, 0]] - vertices[faces[:, 2]], axis=-1) < \
47
+ np.linalg.norm(vertices[faces[:, 1]] - vertices[faces[:, 3]], axis=-1)
48
+ if backslash is None:
49
+ loop_indice = np.stack([
50
+ np.zeros(P - 2, dtype=int),
51
+ np.arange(1, P - 1, 1, dtype=int),
52
+ np.arange(2, P, 1, dtype=int)
53
+ ], axis=1)
54
+ return faces[:, loop_indice].reshape((-1, 3))
55
+ else:
56
+ assert faces.shape[-1] == 4, "now only support quad mesh"
57
+ faces = np.where(
58
+ backslash[:, None],
59
+ faces[:, [0, 1, 2, 0, 2, 3]],
60
+ faces[:, [0, 1, 3, 3, 1, 2]]
61
+ ).reshape((-1, 3))
62
+ return faces
63
+
64
+
65
+ @batched(2, None)
66
+ def compute_face_normal(
67
+ vertices: np.ndarray,
68
+ faces: np.ndarray
69
+ ) -> np.ndarray:
70
+ """
71
+ Compute face normals of a triangular mesh
72
+
73
+ Args:
74
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
75
+ faces (np.ndarray): [T, 3] triangular face indices
76
+
77
+ Returns:
78
+ normals (np.ndarray): [..., T, 3] face normals
79
+ """
80
+ normal = np.cross(
81
+ vertices[..., faces[:, 1], :] - vertices[..., faces[:, 0], :],
82
+ vertices[..., faces[:, 2], :] - vertices[..., faces[:, 0], :]
83
+ )
84
+ normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True)
85
+ normal_norm[normal_norm == 0] = 1
86
+ normal /= normal_norm
87
+ return normal
88
+
89
+
90
+ @batched(2, None)
91
+ def compute_face_angle(
92
+ vertices: np.ndarray,
93
+ faces: np.ndarray,
94
+ eps: float = 1e-12
95
+ ) -> np.ndarray:
96
+ """
97
+ Compute face angles of a triangular mesh
98
+
99
+ Args:
100
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
101
+ faces (np.ndarray): [T, 3] triangular face indices
102
+
103
+ Returns:
104
+ angles (np.ndarray): [..., T, 3] face angles
105
+ """
106
+ face_angle = np.zeros_like(faces, dtype=vertices.dtype)
107
+ for i in range(3):
108
+ edge1 = vertices[..., faces[:, (i + 1) % 3], :] - vertices[..., faces[:, i], :]
109
+ edge2 = vertices[..., faces[:, (i + 2) % 3], :] - vertices[..., faces[:, i], :]
110
+ face_angle[..., i] = np.arccos(np.sum(
111
+ edge1 / np.clip(np.linalg.norm(edge1, axis=-1, keepdims=True), eps, None) *
112
+ edge2 / np.clip(np.linalg.norm(edge2, axis=-1, keepdims=True), eps, None),
113
+ axis=-1
114
+ ))
115
+ return face_angle
116
+
117
+
118
+ @batched(2, None, 2)
119
+ def compute_vertex_normal(
120
+ vertices: np.ndarray,
121
+ faces: np.ndarray,
122
+ face_normal: np.ndarray = None
123
+ ) -> np.ndarray:
124
+ """
125
+     Compute vertex normals of a triangular mesh by averaging neighboring face normals
126
+ TODO: can be improved.
127
+
128
+ Args:
129
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
130
+ faces (np.ndarray): [T, 3] triangular face indices
131
+ face_normal (np.ndarray, optional): [..., T, 3] face normals.
132
+ None to compute face normals from vertices and faces. Defaults to None.
133
+
134
+ Returns:
135
+ normals (np.ndarray): [..., N, 3] vertex normals
136
+ """
137
+ if face_normal is None:
138
+ face_normal = compute_face_normal(vertices, faces)
139
+ vertex_normal = np.zeros_like(vertices, dtype=vertices.dtype)
140
+ for n in range(vertices.shape[0]):
141
+ for i in range(3):
142
+ vertex_normal[n, :, 0] += np.bincount(faces[:, i], weights=face_normal[n, :, 0], minlength=vertices.shape[1])
143
+ vertex_normal[n, :, 1] += np.bincount(faces[:, i], weights=face_normal[n, :, 1], minlength=vertices.shape[1])
144
+ vertex_normal[n, :, 2] += np.bincount(faces[:, i], weights=face_normal[n, :, 2], minlength=vertices.shape[1])
145
+ vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True)
146
+ vertex_normal_norm[vertex_normal_norm == 0] = 1
147
+ vertex_normal /= vertex_normal_norm
148
+ return vertex_normal
149
+
150
+
151
+ @batched(2, None, 2)
152
+ def compute_vertex_normal_weighted(
153
+ vertices: np.ndarray,
154
+ faces: np.ndarray,
155
+ face_normal: np.ndarray = None
156
+ ) -> np.ndarray:
157
+ """
158
+     Compute vertex normals of a triangular mesh by weighted sum of neighboring face normals
159
+ according to the angles
160
+
161
+ Args:
162
+ vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
163
+ faces (np.ndarray): [..., T, 3] triangular face indices
164
+ face_normal (np.ndarray, optional): [..., T, 3] face normals.
165
+ None to compute face normals from vertices and faces. Defaults to None.
166
+
167
+ Returns:
168
+ normals (np.ndarray): [..., N, 3] vertex normals
169
+ """
170
+ if face_normal is None:
171
+ face_normal = compute_face_normal(vertices, faces)
172
+ face_angle = compute_face_angle(vertices, faces)
173
+ vertex_normal = np.zeros_like(vertices)
174
+ for n in range(vertices.shape[0]):
175
+ for i in range(3):
176
+ vertex_normal[n, :, 0] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 0] * face_angle[n, :, i], minlength=vertices.shape[1])
177
+ vertex_normal[n, :, 1] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 1] * face_angle[n, :, i], minlength=vertices.shape[1])
178
+ vertex_normal[n, :, 2] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 2] * face_angle[n, :, i], minlength=vertices.shape[1])
179
+ vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True)
180
+ vertex_normal_norm[vertex_normal_norm == 0] = 1
181
+ vertex_normal /= vertex_normal_norm
182
+ return vertex_normal
183
+
184
+
185
+ def remove_corrupted_faces(
186
+ faces: np.ndarray
187
+ ) -> np.ndarray:
188
+ """
189
+ Remove corrupted faces (faces with duplicated vertices)
190
+
191
+ Args:
192
+ faces (np.ndarray): [T, 3] triangular face indices
193
+
194
+ Returns:
195
+ np.ndarray: [T_, 3] triangular face indices
196
+ """
197
+ corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0])
198
+ return faces[~corrupted]
199
+
200
+
201
+ def merge_duplicate_vertices(
202
+ vertices: np.ndarray,
203
+ faces: np.ndarray,
204
+ tol: float = 1e-6
205
+ ) -> Tuple[np.ndarray, np.ndarray]:
206
+ """
207
+ Merge duplicate vertices of a triangular mesh.
208
+     Duplicate vertices are merged by selecting one of them, and the face indices are updated accordingly.
209
+
210
+ Args:
211
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
212
+ faces (np.ndarray): [T, 3] triangular face indices
213
+ tol (float, optional): tolerance for merging. Defaults to 1e-6.
214
+
215
+ Returns:
216
+ vertices (np.ndarray): [N_, 3] 3-dimensional vertices
217
+ faces (np.ndarray): [T, 3] triangular face indices
218
+ """
219
+ vertices_round = np.round(vertices / tol)
220
+ _, uni_i, uni_inv = np.unique(vertices_round, return_index=True, return_inverse=True, axis=0)
221
+ vertices = vertices[uni_i]
222
+ faces = uni_inv[faces]
223
+ return vertices, faces
224
+
225
+
226
+ def remove_unreferenced_vertices(
227
+ faces: np.ndarray,
228
+ *vertice_attrs,
229
+ return_indices: bool = False
230
+ ) -> Tuple[np.ndarray, ...]:
231
+ """
232
+ Remove unreferenced vertices of a mesh.
233
+ Unreferenced vertices are removed, and the face indices are updated accordingly.
234
+
235
+ Args:
236
+ faces (np.ndarray): [T, P] face indices
237
+ *vertice_attrs: vertex attributes
238
+
239
+ Returns:
240
+ faces (np.ndarray): [T, P] face indices
241
+ *vertice_attrs: vertex attributes
242
+ indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None.
243
+ """
244
+ P = faces.shape[-1]
245
+ fewer_indices, inv_map = np.unique(faces, return_inverse=True)
246
+ faces = inv_map.astype(np.int32).reshape(-1, P)
247
+ ret = [faces]
248
+ for attr in vertice_attrs:
249
+ ret.append(attr[fewer_indices])
250
+ if return_indices:
251
+ ret.append(fewer_indices)
252
+ return tuple(ret)
253
+
254
+
255
+ def subdivide_mesh_simple(
256
+ vertices: np.ndarray,
257
+ faces: np.ndarray,
258
+ n: int = 1
259
+ ) -> Tuple[np.ndarray, np.ndarray]:
260
+ """
261
+ Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles.
262
+ NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list.
263
+
264
+ Args:
265
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
266
+ faces (np.ndarray): [T, 3] triangular face indices
267
+ n (int, optional): number of subdivisions. Defaults to 1.
268
+
269
+ Returns:
270
+ vertices (np.ndarray): [N_, 3] subdivided 3-dimensional vertices
271
+ faces (np.ndarray): [4 * T, 3] subdivided triangular face indices
272
+ """
273
+ for _ in range(n):
274
+ edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0)
275
+ edges = np.sort(edges, axis=2)
276
+ uni_edges, uni_inv = np.unique(edges.reshape(-1, 2), return_inverse=True, axis=0)
277
+ uni_inv = uni_inv.reshape(3, -1)
278
+ midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2
279
+
280
+ n_vertices = vertices.shape[0]
281
+ vertices = np.concatenate([vertices, midpoints], axis=0)
282
+ faces = np.concatenate([
283
+ np.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1),
284
+ np.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1),
285
+ np.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1),
286
+ np.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1),
287
+ ], axis=0)
288
+ return vertices, faces
289
+
290
+
291
+ def mesh_relations(
292
+ faces: np.ndarray,
293
+ ) -> Tuple[np.ndarray, np.ndarray]:
294
+ """
295
+ Calculate the relation between vertices and faces.
296
+ NOTE: The input mesh must be a manifold triangle mesh.
297
+
298
+ Args:
299
+ faces (np.ndarray): [T, 3] triangular face indices
300
+
301
+ Returns:
302
+ edges (np.ndarray): [E, 2] edge indices
303
+ edge2face (np.ndarray): [E, 2] edge to face relation. The second column is -1 if the edge is boundary.
304
+ face2edge (np.ndarray): [T, 3] face to edge relation
305
+ face2face (np.ndarray): [T, 3] face to face relation
306
+ """
307
+ T = faces.shape[0]
308
+ edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=1).reshape(-1, 2) # [3T, 2]
309
+ edges = np.sort(edges, axis=1) # [3T, 2]
310
+ edges, face2edge, occurence = np.unique(edges, axis=0, return_inverse=True, return_counts=True) # [E, 2], [3T], [E]
311
+ E = edges.shape[0]
312
+ assert np.all(occurence <= 2), "The input mesh is not a manifold mesh."
313
+
314
+ # Edge to face relation
315
+ padding = np.arange(E, dtype=np.int32)[occurence == 1]
316
+ padded_face2edge = np.concatenate([face2edge, padding], axis=0) # [2E]
317
+ edge2face = np.argsort(padded_face2edge, kind='stable').reshape(-1, 2) // 3 # [E, 2]
318
+ edge2face_valid = edge2face[:, 1] < T # [E]
319
+ edge2face[~edge2face_valid, 1] = -1
320
+
321
+ # Face to edge relation
322
+ face2edge = face2edge.reshape(-1, 3) # [T, 3]
323
+
324
+ # Face to face relation
325
+ face2face = edge2face[face2edge] # [T, 3, 2]
326
+ face2face = face2face[face2face != np.arange(T)[:, None, None]].reshape(T, 3) # [T, 3]
327
+
328
+ return edges, edge2face, face2edge, face2face
329
+
330
+
331
+ @overload
332
+ def flatten_mesh_indices(faces1: np.ndarray, attr1: np.ndarray, *other_faces_attrs_pairs: np.ndarray) -> Tuple[np.ndarray, ...]:
333
+ """
334
+ Rearrange the indices of a mesh to a flattened version. Vertices will be no longer shared.
335
+
336
+ ### Parameters:
337
+ - `faces1`: [T, P] face indices of the first attribute
338
+ - `attr1`: [N1, ...] attributes of the first mesh
339
+ - ...
340
+
341
+ ### Returns:
342
+     - `faces`: [T, P] flattened face indices, contiguous from 0 to T * P - 1
343
+ - `attr1`: [T * P, ...] attributes of the first mesh, where every P values correspond to a face
344
+ _ ...
345
+ """
346
+ def flatten_mesh_indices(*args: np.ndarray) -> Tuple[np.ndarray, ...]:
347
+ assert len(args) % 2 == 0, "The number of arguments must be even."
348
+ T, P = args[0].shape
349
+ assert all(arg.shape[0] == T and arg.shape[1] == P for arg in args[::2]), "The faces must have the same shape."
350
+ attr_flat = []
351
+ for faces_, attr_ in zip(args[::2], args[1::2]):
352
+ attr_flat_ = attr_[faces_].reshape(-1, *attr_.shape[1:])
353
+ attr_flat.append(attr_flat_)
354
+ faces_flat = np.arange(T * P, dtype=np.int32).reshape(T, P)
355
+ return faces_flat, *attr_flat
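A minimal sketch of `triangulate` on a quad mesh (hypothetical indices): each quad is fanned into two triangles.

    import numpy as np
    from utils3d.numpy.mesh import triangulate

    quads = np.array([[0, 1, 2, 3], [4, 5, 6, 7]], dtype=np.int32)
    tris = triangulate(quads)    # (4, 3): [[0,1,2],[0,2,3],[4,5,6],[4,6,7]]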
utils3d/numpy/quadmesh.py ADDED
@@ -0,0 +1,472 @@
1
+ import numpy as np
2
+ import scipy as sp
3
+ import scipy.optimize as spopt
4
+ import piqp
5
+ from typing import *
6
+
7
+
8
+ __all__ = [
9
+ 'calc_quad_candidates',
10
+ 'calc_quad_distortion',
11
+ 'calc_quad_direction',
12
+ 'calc_quad_smoothness',
13
+ 'sovle_quad',
14
+ 'sovle_quad_qp',
15
+ 'tri_to_quad'
16
+ ]
17
+
18
+
19
+ def calc_quad_candidates(
20
+ edges: np.ndarray,
21
+ face2edge: np.ndarray,
22
+ edge2face: np.ndarray,
23
+ ):
24
+ """
25
+ Calculate the candidate quad faces.
26
+
27
+ Args:
28
+ edges (np.ndarray): [E, 2] edge indices
29
+ face2edge (np.ndarray): [T, 3] face to edge relation
30
+ edge2face (np.ndarray): [E, 2] edge to face relation
31
+
32
+ Returns:
33
+ quads (np.ndarray): [Q, 4] quad candidate indices
34
+ quad2edge (np.ndarray): [Q, 4] edge to quad candidate relation
35
+ quad2adj (np.ndarray): [Q, 8] adjacent quad candidates of each quad candidate
36
+ quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
37
+ """
38
+ E = edges.shape[0]
39
+ T = face2edge.shape[0]
40
+
41
+ quads_valid = edge2face[:, 1] != -1
42
+ Q = quads_valid.sum()
43
+ quad2face = edge2face[quads_valid] # [Q, 2]
44
+ quad2edge = face2edge[quad2face] # [Q, 2, 3]
45
+ flag = quad2edge == np.arange(E)[quads_valid][:, None, None] # [Q, 2, 3]
46
+ flag = flag.argmax(axis=-1) # [Q, 2]
47
+ quad2edge = np.stack([
48
+ quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 1) % 3],
49
+ quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 2) % 3],
50
+ ], axis=-1).reshape(Q, 4) # [Q, 4]
51
+
52
+ quads = np.concatenate([
53
+ np.where(
54
+ (edges[quad2edge[:, 0:1], 1:] == edges[quad2edge[:, 1:2], :]).any(axis=-1),
55
+ edges[quad2edge[:, 0:1], [[0, 1]]],
56
+ edges[quad2edge[:, 0:1], [[1, 0]]],
57
+ ),
58
+ np.where(
59
+ (edges[quad2edge[:, 2:3], 1:] == edges[quad2edge[:, 3:4], :]).any(axis=-1),
60
+ edges[quad2edge[:, 2:3], [[0, 1]]],
61
+ edges[quad2edge[:, 2:3], [[1, 0]]],
62
+ ),
63
+ ], axis=1) # [Q, 4]
64
+
65
+ quad2adj = edge2face[quad2edge] # [Q, 4, 2]
66
+ quad2adj = quad2adj[quad2adj != quad2face[:, [0,0,1,1], None]].reshape(Q, 4) # [Q, 4]
67
+ quad2adj_valid = quad2adj != -1
68
+ quad2adj = face2edge[quad2adj] # [Q, 4, 3]
69
+ quad2adj[~quad2adj_valid, 0] = quad2edge[~quad2adj_valid]
70
+ quad2adj[~quad2adj_valid, 1:] = -1
71
+ quad2adj = quad2adj[quad2adj != quad2edge[..., None]].reshape(Q, 8) # [Q, 8]
72
+ edge_valid = -np.ones(E, dtype=np.int32)
73
+ edge_valid[quads_valid] = np.arange(Q)
74
+ quad2adj_valid = quad2adj != -1
75
+ quad2adj[quad2adj_valid] = edge_valid[quad2adj[quad2adj_valid]] # [Q, 8]
76
+
77
+ return quads, quad2edge, quad2adj, quads_valid
78
+
79
+
80
+ def calc_quad_distortion(
81
+ vertices: np.ndarray,
82
+ quads: np.ndarray,
83
+ ):
84
+ """
85
+ Calculate the distortion of each candidate quad face.
86
+
87
+ Args:
88
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
89
+ quads (np.ndarray): [Q, 4] quad face indices
90
+
91
+ Returns:
92
+ distortion (np.ndarray): [Q] distortion of each quad face
93
+ """
94
+ edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3]
95
+ edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3]
96
+ edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3]
97
+ edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3]
98
+ cross = vertices[quads[:, 0]] - vertices[quads[:, 2]] # [Q, 3]
99
+
100
+ len0 = np.maximum(np.linalg.norm(edge0, axis=-1), 1e-10) # [Q]
101
+ len1 = np.maximum(np.linalg.norm(edge1, axis=-1), 1e-10) # [Q]
102
+ len2 = np.maximum(np.linalg.norm(edge2, axis=-1), 1e-10) # [Q]
103
+ len3 = np.maximum(np.linalg.norm(edge3, axis=-1), 1e-10) # [Q]
104
+ len_cross = np.maximum(np.linalg.norm(cross, axis=-1), 1e-10) # [Q]
105
+
106
+ angle0 = np.arccos(np.clip(np.sum(-edge0 * edge1, axis=-1) / (len0 * len1), -1, 1)) # [Q]
107
+ angle1 = np.arccos(np.clip(np.sum(-edge1 * cross, axis=-1) / (len1 * len_cross), -1, 1)) \
108
+ + np.arccos(np.clip(np.sum(cross * edge2, axis=-1) / (len_cross * len2), -1, 1)) # [Q]
109
+ angle2 = np.arccos(np.clip(np.sum(-edge2 * edge3, axis=-1) / (len2 * len3), -1, 1)) # [Q]
110
+ angle3 = np.arccos(np.clip(np.sum(-edge3 * -cross, axis=-1) / (len3 * len_cross), -1, 1)) \
111
+ + np.arccos(np.clip(np.sum(-cross * edge0, axis=-1) / (len_cross * len0), -1, 1)) # [Q]
112
+
113
+ normal0 = np.cross(edge0, edge1) # [Q, 3]
114
+ normal1 = np.cross(edge2, edge3) # [Q, 3]
115
+ normal0 = normal0 / np.maximum(np.linalg.norm(normal0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
116
+ normal1 = normal1 / np.maximum(np.linalg.norm(normal1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
117
+ angle_normal = np.arccos(np.clip(np.sum(normal0 * normal1, axis=-1), -1, 1)) # [Q]
118
+
119
+ D90 = np.pi / 2
120
+ D180 = np.pi
121
+ D360 = np.pi * 2
122
+ ang_eng = (np.abs(angle0 - D90)**2 + np.abs(angle1 - D90)**2 + np.abs(angle2 - D90)**2 + np.abs(angle3 - D90)**2) / 4 # [Q]
123
+ dist_eng = np.abs(angle0 - angle2)**2 / np.minimum(np.maximum(np.minimum(angle0, angle2), 1e-10), np.maximum(D180 - np.maximum(angle0, angle2), 1e-10)) \
124
+ + np.abs(angle1 - angle3)**2 / np.minimum(np.maximum(np.minimum(angle1, angle3), 1e-10), np.maximum(D180 - np.maximum(angle1, angle3), 1e-10)) # [Q]
125
+ plane_eng = np.where(angle_normal < D90/2, np.abs(angle_normal)**2, 1e10) # [Q]
126
+ eng = ang_eng + 2 * dist_eng + 2 * plane_eng # [Q]
127
+
128
+ return eng
129
+
130
+
131
+ def calc_quad_direction(
132
+ vertices: np.ndarray,
133
+ quads: np.ndarray,
134
+ ):
135
+ """
136
+ Calculate the direction of each candidate quad face.
137
+
138
+ Args:
139
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
140
+ quads (np.ndarray): [Q, 4] quad face indices
141
+
142
+ Returns:
143
+ direction (np.ndarray): [Q, 4] direction of each quad face.
144
+ Represented by the angle between the crossing and each edge.
145
+ """
146
+ mid0 = (vertices[quads[:, 0]] + vertices[quads[:, 1]]) / 2 # [Q, 3]
147
+ mid1 = (vertices[quads[:, 1]] + vertices[quads[:, 2]]) / 2 # [Q, 3]
148
+ mid2 = (vertices[quads[:, 2]] + vertices[quads[:, 3]]) / 2 # [Q, 3]
149
+ mid3 = (vertices[quads[:, 3]] + vertices[quads[:, 0]]) / 2 # [Q, 3]
150
+
151
+ cross0 = mid2 - mid0 # [Q, 3]
152
+ cross1 = mid3 - mid1 # [Q, 3]
153
+ cross0 = cross0 / np.maximum(np.linalg.norm(cross0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
154
+ cross1 = cross1 / np.maximum(np.linalg.norm(cross1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
155
+
156
+ edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3]
157
+ edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3]
158
+ edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3]
159
+ edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3]
160
+ edge0 = edge0 / np.maximum(np.linalg.norm(edge0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
161
+ edge1 = edge1 / np.maximum(np.linalg.norm(edge1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
162
+ edge2 = edge2 / np.maximum(np.linalg.norm(edge2, axis=-1, keepdims=True), 1e-10) # [Q, 3]
163
+ edge3 = edge3 / np.maximum(np.linalg.norm(edge3, axis=-1, keepdims=True), 1e-10) # [Q, 3]
164
+
165
+ direction = np.stack([
166
+ np.arccos(np.clip(np.sum(cross0 * edge0, axis=-1), -1, 1)),
167
+ np.arccos(np.clip(np.sum(cross1 * edge1, axis=-1), -1, 1)),
168
+ np.arccos(np.clip(np.sum(-cross0 * edge2, axis=-1), -1, 1)),
169
+ np.arccos(np.clip(np.sum(-cross1 * edge3, axis=-1), -1, 1)),
170
+ ], axis=-1) # [Q, 4]
171
+
172
+ return direction
173
+
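For the same axis-aligned unit square, each midpoint connector is perpendicular to the edges it crosses, so all four direction angles come out as pi/2 (editorial sketch, same assumed import path as above):

import numpy as np
from utils3d.numpy.quadmesh import calc_quad_direction  # assumed module path

square = np.array([[0, 0, 0], [1, 0, 0], [1, 1, 0], [0, 1, 0]], dtype=np.float32)
direction = calc_quad_direction(square, np.array([[0, 1, 2, 3]]))
assert np.allclose(direction, np.pi / 2)   # [1, 4], all right angles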
174
+
175
+ def calc_quad_smoothness(
176
+ quad2edge: np.ndarray,
177
+ quad2adj: np.ndarray,
178
+ quads_direction: np.ndarray,
179
+ ):
180
+ """
181
+ Calculate the smoothness of each candidate quad face connection.
182
+
183
+ Args:
184
+         quad2edge (np.ndarray): [Q, 4] edge indices of each quad face
+         quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
185
+ quads_direction (np.ndarray): [Q, 4] direction of each quad face
186
+
187
+ Returns:
188
+ smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
189
+ """
190
+ Q = quad2adj.shape[0]
191
+ quad2adj_valid = quad2adj != -1
192
+ connections = np.stack([
193
+ np.arange(Q)[:, None].repeat(8, axis=1),
194
+ quad2adj,
195
+ ], axis=-1)[quad2adj_valid] # [C, 2]
196
+ shared_edge_idx_0 = np.array([[0, 0, 1, 1, 2, 2, 3, 3]]).repeat(Q, axis=0)[quad2adj_valid] # [C]
197
+ shared_edge_idx_1 = np.argmax(quad2edge[quad2adj][quad2adj_valid] == quad2edge[connections[:, 0], shared_edge_idx_0][:, None], axis=-1) # [C]
198
+ valid_smoothness = np.abs(quads_direction[connections[:, 0], shared_edge_idx_0] - quads_direction[connections[:, 1], shared_edge_idx_1])**2 # [C]
199
+ smoothness = np.zeros([Q, 8], dtype=np.float32)
200
+ smoothness[quad2adj_valid] = valid_smoothness
201
+ return smoothness
202
+
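A tiny worked example of the smoothness term (an editorial sketch, not part of the commit; the edge ids and directions are made up, and the import path assumes this module lives at utils3d/numpy/quadmesh.py): two candidate quads meet across global edge 5 through their first edge slot, so the connection cost is the squared difference of the corresponding direction angles.

import numpy as np
from utils3d.numpy.quadmesh import calc_quad_smoothness  # assumed module path

quad2edge = np.array([[5, 1, 2, 3],
                      [5, 6, 7, 8]])                      # per-quad edge ids (made up)
quad2adj = -np.ones((2, 8), dtype=np.int64)
quad2adj[0, 0] = 1                                        # quad 0 meets quad 1 across its edge slot 0
quad2adj[1, 0] = 0
direction = np.array([[0.5, 0.0, 0.0, 0.0],
                      [0.7, 0.0, 0.0, 0.0]])

s = calc_quad_smoothness(quad2edge, quad2adj, direction)  # [2, 8]
assert np.isclose(s[0, 0], 0.04) and np.isclose(s[1, 0], 0.04)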
203
+
204
+ def sovle_quad(
205
+ face2edge: np.ndarray,
206
+ edge2face: np.ndarray,
207
+ quad2adj: np.ndarray,
208
+ quads_distortion: np.ndarray,
209
+ quads_smoothness: np.ndarray,
210
+ quads_valid: np.ndarray,
211
+ ):
212
+ """
213
+ Solve the quad mesh from the candidate quad faces.
214
+
215
+ Args:
216
+ face2edge (np.ndarray): [T, 3] face to edge relation
217
+ edge2face (np.ndarray): [E, 2] edge to face relation
218
+ quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
219
+ quads_distortion (np.ndarray): [Q] distortion of each quad face
220
+ quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
221
+ quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
222
+
223
+ Returns:
224
+         quads_weight (np.ndarray): [Q] weight of each candidate quad face
+         conn_min_weight (np.ndarray): [C] auxiliary minimum weight of each quad face connection
+         conn_max_weight (np.ndarray): [C] auxiliary maximum weight of each quad face connection
225
+ """
226
+ T = face2edge.shape[0]
227
+ E = edge2face.shape[0]
228
+ Q = quads_distortion.shape[0]
229
+ edge_valid = -np.ones(E, dtype=np.int32)
230
+ edge_valid[quads_valid] = np.arange(Q)
231
+
232
+ quads_connection = np.stack([
233
+ np.arange(Q)[:, None].repeat(8, axis=1),
234
+ quad2adj,
235
+ ], axis=-1)[quad2adj != -1] # [C, 2]
236
+ quads_connection = np.sort(quads_connection, axis=-1) # [C, 2]
237
+ quads_connection, quads_connection_idx = np.unique(quads_connection, axis=0, return_index=True) # [C, 2], [C]
238
+ quads_smoothness = quads_smoothness[quad2adj != -1] # [C]
239
+ quads_smoothness = quads_smoothness[quads_connection_idx] # [C]
240
+ C = quads_connection.shape[0]
241
+
242
+ # Construct the linear programming problem
243
+
244
+ # Variables:
245
+ # quads_weight: [Q] weight of each quad face
246
+ # tri_min_weight: [T] minimum weight of each triangle face
247
+ # conn_min_weight: [C] minimum weight of each quad face connection
248
+ # conn_max_weight: [C] maximum weight of each quad face connection
249
+ # Objective:
250
+     #   minimize c^T x: per-quad distortion costs plus per-connection smoothness costs (see the construction of c below)
251
+
252
+ c = np.concatenate([
253
+ quads_distortion - 3,
254
+ quads_smoothness*4 - 2,
255
+ quads_smoothness*4,
256
+ ], axis=0) # [Q+C]
257
+
258
+ A_ub_triplet = np.concatenate([
259
+ np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3]
260
+ np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3]
261
+ np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3]
262
+ np.stack([np.arange(T, T+C), np.arange(Q, Q+C), np.ones(C)], axis=1), # [C, 3]
263
+ np.stack([np.arange(T, T+C), quads_connection[:, 0], -np.ones(C)], axis=1), # [C, 3]
264
+ np.stack([np.arange(T, T+C), quads_connection[:, 1], -np.ones(C)], axis=1), # [C, 3]
265
+ np.stack([np.arange(T+C, T+2*C), np.arange(Q+C, Q+2*C), -np.ones(C)], axis=1), # [C, 3]
266
+ np.stack([np.arange(T+C, T+2*C), quads_connection[:, 0], np.ones(C)], axis=1), # [C, 3]
267
+ np.stack([np.arange(T+C, T+2*C), quads_connection[:, 1], np.ones(C)], axis=1), # [C, 3]
268
+ ], axis=0) # [3T+6C, 3]
269
+ A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3]
270
+     A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T+2*C, Q+2*C])    # [T+2C, Q+2C]
271
+ b_ub = np.concatenate([np.ones(T), -np.ones(C), np.ones(C)], axis=0) # [T+2C]
272
+ bound = np.stack([
273
+ np.concatenate([np.zeros(Q), -np.ones(C), np.zeros(C)], axis=0),
274
+ np.concatenate([np.ones(Q), np.ones(C), np.ones(C)], axis=0),
275
+ ], axis=1) # [Q+2C, 2]
276
+ A_eq = None
277
+ b_eq = None
278
+
279
+ print('Solver statistics:')
280
+ print(f' #T = {T}')
281
+ print(f' #Q = {Q}')
282
+ print(f' #C = {C}')
283
+
284
+ # Solve the linear programming problem
285
+ last_num_valid = 0
286
+ for i in range(100):
287
+ res_ = spopt.linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bound)
288
+ if not res_.success:
289
+ print(f' Iter {i} | Failed with {res_.message}')
290
+ break
291
+ res = res_
292
+ weights = res.x[:Q]
293
+ valid = (weights > 0.5)
294
+ num_valid = valid.sum()
295
+ print(f' Iter {i} | #Q_valid = {num_valid}')
296
+ if num_valid == last_num_valid:
297
+ break
298
+ last_num_valid = num_valid
299
+ A_eq_triplet = np.stack([
300
+ np.arange(num_valid),
301
+ np.arange(Q)[valid],
302
+ np.ones(num_valid),
303
+ ], axis=1) # [num_valid, 3]
304
+         A_eq = sp.sparse.coo_matrix((A_eq_triplet[:, 2], (A_eq_triplet[:, 0], A_eq_triplet[:, 1])), shape=[num_valid, Q+2*C])   # [num_valid, Q+2C]
305
+ b_eq = np.where(weights[valid] > 0.5, 1, 0) # [num_valid]
306
+
307
+ # Return the result
308
+ quads_weight = res.x[:Q]
309
+ conn_min_weight = res.x[Q:Q+C]
310
+ conn_max_weight = res.x[Q+C:Q+2*C]
311
+ return quads_weight, conn_min_weight, conn_max_weight
312
+
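The selection step above is a relaxed linear program: choose quad weights in [0, 1] so that no triangle is claimed by more than one quad, while rewarding low distortion and smooth neighbours. A self-contained toy instance of the same scipy.optimize.linprog call pattern (editorial; the numbers are made up for illustration):

import numpy as np
from scipy import optimize as spopt

# Two candidate quads compete for one shared triangle; quad 0 is less distorted.
c = np.array([0.2 - 3, 0.8 - 3])            # negative costs reward selecting a quad
A_ub = np.array([[1.0, 1.0]])               # the shared triangle: w0 + w1 <= 1
b_ub = np.array([1.0])

res = spopt.linprog(c, A_ub=A_ub, b_ub=b_ub, bounds=[(0, 1), (0, 1)])
assert res.success and res.x[0] > 0.5 and res.x[1] < 0.5   # the better quad wins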
313
+
314
+ def sovle_quad_qp(
315
+ face2edge: np.ndarray,
316
+ edge2face: np.ndarray,
317
+ quad2adj: np.ndarray,
318
+ quads_distortion: np.ndarray,
319
+ quads_smoothness: np.ndarray,
320
+ quads_valid: np.ndarray,
321
+ ):
322
+ """
323
+ Solve the quad mesh from the candidate quad faces.
324
+
325
+ Args:
326
+ face2edge (np.ndarray): [T, 3] face to edge relation
327
+ edge2face (np.ndarray): [E, 2] edge to face relation
328
+ quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
329
+ quads_distortion (np.ndarray): [Q] distortion of each quad face
330
+ quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
331
+ quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
332
+
333
+ Returns:
334
+ weights (np.ndarray): [Q] weight of each valid quad face
335
+ """
336
+ T = face2edge.shape[0]
337
+ E = edge2face.shape[0]
338
+ Q = quads_distortion.shape[0]
339
+ edge_valid = -np.ones(E, dtype=np.int32)
340
+ edge_valid[quads_valid] = np.arange(Q)
341
+
342
+ # Construct the quadratic programming problem
343
+ C_smoothness_triplet = np.stack([
344
+ np.arange(Q)[:, None].repeat(8, axis=1)[quad2adj != -1],
345
+ quad2adj[quad2adj != -1],
346
+ 5 * quads_smoothness[quad2adj != -1],
347
+ ], axis=-1) # [C, 3]
348
+ # C_smoothness_triplet = np.concatenate([
349
+ # C_smoothness_triplet,
350
+ # np.stack([np.arange(Q), np.arange(Q), 20*np.ones(Q)], axis=1),
351
+ # ], axis=0) # [C+Q, 3]
352
+ C_smoothness = sp.sparse.coo_matrix((C_smoothness_triplet[:, 2], (C_smoothness_triplet[:, 0], C_smoothness_triplet[:, 1])), shape=[Q, Q]) # [Q, Q]
353
+ C_smoothness = C_smoothness.tocsc()
354
+ C_dist = quads_distortion - 20 # [Q]
355
+
356
+     A_eq = sp.sparse.coo_matrix((np.zeros(Q), (np.zeros(Q), np.arange(Q))), shape=[1, Q])  # [1, Q]
357
+ A_eq = A_eq.tocsc()
358
+ b_eq = np.array([0])
359
+
360
+ A_ub_triplet = np.concatenate([
361
+ np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3]
362
+ np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3]
363
+ np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3]
364
+ ], axis=0) # [3T, 3]
365
+ A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3]
366
+ A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T, Q]) # [T, Q]
367
+ A_ub = A_ub.tocsc()
368
+ b_ub = np.ones(T)
369
+
370
+ lb = np.zeros(Q)
371
+ ub = np.ones(Q)
372
+
373
+ solver = piqp.SparseSolver()
374
+ solver.settings.verbose = True
375
+ solver.settings.compute_timings = True
376
+ solver.setup(C_smoothness, C_dist, A_eq, b_eq, A_ub, b_ub, lb, ub)
377
+
378
+ status = solver.solve()
379
+
380
+ # x = cp.Variable(Q)
381
+ # prob = cp.Problem(
382
+ # cp.Minimize(cp.quad_form(x, C_smoothness) + C_dist.T @ x),
383
+ # [
384
+ # A_ub @ x <= b_ub,
385
+ # x >= 0, x <= 1,
386
+ # ]
387
+ # )
388
+
389
+ # # Solve the quadratic programming problem
390
+ # prob.solve(solver=cp.PIQP, verbose=True)
391
+
392
+ # Return the result
393
+ weights = solver.result.x
394
+ return weights
395
+
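The commented-out cvxpy formulation expresses the same quadratic program declaratively. A self-contained toy version (editorial sketch; cvxpy is an assumed optional dependency, and the matrices are made up and chosen positive semidefinite so quad_form accepts them):

import numpy as np
import cvxpy as cp

P = np.array([[1.0, 0.5],
              [0.5, 1.0]])                  # toy smoothness coupling between two quads
q = np.array([0.2 - 20, 0.8 - 20])          # shifted distortion costs reward selection
A_ub = np.array([[1.0, 1.0]])               # the two quads share a triangle
b_ub = np.array([1.0])

x = cp.Variable(2)
prob = cp.Problem(
    cp.Minimize(cp.quad_form(x, P) + q @ x),
    [A_ub @ x <= b_ub, x >= 0, x <= 1],
)
prob.solve()
assert prob.status == cp.OPTIMAL and x.value[0] > x.value[1]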
396
+
397
+ def tri_to_quad(
398
+ vertices: np.ndarray,
399
+ faces: np.ndarray,
400
+ ) -> Tuple[np.ndarray, np.ndarray]:
401
+ """
402
+ Convert a triangle mesh to a quad mesh.
403
+ NOTE: The input mesh must be a manifold mesh.
404
+
405
+ Args:
406
+ vertices (np.ndarray): [N, 3] 3-dimensional vertices
407
+ faces (np.ndarray): [T, 3] triangular face indices
408
+
409
+ Returns:
410
+ vertices (np.ndarray): [N_, 3] 3-dimensional vertices
411
+ faces (np.ndarray): [Q, 4] quad face indices
412
+ """
413
+ raise NotImplementedError
414
+
415
+
416
+ if __name__ == '__main__':
417
+ import os
418
+ import sys
419
+ sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))
420
+ import utils3d
421
+ import numpy as np
422
+ import cv2
423
+ from vis import vis_edge_color
424
+
425
+ file = 'miku'
426
+
427
+ vertices, faces = utils3d.io.read_ply(f'test/assets/{file}.ply')
428
+ edges, edge2face, face2edge, face2face = calc_relations(faces)
429
+ quad_cands, quad2edge, quad2adj, quad_valid = calc_quad_candidates(edges, face2edge, edge2face)
430
+ distortion = calc_quad_distortion(vertices, quad_cands)
431
+ direction = calc_quad_direction(vertices, quad_cands)
432
+ smoothness = calc_quad_smoothness(quad2edge, quad2adj, direction)
433
+ boundary_edges = edges[edge2face[:, 1] == -1]
434
+ quads_weight, conn_min_weight, conn_max_weight = sovle_quad(face2edge, edge2face, quad2adj, distortion, smoothness, quad_valid)
435
+ quads = quad_cands[quads_weight > 0.5]
436
+ print('Mesh statistics')
437
+ print(f' #V = {vertices.shape[0]}')
438
+ print(f' #F = {faces.shape[0]}')
439
+ print(f' #E = {edges.shape[0]}')
440
+ print(f' #B = {boundary_edges.shape[0]}')
441
+ print(f' #Q_cand = {quad_cands.shape[0]}')
442
+ print(f' #Q = {quads.shape[0]}')
443
+
444
+ utils3d.io.write_ply(f'test/assets/{file}_boundary_edges.ply', vertices=vertices, edges=boundary_edges)
445
+ utils3d.io.write_ply(f'test/assets/{file}_quad_candidates.ply', vertices=vertices, faces=quads)
446
+
447
+ edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8)
448
+ distortion = (distortion - distortion.min()) / (distortion.max() - distortion.min())
449
+ distortion = (distortion * 255).astype(np.uint8)
450
+ edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap(distortion, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
451
+ utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_distortion.ply', **vis_edge_color(vertices, edges, edge_colors))
452
+
453
+ edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8)
454
+ edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap((quads_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
455
+ utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_weights.ply', **vis_edge_color(vertices, edges, edge_colors))
456
+ utils3d.io.write_ply(f'test/assets/{file}_quad.ply', vertices=vertices, faces=quads)
457
+
458
+ quad_centers = vertices[quad_cands].mean(axis=1)
459
+ conns = np.stack([
460
+ np.arange(quad_cands.shape[0])[:, None].repeat(8, axis=1),
461
+ quad2adj,
462
+ ], axis=-1)[quad2adj != -1] # [C, 2]
463
+ conns, conns_idx = np.unique(np.sort(conns, axis=-1), axis=0, return_index=True) # [C, 2], [C]
464
+ smoothness = smoothness[quad2adj != -1][conns_idx] # [C]
465
+ conns_color = cv2.cvtColor(cv2.applyColorMap((smoothness * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
466
+ utils3d.io.write_ply(f'test/assets/{file}_quad_conn_smoothness.ply', **vis_edge_color(quad_centers, conns, conns_color))
467
+ conns_color = cv2.cvtColor(cv2.applyColorMap((conn_min_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
468
+ utils3d.io.write_ply(f'test/assets/{file}_quad_conn_min.ply', **vis_edge_color(quad_centers, conns, conns_color))
469
+ conns_color = cv2.cvtColor(cv2.applyColorMap((conn_max_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
470
+ utils3d.io.write_ply(f'test/assets/{file}_quad_conn_max.ply', **vis_edge_color(quad_centers, conns, conns_color))
471
+
472
+
utils3d/numpy/rasterization.py ADDED
@@ -0,0 +1,471 @@
1
+ import os
2
+ from typing import *
3
+
4
+ import numpy as np
5
+ import moderngl
6
+
7
+ from . import transforms, utils, mesh
8
+
9
+
10
+ __all__ = [
11
+ 'RastContext',
12
+ 'rasterize_triangle_faces',
13
+ 'rasterize_edges',
14
+ 'texture',
15
+ 'warp_image_by_depth',
16
+ ]
17
+
18
+
19
+ def map_np_dtype(dtype) -> str:
20
+ if dtype == int:
21
+ return 'i4'
22
+ elif dtype == np.uint8:
23
+ return 'u1'
24
+ elif dtype == np.uint32:
25
+         return 'u4'
26
+ elif dtype == np.float16:
27
+ return 'f2'
28
+ elif dtype == np.float32:
29
+ return 'f4'
30
+
31
+
32
+ def one_value(dtype):
33
+ if dtype == 'u1':
34
+ return 255
35
+ elif dtype == 'u2':
36
+ return 65535
37
+ else:
38
+ return 1
39
+
40
+
41
+ class RastContext:
42
+ def __init__(self, standalone: bool = True, backend: str = None, **kwargs):
43
+ """
44
+ Create a moderngl context.
45
+
46
+ Args:
47
+ standalone (bool, optional): whether to create a standalone context. Defaults to True.
48
+ backend (str, optional): backend to use. Defaults to None.
49
+
50
+ Keyword Args:
51
+ See moderngl.create_context
52
+ """
53
+ if backend is None:
54
+ self.mgl_ctx = moderngl.create_context(standalone=standalone, **kwargs)
55
+ else:
56
+ self.mgl_ctx = moderngl.create_context(standalone=standalone, backend=backend, **kwargs)
57
+
58
+ self.__prog_src = {}
59
+ self.__prog = {}
60
+
61
+ def __del__(self):
62
+ self.mgl_ctx.release()
63
+
64
+     def screen_quad(self) -> None:
65
+ self.screen_quad_vbo = self.mgl_ctx.buffer(np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype='f4'))
66
+ self.screen_quad_ibo = self.mgl_ctx.buffer(np.array([0, 1, 2, 0, 2, 3], dtype=np.int32))
67
+
68
+ def program_vertex_attribute(self, n: int) -> moderngl.Program:
69
+ assert n in [1, 2, 3, 4], 'vertex attribute only supports channels 1, 2, 3, 4'
70
+
71
+ if 'vertex_attribute_vsh' not in self.__prog_src:
72
+ with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.vsh'), 'r') as f:
73
+ self.__prog_src['vertex_attribute_vsh'] = f.read()
74
+ if 'vertex_attribute_fsh' not in self.__prog_src:
75
+ with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.fsh'), 'r') as f:
76
+ self.__prog_src['vertex_attribute_fsh'] = f.read()
77
+
78
+ if f'vertex_attribute_{n}' not in self.__prog:
79
+ vsh = self.__prog_src['vertex_attribute_vsh'].replace('vecN', f'vec{n}')
80
+ fsh = self.__prog_src['vertex_attribute_fsh'].replace('vecN', f'vec{n}')
81
+ self.__prog[f'vertex_attribute_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh)
82
+
83
+ return self.__prog[f'vertex_attribute_{n}']
84
+
85
+ def program_texture(self, n: int) -> moderngl.Program:
86
+ assert n in [1, 2, 3, 4], 'texture only supports channels 1, 2, 3, 4'
87
+
88
+ if 'texture_vsh' not in self.__prog_src:
89
+ with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.vsh'), 'r') as f:
90
+ self.__prog_src['texture_vsh'] = f.read()
91
+ if 'texture_fsh' not in self.__prog_src:
92
+ with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.fsh'), 'r') as f:
93
+ self.__prog_src['texture_fsh'] = f.read()
94
+
95
+ if f'texture_{n}' not in self.__prog:
96
+ vsh = self.__prog_src['texture_vsh'].replace('vecN', f'vec{n}')
97
+ fsh = self.__prog_src['texture_fsh'].replace('vecN', f'vec{n}')
98
+ self.__prog[f'texture_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh)
99
+ self.__prog[f'texture_{n}']['tex'] = 0
100
+ self.__prog[f'texture_{n}']['uv'] = 1
101
+
102
+ return self.__prog[f'texture_{n}']
103
+
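A minimal way to obtain a context (editorial sketch, not part of the commit; EGL is the usual choice for headless rendering and its availability depends on the platform):

from utils3d.numpy.rasterization import RastContext

ctx = RastContext(standalone=True, backend='egl')   # off-screen rendering, e.g. on a server
# or RastContext(standalone=True) to let moderngl pick the default backend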
104
+
105
+ def rasterize_triangle_faces(
106
+ ctx: RastContext,
107
+ vertices: np.ndarray,
108
+ faces: np.ndarray,
109
+ attr: np.ndarray,
110
+ width: int,
111
+ height: int,
112
+ transform: np.ndarray = None,
113
+ cull_backface: bool = True,
114
+ return_depth: bool = False,
115
+ image: np.ndarray = None,
116
+ depth: np.ndarray = None
117
+ ) -> Tuple[np.ndarray, np.ndarray]:
118
+ """
119
+ Rasterize vertex attribute.
120
+
121
+ Args:
122
+ vertices (np.ndarray): [N, 3]
123
+ faces (np.ndarray): [T, 3]
124
+ attr (np.ndarray): [N, C]
125
+ width (int): width of rendered image
126
+ height (int): height of rendered image
127
+ transform (np.ndarray): [4, 4] model-view-projection transformation matrix.
128
+         cull_backface (bool): whether to cull backface
+         return_depth (bool): whether to return the screen space depth map
129
+         image (np.ndarray): [H, W, C] background image
130
+         depth (np.ndarray): [H, W] background depth
131
+
132
+ Returns:
133
+ image (np.ndarray): [H, W, C] rendered image
134
+ depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.
135
+ """
136
+ assert vertices.ndim == 2 and vertices.shape[1] == 3
137
+ assert faces.ndim == 2 and faces.shape[1] == 3, f"Faces should be a 2D array with shape (T, 3), but got {faces.shape}"
138
+ assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}'
139
+ assert vertices.shape[0] == attr.shape[0]
140
+ assert vertices.dtype == np.float32
141
+ assert faces.dtype == np.uint32 or faces.dtype == np.int32
142
+ assert attr.dtype == np.float32, "Attribute should be float32"
143
+
144
+ C = attr.shape[1]
145
+ prog = ctx.program_vertex_attribute(C)
146
+
147
+     transform = np.eye(4, dtype=np.float32) if transform is None else transform
148
+
149
+ # Create buffers
150
+ ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(faces, dtype='i4'))
151
+ vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4'))
152
+ vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4'))
153
+ vao = ctx.mgl_ctx.vertex_array(
154
+ prog,
155
+ [
156
+ (vbo_vertices, '3f', 'i_position'),
157
+ (vbo_attr, f'{C}f', 'i_attr'),
158
+ ],
159
+ ibo,
160
+ mode=moderngl.TRIANGLES,
161
+ )
162
+
163
+ # Create framebuffer
164
+ image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None)
165
+ depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None)
166
+ fbo = ctx.mgl_ctx.framebuffer(
167
+ color_attachments=[image_tex],
168
+ depth_attachment=depth_tex,
169
+ )
170
+
171
+ # Render
172
+ prog['u_mvp'].write(transform.transpose().copy().astype('f4'))
173
+ fbo.use()
174
+ fbo.viewport = (0, 0, width, height)
175
+ ctx.mgl_ctx.depth_func = '<'
176
+ ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST)
177
+ if cull_backface:
178
+ ctx.mgl_ctx.enable(ctx.mgl_ctx.CULL_FACE)
179
+ else:
180
+ ctx.mgl_ctx.disable(ctx.mgl_ctx.CULL_FACE)
181
+ vao.render()
182
+ ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST)
183
+
184
+ # Read
185
+ image = np.zeros((height, width, C), dtype='f4')
186
+ image_tex.read_into(image)
187
+ image = image[::-1, :, :]
188
+ if return_depth:
189
+ depth = np.zeros((height, width), dtype='f4')
190
+ depth_tex.read_into(depth)
191
+ depth = depth[::-1, :]
192
+ else:
193
+ depth = None
194
+
195
+ # Release
196
+ vao.release()
197
+ ibo.release()
198
+ vbo_vertices.release()
199
+ vbo_attr.release()
200
+ fbo.release()
201
+ image_tex.release()
202
+ depth_tex.release()
203
+
204
+ return image, depth
205
+
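A self-contained usage sketch (editorial, with made-up geometry): render one vertex-colored triangle into a 256x256 buffer. A far-plane background depth of 1.0 is passed explicitly so the '<' depth test starts from a defined value.

import numpy as np
from utils3d.numpy.rasterization import RastContext, rasterize_triangle_faces

ctx = RastContext(standalone=True)
vertices = np.array([[-0.5, -0.5, 0.0],
                     [ 0.5, -0.5, 0.0],
                     [ 0.0,  0.5, 0.0]], dtype=np.float32)      # already in clip space
faces = np.array([[0, 1, 2]], dtype=np.int32)
colors = np.array([[1, 0, 0], [0, 1, 0], [0, 0, 1]], dtype=np.float32)

image, depth = rasterize_triangle_faces(
    ctx, vertices, faces, colors, 256, 256,
    transform=None,                                  # identity model-view-projection
    cull_backface=False,
    return_depth=True,
    depth=np.ones((256, 256), dtype=np.float32),     # far-plane background depth
)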
206
+
207
+ def rasterize_edges(
208
+ ctx: RastContext,
209
+ vertices: np.ndarray,
210
+ edges: np.ndarray,
211
+ attr: np.ndarray,
212
+ width: int,
213
+ height: int,
214
+ transform: np.ndarray = None,
215
+ line_width: float = 1.0,
216
+ return_depth: bool = False,
217
+ image: np.ndarray = None,
218
+ depth: np.ndarray = None
219
+ ) -> Tuple[np.ndarray, ...]:
220
+ """
221
+ Rasterize vertex attribute.
222
+
223
+ Args:
224
+ vertices (np.ndarray): [N, 3]
225
+         edges (np.ndarray): [E, 2]
226
+ attr (np.ndarray): [N, C]
227
+ width (int): width of rendered image
228
+ height (int): height of rendered image
229
+ transform (np.ndarray): [4, 4] model-view-projection matrix
230
+ line_width (float): width of line. Defaults to 1.0. NOTE: Values other than 1.0 may not work across all platforms.
231
+         return_depth (bool): whether to return the screen space depth map
232
+
233
+ Returns:
234
+ image (np.ndarray): [H, W, C] rendered image
235
+ depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.
236
+ """
237
+ assert vertices.ndim == 2 and vertices.shape[1] == 3
238
+ assert edges.ndim == 2 and edges.shape[1] == 2, f"Edges should be a 2D array with shape (T, 2), but got {edges.shape}"
239
+ assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}'
240
+ assert vertices.shape[0] == attr.shape[0]
241
+ assert vertices.dtype == np.float32
242
+ assert edges.dtype == np.uint32 or edges.dtype == np.int32
243
+ assert attr.dtype == np.float32, "Attribute should be float32"
244
+
245
+ C = attr.shape[1]
246
+ prog = ctx.program_vertex_attribute(C)
247
+
248
+     transform = transform if transform is not None else np.eye(4, dtype=np.float32)
249
+
250
+ # Create buffers
251
+ ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(edges, dtype='i4'))
252
+ vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4'))
253
+ vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4'))
254
+ vao = ctx.mgl_ctx.vertex_array(
255
+ prog,
256
+ [
257
+ (vbo_vertices, '3f', 'i_position'),
258
+ (vbo_attr, f'{C}f', 'i_attr'),
259
+ ],
260
+ ibo,
261
+ mode=moderngl.LINES,
262
+ )
263
+
264
+ # Create framebuffer
265
+ image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None)
266
+ depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None)
267
+ fbo = ctx.mgl_ctx.framebuffer(
268
+ color_attachments=[image_tex],
269
+ depth_attachment=depth_tex,
270
+ )
271
+
272
+ # Render
273
+ prog['u_mvp'].write(transform.transpose().copy().astype('f4'))
274
+ fbo.use()
275
+ fbo.viewport = (0, 0, width, height)
276
+ ctx.mgl_ctx.depth_func = '<'
277
+ ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST)
278
+ ctx.mgl_ctx.line_width = line_width
279
+ vao.render()
280
+ ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST)
281
+
282
+ # Read
283
+ image = np.zeros((height, width, C), dtype='f4')
284
+ image_tex.read_into(image)
285
+ image = image[::-1, :, :]
286
+ if return_depth:
287
+ depth = np.zeros((height, width), dtype='f4')
288
+ depth_tex.read_into(depth)
289
+ depth = depth[::-1, :]
290
+ else:
291
+ depth = None
292
+
293
+ # Release
294
+ vao.release()
295
+ ibo.release()
296
+ vbo_vertices.release()
297
+ vbo_attr.release()
298
+ fbo.release()
299
+ image_tex.release()
300
+ depth_tex.release()
301
+
302
+ return image, depth
303
+
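The edge rasterizer follows the same pattern; a short editorial sketch drawing a single colored segment on top of a far-plane depth background (made-up geometry):

import numpy as np
from utils3d.numpy.rasterization import RastContext, rasterize_edges

ctx = RastContext(standalone=True)
vertices = np.array([[-0.8, -0.8, 0.0],
                     [ 0.8,  0.8, 0.0]], dtype=np.float32)
edges = np.array([[0, 1]], dtype=np.int32)
colors = np.array([[1.0, 1.0, 0.0],
                   [0.0, 1.0, 1.0]], dtype=np.float32)

line_img, _ = rasterize_edges(
    ctx, vertices, edges, colors, 128, 128,
    depth=np.ones((128, 128), dtype=np.float32),     # background depth so the test passes
)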
304
+
305
+ def texture(
306
+ ctx: RastContext,
307
+ uv: np.ndarray,
308
+ texture: np.ndarray,
309
+ interpolation: str= 'linear',
310
+ wrap: str = 'clamp'
311
+ ) -> np.ndarray:
312
+ """
313
+ Given an UV image, texturing from the texture map
314
+     Given a UV coordinate image, sample colors from the texture map.
315
+ assert len(texture.shape) == 3 and 1 <= texture.shape[2] <= 4
316
+ assert uv.shape[2] == 2
317
+ height, width = uv.shape[:2]
318
+ texture_dtype = map_np_dtype(texture.dtype)
319
+
320
+ # Create VAO
321
+ screen_quad_vbo = ctx.mgl_ctx.buffer(np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype='f4'))
322
+ screen_quad_ibo = ctx.mgl_ctx.buffer(np.array([0, 1, 2, 0, 2, 3], dtype=np.int32))
323
+ screen_quad_vao = ctx.mgl_ctx.vertex_array(ctx.program_texture(texture.shape[2]), [(screen_quad_vbo, '2f4', 'in_vert')], index_buffer=screen_quad_ibo, index_element_size=4)
324
+
325
+ # Create texture, set filter and bind. TODO: min mag filter, mipmap
326
+ texture_tex = ctx.mgl_ctx.texture((texture.shape[1], texture.shape[0]), texture.shape[2], dtype=texture_dtype, data=np.ascontiguousarray(texture))
327
+ if interpolation == 'linear':
328
+ texture_tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
329
+ elif interpolation == 'nearest':
330
+ texture_tex.filter = (moderngl.NEAREST, moderngl.NEAREST)
331
+ texture_tex.use(location=0)
332
+ texture_uv = ctx.mgl_ctx.texture((width, height), 2, dtype='f4', data=np.ascontiguousarray(uv.astype('f4', copy=False)))
333
+ texture_uv.filter = (moderngl.NEAREST, moderngl.NEAREST)
334
+ texture_uv.use(location=1)
335
+
336
+ # Create render buffer and frame buffer
337
+ rb = ctx.mgl_ctx.renderbuffer((uv.shape[1], uv.shape[0]), texture.shape[2], dtype=texture_dtype)
338
+ fbo = ctx.mgl_ctx.framebuffer(color_attachments=[rb])
339
+
340
+ # Render
341
+ fbo.use()
342
+ fbo.viewport = (0, 0, width, height)
343
+ ctx.mgl_ctx.disable(ctx.mgl_ctx.BLEND)
344
+ screen_quad_vao.render()
345
+
346
+ # Read buffer
347
+ image_buffer = np.frombuffer(fbo.read(components=texture.shape[2], attachment=0, dtype=texture_dtype), dtype=texture_dtype).reshape((height, width, texture.shape[2]))
348
+
349
+ # Release
350
+ texture_tex.release()
351
+ rb.release()
352
+ fbo.release()
353
+
354
+ return image_buffer
355
+
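A small sketch of texture() (editorial, not part of the commit): resample a checkerboard with a regular UV grid, which should reproduce the pattern at the output resolution up to filtering.

import numpy as np
from utils3d.numpy.rasterization import RastContext, texture

ctx = RastContext(standalone=True)
checker = ((np.indices((64, 64)).sum(axis=0) % 2) * 255).astype(np.uint8)
checker = np.stack([checker] * 3, axis=-1)                       # [64, 64, 3] uint8
u, v = np.meshgrid(np.linspace(0, 1, 256), np.linspace(0, 1, 256))
uv = np.stack([u, v], axis=-1).astype(np.float32)                # [256, 256, 2]

resampled = texture(ctx, uv, checker, interpolation='nearest')   # [256, 256, 3]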
356
+
357
+ def warp_image_by_depth(
358
+ ctx: RastContext,
359
+ src_depth: np.ndarray,
360
+ src_image: np.ndarray = None,
361
+ width: int = None,
362
+ height: int = None,
363
+ *,
364
+ extrinsics_src: np.ndarray = None,
365
+ extrinsics_tgt: np.ndarray = None,
366
+ intrinsics_src: np.ndarray = None,
367
+ intrinsics_tgt: np.ndarray = None,
368
+ near: float = 0.1,
369
+ far: float = 100.0,
370
+ cull_backface: bool = True,
371
+ ssaa: int = 1,
372
+ return_depth: bool = False,
373
+ ) -> Tuple[np.ndarray, ...]:
374
+ """
375
+ Warp image by depth map.
376
+
377
+ Args:
378
+ ctx (RastContext): rasterizer context
379
+ src_depth (np.ndarray): [H, W]
380
+ src_image (np.ndarray, optional): [H, W, C]. The image to warp. Defaults to None (use uv coordinates).
381
+ width (int, optional): width of the output image. None to use depth map width. Defaults to None.
382
+ height (int, optional): height of the output image. None to use depth map height. Defaults to None.
383
+ extrinsics_src (np.ndarray, optional): extrinsics matrix of the source camera. Defaults to None (identity).
384
+ extrinsics_tgt (np.ndarray, optional): extrinsics matrix of the target camera. Defaults to None (identity).
385
+ intrinsics_src (np.ndarray, optional): intrinsics matrix of the source camera. Defaults to None (use the same as intrinsics_tgt).
386
+ intrinsics_tgt (np.ndarray, optional): intrinsics matrix of the target camera. Defaults to None (use the same as intrinsics_src).
387
+ cull_backface (bool, optional): whether to cull backface. Defaults to True.
388
+         near (float, optional): near plane of the target perspective projection. Defaults to 0.1.
+         far (float, optional): far plane of the target perspective projection. Defaults to 100.0.
+         cull_backface (bool, optional): whether to cull backface. Defaults to True.
389
+
390
+ Returns:
391
+ tgt_image (np.ndarray): [H, W, C] warped image (or uv coordinates if image is None).
392
+ tgt_depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.
393
+ """
394
+ assert src_depth.ndim == 2
395
+
396
+ if width is None:
397
+ width = src_depth.shape[1]
398
+ if height is None:
399
+ height = src_depth.shape[0]
400
+ if src_image is not None:
401
+         assert src_image.shape[:2] == src_depth.shape[:2], f'Shape of source image {src_image.shape} does not match shape of source depth {src_depth.shape}'
402
+
403
+ # set up default camera parameters
404
+ extrinsics_src = np.eye(4) if extrinsics_src is None else extrinsics_src
405
+ extrinsics_tgt = np.eye(4) if extrinsics_tgt is None else extrinsics_tgt
406
+ intrinsics_src = intrinsics_tgt if intrinsics_src is None else intrinsics_src
407
+ intrinsics_tgt = intrinsics_src if intrinsics_tgt is None else intrinsics_tgt
408
+
409
+ assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters."
410
+
411
+ # check shapes
412
+ assert extrinsics_src.shape == (4, 4) and extrinsics_tgt.shape == (4, 4)
413
+ assert intrinsics_src.shape == (3, 3) and intrinsics_tgt.shape == (3, 3)
414
+
415
+ # convert to view and perspective matrices
416
+ view_tgt = transforms.extrinsics_to_view(extrinsics_tgt)
417
+ perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far)
418
+
419
+ # unproject depth map
420
+ uv, faces = utils.image_mesh(*src_depth.shape[-2:])
421
+ pts = transforms.unproject_cv(uv, src_depth.reshape(-1), extrinsics_src, intrinsics_src)
422
+ faces = mesh.triangulate(faces, vertices=pts)
423
+
424
+ # rasterize attributes
425
+ if src_image is not None:
426
+ attr = src_image.reshape(-1, src_image.shape[-1])
427
+ else:
428
+ attr = uv
429
+
430
+ tgt_image, tgt_depth = rasterize_triangle_faces(
431
+ ctx,
432
+ pts,
433
+ faces,
434
+ attr,
435
+ width * ssaa,
436
+ height * ssaa,
437
+ transform=perspective_tgt @ view_tgt,
438
+ cull_backface=cull_backface,
439
+ return_depth=return_depth,
440
+ )
441
+
442
+ if ssaa > 1:
443
+ tgt_image = tgt_image.reshape(height, ssaa, width, ssaa, -1).mean(axis=(1, 3))
444
+ tgt_depth = tgt_depth.reshape(height, ssaa, width, ssaa, -1).mean(axis=(1, 3)) if return_depth else None
445
+
446
+ return tgt_image, tgt_depth
447
+
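A hedged usage sketch for warp_image_by_depth (editorial; the geometry is made up and the intrinsics are assumed to follow the normalized, UV-space convention used by the unprojection helpers): warp a constant-depth image into a camera shifted along +x.

import numpy as np
from utils3d.numpy.rasterization import RastContext, warp_image_by_depth

ctx = RastContext(standalone=True)
H, W = 128, 128
src_depth = np.full((H, W), 2.0, dtype=np.float32)            # a flat plane 2 units away
src_image = np.random.rand(H, W, 3).astype(np.float32)

intrinsics = np.array([[1.0, 0.0, 0.5],
                       [0.0, 1.0, 0.5],
                       [0.0, 0.0, 1.0]], dtype=np.float32)     # assumed normalized pinhole intrinsics
extrinsics_src = np.eye(4, dtype=np.float32)
extrinsics_tgt = np.eye(4, dtype=np.float32)
extrinsics_tgt[0, 3] = 0.1                                     # shift the target camera

tgt_image, _ = warp_image_by_depth(
    ctx, src_depth, src_image,
    extrinsics_src=extrinsics_src, extrinsics_tgt=extrinsics_tgt,
    intrinsics_src=intrinsics, intrinsics_tgt=intrinsics,
    cull_backface=False,
)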
448
+ def test():
449
+ """
450
+ Test if rasterization works. It will render a cube with random colors and save it as a CHECKME.png file.
451
+ """
452
+ ctx = RastContext(backend='egl')
453
+ vertices, faces = utils.cube(tri=True)
454
+ attr = np.random.rand(len(vertices), 3).astype(np.float32)
455
+ perspective = transforms.perspective(np.deg2rad(60), 1, 0.01, 100)
456
+ view = transforms.view_look_at(np.array([2, 2, 2]), np.array([0, 0, 0]), np.array([0, 1, 0]))
457
+ image, _ = rasterize_triangle_faces(
458
+ ctx,
459
+ vertices,
460
+ faces,
461
+ attr,
462
+ 512, 512,
463
+         transform=perspective @ view,
+         cull_backface=True,
467
+ return_depth=True,
468
+ )
469
+ import cv2
470
+ cv2.imwrite('CHECKME.png', cv2.cvtColor((image.clip(0, 1) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
471
+