first commit
Browse filesThis view is limited to 50 files because it contains too many changes.
See raw diff
- .gitignore +25 -0
- app.py +111 -0
- moge/model/__init__.py +1 -0
- moge/model/dinov2/__init__.py +6 -0
- moge/model/dinov2/hub/__init__.py +4 -0
- moge/model/dinov2/hub/backbones.py +156 -0
- moge/model/dinov2/hub/utils.py +39 -0
- moge/model/dinov2/layers/__init__.py +11 -0
- moge/model/dinov2/layers/attention.py +89 -0
- moge/model/dinov2/layers/block.py +259 -0
- moge/model/dinov2/layers/dino_head.py +58 -0
- moge/model/dinov2/layers/drop_path.py +34 -0
- moge/model/dinov2/layers/layer_scale.py +27 -0
- moge/model/dinov2/layers/mlp.py +40 -0
- moge/model/dinov2/layers/patch_embed.py +88 -0
- moge/model/dinov2/layers/swiglu_ffn.py +72 -0
- moge/model/dinov2/models/__init__.py +43 -0
- moge/model/dinov2/models/vision_transformer.py +396 -0
- moge/model/dinov2/utils/__init__.py +4 -0
- moge/model/dinov2/utils/cluster.py +95 -0
- moge/model/dinov2/utils/config.py +72 -0
- moge/model/dinov2/utils/dtype.py +37 -0
- moge/model/dinov2/utils/param_groups.py +103 -0
- moge/model/dinov2/utils/utils.py +95 -0
- moge/model/moge_model.py +376 -0
- moge/model/utils.py +38 -0
- moge/utils/__init__.py +0 -0
- moge/utils/blob.py +314 -0
- moge/utils/download.py +55 -0
- moge/utils/geometry_numpy.py +175 -0
- moge/utils/geometry_torch.py +231 -0
- moge/utils/io.py +347 -0
- moge/utils/pipeline.py +503 -0
- moge/utils/tools.py +240 -0
- moge/utils/vis.py +51 -0
- moge/utils/webfile.py +73 -0
- moge/utils/webzipfile.py +128 -0
- packages.txt +1 -0
- requirements.txt +5 -0
- utils3d/__init__.py +14 -0
- utils3d/io/__init__.py +4 -0
- utils3d/io/colmap.py +139 -0
- utils3d/io/glb.py +105 -0
- utils3d/io/ply.py +104 -0
- utils3d/io/wavefront_obj.py +146 -0
- utils3d/numpy/__init__.py +135 -0
- utils3d/numpy/_helpers.py +88 -0
- utils3d/numpy/mesh.py +355 -0
- utils3d/numpy/quadmesh.py +472 -0
- utils3d/numpy/rasterization.py +471 -0
.gitignore
ADDED
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
/data
|
2 |
+
/download
|
3 |
+
/extract
|
4 |
+
/view_point_cloud
|
5 |
+
/view_depth_map
|
6 |
+
/blobcache
|
7 |
+
/snapshot
|
8 |
+
/reference_embeddings
|
9 |
+
/.gradio
|
10 |
+
/debug
|
11 |
+
/workspace
|
12 |
+
/mlruns
|
13 |
+
/infer_output
|
14 |
+
/video_output
|
15 |
+
/eval_output
|
16 |
+
/.blobcache
|
17 |
+
/test_images
|
18 |
+
/test_videos
|
19 |
+
/vis
|
20 |
+
/videos
|
21 |
+
/raid
|
22 |
+
/blobmnt
|
23 |
+
/eval_dump
|
24 |
+
/pretrained
|
25 |
+
__pycache__/
|
app.py
ADDED
@@ -0,0 +1,111 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
import time
|
3 |
+
from pathlib import Path
|
4 |
+
import uuid
|
5 |
+
import tempfile
|
6 |
+
from typing import Union
|
7 |
+
import spaces
|
8 |
+
import atexit
|
9 |
+
from concurrent.futures import ThreadPoolExecutor
|
10 |
+
|
11 |
+
import gradio as gr
|
12 |
+
import cv2
|
13 |
+
import torch
|
14 |
+
import numpy as np
|
15 |
+
|
16 |
+
from moge.model import MoGeModel
|
17 |
+
from moge.utils.vis import colorize_depth
|
18 |
+
import utils3d
|
19 |
+
|
20 |
+
model = MoGeModel.from_pretrained('Ruicheng/moge-vitl').cuda().eval()
|
21 |
+
thread_pool_executor = ThreadPoolExecutor(max_workers=1)
|
22 |
+
|
23 |
+
|
24 |
+
def delete_later(path: Union[str, os.PathLike], delay: int = 300):
|
25 |
+
def _delete():
|
26 |
+
try:
|
27 |
+
os.remove(path)
|
28 |
+
except:
|
29 |
+
pass
|
30 |
+
def _wait_and_delete():
|
31 |
+
time.sleep(delay)
|
32 |
+
_delete(path)
|
33 |
+
thread_pool_executor.submit(_wait_and_delete)
|
34 |
+
atexit.register(_delete)
|
35 |
+
|
36 |
+
@spaces.GPU
|
37 |
+
def run(image: np.ndarray, remove_edge: bool = True):
|
38 |
+
run_id = str(uuid.uuid4())
|
39 |
+
|
40 |
+
larger_size = max(image.shape[:2])
|
41 |
+
if larger_size > 1024:
|
42 |
+
scale = 1024 / larger_size
|
43 |
+
image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
|
44 |
+
|
45 |
+
image_tensor = torch.tensor(image, dtype=torch.float32, device=torch.device('cuda')).permute(2, 0, 1) / 255
|
46 |
+
output = model.infer(image_tensor, resolution_level=9, apply_mask=True)
|
47 |
+
points, depth, mask = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy()
|
48 |
+
|
49 |
+
if remove_edge:
|
50 |
+
mask = mask & ~utils3d.numpy.depth_edge(depth, mask=mask, rtol=0.02)
|
51 |
+
mask = mask & (depth > 0)
|
52 |
+
|
53 |
+
_, faces, indices = utils3d.numpy.image_mesh(width=image.shape[1], height=image.shape[0], mask=mask)
|
54 |
+
faces = utils3d.numpy.triangulate(faces)
|
55 |
+
|
56 |
+
tempdir = Path(tempfile.gettempdir(), 'moge')
|
57 |
+
tempdir.mkdir(exist_ok=True)
|
58 |
+
|
59 |
+
output_glb_path = Path(tempdir, f'{run_id}.glb')
|
60 |
+
output_glb_path.parent.mkdir(exist_ok=True)
|
61 |
+
tempfile.TemporaryFile()
|
62 |
+
utils3d.io.write_glb(
|
63 |
+
output_glb_path,
|
64 |
+
vertices=points.reshape(-1, 3)[indices] * [-1, -1, 1],
|
65 |
+
faces=faces,
|
66 |
+
vertex_colors=image.reshape(-1, 3)[indices] / 255,
|
67 |
+
)
|
68 |
+
|
69 |
+
output_ply_path = Path(tempdir, f'{run_id}.ply')
|
70 |
+
output_ply_path.parent.mkdir(exist_ok=True)
|
71 |
+
utils3d.io.write_ply(
|
72 |
+
output_ply_path,
|
73 |
+
vertices=points.reshape(-1, 3)[indices] * [-1, -1, 1],
|
74 |
+
faces=faces,
|
75 |
+
vertex_colors=image.reshape(-1, 3)[indices] / 255,
|
76 |
+
)
|
77 |
+
|
78 |
+
colorized_depth = colorize_depth(depth)
|
79 |
+
|
80 |
+
delete_later(output_glb_path, delay=300)
|
81 |
+
delete_later(output_ply_path, delay=300)
|
82 |
+
|
83 |
+
return colorized_depth, output_glb_path, output_ply_path.as_posix()
|
84 |
+
|
85 |
+
|
86 |
+
DESCRIPTION = """
|
87 |
+
MoGe turns 2D images into 3D point maps.
|
88 |
+
|
89 |
+
NOTE:
|
90 |
+
* If the image is too large (> 1024px), it will be resized accordingly.
|
91 |
+
* The color in the 3D viewer may look dark due to rendering of 3D viewer. You may download the 3D model as .glb or .ply file to view it in other 3D viewers.
|
92 |
+
"""
|
93 |
+
|
94 |
+
if __name__ == '__main__':
|
95 |
+
|
96 |
+
gr.Interface(
|
97 |
+
fn=run,
|
98 |
+
inputs=[
|
99 |
+
gr.Image(type="numpy", image_mode="RGB"),
|
100 |
+
gr.Checkbox(True, label="Remove edges"),
|
101 |
+
],
|
102 |
+
outputs=[
|
103 |
+
gr.Image(type="numpy", label="Depth map (colorized)"),
|
104 |
+
gr.Model3D(display_mode="solid", clear_color=[1.0, 1.0, 1.0, 1.0], label="3D Viewer"),
|
105 |
+
gr.File(type="filepath", label="Download the model as .ply file"),
|
106 |
+
],
|
107 |
+
title="MoGe Live Demo",
|
108 |
+
description=DESCRIPTION,
|
109 |
+
clear_btn=None,
|
110 |
+
allow_flagging="never",
|
111 |
+
).launch(share=False)
|
moge/model/__init__.py
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
from .moge_model import MoGeModel
|
moge/model/dinov2/__init__.py
ADDED
@@ -0,0 +1,6 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
__version__ = "0.0.1"
|
moge/model/dinov2/hub/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
moge/model/dinov2/hub/backbones.py
ADDED
@@ -0,0 +1,156 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from enum import Enum
|
7 |
+
from typing import Union
|
8 |
+
|
9 |
+
import torch
|
10 |
+
|
11 |
+
from .utils import _DINOV2_BASE_URL, _make_dinov2_model_name
|
12 |
+
|
13 |
+
|
14 |
+
class Weights(Enum):
|
15 |
+
LVD142M = "LVD142M"
|
16 |
+
|
17 |
+
|
18 |
+
def _make_dinov2_model(
|
19 |
+
*,
|
20 |
+
arch_name: str = "vit_large",
|
21 |
+
img_size: int = 518,
|
22 |
+
patch_size: int = 14,
|
23 |
+
init_values: float = 1.0,
|
24 |
+
ffn_layer: str = "mlp",
|
25 |
+
block_chunks: int = 0,
|
26 |
+
num_register_tokens: int = 0,
|
27 |
+
interpolate_antialias: bool = False,
|
28 |
+
interpolate_offset: float = 0.1,
|
29 |
+
pretrained: bool = True,
|
30 |
+
weights: Union[Weights, str] = Weights.LVD142M,
|
31 |
+
**kwargs,
|
32 |
+
):
|
33 |
+
from ..models import vision_transformer as vits
|
34 |
+
|
35 |
+
if isinstance(weights, str):
|
36 |
+
try:
|
37 |
+
weights = Weights[weights]
|
38 |
+
except KeyError:
|
39 |
+
raise AssertionError(f"Unsupported weights: {weights}")
|
40 |
+
|
41 |
+
model_base_name = _make_dinov2_model_name(arch_name, patch_size)
|
42 |
+
vit_kwargs = dict(
|
43 |
+
img_size=img_size,
|
44 |
+
patch_size=patch_size,
|
45 |
+
init_values=init_values,
|
46 |
+
ffn_layer=ffn_layer,
|
47 |
+
block_chunks=block_chunks,
|
48 |
+
num_register_tokens=num_register_tokens,
|
49 |
+
interpolate_antialias=interpolate_antialias,
|
50 |
+
interpolate_offset=interpolate_offset,
|
51 |
+
)
|
52 |
+
vit_kwargs.update(**kwargs)
|
53 |
+
model = vits.__dict__[arch_name](**vit_kwargs)
|
54 |
+
|
55 |
+
if pretrained:
|
56 |
+
model_full_name = _make_dinov2_model_name(arch_name, patch_size, num_register_tokens)
|
57 |
+
url = _DINOV2_BASE_URL + f"/{model_base_name}/{model_full_name}_pretrain.pth"
|
58 |
+
state_dict = torch.hub.load_state_dict_from_url(url, map_location="cpu")
|
59 |
+
model.load_state_dict(state_dict, strict=True)
|
60 |
+
|
61 |
+
return model
|
62 |
+
|
63 |
+
|
64 |
+
def dinov2_vits14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
65 |
+
"""
|
66 |
+
DINOv2 ViT-S/14 model (optionally) pretrained on the LVD-142M dataset.
|
67 |
+
"""
|
68 |
+
return _make_dinov2_model(arch_name="vit_small", pretrained=pretrained, weights=weights, **kwargs)
|
69 |
+
|
70 |
+
|
71 |
+
def dinov2_vitb14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
72 |
+
"""
|
73 |
+
DINOv2 ViT-B/14 model (optionally) pretrained on the LVD-142M dataset.
|
74 |
+
"""
|
75 |
+
return _make_dinov2_model(arch_name="vit_base", pretrained=pretrained, weights=weights, **kwargs)
|
76 |
+
|
77 |
+
|
78 |
+
def dinov2_vitl14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
79 |
+
"""
|
80 |
+
DINOv2 ViT-L/14 model (optionally) pretrained on the LVD-142M dataset.
|
81 |
+
"""
|
82 |
+
return _make_dinov2_model(arch_name="vit_large", pretrained=pretrained, weights=weights, **kwargs)
|
83 |
+
|
84 |
+
|
85 |
+
def dinov2_vitg14(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
86 |
+
"""
|
87 |
+
DINOv2 ViT-g/14 model (optionally) pretrained on the LVD-142M dataset.
|
88 |
+
"""
|
89 |
+
return _make_dinov2_model(
|
90 |
+
arch_name="vit_giant2",
|
91 |
+
ffn_layer="swiglufused",
|
92 |
+
weights=weights,
|
93 |
+
pretrained=pretrained,
|
94 |
+
**kwargs,
|
95 |
+
)
|
96 |
+
|
97 |
+
|
98 |
+
def dinov2_vits14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
99 |
+
"""
|
100 |
+
DINOv2 ViT-S/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
101 |
+
"""
|
102 |
+
return _make_dinov2_model(
|
103 |
+
arch_name="vit_small",
|
104 |
+
pretrained=pretrained,
|
105 |
+
weights=weights,
|
106 |
+
num_register_tokens=4,
|
107 |
+
interpolate_antialias=True,
|
108 |
+
interpolate_offset=0.0,
|
109 |
+
**kwargs,
|
110 |
+
)
|
111 |
+
|
112 |
+
|
113 |
+
def dinov2_vitb14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
114 |
+
"""
|
115 |
+
DINOv2 ViT-B/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
116 |
+
"""
|
117 |
+
return _make_dinov2_model(
|
118 |
+
arch_name="vit_base",
|
119 |
+
pretrained=pretrained,
|
120 |
+
weights=weights,
|
121 |
+
num_register_tokens=4,
|
122 |
+
interpolate_antialias=True,
|
123 |
+
interpolate_offset=0.0,
|
124 |
+
**kwargs,
|
125 |
+
)
|
126 |
+
|
127 |
+
|
128 |
+
def dinov2_vitl14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
129 |
+
"""
|
130 |
+
DINOv2 ViT-L/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
131 |
+
"""
|
132 |
+
return _make_dinov2_model(
|
133 |
+
arch_name="vit_large",
|
134 |
+
pretrained=pretrained,
|
135 |
+
weights=weights,
|
136 |
+
num_register_tokens=4,
|
137 |
+
interpolate_antialias=True,
|
138 |
+
interpolate_offset=0.0,
|
139 |
+
**kwargs,
|
140 |
+
)
|
141 |
+
|
142 |
+
|
143 |
+
def dinov2_vitg14_reg(*, pretrained: bool = True, weights: Union[Weights, str] = Weights.LVD142M, **kwargs):
|
144 |
+
"""
|
145 |
+
DINOv2 ViT-g/14 model with registers (optionally) pretrained on the LVD-142M dataset.
|
146 |
+
"""
|
147 |
+
return _make_dinov2_model(
|
148 |
+
arch_name="vit_giant2",
|
149 |
+
ffn_layer="swiglufused",
|
150 |
+
weights=weights,
|
151 |
+
pretrained=pretrained,
|
152 |
+
num_register_tokens=4,
|
153 |
+
interpolate_antialias=True,
|
154 |
+
interpolate_offset=0.0,
|
155 |
+
**kwargs,
|
156 |
+
)
|
moge/model/dinov2/hub/utils.py
ADDED
@@ -0,0 +1,39 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import itertools
|
7 |
+
import math
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
|
14 |
+
_DINOV2_BASE_URL = "https://dl.fbaipublicfiles.com/dinov2"
|
15 |
+
|
16 |
+
|
17 |
+
def _make_dinov2_model_name(arch_name: str, patch_size: int, num_register_tokens: int = 0) -> str:
|
18 |
+
compact_arch_name = arch_name.replace("_", "")[:4]
|
19 |
+
registers_suffix = f"_reg{num_register_tokens}" if num_register_tokens else ""
|
20 |
+
return f"dinov2_{compact_arch_name}{patch_size}{registers_suffix}"
|
21 |
+
|
22 |
+
|
23 |
+
class CenterPadding(nn.Module):
|
24 |
+
def __init__(self, multiple):
|
25 |
+
super().__init__()
|
26 |
+
self.multiple = multiple
|
27 |
+
|
28 |
+
def _get_pad(self, size):
|
29 |
+
new_size = math.ceil(size / self.multiple) * self.multiple
|
30 |
+
pad_size = new_size - size
|
31 |
+
pad_size_left = pad_size // 2
|
32 |
+
pad_size_right = pad_size - pad_size_left
|
33 |
+
return pad_size_left, pad_size_right
|
34 |
+
|
35 |
+
@torch.inference_mode()
|
36 |
+
def forward(self, x):
|
37 |
+
pads = list(itertools.chain.from_iterable(self._get_pad(m) for m in x.shape[:1:-1]))
|
38 |
+
output = F.pad(x, pads)
|
39 |
+
return output
|
moge/model/dinov2/layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from .dino_head import DINOHead
|
7 |
+
from .mlp import Mlp
|
8 |
+
from .patch_embed import PatchEmbed
|
9 |
+
from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
|
10 |
+
from .block import NestedTensorBlock
|
11 |
+
from .attention import MemEffAttention
|
moge/model/dinov2/layers/attention.py
ADDED
@@ -0,0 +1,89 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
import warnings
|
13 |
+
|
14 |
+
from torch import Tensor
|
15 |
+
from torch import nn
|
16 |
+
|
17 |
+
|
18 |
+
logger = logging.getLogger("dinov2")
|
19 |
+
|
20 |
+
|
21 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
22 |
+
try:
|
23 |
+
if XFORMERS_ENABLED:
|
24 |
+
from xformers.ops import memory_efficient_attention, unbind
|
25 |
+
|
26 |
+
XFORMERS_AVAILABLE = True
|
27 |
+
# warnings.warn("xFormers is available (Attention)")
|
28 |
+
else:
|
29 |
+
# warnings.warn("xFormers is disabled (Attention)")
|
30 |
+
raise ImportError
|
31 |
+
except ImportError:
|
32 |
+
XFORMERS_AVAILABLE = False
|
33 |
+
# warnings.warn("xFormers is not available (Attention)")
|
34 |
+
|
35 |
+
|
36 |
+
class Attention(nn.Module):
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
dim: int,
|
40 |
+
num_heads: int = 8,
|
41 |
+
qkv_bias: bool = False,
|
42 |
+
proj_bias: bool = True,
|
43 |
+
attn_drop: float = 0.0,
|
44 |
+
proj_drop: float = 0.0,
|
45 |
+
) -> None:
|
46 |
+
super().__init__()
|
47 |
+
self.num_heads = num_heads
|
48 |
+
head_dim = dim // num_heads
|
49 |
+
self.scale = head_dim**-0.5
|
50 |
+
|
51 |
+
self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
|
52 |
+
self.attn_drop = nn.Dropout(attn_drop)
|
53 |
+
self.proj = nn.Linear(dim, dim, bias=proj_bias)
|
54 |
+
self.proj_drop = nn.Dropout(proj_drop)
|
55 |
+
|
56 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
57 |
+
B, N, C = x.shape
|
58 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
|
59 |
+
|
60 |
+
q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
|
61 |
+
attn = q @ k.transpose(-2, -1)
|
62 |
+
|
63 |
+
attn = attn.softmax(dim=-1)
|
64 |
+
attn = self.attn_drop(attn)
|
65 |
+
|
66 |
+
x = (attn @ v).transpose(1, 2).reshape(B, N, C)
|
67 |
+
x = self.proj(x)
|
68 |
+
x = self.proj_drop(x)
|
69 |
+
return x
|
70 |
+
|
71 |
+
|
72 |
+
class MemEffAttention(Attention):
|
73 |
+
def forward(self, x: Tensor, attn_bias=None) -> Tensor:
|
74 |
+
if not XFORMERS_AVAILABLE:
|
75 |
+
if attn_bias is not None:
|
76 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
77 |
+
return super().forward(x)
|
78 |
+
|
79 |
+
B, N, C = x.shape
|
80 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
|
81 |
+
|
82 |
+
q, k, v = unbind(qkv, 2)
|
83 |
+
|
84 |
+
x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
|
85 |
+
x = x.reshape([B, N, C])
|
86 |
+
|
87 |
+
x = self.proj(x)
|
88 |
+
x = self.proj_drop(x)
|
89 |
+
return x
|
moge/model/dinov2/layers/block.py
ADDED
@@ -0,0 +1,259 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
9 |
+
|
10 |
+
import logging
|
11 |
+
import os
|
12 |
+
from typing import Callable, List, Any, Tuple, Dict
|
13 |
+
import warnings
|
14 |
+
|
15 |
+
import torch
|
16 |
+
from torch import nn, Tensor
|
17 |
+
|
18 |
+
from .attention import Attention, MemEffAttention
|
19 |
+
from .drop_path import DropPath
|
20 |
+
from .layer_scale import LayerScale
|
21 |
+
from .mlp import Mlp
|
22 |
+
|
23 |
+
|
24 |
+
logger = logging.getLogger("dinov2")
|
25 |
+
|
26 |
+
|
27 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
28 |
+
try:
|
29 |
+
if XFORMERS_ENABLED:
|
30 |
+
from xformers.ops import fmha, scaled_index_add, index_select_cat
|
31 |
+
|
32 |
+
XFORMERS_AVAILABLE = True
|
33 |
+
# warnings.warn("xFormers is available (Block)")
|
34 |
+
else:
|
35 |
+
# warnings.warn("xFormers is disabled (Block)")
|
36 |
+
raise ImportError
|
37 |
+
except ImportError:
|
38 |
+
XFORMERS_AVAILABLE = False
|
39 |
+
# warnings.warn("xFormers is not available (Block)")
|
40 |
+
|
41 |
+
|
42 |
+
class Block(nn.Module):
|
43 |
+
def __init__(
|
44 |
+
self,
|
45 |
+
dim: int,
|
46 |
+
num_heads: int,
|
47 |
+
mlp_ratio: float = 4.0,
|
48 |
+
qkv_bias: bool = False,
|
49 |
+
proj_bias: bool = True,
|
50 |
+
ffn_bias: bool = True,
|
51 |
+
drop: float = 0.0,
|
52 |
+
attn_drop: float = 0.0,
|
53 |
+
init_values=None,
|
54 |
+
drop_path: float = 0.0,
|
55 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
56 |
+
norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
|
57 |
+
attn_class: Callable[..., nn.Module] = Attention,
|
58 |
+
ffn_layer: Callable[..., nn.Module] = Mlp,
|
59 |
+
) -> None:
|
60 |
+
super().__init__()
|
61 |
+
# print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
|
62 |
+
self.norm1 = norm_layer(dim)
|
63 |
+
self.attn = attn_class(
|
64 |
+
dim,
|
65 |
+
num_heads=num_heads,
|
66 |
+
qkv_bias=qkv_bias,
|
67 |
+
proj_bias=proj_bias,
|
68 |
+
attn_drop=attn_drop,
|
69 |
+
proj_drop=drop,
|
70 |
+
)
|
71 |
+
self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
72 |
+
self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
73 |
+
|
74 |
+
self.norm2 = norm_layer(dim)
|
75 |
+
mlp_hidden_dim = int(dim * mlp_ratio)
|
76 |
+
self.mlp = ffn_layer(
|
77 |
+
in_features=dim,
|
78 |
+
hidden_features=mlp_hidden_dim,
|
79 |
+
act_layer=act_layer,
|
80 |
+
drop=drop,
|
81 |
+
bias=ffn_bias,
|
82 |
+
)
|
83 |
+
self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
|
84 |
+
self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
|
85 |
+
|
86 |
+
self.sample_drop_ratio = drop_path
|
87 |
+
|
88 |
+
def forward(self, x: Tensor) -> Tensor:
|
89 |
+
def attn_residual_func(x: Tensor) -> Tensor:
|
90 |
+
return self.ls1(self.attn(self.norm1(x)))
|
91 |
+
|
92 |
+
def ffn_residual_func(x: Tensor) -> Tensor:
|
93 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
94 |
+
|
95 |
+
if self.training and self.sample_drop_ratio > 0.1:
|
96 |
+
# the overhead is compensated only for a drop path rate larger than 0.1
|
97 |
+
x = drop_add_residual_stochastic_depth(
|
98 |
+
x,
|
99 |
+
residual_func=attn_residual_func,
|
100 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
101 |
+
)
|
102 |
+
x = drop_add_residual_stochastic_depth(
|
103 |
+
x,
|
104 |
+
residual_func=ffn_residual_func,
|
105 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
106 |
+
)
|
107 |
+
elif self.training and self.sample_drop_ratio > 0.0:
|
108 |
+
x = x + self.drop_path1(attn_residual_func(x))
|
109 |
+
x = x + self.drop_path1(ffn_residual_func(x)) # FIXME: drop_path2
|
110 |
+
else:
|
111 |
+
x = x + attn_residual_func(x)
|
112 |
+
x = x + ffn_residual_func(x)
|
113 |
+
return x
|
114 |
+
|
115 |
+
|
116 |
+
def drop_add_residual_stochastic_depth(
|
117 |
+
x: Tensor,
|
118 |
+
residual_func: Callable[[Tensor], Tensor],
|
119 |
+
sample_drop_ratio: float = 0.0,
|
120 |
+
) -> Tensor:
|
121 |
+
# 1) extract subset using permutation
|
122 |
+
b, n, d = x.shape
|
123 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
124 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
125 |
+
x_subset = x[brange]
|
126 |
+
|
127 |
+
# 2) apply residual_func to get residual
|
128 |
+
residual = residual_func(x_subset)
|
129 |
+
|
130 |
+
x_flat = x.flatten(1)
|
131 |
+
residual = residual.flatten(1)
|
132 |
+
|
133 |
+
residual_scale_factor = b / sample_subset_size
|
134 |
+
|
135 |
+
# 3) add the residual
|
136 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
137 |
+
return x_plus_residual.view_as(x)
|
138 |
+
|
139 |
+
|
140 |
+
def get_branges_scales(x, sample_drop_ratio=0.0):
|
141 |
+
b, n, d = x.shape
|
142 |
+
sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
|
143 |
+
brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
|
144 |
+
residual_scale_factor = b / sample_subset_size
|
145 |
+
return brange, residual_scale_factor
|
146 |
+
|
147 |
+
|
148 |
+
def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
|
149 |
+
if scaling_vector is None:
|
150 |
+
x_flat = x.flatten(1)
|
151 |
+
residual = residual.flatten(1)
|
152 |
+
x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
|
153 |
+
else:
|
154 |
+
x_plus_residual = scaled_index_add(
|
155 |
+
x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
|
156 |
+
)
|
157 |
+
return x_plus_residual
|
158 |
+
|
159 |
+
|
160 |
+
attn_bias_cache: Dict[Tuple, Any] = {}
|
161 |
+
|
162 |
+
|
163 |
+
def get_attn_bias_and_cat(x_list, branges=None):
|
164 |
+
"""
|
165 |
+
this will perform the index select, cat the tensors, and provide the attn_bias from cache
|
166 |
+
"""
|
167 |
+
batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
|
168 |
+
all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
|
169 |
+
if all_shapes not in attn_bias_cache.keys():
|
170 |
+
seqlens = []
|
171 |
+
for b, x in zip(batch_sizes, x_list):
|
172 |
+
for _ in range(b):
|
173 |
+
seqlens.append(x.shape[1])
|
174 |
+
attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
|
175 |
+
attn_bias._batch_sizes = batch_sizes
|
176 |
+
attn_bias_cache[all_shapes] = attn_bias
|
177 |
+
|
178 |
+
if branges is not None:
|
179 |
+
cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
|
180 |
+
else:
|
181 |
+
tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
|
182 |
+
cat_tensors = torch.cat(tensors_bs1, dim=1)
|
183 |
+
|
184 |
+
return attn_bias_cache[all_shapes], cat_tensors
|
185 |
+
|
186 |
+
|
187 |
+
def drop_add_residual_stochastic_depth_list(
|
188 |
+
x_list: List[Tensor],
|
189 |
+
residual_func: Callable[[Tensor, Any], Tensor],
|
190 |
+
sample_drop_ratio: float = 0.0,
|
191 |
+
scaling_vector=None,
|
192 |
+
) -> Tensor:
|
193 |
+
# 1) generate random set of indices for dropping samples in the batch
|
194 |
+
branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
|
195 |
+
branges = [s[0] for s in branges_scales]
|
196 |
+
residual_scale_factors = [s[1] for s in branges_scales]
|
197 |
+
|
198 |
+
# 2) get attention bias and index+concat the tensors
|
199 |
+
attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
|
200 |
+
|
201 |
+
# 3) apply residual_func to get residual, and split the result
|
202 |
+
residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias)) # type: ignore
|
203 |
+
|
204 |
+
outputs = []
|
205 |
+
for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
|
206 |
+
outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
|
207 |
+
return outputs
|
208 |
+
|
209 |
+
|
210 |
+
class NestedTensorBlock(Block):
|
211 |
+
def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
|
212 |
+
"""
|
213 |
+
x_list contains a list of tensors to nest together and run
|
214 |
+
"""
|
215 |
+
assert isinstance(self.attn, MemEffAttention)
|
216 |
+
|
217 |
+
if self.training and self.sample_drop_ratio > 0.0:
|
218 |
+
|
219 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
220 |
+
return self.attn(self.norm1(x), attn_bias=attn_bias)
|
221 |
+
|
222 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
223 |
+
return self.mlp(self.norm2(x))
|
224 |
+
|
225 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
226 |
+
x_list,
|
227 |
+
residual_func=attn_residual_func,
|
228 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
229 |
+
scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
|
230 |
+
)
|
231 |
+
x_list = drop_add_residual_stochastic_depth_list(
|
232 |
+
x_list,
|
233 |
+
residual_func=ffn_residual_func,
|
234 |
+
sample_drop_ratio=self.sample_drop_ratio,
|
235 |
+
scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
|
236 |
+
)
|
237 |
+
return x_list
|
238 |
+
else:
|
239 |
+
|
240 |
+
def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
241 |
+
return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
|
242 |
+
|
243 |
+
def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
|
244 |
+
return self.ls2(self.mlp(self.norm2(x)))
|
245 |
+
|
246 |
+
attn_bias, x = get_attn_bias_and_cat(x_list)
|
247 |
+
x = x + attn_residual_func(x, attn_bias=attn_bias)
|
248 |
+
x = x + ffn_residual_func(x)
|
249 |
+
return attn_bias.split(x)
|
250 |
+
|
251 |
+
def forward(self, x_or_x_list):
|
252 |
+
if isinstance(x_or_x_list, Tensor):
|
253 |
+
return super().forward(x_or_x_list)
|
254 |
+
elif isinstance(x_or_x_list, list):
|
255 |
+
if not XFORMERS_AVAILABLE:
|
256 |
+
raise AssertionError("xFormers is required for using nested tensors")
|
257 |
+
return self.forward_nested(x_or_x_list)
|
258 |
+
else:
|
259 |
+
raise AssertionError
|
moge/model/dinov2/layers/dino_head.py
ADDED
@@ -0,0 +1,58 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
from torch.nn.init import trunc_normal_
|
9 |
+
from torch.nn.utils import weight_norm
|
10 |
+
|
11 |
+
|
12 |
+
class DINOHead(nn.Module):
|
13 |
+
def __init__(
|
14 |
+
self,
|
15 |
+
in_dim,
|
16 |
+
out_dim,
|
17 |
+
use_bn=False,
|
18 |
+
nlayers=3,
|
19 |
+
hidden_dim=2048,
|
20 |
+
bottleneck_dim=256,
|
21 |
+
mlp_bias=True,
|
22 |
+
):
|
23 |
+
super().__init__()
|
24 |
+
nlayers = max(nlayers, 1)
|
25 |
+
self.mlp = _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=hidden_dim, use_bn=use_bn, bias=mlp_bias)
|
26 |
+
self.apply(self._init_weights)
|
27 |
+
self.last_layer = weight_norm(nn.Linear(bottleneck_dim, out_dim, bias=False))
|
28 |
+
self.last_layer.weight_g.data.fill_(1)
|
29 |
+
|
30 |
+
def _init_weights(self, m):
|
31 |
+
if isinstance(m, nn.Linear):
|
32 |
+
trunc_normal_(m.weight, std=0.02)
|
33 |
+
if isinstance(m, nn.Linear) and m.bias is not None:
|
34 |
+
nn.init.constant_(m.bias, 0)
|
35 |
+
|
36 |
+
def forward(self, x):
|
37 |
+
x = self.mlp(x)
|
38 |
+
eps = 1e-6 if x.dtype == torch.float16 else 1e-12
|
39 |
+
x = nn.functional.normalize(x, dim=-1, p=2, eps=eps)
|
40 |
+
x = self.last_layer(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
def _build_mlp(nlayers, in_dim, bottleneck_dim, hidden_dim=None, use_bn=False, bias=True):
|
45 |
+
if nlayers == 1:
|
46 |
+
return nn.Linear(in_dim, bottleneck_dim, bias=bias)
|
47 |
+
else:
|
48 |
+
layers = [nn.Linear(in_dim, hidden_dim, bias=bias)]
|
49 |
+
if use_bn:
|
50 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
51 |
+
layers.append(nn.GELU())
|
52 |
+
for _ in range(nlayers - 2):
|
53 |
+
layers.append(nn.Linear(hidden_dim, hidden_dim, bias=bias))
|
54 |
+
if use_bn:
|
55 |
+
layers.append(nn.BatchNorm1d(hidden_dim))
|
56 |
+
layers.append(nn.GELU())
|
57 |
+
layers.append(nn.Linear(hidden_dim, bottleneck_dim, bias=bias))
|
58 |
+
return nn.Sequential(*layers)
|
moge/model/dinov2/layers/drop_path.py
ADDED
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
|
9 |
+
|
10 |
+
|
11 |
+
from torch import nn
|
12 |
+
|
13 |
+
|
14 |
+
def drop_path(x, drop_prob: float = 0.0, training: bool = False):
|
15 |
+
if drop_prob == 0.0 or not training:
|
16 |
+
return x
|
17 |
+
keep_prob = 1 - drop_prob
|
18 |
+
shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets
|
19 |
+
random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
|
20 |
+
if keep_prob > 0.0:
|
21 |
+
random_tensor.div_(keep_prob)
|
22 |
+
output = x * random_tensor
|
23 |
+
return output
|
24 |
+
|
25 |
+
|
26 |
+
class DropPath(nn.Module):
|
27 |
+
"""Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
|
28 |
+
|
29 |
+
def __init__(self, drop_prob=None):
|
30 |
+
super(DropPath, self).__init__()
|
31 |
+
self.drop_prob = drop_prob
|
32 |
+
|
33 |
+
def forward(self, x):
|
34 |
+
return drop_path(x, self.drop_prob, self.training)
|
moge/model/dinov2/layers/layer_scale.py
ADDED
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
|
7 |
+
|
8 |
+
from typing import Union
|
9 |
+
|
10 |
+
import torch
|
11 |
+
from torch import Tensor
|
12 |
+
from torch import nn
|
13 |
+
|
14 |
+
|
15 |
+
class LayerScale(nn.Module):
|
16 |
+
def __init__(
|
17 |
+
self,
|
18 |
+
dim: int,
|
19 |
+
init_values: Union[float, Tensor] = 1e-5,
|
20 |
+
inplace: bool = False,
|
21 |
+
) -> None:
|
22 |
+
super().__init__()
|
23 |
+
self.inplace = inplace
|
24 |
+
self.gamma = nn.Parameter(init_values * torch.ones(dim))
|
25 |
+
|
26 |
+
def forward(self, x: Tensor) -> Tensor:
|
27 |
+
return x.mul_(self.gamma) if self.inplace else x * self.gamma
|
moge/model/dinov2/layers/mlp.py
ADDED
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
|
9 |
+
|
10 |
+
|
11 |
+
from typing import Callable, Optional
|
12 |
+
|
13 |
+
from torch import Tensor, nn
|
14 |
+
|
15 |
+
|
16 |
+
class Mlp(nn.Module):
|
17 |
+
def __init__(
|
18 |
+
self,
|
19 |
+
in_features: int,
|
20 |
+
hidden_features: Optional[int] = None,
|
21 |
+
out_features: Optional[int] = None,
|
22 |
+
act_layer: Callable[..., nn.Module] = nn.GELU,
|
23 |
+
drop: float = 0.0,
|
24 |
+
bias: bool = True,
|
25 |
+
) -> None:
|
26 |
+
super().__init__()
|
27 |
+
out_features = out_features or in_features
|
28 |
+
hidden_features = hidden_features or in_features
|
29 |
+
self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
|
30 |
+
self.act = act_layer()
|
31 |
+
self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
|
32 |
+
self.drop = nn.Dropout(drop)
|
33 |
+
|
34 |
+
def forward(self, x: Tensor) -> Tensor:
|
35 |
+
x = self.fc1(x)
|
36 |
+
x = self.act(x)
|
37 |
+
x = self.drop(x)
|
38 |
+
x = self.fc2(x)
|
39 |
+
x = self.drop(x)
|
40 |
+
return x
|
moge/model/dinov2/layers/patch_embed.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
|
9 |
+
|
10 |
+
from typing import Callable, Optional, Tuple, Union
|
11 |
+
|
12 |
+
from torch import Tensor
|
13 |
+
import torch.nn as nn
|
14 |
+
|
15 |
+
|
16 |
+
def make_2tuple(x):
|
17 |
+
if isinstance(x, tuple):
|
18 |
+
assert len(x) == 2
|
19 |
+
return x
|
20 |
+
|
21 |
+
assert isinstance(x, int)
|
22 |
+
return (x, x)
|
23 |
+
|
24 |
+
|
25 |
+
class PatchEmbed(nn.Module):
|
26 |
+
"""
|
27 |
+
2D image to patch embedding: (B,C,H,W) -> (B,N,D)
|
28 |
+
|
29 |
+
Args:
|
30 |
+
img_size: Image size.
|
31 |
+
patch_size: Patch token size.
|
32 |
+
in_chans: Number of input image channels.
|
33 |
+
embed_dim: Number of linear projection output channels.
|
34 |
+
norm_layer: Normalization layer.
|
35 |
+
"""
|
36 |
+
|
37 |
+
def __init__(
|
38 |
+
self,
|
39 |
+
img_size: Union[int, Tuple[int, int]] = 224,
|
40 |
+
patch_size: Union[int, Tuple[int, int]] = 16,
|
41 |
+
in_chans: int = 3,
|
42 |
+
embed_dim: int = 768,
|
43 |
+
norm_layer: Optional[Callable] = None,
|
44 |
+
flatten_embedding: bool = True,
|
45 |
+
) -> None:
|
46 |
+
super().__init__()
|
47 |
+
|
48 |
+
image_HW = make_2tuple(img_size)
|
49 |
+
patch_HW = make_2tuple(patch_size)
|
50 |
+
patch_grid_size = (
|
51 |
+
image_HW[0] // patch_HW[0],
|
52 |
+
image_HW[1] // patch_HW[1],
|
53 |
+
)
|
54 |
+
|
55 |
+
self.img_size = image_HW
|
56 |
+
self.patch_size = patch_HW
|
57 |
+
self.patches_resolution = patch_grid_size
|
58 |
+
self.num_patches = patch_grid_size[0] * patch_grid_size[1]
|
59 |
+
|
60 |
+
self.in_chans = in_chans
|
61 |
+
self.embed_dim = embed_dim
|
62 |
+
|
63 |
+
self.flatten_embedding = flatten_embedding
|
64 |
+
|
65 |
+
self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
|
66 |
+
self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
|
67 |
+
|
68 |
+
def forward(self, x: Tensor) -> Tensor:
|
69 |
+
_, _, H, W = x.shape
|
70 |
+
patch_H, patch_W = self.patch_size
|
71 |
+
|
72 |
+
assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
|
73 |
+
assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
|
74 |
+
|
75 |
+
x = self.proj(x) # B C H W
|
76 |
+
H, W = x.size(2), x.size(3)
|
77 |
+
x = x.flatten(2).transpose(1, 2) # B HW C
|
78 |
+
x = self.norm(x)
|
79 |
+
if not self.flatten_embedding:
|
80 |
+
x = x.reshape(-1, H, W, self.embed_dim) # B H W C
|
81 |
+
return x
|
82 |
+
|
83 |
+
def flops(self) -> float:
|
84 |
+
Ho, Wo = self.patches_resolution
|
85 |
+
flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
|
86 |
+
if self.norm is not None:
|
87 |
+
flops += Ho * Wo * self.embed_dim
|
88 |
+
return flops
|
moge/model/dinov2/layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import os
|
7 |
+
from typing import Callable, Optional
|
8 |
+
import warnings
|
9 |
+
|
10 |
+
from torch import Tensor, nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
|
13 |
+
|
14 |
+
class SwiGLUFFN(nn.Module):
|
15 |
+
def __init__(
|
16 |
+
self,
|
17 |
+
in_features: int,
|
18 |
+
hidden_features: Optional[int] = None,
|
19 |
+
out_features: Optional[int] = None,
|
20 |
+
act_layer: Callable[..., nn.Module] = None,
|
21 |
+
drop: float = 0.0,
|
22 |
+
bias: bool = True,
|
23 |
+
) -> None:
|
24 |
+
super().__init__()
|
25 |
+
out_features = out_features or in_features
|
26 |
+
hidden_features = hidden_features or in_features
|
27 |
+
self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
|
28 |
+
self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
|
29 |
+
|
30 |
+
def forward(self, x: Tensor) -> Tensor:
|
31 |
+
x12 = self.w12(x)
|
32 |
+
x1, x2 = x12.chunk(2, dim=-1)
|
33 |
+
hidden = F.silu(x1) * x2
|
34 |
+
return self.w3(hidden)
|
35 |
+
|
36 |
+
|
37 |
+
XFORMERS_ENABLED = os.environ.get("XFORMERS_DISABLED") is None
|
38 |
+
try:
|
39 |
+
if XFORMERS_ENABLED:
|
40 |
+
from xformers.ops import SwiGLU
|
41 |
+
|
42 |
+
XFORMERS_AVAILABLE = True
|
43 |
+
# warnings.warn("xFormers is available (SwiGLU)")
|
44 |
+
else:
|
45 |
+
# warnings.warn("xFormers is disabled (SwiGLU)")
|
46 |
+
raise ImportError
|
47 |
+
except ImportError:
|
48 |
+
SwiGLU = SwiGLUFFN
|
49 |
+
XFORMERS_AVAILABLE = False
|
50 |
+
|
51 |
+
# warnings.warn("xFormers is not available (SwiGLU)")
|
52 |
+
|
53 |
+
|
54 |
+
class SwiGLUFFNFused(SwiGLU):
|
55 |
+
def __init__(
|
56 |
+
self,
|
57 |
+
in_features: int,
|
58 |
+
hidden_features: Optional[int] = None,
|
59 |
+
out_features: Optional[int] = None,
|
60 |
+
act_layer: Callable[..., nn.Module] = None,
|
61 |
+
drop: float = 0.0,
|
62 |
+
bias: bool = True,
|
63 |
+
) -> None:
|
64 |
+
out_features = out_features or in_features
|
65 |
+
hidden_features = hidden_features or in_features
|
66 |
+
hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
|
67 |
+
super().__init__(
|
68 |
+
in_features=in_features,
|
69 |
+
hidden_features=hidden_features,
|
70 |
+
out_features=out_features,
|
71 |
+
bias=bias,
|
72 |
+
)
|
moge/model/dinov2/models/__init__.py
ADDED
@@ -0,0 +1,43 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
|
8 |
+
from . import vision_transformer as vits
|
9 |
+
|
10 |
+
|
11 |
+
logger = logging.getLogger("dinov2")
|
12 |
+
|
13 |
+
|
14 |
+
def build_model(args, only_teacher=False, img_size=224):
|
15 |
+
args.arch = args.arch.removesuffix("_memeff")
|
16 |
+
if "vit" in args.arch:
|
17 |
+
vit_kwargs = dict(
|
18 |
+
img_size=img_size,
|
19 |
+
patch_size=args.patch_size,
|
20 |
+
init_values=args.layerscale,
|
21 |
+
ffn_layer=args.ffn_layer,
|
22 |
+
block_chunks=args.block_chunks,
|
23 |
+
qkv_bias=args.qkv_bias,
|
24 |
+
proj_bias=args.proj_bias,
|
25 |
+
ffn_bias=args.ffn_bias,
|
26 |
+
num_register_tokens=args.num_register_tokens,
|
27 |
+
interpolate_offset=args.interpolate_offset,
|
28 |
+
interpolate_antialias=args.interpolate_antialias,
|
29 |
+
)
|
30 |
+
teacher = vits.__dict__[args.arch](**vit_kwargs)
|
31 |
+
if only_teacher:
|
32 |
+
return teacher, teacher.embed_dim
|
33 |
+
student = vits.__dict__[args.arch](
|
34 |
+
**vit_kwargs,
|
35 |
+
drop_path_rate=args.drop_path_rate,
|
36 |
+
drop_path_uniform=args.drop_path_uniform,
|
37 |
+
)
|
38 |
+
embed_dim = student.embed_dim
|
39 |
+
return student, teacher, embed_dim
|
40 |
+
|
41 |
+
|
42 |
+
def build_model_from_cfg(cfg, only_teacher=False):
|
43 |
+
return build_model(cfg.student, only_teacher=only_teacher, img_size=cfg.crops.global_crops_size)
|
moge/model/dinov2/models/vision_transformer.py
ADDED
@@ -0,0 +1,396 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
# References:
|
7 |
+
# https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
|
8 |
+
# https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
|
9 |
+
|
10 |
+
from functools import partial
|
11 |
+
import math
|
12 |
+
import logging
|
13 |
+
from typing import Sequence, Tuple, Union, Callable
|
14 |
+
|
15 |
+
import torch
|
16 |
+
import torch.nn as nn
|
17 |
+
import torch.utils.checkpoint
|
18 |
+
from torch.nn.init import trunc_normal_
|
19 |
+
|
20 |
+
from ..layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
|
21 |
+
|
22 |
+
|
23 |
+
logger = logging.getLogger("dinov2")
|
24 |
+
|
25 |
+
|
26 |
+
def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
|
27 |
+
if not depth_first and include_root:
|
28 |
+
fn(module=module, name=name)
|
29 |
+
for child_name, child_module in module.named_children():
|
30 |
+
child_name = ".".join((name, child_name)) if name else child_name
|
31 |
+
named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
|
32 |
+
if depth_first and include_root:
|
33 |
+
fn(module=module, name=name)
|
34 |
+
return module
|
35 |
+
|
36 |
+
|
37 |
+
class BlockChunk(nn.ModuleList):
|
38 |
+
def forward(self, x):
|
39 |
+
for b in self:
|
40 |
+
x = b(x)
|
41 |
+
return x
|
42 |
+
|
43 |
+
|
44 |
+
class DinoVisionTransformer(nn.Module):
|
45 |
+
def __init__(
|
46 |
+
self,
|
47 |
+
img_size=224,
|
48 |
+
patch_size=16,
|
49 |
+
in_chans=3,
|
50 |
+
embed_dim=768,
|
51 |
+
depth=12,
|
52 |
+
num_heads=12,
|
53 |
+
mlp_ratio=4.0,
|
54 |
+
qkv_bias=True,
|
55 |
+
ffn_bias=True,
|
56 |
+
proj_bias=True,
|
57 |
+
drop_path_rate=0.0,
|
58 |
+
drop_path_uniform=False,
|
59 |
+
init_values=None, # for layerscale: None or 0 => no layerscale
|
60 |
+
embed_layer=PatchEmbed,
|
61 |
+
act_layer=nn.GELU,
|
62 |
+
block_fn=Block,
|
63 |
+
ffn_layer="mlp",
|
64 |
+
block_chunks=1,
|
65 |
+
num_register_tokens=0,
|
66 |
+
interpolate_antialias=False,
|
67 |
+
interpolate_offset=0.1,
|
68 |
+
):
|
69 |
+
"""
|
70 |
+
Args:
|
71 |
+
img_size (int, tuple): input image size
|
72 |
+
patch_size (int, tuple): patch size
|
73 |
+
in_chans (int): number of input channels
|
74 |
+
embed_dim (int): embedding dimension
|
75 |
+
depth (int): depth of transformer
|
76 |
+
num_heads (int): number of attention heads
|
77 |
+
mlp_ratio (int): ratio of mlp hidden dim to embedding dim
|
78 |
+
qkv_bias (bool): enable bias for qkv if True
|
79 |
+
proj_bias (bool): enable bias for proj in attn if True
|
80 |
+
ffn_bias (bool): enable bias for ffn if True
|
81 |
+
drop_path_rate (float): stochastic depth rate
|
82 |
+
drop_path_uniform (bool): apply uniform drop rate across blocks
|
83 |
+
weight_init (str): weight init scheme
|
84 |
+
init_values (float): layer-scale init values
|
85 |
+
embed_layer (nn.Module): patch embedding layer
|
86 |
+
act_layer (nn.Module): MLP activation layer
|
87 |
+
block_fn (nn.Module): transformer block class
|
88 |
+
ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
|
89 |
+
block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
|
90 |
+
num_register_tokens: (int) number of extra cls tokens (so-called "registers")
|
91 |
+
interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
|
92 |
+
interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
|
93 |
+
"""
|
94 |
+
super().__init__()
|
95 |
+
norm_layer = partial(nn.LayerNorm, eps=1e-6)
|
96 |
+
|
97 |
+
self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
|
98 |
+
self.num_tokens = 1
|
99 |
+
self.n_blocks = depth
|
100 |
+
self.num_heads = num_heads
|
101 |
+
self.patch_size = patch_size
|
102 |
+
self.num_register_tokens = num_register_tokens
|
103 |
+
self.interpolate_antialias = interpolate_antialias
|
104 |
+
self.interpolate_offset = interpolate_offset
|
105 |
+
|
106 |
+
self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
|
107 |
+
num_patches = self.patch_embed.num_patches
|
108 |
+
|
109 |
+
self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
|
110 |
+
self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
|
111 |
+
assert num_register_tokens >= 0
|
112 |
+
self.register_tokens = (
|
113 |
+
nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
|
114 |
+
)
|
115 |
+
|
116 |
+
if drop_path_uniform is True:
|
117 |
+
dpr = [drop_path_rate] * depth
|
118 |
+
else:
|
119 |
+
dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)] # stochastic depth decay rule
|
120 |
+
|
121 |
+
if ffn_layer == "mlp":
|
122 |
+
logger.info("using MLP layer as FFN")
|
123 |
+
ffn_layer = Mlp
|
124 |
+
elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
|
125 |
+
logger.info("using SwiGLU layer as FFN")
|
126 |
+
ffn_layer = SwiGLUFFNFused
|
127 |
+
elif ffn_layer == "identity":
|
128 |
+
logger.info("using Identity layer as FFN")
|
129 |
+
|
130 |
+
def f(*args, **kwargs):
|
131 |
+
return nn.Identity()
|
132 |
+
|
133 |
+
ffn_layer = f
|
134 |
+
else:
|
135 |
+
raise NotImplementedError
|
136 |
+
|
137 |
+
blocks_list = [
|
138 |
+
block_fn(
|
139 |
+
dim=embed_dim,
|
140 |
+
num_heads=num_heads,
|
141 |
+
mlp_ratio=mlp_ratio,
|
142 |
+
qkv_bias=qkv_bias,
|
143 |
+
proj_bias=proj_bias,
|
144 |
+
ffn_bias=ffn_bias,
|
145 |
+
drop_path=dpr[i],
|
146 |
+
norm_layer=norm_layer,
|
147 |
+
act_layer=act_layer,
|
148 |
+
ffn_layer=ffn_layer,
|
149 |
+
init_values=init_values,
|
150 |
+
)
|
151 |
+
for i in range(depth)
|
152 |
+
]
|
153 |
+
if block_chunks > 0:
|
154 |
+
self.chunked_blocks = True
|
155 |
+
chunked_blocks = []
|
156 |
+
chunksize = depth // block_chunks
|
157 |
+
for i in range(0, depth, chunksize):
|
158 |
+
# this is to keep the block index consistent if we chunk the block list
|
159 |
+
chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
|
160 |
+
self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
|
161 |
+
else:
|
162 |
+
self.chunked_blocks = False
|
163 |
+
self.blocks = nn.ModuleList(blocks_list)
|
164 |
+
|
165 |
+
self.norm = norm_layer(embed_dim)
|
166 |
+
self.head = nn.Identity()
|
167 |
+
|
168 |
+
self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
|
169 |
+
|
170 |
+
self.init_weights()
|
171 |
+
|
172 |
+
def init_weights(self):
|
173 |
+
trunc_normal_(self.pos_embed, std=0.02)
|
174 |
+
nn.init.normal_(self.cls_token, std=1e-6)
|
175 |
+
if self.register_tokens is not None:
|
176 |
+
nn.init.normal_(self.register_tokens, std=1e-6)
|
177 |
+
named_apply(init_weights_vit_timm, self)
|
178 |
+
|
179 |
+
def interpolate_pos_encoding(self, x, w, h):
|
180 |
+
previous_dtype = x.dtype
|
181 |
+
npatch = x.shape[1] - 1
|
182 |
+
N = self.pos_embed.shape[1] - 1
|
183 |
+
if npatch == N and w == h:
|
184 |
+
return self.pos_embed
|
185 |
+
pos_embed = self.pos_embed.float()
|
186 |
+
class_pos_embed = pos_embed[:, 0]
|
187 |
+
patch_pos_embed = pos_embed[:, 1:]
|
188 |
+
dim = x.shape[-1]
|
189 |
+
w0 = w // self.patch_size
|
190 |
+
h0 = h // self.patch_size
|
191 |
+
M = int(math.sqrt(N)) # Recover the number of patches in each dimension
|
192 |
+
assert N == M * M
|
193 |
+
kwargs = {}
|
194 |
+
if self.interpolate_offset:
|
195 |
+
# Historical kludge: add a small number to avoid floating point error in the interpolation, see https://github.com/facebookresearch/dino/issues/8
|
196 |
+
# Note: still needed for backward-compatibility, the underlying operators are using both output size and scale factors
|
197 |
+
sx = float(w0 + self.interpolate_offset) / M
|
198 |
+
sy = float(h0 + self.interpolate_offset) / M
|
199 |
+
kwargs["scale_factor"] = (sx, sy)
|
200 |
+
else:
|
201 |
+
# Simply specify an output size instead of a scale factor
|
202 |
+
kwargs["size"] = (w0, h0)
|
203 |
+
patch_pos_embed = nn.functional.interpolate(
|
204 |
+
patch_pos_embed.reshape(1, M, M, dim).permute(0, 3, 1, 2),
|
205 |
+
mode="bicubic",
|
206 |
+
antialias=self.interpolate_antialias,
|
207 |
+
**kwargs,
|
208 |
+
)
|
209 |
+
assert (w0, h0) == patch_pos_embed.shape[-2:]
|
210 |
+
patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
|
211 |
+
return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
|
212 |
+
|
213 |
+
def prepare_tokens_with_masks(self, x, masks=None):
|
214 |
+
B, nc, w, h = x.shape
|
215 |
+
x = self.patch_embed(x)
|
216 |
+
if masks is not None:
|
217 |
+
x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
|
218 |
+
|
219 |
+
x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
|
220 |
+
x = x + self.interpolate_pos_encoding(x, w, h)
|
221 |
+
|
222 |
+
if self.register_tokens is not None:
|
223 |
+
x = torch.cat(
|
224 |
+
(
|
225 |
+
x[:, :1],
|
226 |
+
self.register_tokens.expand(x.shape[0], -1, -1),
|
227 |
+
x[:, 1:],
|
228 |
+
),
|
229 |
+
dim=1,
|
230 |
+
)
|
231 |
+
|
232 |
+
return x
|
233 |
+
|
234 |
+
def forward_features_list(self, x_list, masks_list):
|
235 |
+
x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
|
236 |
+
for blk in self.blocks:
|
237 |
+
x = blk(x)
|
238 |
+
|
239 |
+
all_x = x
|
240 |
+
output = []
|
241 |
+
for x, masks in zip(all_x, masks_list):
|
242 |
+
x_norm = self.norm(x)
|
243 |
+
output.append(
|
244 |
+
{
|
245 |
+
"x_norm_clstoken": x_norm[:, 0],
|
246 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
247 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
248 |
+
"x_prenorm": x,
|
249 |
+
"masks": masks,
|
250 |
+
}
|
251 |
+
)
|
252 |
+
return output
|
253 |
+
|
254 |
+
def forward_features(self, x, masks=None):
|
255 |
+
if isinstance(x, list):
|
256 |
+
return self.forward_features_list(x, masks)
|
257 |
+
|
258 |
+
x = self.prepare_tokens_with_masks(x, masks)
|
259 |
+
|
260 |
+
for blk in self.blocks:
|
261 |
+
x = blk(x)
|
262 |
+
|
263 |
+
x_norm = self.norm(x)
|
264 |
+
return {
|
265 |
+
"x_norm_clstoken": x_norm[:, 0],
|
266 |
+
"x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
|
267 |
+
"x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
|
268 |
+
"x_prenorm": x,
|
269 |
+
"masks": masks,
|
270 |
+
}
|
271 |
+
|
272 |
+
def _get_intermediate_layers_not_chunked(self, x, n=1):
|
273 |
+
x = self.prepare_tokens_with_masks(x)
|
274 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
275 |
+
output, total_block_len = [], len(self.blocks)
|
276 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
277 |
+
for i, blk in enumerate(self.blocks):
|
278 |
+
x = blk(x)
|
279 |
+
if i in blocks_to_take:
|
280 |
+
output.append(x)
|
281 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
282 |
+
return output
|
283 |
+
|
284 |
+
def _get_intermediate_layers_chunked(self, x, n=1):
|
285 |
+
x = self.prepare_tokens_with_masks(x)
|
286 |
+
output, i, total_block_len = [], 0, len(self.blocks[-1])
|
287 |
+
# If n is an int, take the n last blocks. If it's a list, take them
|
288 |
+
blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
|
289 |
+
for block_chunk in self.blocks:
|
290 |
+
for blk in block_chunk[i:]: # Passing the nn.Identity()
|
291 |
+
x = blk(x)
|
292 |
+
if i in blocks_to_take:
|
293 |
+
output.append(x)
|
294 |
+
i += 1
|
295 |
+
assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
|
296 |
+
return output
|
297 |
+
|
298 |
+
def get_intermediate_layers(
|
299 |
+
self,
|
300 |
+
x: torch.Tensor,
|
301 |
+
n: Union[int, Sequence] = 1, # Layers or n last layers to take
|
302 |
+
reshape: bool = False,
|
303 |
+
return_class_token: bool = False,
|
304 |
+
norm=True,
|
305 |
+
) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
|
306 |
+
if self.chunked_blocks:
|
307 |
+
outputs = self._get_intermediate_layers_chunked(x, n)
|
308 |
+
else:
|
309 |
+
outputs = self._get_intermediate_layers_not_chunked(x, n)
|
310 |
+
if norm:
|
311 |
+
outputs = [self.norm(out) for out in outputs]
|
312 |
+
class_tokens = [out[:, 0] for out in outputs]
|
313 |
+
outputs = [out[:, 1 + self.num_register_tokens :] for out in outputs]
|
314 |
+
if reshape:
|
315 |
+
B, _, w, h = x.shape
|
316 |
+
outputs = [
|
317 |
+
out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
|
318 |
+
for out in outputs
|
319 |
+
]
|
320 |
+
if return_class_token:
|
321 |
+
return tuple(zip(outputs, class_tokens))
|
322 |
+
return tuple(outputs)
|
323 |
+
|
324 |
+
def forward(self, *args, is_training=False, **kwargs):
|
325 |
+
ret = self.forward_features(*args, **kwargs)
|
326 |
+
if is_training:
|
327 |
+
return ret
|
328 |
+
else:
|
329 |
+
return self.head(ret["x_norm_clstoken"])
|
330 |
+
|
331 |
+
|
332 |
+
def init_weights_vit_timm(module: nn.Module, name: str = ""):
|
333 |
+
"""ViT weight initialization, original timm impl (for reproducibility)"""
|
334 |
+
if isinstance(module, nn.Linear):
|
335 |
+
trunc_normal_(module.weight, std=0.02)
|
336 |
+
if module.bias is not None:
|
337 |
+
nn.init.zeros_(module.bias)
|
338 |
+
|
339 |
+
|
340 |
+
def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
|
341 |
+
model = DinoVisionTransformer(
|
342 |
+
patch_size=patch_size,
|
343 |
+
embed_dim=384,
|
344 |
+
depth=12,
|
345 |
+
num_heads=6,
|
346 |
+
mlp_ratio=4,
|
347 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
348 |
+
num_register_tokens=num_register_tokens,
|
349 |
+
**kwargs,
|
350 |
+
)
|
351 |
+
return model
|
352 |
+
|
353 |
+
|
354 |
+
def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
|
355 |
+
model = DinoVisionTransformer(
|
356 |
+
patch_size=patch_size,
|
357 |
+
embed_dim=768,
|
358 |
+
depth=12,
|
359 |
+
num_heads=12,
|
360 |
+
mlp_ratio=4,
|
361 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
362 |
+
num_register_tokens=num_register_tokens,
|
363 |
+
**kwargs,
|
364 |
+
)
|
365 |
+
return model
|
366 |
+
|
367 |
+
|
368 |
+
def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
|
369 |
+
model = DinoVisionTransformer(
|
370 |
+
patch_size=patch_size,
|
371 |
+
embed_dim=1024,
|
372 |
+
depth=24,
|
373 |
+
num_heads=16,
|
374 |
+
mlp_ratio=4,
|
375 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
376 |
+
num_register_tokens=num_register_tokens,
|
377 |
+
**kwargs,
|
378 |
+
)
|
379 |
+
return model
|
380 |
+
|
381 |
+
|
382 |
+
def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
|
383 |
+
"""
|
384 |
+
Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
|
385 |
+
"""
|
386 |
+
model = DinoVisionTransformer(
|
387 |
+
patch_size=patch_size,
|
388 |
+
embed_dim=1536,
|
389 |
+
depth=40,
|
390 |
+
num_heads=24,
|
391 |
+
mlp_ratio=4,
|
392 |
+
block_fn=partial(Block, attn_class=MemEffAttention),
|
393 |
+
num_register_tokens=num_register_tokens,
|
394 |
+
**kwargs,
|
395 |
+
)
|
396 |
+
return model
|
moge/model/dinov2/utils/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
moge/model/dinov2/utils/cluster.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from enum import Enum
|
7 |
+
import os
|
8 |
+
from pathlib import Path
|
9 |
+
from typing import Any, Dict, Optional
|
10 |
+
|
11 |
+
|
12 |
+
class ClusterType(Enum):
|
13 |
+
AWS = "aws"
|
14 |
+
FAIR = "fair"
|
15 |
+
RSC = "rsc"
|
16 |
+
|
17 |
+
|
18 |
+
def _guess_cluster_type() -> ClusterType:
|
19 |
+
uname = os.uname()
|
20 |
+
if uname.sysname == "Linux":
|
21 |
+
if uname.release.endswith("-aws"):
|
22 |
+
# Linux kernel versions on AWS instances are of the form "5.4.0-1051-aws"
|
23 |
+
return ClusterType.AWS
|
24 |
+
elif uname.nodename.startswith("rsc"):
|
25 |
+
# Linux kernel versions on RSC instances are standard ones but hostnames start with "rsc"
|
26 |
+
return ClusterType.RSC
|
27 |
+
|
28 |
+
return ClusterType.FAIR
|
29 |
+
|
30 |
+
|
31 |
+
def get_cluster_type(cluster_type: Optional[ClusterType] = None) -> Optional[ClusterType]:
|
32 |
+
if cluster_type is None:
|
33 |
+
return _guess_cluster_type()
|
34 |
+
|
35 |
+
return cluster_type
|
36 |
+
|
37 |
+
|
38 |
+
def get_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
|
39 |
+
cluster_type = get_cluster_type(cluster_type)
|
40 |
+
if cluster_type is None:
|
41 |
+
return None
|
42 |
+
|
43 |
+
CHECKPOINT_DIRNAMES = {
|
44 |
+
ClusterType.AWS: "checkpoints",
|
45 |
+
ClusterType.FAIR: "checkpoint",
|
46 |
+
ClusterType.RSC: "checkpoint/dino",
|
47 |
+
}
|
48 |
+
return Path("/") / CHECKPOINT_DIRNAMES[cluster_type]
|
49 |
+
|
50 |
+
|
51 |
+
def get_user_checkpoint_path(cluster_type: Optional[ClusterType] = None) -> Optional[Path]:
|
52 |
+
checkpoint_path = get_checkpoint_path(cluster_type)
|
53 |
+
if checkpoint_path is None:
|
54 |
+
return None
|
55 |
+
|
56 |
+
username = os.environ.get("USER")
|
57 |
+
assert username is not None
|
58 |
+
return checkpoint_path / username
|
59 |
+
|
60 |
+
|
61 |
+
def get_slurm_partition(cluster_type: Optional[ClusterType] = None) -> Optional[str]:
|
62 |
+
cluster_type = get_cluster_type(cluster_type)
|
63 |
+
if cluster_type is None:
|
64 |
+
return None
|
65 |
+
|
66 |
+
SLURM_PARTITIONS = {
|
67 |
+
ClusterType.AWS: "learnlab",
|
68 |
+
ClusterType.FAIR: "learnlab",
|
69 |
+
ClusterType.RSC: "learn",
|
70 |
+
}
|
71 |
+
return SLURM_PARTITIONS[cluster_type]
|
72 |
+
|
73 |
+
|
74 |
+
def get_slurm_executor_parameters(
|
75 |
+
nodes: int, num_gpus_per_node: int, cluster_type: Optional[ClusterType] = None, **kwargs
|
76 |
+
) -> Dict[str, Any]:
|
77 |
+
# create default parameters
|
78 |
+
params = {
|
79 |
+
"mem_gb": 0, # Requests all memory on a node, see https://slurm.schedmd.com/sbatch.html
|
80 |
+
"gpus_per_node": num_gpus_per_node,
|
81 |
+
"tasks_per_node": num_gpus_per_node, # one task per GPU
|
82 |
+
"cpus_per_task": 10,
|
83 |
+
"nodes": nodes,
|
84 |
+
"slurm_partition": get_slurm_partition(cluster_type),
|
85 |
+
}
|
86 |
+
# apply cluster-specific adjustments
|
87 |
+
cluster_type = get_cluster_type(cluster_type)
|
88 |
+
if cluster_type == ClusterType.AWS:
|
89 |
+
params["cpus_per_task"] = 12
|
90 |
+
del params["mem_gb"]
|
91 |
+
elif cluster_type == ClusterType.RSC:
|
92 |
+
params["cpus_per_task"] = 12
|
93 |
+
# set additional parameters / apply overrides
|
94 |
+
params.update(kwargs)
|
95 |
+
return params
|
moge/model/dinov2/utils/config.py
ADDED
@@ -0,0 +1,72 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import math
|
7 |
+
import logging
|
8 |
+
import os
|
9 |
+
|
10 |
+
from omegaconf import OmegaConf
|
11 |
+
|
12 |
+
import dinov2.distributed as distributed
|
13 |
+
from dinov2.logging import setup_logging
|
14 |
+
from dinov2.utils import utils
|
15 |
+
from dinov2.configs import dinov2_default_config
|
16 |
+
|
17 |
+
|
18 |
+
logger = logging.getLogger("dinov2")
|
19 |
+
|
20 |
+
|
21 |
+
def apply_scaling_rules_to_cfg(cfg): # to fix
|
22 |
+
if cfg.optim.scaling_rule == "sqrt_wrt_1024":
|
23 |
+
base_lr = cfg.optim.base_lr
|
24 |
+
cfg.optim.lr = base_lr
|
25 |
+
cfg.optim.lr *= math.sqrt(cfg.train.batch_size_per_gpu * distributed.get_global_size() / 1024.0)
|
26 |
+
logger.info(f"sqrt scaling learning rate; base: {base_lr}, new: {cfg.optim.lr}")
|
27 |
+
else:
|
28 |
+
raise NotImplementedError
|
29 |
+
return cfg
|
30 |
+
|
31 |
+
|
32 |
+
def write_config(cfg, output_dir, name="config.yaml"):
|
33 |
+
logger.info(OmegaConf.to_yaml(cfg))
|
34 |
+
saved_cfg_path = os.path.join(output_dir, name)
|
35 |
+
with open(saved_cfg_path, "w") as f:
|
36 |
+
OmegaConf.save(config=cfg, f=f)
|
37 |
+
return saved_cfg_path
|
38 |
+
|
39 |
+
|
40 |
+
def get_cfg_from_args(args):
|
41 |
+
args.output_dir = os.path.abspath(args.output_dir)
|
42 |
+
args.opts += [f"train.output_dir={args.output_dir}"]
|
43 |
+
default_cfg = OmegaConf.create(dinov2_default_config)
|
44 |
+
cfg = OmegaConf.load(args.config_file)
|
45 |
+
cfg = OmegaConf.merge(default_cfg, cfg, OmegaConf.from_cli(args.opts))
|
46 |
+
return cfg
|
47 |
+
|
48 |
+
|
49 |
+
def default_setup(args):
|
50 |
+
distributed.enable(overwrite=True)
|
51 |
+
seed = getattr(args, "seed", 0)
|
52 |
+
rank = distributed.get_global_rank()
|
53 |
+
|
54 |
+
global logger
|
55 |
+
setup_logging(output=args.output_dir, level=logging.INFO)
|
56 |
+
logger = logging.getLogger("dinov2")
|
57 |
+
|
58 |
+
utils.fix_random_seeds(seed + rank)
|
59 |
+
logger.info("git:\n {}\n".format(utils.get_sha()))
|
60 |
+
logger.info("\n".join("%s: %s" % (k, str(v)) for k, v in sorted(dict(vars(args)).items())))
|
61 |
+
|
62 |
+
|
63 |
+
def setup(args):
|
64 |
+
"""
|
65 |
+
Create configs and perform basic setups.
|
66 |
+
"""
|
67 |
+
cfg = get_cfg_from_args(args)
|
68 |
+
os.makedirs(args.output_dir, exist_ok=True)
|
69 |
+
default_setup(args)
|
70 |
+
apply_scaling_rules_to_cfg(cfg)
|
71 |
+
write_config(cfg, args.output_dir)
|
72 |
+
return cfg
|
moge/model/dinov2/utils/dtype.py
ADDED
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
|
7 |
+
from typing import Dict, Union
|
8 |
+
|
9 |
+
import numpy as np
|
10 |
+
import torch
|
11 |
+
|
12 |
+
|
13 |
+
TypeSpec = Union[str, np.dtype, torch.dtype]
|
14 |
+
|
15 |
+
|
16 |
+
_NUMPY_TO_TORCH_DTYPE: Dict[np.dtype, torch.dtype] = {
|
17 |
+
np.dtype("bool"): torch.bool,
|
18 |
+
np.dtype("uint8"): torch.uint8,
|
19 |
+
np.dtype("int8"): torch.int8,
|
20 |
+
np.dtype("int16"): torch.int16,
|
21 |
+
np.dtype("int32"): torch.int32,
|
22 |
+
np.dtype("int64"): torch.int64,
|
23 |
+
np.dtype("float16"): torch.float16,
|
24 |
+
np.dtype("float32"): torch.float32,
|
25 |
+
np.dtype("float64"): torch.float64,
|
26 |
+
np.dtype("complex64"): torch.complex64,
|
27 |
+
np.dtype("complex128"): torch.complex128,
|
28 |
+
}
|
29 |
+
|
30 |
+
|
31 |
+
def as_torch_dtype(dtype: TypeSpec) -> torch.dtype:
|
32 |
+
if isinstance(dtype, torch.dtype):
|
33 |
+
return dtype
|
34 |
+
if isinstance(dtype, str):
|
35 |
+
dtype = np.dtype(dtype)
|
36 |
+
assert isinstance(dtype, np.dtype), f"Expected an instance of nunpy dtype, got {type(dtype)}"
|
37 |
+
return _NUMPY_TO_TORCH_DTYPE[dtype]
|
moge/model/dinov2/utils/param_groups.py
ADDED
@@ -0,0 +1,103 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
from collections import defaultdict
|
7 |
+
import logging
|
8 |
+
|
9 |
+
|
10 |
+
logger = logging.getLogger("dinov2")
|
11 |
+
|
12 |
+
|
13 |
+
def get_vit_lr_decay_rate(name, lr_decay_rate=1.0, num_layers=12, force_is_backbone=False, chunked_blocks=False):
|
14 |
+
"""
|
15 |
+
Calculate lr decay rate for different ViT blocks.
|
16 |
+
Args:
|
17 |
+
name (string): parameter name.
|
18 |
+
lr_decay_rate (float): base lr decay rate.
|
19 |
+
num_layers (int): number of ViT blocks.
|
20 |
+
Returns:
|
21 |
+
lr decay rate for the given parameter.
|
22 |
+
"""
|
23 |
+
layer_id = num_layers + 1
|
24 |
+
if name.startswith("backbone") or force_is_backbone:
|
25 |
+
if (
|
26 |
+
".pos_embed" in name
|
27 |
+
or ".patch_embed" in name
|
28 |
+
or ".mask_token" in name
|
29 |
+
or ".cls_token" in name
|
30 |
+
or ".register_tokens" in name
|
31 |
+
):
|
32 |
+
layer_id = 0
|
33 |
+
elif force_is_backbone and (
|
34 |
+
"pos_embed" in name
|
35 |
+
or "patch_embed" in name
|
36 |
+
or "mask_token" in name
|
37 |
+
or "cls_token" in name
|
38 |
+
or "register_tokens" in name
|
39 |
+
):
|
40 |
+
layer_id = 0
|
41 |
+
elif ".blocks." in name and ".residual." not in name:
|
42 |
+
layer_id = int(name[name.find(".blocks.") :].split(".")[2]) + 1
|
43 |
+
elif chunked_blocks and "blocks." in name and "residual." not in name:
|
44 |
+
layer_id = int(name[name.find("blocks.") :].split(".")[2]) + 1
|
45 |
+
elif "blocks." in name and "residual." not in name:
|
46 |
+
layer_id = int(name[name.find("blocks.") :].split(".")[1]) + 1
|
47 |
+
|
48 |
+
return lr_decay_rate ** (num_layers + 1 - layer_id)
|
49 |
+
|
50 |
+
|
51 |
+
def get_params_groups_with_decay(model, lr_decay_rate=1.0, patch_embed_lr_mult=1.0):
|
52 |
+
chunked_blocks = False
|
53 |
+
if hasattr(model, "n_blocks"):
|
54 |
+
logger.info("chunked fsdp")
|
55 |
+
n_blocks = model.n_blocks
|
56 |
+
chunked_blocks = model.chunked_blocks
|
57 |
+
elif hasattr(model, "blocks"):
|
58 |
+
logger.info("first code branch")
|
59 |
+
n_blocks = len(model.blocks)
|
60 |
+
elif hasattr(model, "backbone"):
|
61 |
+
logger.info("second code branch")
|
62 |
+
n_blocks = len(model.backbone.blocks)
|
63 |
+
else:
|
64 |
+
logger.info("else code branch")
|
65 |
+
n_blocks = 0
|
66 |
+
all_param_groups = []
|
67 |
+
|
68 |
+
for name, param in model.named_parameters():
|
69 |
+
name = name.replace("_fsdp_wrapped_module.", "")
|
70 |
+
if not param.requires_grad:
|
71 |
+
continue
|
72 |
+
decay_rate = get_vit_lr_decay_rate(
|
73 |
+
name, lr_decay_rate, num_layers=n_blocks, force_is_backbone=n_blocks > 0, chunked_blocks=chunked_blocks
|
74 |
+
)
|
75 |
+
d = {"params": param, "is_last_layer": False, "lr_multiplier": decay_rate, "wd_multiplier": 1.0, "name": name}
|
76 |
+
|
77 |
+
if "last_layer" in name:
|
78 |
+
d.update({"is_last_layer": True})
|
79 |
+
|
80 |
+
if name.endswith(".bias") or "norm" in name or "gamma" in name:
|
81 |
+
d.update({"wd_multiplier": 0.0})
|
82 |
+
|
83 |
+
if "patch_embed" in name:
|
84 |
+
d.update({"lr_multiplier": d["lr_multiplier"] * patch_embed_lr_mult})
|
85 |
+
|
86 |
+
all_param_groups.append(d)
|
87 |
+
logger.info(f"""{name}: lr_multiplier: {d["lr_multiplier"]}, wd_multiplier: {d["wd_multiplier"]}""")
|
88 |
+
|
89 |
+
return all_param_groups
|
90 |
+
|
91 |
+
|
92 |
+
def fuse_params_groups(all_params_groups, keys=("lr_multiplier", "wd_multiplier", "is_last_layer")):
|
93 |
+
fused_params_groups = defaultdict(lambda: {"params": []})
|
94 |
+
for d in all_params_groups:
|
95 |
+
identifier = ""
|
96 |
+
for k in keys:
|
97 |
+
identifier += k + str(d[k]) + "_"
|
98 |
+
|
99 |
+
for k in keys:
|
100 |
+
fused_params_groups[identifier][k] = d[k]
|
101 |
+
fused_params_groups[identifier]["params"].append(d["params"])
|
102 |
+
|
103 |
+
return fused_params_groups.values()
|
moge/model/dinov2/utils/utils.py
ADDED
@@ -0,0 +1,95 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# Copyright (c) Meta Platforms, Inc. and affiliates.
|
2 |
+
#
|
3 |
+
# This source code is licensed under the Apache License, Version 2.0
|
4 |
+
# found in the LICENSE file in the root directory of this source tree.
|
5 |
+
|
6 |
+
import logging
|
7 |
+
import os
|
8 |
+
import random
|
9 |
+
import subprocess
|
10 |
+
from urllib.parse import urlparse
|
11 |
+
|
12 |
+
import numpy as np
|
13 |
+
import torch
|
14 |
+
from torch import nn
|
15 |
+
|
16 |
+
|
17 |
+
logger = logging.getLogger("dinov2")
|
18 |
+
|
19 |
+
|
20 |
+
def load_pretrained_weights(model, pretrained_weights, checkpoint_key):
|
21 |
+
if urlparse(pretrained_weights).scheme: # If it looks like an URL
|
22 |
+
state_dict = torch.hub.load_state_dict_from_url(pretrained_weights, map_location="cpu")
|
23 |
+
else:
|
24 |
+
state_dict = torch.load(pretrained_weights, map_location="cpu")
|
25 |
+
if checkpoint_key is not None and checkpoint_key in state_dict:
|
26 |
+
logger.info(f"Take key {checkpoint_key} in provided checkpoint dict")
|
27 |
+
state_dict = state_dict[checkpoint_key]
|
28 |
+
# remove `module.` prefix
|
29 |
+
state_dict = {k.replace("module.", ""): v for k, v in state_dict.items()}
|
30 |
+
# remove `backbone.` prefix induced by multicrop wrapper
|
31 |
+
state_dict = {k.replace("backbone.", ""): v for k, v in state_dict.items()}
|
32 |
+
msg = model.load_state_dict(state_dict, strict=False)
|
33 |
+
logger.info("Pretrained weights found at {} and loaded with msg: {}".format(pretrained_weights, msg))
|
34 |
+
|
35 |
+
|
36 |
+
def fix_random_seeds(seed=31):
|
37 |
+
"""
|
38 |
+
Fix random seeds.
|
39 |
+
"""
|
40 |
+
torch.manual_seed(seed)
|
41 |
+
torch.cuda.manual_seed_all(seed)
|
42 |
+
np.random.seed(seed)
|
43 |
+
random.seed(seed)
|
44 |
+
|
45 |
+
|
46 |
+
def get_sha():
|
47 |
+
cwd = os.path.dirname(os.path.abspath(__file__))
|
48 |
+
|
49 |
+
def _run(command):
|
50 |
+
return subprocess.check_output(command, cwd=cwd).decode("ascii").strip()
|
51 |
+
|
52 |
+
sha = "N/A"
|
53 |
+
diff = "clean"
|
54 |
+
branch = "N/A"
|
55 |
+
try:
|
56 |
+
sha = _run(["git", "rev-parse", "HEAD"])
|
57 |
+
subprocess.check_output(["git", "diff"], cwd=cwd)
|
58 |
+
diff = _run(["git", "diff-index", "HEAD"])
|
59 |
+
diff = "has uncommitted changes" if diff else "clean"
|
60 |
+
branch = _run(["git", "rev-parse", "--abbrev-ref", "HEAD"])
|
61 |
+
except Exception:
|
62 |
+
pass
|
63 |
+
message = f"sha: {sha}, status: {diff}, branch: {branch}"
|
64 |
+
return message
|
65 |
+
|
66 |
+
|
67 |
+
class CosineScheduler(object):
|
68 |
+
def __init__(self, base_value, final_value, total_iters, warmup_iters=0, start_warmup_value=0, freeze_iters=0):
|
69 |
+
super().__init__()
|
70 |
+
self.final_value = final_value
|
71 |
+
self.total_iters = total_iters
|
72 |
+
|
73 |
+
freeze_schedule = np.zeros((freeze_iters))
|
74 |
+
|
75 |
+
warmup_schedule = np.linspace(start_warmup_value, base_value, warmup_iters)
|
76 |
+
|
77 |
+
iters = np.arange(total_iters - warmup_iters - freeze_iters)
|
78 |
+
schedule = final_value + 0.5 * (base_value - final_value) * (1 + np.cos(np.pi * iters / len(iters)))
|
79 |
+
self.schedule = np.concatenate((freeze_schedule, warmup_schedule, schedule))
|
80 |
+
|
81 |
+
assert len(self.schedule) == self.total_iters
|
82 |
+
|
83 |
+
def __getitem__(self, it):
|
84 |
+
if it >= self.total_iters:
|
85 |
+
return self.final_value
|
86 |
+
else:
|
87 |
+
return self.schedule[it]
|
88 |
+
|
89 |
+
|
90 |
+
def has_batchnorms(model):
|
91 |
+
bn_types = (nn.BatchNorm1d, nn.BatchNorm2d, nn.BatchNorm3d, nn.SyncBatchNorm)
|
92 |
+
for name, module in model.named_modules():
|
93 |
+
if isinstance(module, bn_types):
|
94 |
+
return True
|
95 |
+
return False
|
moge/model/moge_model.py
ADDED
@@ -0,0 +1,376 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
from numbers import Number
|
3 |
+
from functools import partial
|
4 |
+
from pathlib import Path
|
5 |
+
import importlib
|
6 |
+
import warnings
|
7 |
+
import json
|
8 |
+
|
9 |
+
import torch
|
10 |
+
import torch.nn as nn
|
11 |
+
import torch.nn.functional as F
|
12 |
+
import torch.utils
|
13 |
+
import torch.utils.checkpoint
|
14 |
+
import torch.version
|
15 |
+
import utils3d
|
16 |
+
from huggingface_hub import hf_hub_download
|
17 |
+
|
18 |
+
from ..utils.geometry_torch import image_plane_uv, point_map_to_depth, gaussian_blur_2d
|
19 |
+
from .utils import wrap_dinov2_attention_with_sdpa, wrap_module_with_gradient_checkpointing, unwrap_module_with_gradient_checkpointing
|
20 |
+
from ..utils.tools import timeit
|
21 |
+
|
22 |
+
|
23 |
+
class ResidualConvBlock(nn.Module):
|
24 |
+
def __init__(self, in_channels: int, out_channels: int = None, hidden_channels: int = None, padding_mode: str = 'replicate', activation: Literal['relu', 'leaky_relu', 'silu', 'elu'] = 'relu', norm: Literal['group_norm', 'layer_norm'] = 'group_norm'):
|
25 |
+
super(ResidualConvBlock, self).__init__()
|
26 |
+
if out_channels is None:
|
27 |
+
out_channels = in_channels
|
28 |
+
if hidden_channels is None:
|
29 |
+
hidden_channels = in_channels
|
30 |
+
|
31 |
+
if activation =='relu':
|
32 |
+
activation_cls = lambda: nn.ReLU(inplace=True)
|
33 |
+
elif activation == 'leaky_relu':
|
34 |
+
activation_cls = lambda: nn.LeakyReLU(negative_slope=0.2, inplace=True)
|
35 |
+
elif activation =='silu':
|
36 |
+
activation_cls = lambda: nn.SiLU(inplace=True)
|
37 |
+
elif activation == 'elu':
|
38 |
+
activation_cls = lambda: nn.ELU(inplace=True)
|
39 |
+
else:
|
40 |
+
raise ValueError(f'Unsupported activation function: {activation}')
|
41 |
+
|
42 |
+
self.layers = nn.Sequential(
|
43 |
+
nn.GroupNorm(1, in_channels),
|
44 |
+
activation_cls(),
|
45 |
+
nn.Conv2d(in_channels, hidden_channels, kernel_size=3, padding=1, padding_mode=padding_mode),
|
46 |
+
nn.GroupNorm(hidden_channels // 32 if norm == 'group_norm' else 1, hidden_channels),
|
47 |
+
activation_cls(),
|
48 |
+
nn.Conv2d(hidden_channels, out_channels, kernel_size=3, padding=1, padding_mode=padding_mode)
|
49 |
+
)
|
50 |
+
|
51 |
+
self.skip_connection = nn.Conv2d(in_channels, out_channels, kernel_size=1, padding=0) if in_channels != out_channels else nn.Identity()
|
52 |
+
|
53 |
+
def forward(self, x):
|
54 |
+
skip = self.skip_connection(x)
|
55 |
+
x = self.layers(x)
|
56 |
+
x = x + skip
|
57 |
+
return x
|
58 |
+
|
59 |
+
|
60 |
+
class Head(nn.Module):
|
61 |
+
def __init__(
|
62 |
+
self,
|
63 |
+
num_features: int,
|
64 |
+
dim_in: int,
|
65 |
+
dim_out: List[int],
|
66 |
+
dim_proj: int = 512,
|
67 |
+
dim_upsample: List[int] = [256, 128, 128],
|
68 |
+
dim_times_res_block_hidden: int = 1,
|
69 |
+
num_res_blocks: int = 1,
|
70 |
+
res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
|
71 |
+
last_res_blocks: int = 0,
|
72 |
+
last_conv_channels: int = 32,
|
73 |
+
last_conv_size: int = 1
|
74 |
+
):
|
75 |
+
super().__init__()
|
76 |
+
|
77 |
+
self.projects = nn.ModuleList([
|
78 |
+
nn.Conv2d(in_channels=dim_in, out_channels=dim_proj, kernel_size=1, stride=1, padding=0,) for _ in range(num_features)
|
79 |
+
])
|
80 |
+
|
81 |
+
self.upsample_blocks = nn.ModuleList([
|
82 |
+
nn.Sequential(
|
83 |
+
self._make_upsampler(in_ch + 2, out_ch),
|
84 |
+
*(ResidualConvBlock(out_ch, out_ch, dim_times_res_block_hidden * out_ch, activation="relu", norm=res_block_norm) for _ in range(num_res_blocks))
|
85 |
+
) for in_ch, out_ch in zip([dim_proj] + dim_upsample[:-1], dim_upsample)
|
86 |
+
])
|
87 |
+
|
88 |
+
self.output_block = nn.ModuleList([
|
89 |
+
self._make_output_block(
|
90 |
+
dim_upsample[-1] + 2, dim_out_, dim_times_res_block_hidden, last_res_blocks, last_conv_channels, last_conv_size, res_block_norm,
|
91 |
+
) for dim_out_ in dim_out
|
92 |
+
])
|
93 |
+
|
94 |
+
def _make_upsampler(self, in_channels: int, out_channels: int):
|
95 |
+
upsampler = nn.Sequential(
|
96 |
+
nn.ConvTranspose2d(in_channels, out_channels, kernel_size=2, stride=2),
|
97 |
+
nn.Conv2d(out_channels, out_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate')
|
98 |
+
)
|
99 |
+
upsampler[0].weight.data[:] = upsampler[0].weight.data[:, :, :1, :1]
|
100 |
+
return upsampler
|
101 |
+
|
102 |
+
def _make_output_block(self, dim_in: int, dim_out: int, dim_times_res_block_hidden: int, last_res_blocks: int, last_conv_channels: int, last_conv_size: int, res_block_norm: Literal['group_norm', 'layer_norm']):
|
103 |
+
return nn.Sequential(
|
104 |
+
nn.Conv2d(dim_in, last_conv_channels, kernel_size=3, stride=1, padding=1, padding_mode='replicate'),
|
105 |
+
*(ResidualConvBlock(last_conv_channels, last_conv_channels, dim_times_res_block_hidden * last_conv_channels, activation='relu', norm=res_block_norm) for _ in range(last_res_blocks)),
|
106 |
+
nn.ReLU(inplace=True),
|
107 |
+
nn.Conv2d(last_conv_channels, dim_out, kernel_size=last_conv_size, stride=1, padding=last_conv_size // 2, padding_mode='replicate'),
|
108 |
+
)
|
109 |
+
|
110 |
+
def forward(self, hidden_states: torch.Tensor, image: torch.Tensor):
|
111 |
+
img_h, img_w = image.shape[-2:]
|
112 |
+
patch_h, patch_w = img_h // 14, img_w // 14
|
113 |
+
|
114 |
+
# Process the hidden states
|
115 |
+
x = torch.stack([
|
116 |
+
proj(feat.permute(0, 2, 1).unflatten(2, (patch_h, patch_w)).contiguous())
|
117 |
+
for proj, (feat, clstoken) in zip(self.projects, hidden_states)
|
118 |
+
], dim=1).sum(dim=1)
|
119 |
+
|
120 |
+
# Upsample stage
|
121 |
+
# (patch_h, patch_w) -> (patch_h * 2, patch_w * 2) -> (patch_h * 4, patch_w * 4) -> (patch_h * 8, patch_w * 8)
|
122 |
+
for i, block in enumerate(self.upsample_blocks):
|
123 |
+
# UV coordinates is for awareness of image aspect ratio
|
124 |
+
uv = image_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
|
125 |
+
uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
|
126 |
+
x = torch.cat([x, uv], dim=1)
|
127 |
+
for layer in block:
|
128 |
+
x = torch.utils.checkpoint.checkpoint(layer, x, use_reentrant=False)
|
129 |
+
|
130 |
+
# (patch_h * 8, patch_w * 8) -> (img_h, img_w)
|
131 |
+
x = F.interpolate(x, (img_h, img_w), mode="bilinear", align_corners=False)
|
132 |
+
uv = image_plane_uv(width=x.shape[-1], height=x.shape[-2], aspect_ratio=img_w / img_h, dtype=x.dtype, device=x.device)
|
133 |
+
uv = uv.permute(2, 0, 1).unsqueeze(0).expand(x.shape[0], -1, -1, -1)
|
134 |
+
x = torch.cat([x, uv], dim=1)
|
135 |
+
|
136 |
+
if isinstance(self.output_block, nn.ModuleList):
|
137 |
+
output = [torch.utils.checkpoint.checkpoint(block, x, use_reentrant=False) for block in self.output_block]
|
138 |
+
else:
|
139 |
+
output = torch.utils.checkpoint.checkpoint(self.output_block, x, use_reentrant=False)
|
140 |
+
|
141 |
+
return output
|
142 |
+
|
143 |
+
|
144 |
+
class MoGeModel(nn.Module):
|
145 |
+
image_mean: torch.Tensor
|
146 |
+
image_std: torch.Tensor
|
147 |
+
|
148 |
+
def __init__(self,
|
149 |
+
encoder: str = 'dinov2_vitb14',
|
150 |
+
intermediate_layers: Union[int, List[int]] = 4,
|
151 |
+
dim_proj: int = 512,
|
152 |
+
dim_upsample: List[int] = [256, 128, 128],
|
153 |
+
dim_times_res_block_hidden: int = 1,
|
154 |
+
num_res_blocks: int = 1,
|
155 |
+
output_mask: bool = False,
|
156 |
+
split_head: bool = False,
|
157 |
+
remap_output: Literal[False, True, 'linear', 'sinh', 'exp', 'sinh_exp'] = 'linear',
|
158 |
+
res_block_norm: Literal['group_norm', 'layer_norm'] = 'group_norm',
|
159 |
+
trained_diagonal_size_range: Tuple[Number, Number] = (600, 900),
|
160 |
+
trained_area_range: Tuple[Number, Number] = (500 * 500, 700 * 700),
|
161 |
+
last_res_blocks: int = 0,
|
162 |
+
last_conv_channels: int = 32,
|
163 |
+
last_conv_size: int = 1,
|
164 |
+
**deprecated_kwargs
|
165 |
+
):
|
166 |
+
super(MoGeModel, self).__init__()
|
167 |
+
if deprecated_kwargs:
|
168 |
+
warnings.warn(f"The following deprecated/invalid arguments are ignored: {deprecated_kwargs}")
|
169 |
+
|
170 |
+
self.encoder = encoder
|
171 |
+
self.remap_output = remap_output
|
172 |
+
self.intermediate_layers = intermediate_layers
|
173 |
+
self.trained_diagonal_size_range = trained_diagonal_size_range
|
174 |
+
self.trained_area_range = trained_area_range
|
175 |
+
self.output_mask = output_mask
|
176 |
+
self.split_head = split_head
|
177 |
+
|
178 |
+
# NOTE: We have copied the DINOv2 code in torchhub to this repository.
|
179 |
+
# Minimal modifications have been made: removing irrelevant code, unnecessary warnings and fixing importing issues.
|
180 |
+
hub_loader = getattr(importlib.import_module(".dinov2.hub.backbones", __package__), encoder)
|
181 |
+
self.backbone = hub_loader(pretrained=False)
|
182 |
+
dim_feature = self.backbone.blocks[0].attn.qkv.in_features
|
183 |
+
|
184 |
+
self.head = Head(
|
185 |
+
num_features=intermediate_layers if isinstance(intermediate_layers, int) else len(intermediate_layers),
|
186 |
+
dim_in=dim_feature,
|
187 |
+
dim_out=3 if not output_mask else 4 if output_mask and not split_head else [3, 1],
|
188 |
+
dim_proj=dim_proj,
|
189 |
+
dim_upsample=dim_upsample,
|
190 |
+
dim_times_res_block_hidden=dim_times_res_block_hidden,
|
191 |
+
num_res_blocks=num_res_blocks,
|
192 |
+
res_block_norm=res_block_norm,
|
193 |
+
last_res_blocks=last_res_blocks,
|
194 |
+
last_conv_channels=last_conv_channels,
|
195 |
+
last_conv_size=last_conv_size
|
196 |
+
)
|
197 |
+
|
198 |
+
image_mean = torch.tensor([0.485, 0.456, 0.406]).view(1, 3, 1, 1)
|
199 |
+
image_std = torch.tensor([0.229, 0.224, 0.225]).view(1, 3, 1, 1)
|
200 |
+
|
201 |
+
self.register_buffer("image_mean", image_mean)
|
202 |
+
self.register_buffer("image_std", image_std)
|
203 |
+
|
204 |
+
if torch.__version__ >= '2.0':
|
205 |
+
self.enable_pytorch_native_sdpa()
|
206 |
+
|
207 |
+
@classmethod
|
208 |
+
def from_pretrained(cls, pretrained_model_name_or_path: Union[str, Path, IO[bytes]], model_kwargs: Optional[Dict[str, Any]] = None, **hf_kwargs) -> 'MoGeModel':
|
209 |
+
"""
|
210 |
+
Load a model from a checkpoint file.
|
211 |
+
|
212 |
+
### Parameters:
|
213 |
+
- `pretrained_model_name_or_path`: path to the checkpoint file or repo id.
|
214 |
+
- `model_kwargs`: additional keyword arguments to override the parameters in the checkpoint.
|
215 |
+
- `hf_kwargs`: additional keyword arguments to pass to the `hf_hub_download` function. Ignored if `pretrained_model_name_or_path` is a local path.
|
216 |
+
|
217 |
+
### Returns:
|
218 |
+
- A new instance of `MoGe` with the parameters loaded from the checkpoint.
|
219 |
+
"""
|
220 |
+
if Path(pretrained_model_name_or_path).exists():
|
221 |
+
checkpoint = torch.load(pretrained_model_name_or_path, map_location='cpu', weights_only=True)
|
222 |
+
else:
|
223 |
+
cached_checkpoint_path = hf_hub_download(
|
224 |
+
repo_id=pretrained_model_name_or_path,
|
225 |
+
repo_type="model",
|
226 |
+
filename="model.pt",
|
227 |
+
**hf_kwargs
|
228 |
+
)
|
229 |
+
checkpoint = torch.load(cached_checkpoint_path, map_location='cpu', weights_only=True)
|
230 |
+
model_config = checkpoint['model_config']
|
231 |
+
if model_kwargs is not None:
|
232 |
+
model_config.update(model_kwargs)
|
233 |
+
model = cls(**model_config)
|
234 |
+
model.load_state_dict(checkpoint['model'])
|
235 |
+
return model
|
236 |
+
|
237 |
+
@staticmethod
|
238 |
+
def cache_pretrained_backbone(encoder: str, pretrained: bool):
|
239 |
+
_ = torch.hub.load('facebookresearch/dinov2', encoder, pretrained=pretrained)
|
240 |
+
|
241 |
+
def load_pretrained_backbone(self):
|
242 |
+
"Load the backbone with pretrained dinov2 weights from torch hub"
|
243 |
+
state_dict = torch.hub.load('facebookresearch/dinov2', self.encoder, pretrained=True).state_dict()
|
244 |
+
self.backbone.load_state_dict(state_dict)
|
245 |
+
|
246 |
+
def enable_backbone_gradient_checkpointing(self):
|
247 |
+
for i in range(len(self.backbone.blocks)):
|
248 |
+
self.backbone.blocks[i] = wrap_module_with_gradient_checkpointing(self.backbone.blocks[i])
|
249 |
+
|
250 |
+
def enable_pytorch_native_sdpa(self):
|
251 |
+
for i in range(len(self.backbone.blocks)):
|
252 |
+
self.backbone.blocks[i].attn = wrap_dinov2_attention_with_sdpa(self.backbone.blocks[i].attn)
|
253 |
+
|
254 |
+
def forward(self, image: torch.Tensor, mixed_precision: bool = False) -> Dict[str, torch.Tensor]:
|
255 |
+
raw_img_h, raw_img_w = image.shape[-2:]
|
256 |
+
patch_h, patch_w = raw_img_h // 14, raw_img_w // 14
|
257 |
+
|
258 |
+
image = (image - self.image_mean) / self.image_std
|
259 |
+
|
260 |
+
# Apply image transformation for DINOv2
|
261 |
+
image_14 = F.interpolate(image, (patch_h * 14, patch_w * 14), mode="bilinear", align_corners=False, antialias=True)
|
262 |
+
|
263 |
+
# Get intermediate layers from the backbone
|
264 |
+
with torch.autocast(device_type='cuda', dtype=torch.float16, enabled=mixed_precision):
|
265 |
+
features = self.backbone.get_intermediate_layers(image_14, self.intermediate_layers, return_class_token=True)
|
266 |
+
|
267 |
+
# Predict points (and mask)
|
268 |
+
output = self.head(features, image)
|
269 |
+
if self.output_mask:
|
270 |
+
if self.split_head:
|
271 |
+
points, mask = output
|
272 |
+
else:
|
273 |
+
points, mask = output.split([3, 1], dim=1)
|
274 |
+
points, mask = points.permute(0, 2, 3, 1), mask.squeeze(1)
|
275 |
+
else:
|
276 |
+
points = output.permute(0, 2, 3, 1)
|
277 |
+
|
278 |
+
if self.remap_output == 'linear' or self.remap_output == False:
|
279 |
+
pass
|
280 |
+
elif self.remap_output =='sinh' or self.remap_output == True:
|
281 |
+
points = torch.sinh(points)
|
282 |
+
elif self.remap_output == 'exp':
|
283 |
+
xy, z = points.split([2, 1], dim=-1)
|
284 |
+
z = torch.exp(z)
|
285 |
+
points = torch.cat([xy * z, z], dim=-1)
|
286 |
+
elif self.remap_output =='sinh_exp':
|
287 |
+
xy, z = points.split([2, 1], dim=-1)
|
288 |
+
points = torch.cat([torch.sinh(xy), torch.exp(z)], dim=-1)
|
289 |
+
else:
|
290 |
+
raise ValueError(f"Invalid remap output type: {self.remap_output}")
|
291 |
+
|
292 |
+
return_dict = {'points': points}
|
293 |
+
if self.output_mask:
|
294 |
+
return_dict['mask'] = mask
|
295 |
+
return return_dict
|
296 |
+
|
297 |
+
@torch.inference_mode()
|
298 |
+
def infer(
|
299 |
+
self,
|
300 |
+
image: torch.Tensor,
|
301 |
+
force_projection: bool = True,
|
302 |
+
resolution_level: int = 9,
|
303 |
+
apply_mask: bool = True,
|
304 |
+
) -> Dict[str, torch.Tensor]:
|
305 |
+
"""
|
306 |
+
User-friendly inference function
|
307 |
+
|
308 |
+
### Parameters
|
309 |
+
- `image`: input image tensor of shape (B, 3, H, W) or (3, H, W)
|
310 |
+
- `resolution_level`: the resolution level to use for the output point map in 0-9. Default: 9 (highest)
|
311 |
+
- `interpolation_mode`: interpolation mode for the output points map. Default: 'bilinear'.
|
312 |
+
|
313 |
+
### Returns
|
314 |
+
|
315 |
+
A dictionary containing the following keys:
|
316 |
+
- `points`: output tensor of shape (B, H, W, 3) or (H, W, 3).
|
317 |
+
- `depth`: tensor of shape (B, H, W) or (H, W) containing the depth map.
|
318 |
+
- `intrinsics`: tensor of shape (B, 3, 3) or (3, 3) containing the camera intrinsics.
|
319 |
+
"""
|
320 |
+
if image.dim() == 3:
|
321 |
+
omit_batch_dim = True
|
322 |
+
image = image.unsqueeze(0)
|
323 |
+
else:
|
324 |
+
omit_batch_dim = False
|
325 |
+
|
326 |
+
original_height, original_width = image.shape[-2:]
|
327 |
+
area = original_height * original_width
|
328 |
+
|
329 |
+
min_area, max_area = self.trained_area_range
|
330 |
+
expected_area = min_area + (max_area - min_area) * (resolution_level / 9)
|
331 |
+
|
332 |
+
if expected_area != area:
|
333 |
+
expected_width, expected_height = int(original_width * (expected_area / area) ** 0.5), int(original_height * (expected_area / area) ** 0.5)
|
334 |
+
image = F.interpolate(image, (expected_height, expected_width), mode="bicubic", align_corners=False, antialias=True)
|
335 |
+
|
336 |
+
output = self.forward(image)
|
337 |
+
points, mask = output['points'], output.get('mask', None)
|
338 |
+
|
339 |
+
# Get camera-origin-centered point map
|
340 |
+
depth, fov_x, fov_y, z_shift = point_map_to_depth(points, None if mask is None else mask > 0.5)
|
341 |
+
intrinsics = utils3d.torch.intrinsics_from_fov_xy(fov_x, fov_y)
|
342 |
+
|
343 |
+
# If projection constraint is forces, recompute the point map using the actual depth map
|
344 |
+
if force_projection:
|
345 |
+
points = utils3d.torch.unproject_cv(utils3d.torch.image_uv(width=expected_width, height=expected_height, dtype=points.dtype, device=points.device), depth, extrinsics=None, intrinsics=intrinsics[..., None, :, :])
|
346 |
+
else:
|
347 |
+
points = points + torch.stack([torch.zeros_like(z_shift), torch.zeros_like(z_shift), z_shift], dim=-1)[..., None, None, :]
|
348 |
+
|
349 |
+
# Resize the output to the original resolution
|
350 |
+
if expected_area != area:
|
351 |
+
points = F.interpolate(points.permute(0, 3, 1, 2), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).permute(0, 2, 3, 1)
|
352 |
+
depth = F.interpolate(depth.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1)
|
353 |
+
mask = None if mask is None else F.interpolate(mask.unsqueeze(1), (original_height, original_width), mode='bilinear', align_corners=False, antialias=False).squeeze(1)
|
354 |
+
|
355 |
+
# Apply mask if needed
|
356 |
+
if self.output_mask and apply_mask:
|
357 |
+
mask_binary = (depth > 0) & (mask > 0.5)
|
358 |
+
points = torch.where(mask_binary[..., None], points, torch.inf)
|
359 |
+
depth = torch.where(mask_binary, depth, torch.inf)
|
360 |
+
|
361 |
+
if omit_batch_dim:
|
362 |
+
points = points.squeeze(0)
|
363 |
+
intrinsics = intrinsics.squeeze(0)
|
364 |
+
depth = depth.squeeze(0)
|
365 |
+
if self.output_mask:
|
366 |
+
mask = mask.squeeze(0)
|
367 |
+
|
368 |
+
return_dict = {
|
369 |
+
'points': points,
|
370 |
+
'intrinsics': intrinsics,
|
371 |
+
'depth': depth,
|
372 |
+
}
|
373 |
+
if self.output_mask:
|
374 |
+
return_dict['mask'] = mask > 0.5
|
375 |
+
|
376 |
+
return return_dict
|
moge/model/utils.py
ADDED
@@ -0,0 +1,38 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
|
3 |
+
import torch
|
4 |
+
import torch.nn as nn
|
5 |
+
import torch.nn.functional as F
|
6 |
+
|
7 |
+
def wrap_module_with_gradient_checkpointing(module: nn.Module):
|
8 |
+
from torch.utils.checkpoint import checkpoint
|
9 |
+
class _CheckpointingWrapper(module.__class__):
|
10 |
+
_restore_cls = module.__class__
|
11 |
+
def forward(self, *args, **kwargs):
|
12 |
+
return checkpoint(super().forward, *args, use_reentrant=False, **kwargs)
|
13 |
+
|
14 |
+
module.__class__ = _CheckpointingWrapper
|
15 |
+
return module
|
16 |
+
|
17 |
+
|
18 |
+
def unwrap_module_with_gradient_checkpointing(module: nn.Module):
|
19 |
+
module.__class__ = module.__class__._restore_cls
|
20 |
+
|
21 |
+
|
22 |
+
def wrap_dinov2_attention_with_sdpa(module: nn.Module):
|
23 |
+
assert torch.__version__ >= '2.0', "SDPA requires PyTorch 2.0 or later"
|
24 |
+
class _AttentionWrapper(module.__class__):
|
25 |
+
def forward(self, x: torch.Tensor, attn_bias=None) -> torch.Tensor:
|
26 |
+
B, N, C = x.shape
|
27 |
+
qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4) # (3, B, H, N, C // H)
|
28 |
+
|
29 |
+
q, k, v = torch.unbind(qkv, 0) # (B, H, N, C // H)
|
30 |
+
|
31 |
+
x = F.scaled_dot_product_attention(q, k, v, attn_bias)
|
32 |
+
x = x.permute(0, 2, 1, 3).reshape(B, N, C)
|
33 |
+
|
34 |
+
x = self.proj(x)
|
35 |
+
x = self.proj_drop(x)
|
36 |
+
return x
|
37 |
+
module.__class__ = _AttentionWrapper
|
38 |
+
return module
|
moge/utils/__init__.py
ADDED
File without changes
|
moge/utils/blob.py
ADDED
@@ -0,0 +1,314 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import IO, Generator, Tuple, Union, overload
|
2 |
+
from pathlib import Path, PosixPath, PurePosixPath
|
3 |
+
import io
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import requests
|
7 |
+
import fnmatch
|
8 |
+
|
9 |
+
from azure.identity import DefaultAzureCredential
|
10 |
+
from azure.storage.blob import ContainerClient, BlobClient
|
11 |
+
import requests.adapters
|
12 |
+
import requests.packages
|
13 |
+
from urllib3.util.retry import Retry
|
14 |
+
|
15 |
+
|
16 |
+
__all__ = [
|
17 |
+
'download_blob', 'upload_blob',
|
18 |
+
'download_blob_with_cache',
|
19 |
+
'open_blob', 'open_blob_with_cache',
|
20 |
+
'blob_file_exists',
|
21 |
+
'AzureBlobPath','SmartPath'
|
22 |
+
]
|
23 |
+
|
24 |
+
DEFAULT_CREDENTIAL = DefaultAzureCredential()
|
25 |
+
|
26 |
+
BLOB_CACHE_DIR = './.blobcache'
|
27 |
+
|
28 |
+
def download_blob(blob: Union[str, BlobClient]) -> bytes:
|
29 |
+
if isinstance(blob, str):
|
30 |
+
blob_client = BlobClient.from_blob_url(blob_client)
|
31 |
+
else:
|
32 |
+
blob_client = blob
|
33 |
+
return blob_client.download_blob().read()
|
34 |
+
|
35 |
+
|
36 |
+
def upload_blob(blob: Union[str, BlobClient], data: Union[str, bytes]):
|
37 |
+
if isinstance(blob, str):
|
38 |
+
blob_client = BlobClient.from_blob_url(blob)
|
39 |
+
else:
|
40 |
+
blob_client = blob
|
41 |
+
blob_client.upload_blob(data, overwrite=True)
|
42 |
+
|
43 |
+
|
44 |
+
def download_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> bytes:
|
45 |
+
"""
|
46 |
+
Download a blob file from a container and return its content as bytes.
|
47 |
+
If the file is already present in the cache, it is read from there.
|
48 |
+
"""
|
49 |
+
cache_path = Path(cache_dir) / blob_name
|
50 |
+
if cache_path.exists():
|
51 |
+
return cache_path.read_bytes()
|
52 |
+
data = download_blob(container, blob_name)
|
53 |
+
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
54 |
+
cache_path.write_bytes(data)
|
55 |
+
return data
|
56 |
+
|
57 |
+
|
58 |
+
def open_blob(container: Union[str, ContainerClient], blob_name: str) -> io.BytesIO:
|
59 |
+
"""
|
60 |
+
Open a blob file for reading from a container and return its content as a BytesIO object.
|
61 |
+
"""
|
62 |
+
return io.BytesIO(download_blob(container, blob_name))
|
63 |
+
|
64 |
+
|
65 |
+
def open_blob_with_cache(container: Union[str, ContainerClient], blob_name: str, cache_dir: str = 'blobcache') -> io.BytesIO:
|
66 |
+
"""
|
67 |
+
Open a blob file for reading from a container and return its content as a BytesIO object.
|
68 |
+
If the file is already present in the cache, it is read from there.
|
69 |
+
"""
|
70 |
+
return io.BytesIO(download_blob_with_cache(container, blob_name, cache_dir=cache_dir))
|
71 |
+
|
72 |
+
|
73 |
+
def blob_file_exists(container: Union[str, ContainerClient], blob_name: str) -> bool:
|
74 |
+
"""
|
75 |
+
Check if a blob file exists in a container.
|
76 |
+
"""
|
77 |
+
if isinstance(container, str):
|
78 |
+
container = ContainerClient.from_container_url(container)
|
79 |
+
blob_client = container.get_blob_client(blob_name)
|
80 |
+
return blob_client.exists()
|
81 |
+
|
82 |
+
def is_blob_url(url: str) -> bool:
|
83 |
+
return re.match(r'https://[^/]+blob.core.windows.net/+', url) is not None
|
84 |
+
|
85 |
+
|
86 |
+
def split_blob_url(url: str) -> Tuple[str, str, str]:
|
87 |
+
match = re.match(r'(https://[^/]+blob.core.windows.net/[^/?]+)(/([^\?]*))?(\?.+)?', url)
|
88 |
+
if match:
|
89 |
+
container, _, path, sas = match.groups()
|
90 |
+
return container, path or '', sas or ''
|
91 |
+
raise ValueError(f'Not a valid blob URL: {url}')
|
92 |
+
|
93 |
+
|
94 |
+
def join_blob_path(url: str, *others: str) -> str:
|
95 |
+
container, path, sas = split_blob_url(url)
|
96 |
+
return container + '/' + os.path.join(path, *others) + sas
|
97 |
+
|
98 |
+
|
99 |
+
class AzureBlobStringWriter(io.StringIO):
|
100 |
+
def __init__(self, blob_client: BlobClient, encoding: str = 'utf-8', **kwargs):
|
101 |
+
self._encoding = encoding
|
102 |
+
self.blob_client = blob_client
|
103 |
+
self.kwargs = kwargs
|
104 |
+
super().__init__()
|
105 |
+
|
106 |
+
def close(self):
|
107 |
+
self.blob_client.upload_blob(self.getvalue().encode(self._encoding), blob_type='BlockBlob', overwrite=True, **self.kwargs)
|
108 |
+
|
109 |
+
|
110 |
+
class AzureBlobBytesWriter(io.BytesIO):
|
111 |
+
def __init__(self, blob_client: BlobClient, **kwargs):
|
112 |
+
super().__init__()
|
113 |
+
self.blob_client = blob_client
|
114 |
+
self.kwargs = kwargs
|
115 |
+
|
116 |
+
def close(self):
|
117 |
+
self.blob_client.upload_blob(self.getvalue(), blob_type='BlockBlob', overwrite=True, **self.kwargs)
|
118 |
+
|
119 |
+
|
120 |
+
def open_azure_blob(blob: Union[str, BlobClient], mode: str = 'r', encoding: str = 'utf-8', newline: str = None, cache_blob: bool = False, **kwargs) -> IO:
|
121 |
+
if isinstance(blob, str):
|
122 |
+
blob_client = BlobClient.from_blob_url(blob)
|
123 |
+
elif isinstance(blob, BlobClient):
|
124 |
+
blob_client = blob
|
125 |
+
else:
|
126 |
+
raise ValueError(f'Must be a blob URL or a BlobClient object: {blob}')
|
127 |
+
|
128 |
+
if cache_blob:
|
129 |
+
cache_path = Path(BLOB_CACHE_DIR, blob_client.account_name, blob_client.container_name, blob_client.blob_name)
|
130 |
+
|
131 |
+
if mode == 'r' or mode == 'rb':
|
132 |
+
if cache_blob:
|
133 |
+
if cache_path.exists():
|
134 |
+
data = cache_path.read_bytes()
|
135 |
+
else:
|
136 |
+
data = blob_client.download_blob(**kwargs).read()
|
137 |
+
cache_path.parent.mkdir(parents=True, exist_ok=True)
|
138 |
+
cache_path.write_bytes(data)
|
139 |
+
else:
|
140 |
+
data = blob_client.download_blob(**kwargs).read()
|
141 |
+
if mode == 'r':
|
142 |
+
return io.StringIO(data.decode(encoding), newline=newline)
|
143 |
+
else:
|
144 |
+
return io.BytesIO(data)
|
145 |
+
elif mode == 'w':
|
146 |
+
return AzureBlobStringWriter(blob_client, **kwargs)
|
147 |
+
elif mode == 'wb':
|
148 |
+
return AzureBlobBytesWriter(blob_client, **kwargs)
|
149 |
+
else:
|
150 |
+
raise ValueError(f'Unsupported mode: {mode}')
|
151 |
+
|
152 |
+
|
153 |
+
def smart_open(path_or_url: Union[Path, str], mode: str = 'r', encoding: str = 'utf-8') -> IO:
|
154 |
+
if is_blob_url(str(path_or_url)):
|
155 |
+
return open_azure_blob(str(path_or_url), mode, encoding)
|
156 |
+
return open(path_or_url, mode, encoding)
|
157 |
+
|
158 |
+
|
159 |
+
class AzureBlobPath(PurePosixPath):
|
160 |
+
"""
|
161 |
+
Implementation of pathlib.Path like interface for Azure Blob Storage.
|
162 |
+
"""
|
163 |
+
container_client: ContainerClient
|
164 |
+
_parse_path = PurePosixPath._parse_args if hasattr(PurePosixPath, '_parse_args') else PurePosixPath._parse_path
|
165 |
+
|
166 |
+
def __new__(cls, *args, **kwargs):
|
167 |
+
"""Override the old __new__ method. Parts are parsed in __init__"""
|
168 |
+
return object.__new__(cls)
|
169 |
+
|
170 |
+
def __init__(self, root: Union[str, 'AzureBlobPath', ContainerClient], *others: Union[str, PurePosixPath], pool_maxsize: int = 256, retries: int = 3):
|
171 |
+
if isinstance(root, AzureBlobPath):
|
172 |
+
self.container_client = root.container_client
|
173 |
+
parts = root.parts + others
|
174 |
+
elif isinstance(root, str):
|
175 |
+
url = root
|
176 |
+
container, path, sas = split_blob_url(url)
|
177 |
+
session = self._get_session(pool_maxsize=pool_maxsize, retries=retries)
|
178 |
+
if sas:
|
179 |
+
self.container_client = ContainerClient.from_container_url(container + sas, session=session)
|
180 |
+
else:
|
181 |
+
self.container_client = ContainerClient.from_container_url(container, credential=DEFAULT_CREDENTIAL, session=session)
|
182 |
+
parts = (path, *others)
|
183 |
+
elif isinstance(root, ContainerClient):
|
184 |
+
self.container_client = root
|
185 |
+
parts = others
|
186 |
+
else:
|
187 |
+
raise ValueError(f'Invalid root: {root}')
|
188 |
+
|
189 |
+
if hasattr(PurePosixPath, '_parse_args'):
|
190 |
+
# For compatibility with Python 3.10
|
191 |
+
drv, root, parts = PurePosixPath._parse_args(parts)
|
192 |
+
self._drv = drv
|
193 |
+
self._root = root
|
194 |
+
self._parts = parts
|
195 |
+
else:
|
196 |
+
super().__init__(*parts)
|
197 |
+
|
198 |
+
def _get_session(self, pool_maxsize: int = 1024, retries: int = 3) -> requests.Session:
|
199 |
+
session = requests.Session()
|
200 |
+
retry_strategy = Retry(
|
201 |
+
total=retries,
|
202 |
+
status_forcelist=[429, 500, 502, 503, 504],
|
203 |
+
allowed_methods=["HEAD", "GET", "PUT", "DELETE"],
|
204 |
+
backoff_factor=1,
|
205 |
+
raise_on_status=False,
|
206 |
+
read=retries,
|
207 |
+
connect=retries,
|
208 |
+
redirect=retries,
|
209 |
+
)
|
210 |
+
adapter = requests.adapters.HTTPAdapter(pool_connections=pool_maxsize, pool_maxsize=pool_maxsize, max_retries=retry_strategy)
|
211 |
+
session.mount('http://', adapter)
|
212 |
+
session.mount('https://', adapter)
|
213 |
+
return session
|
214 |
+
|
215 |
+
def _from_parsed_parts(self, drv, root, parts):
|
216 |
+
"For compatibility with Python 3.10"
|
217 |
+
return AzureBlobPath(self.container_client, drv, root, *parts)
|
218 |
+
|
219 |
+
def with_segments(self, *pathsegments):
|
220 |
+
return AzureBlobPath(self.container_client, *pathsegments)
|
221 |
+
|
222 |
+
@property
|
223 |
+
def path(self) -> str:
|
224 |
+
return '/'.join(self.parts)
|
225 |
+
|
226 |
+
@property
|
227 |
+
def blob_client(self) -> BlobClient:
|
228 |
+
return self.container_client.get_blob_client(self.path)
|
229 |
+
|
230 |
+
@property
|
231 |
+
def url(self) -> str:
|
232 |
+
if len(self.parts) == 0:
|
233 |
+
return self.container_client.url
|
234 |
+
return self.container_client.get_blob_client(self.path).url
|
235 |
+
|
236 |
+
@property
|
237 |
+
def container_name(self) -> str:
|
238 |
+
return self.container_client.container_name
|
239 |
+
|
240 |
+
@property
|
241 |
+
def account_name(self) -> str:
|
242 |
+
return self.container_client.account_name
|
243 |
+
|
244 |
+
def __str__(self):
|
245 |
+
return self.url
|
246 |
+
|
247 |
+
def __repr__(self):
|
248 |
+
return self.url
|
249 |
+
|
250 |
+
def open(self, mode: str = 'r', encoding: str = 'utf-8', cache_blob: bool = False, **kwargs) -> IO:
|
251 |
+
return open_azure_blob(self.blob_client, mode, encoding, cache_blob=cache_blob, **kwargs)
|
252 |
+
|
253 |
+
def __truediv__(self, other: Union[str, Path]) -> 'AzureBlobPath':
|
254 |
+
return self.joinpath(other)
|
255 |
+
|
256 |
+
def mkdir(self, parents: bool = False, exist_ok: bool = False):
|
257 |
+
pass
|
258 |
+
|
259 |
+
def iterdir(self) -> Generator['AzureBlobPath', None, None]:
|
260 |
+
path = self.path
|
261 |
+
if not path.endswith('/'):
|
262 |
+
path += '/'
|
263 |
+
for item in self.container_client.walk_blobs(self.path):
|
264 |
+
yield AzureBlobPath(self.container_client, item.name)
|
265 |
+
|
266 |
+
def glob(self, pattern: str) -> Generator['AzureBlobPath', None, None]:
|
267 |
+
special_chars = ".^$+{}[]()|/"
|
268 |
+
for char in special_chars:
|
269 |
+
pattern = pattern.replace(char, "\\" + char)
|
270 |
+
pattern = pattern.replace('**', './/.')
|
271 |
+
pattern = pattern.replace('*', '[^/]*')
|
272 |
+
pattern = pattern.replace('.//.', '.*')
|
273 |
+
pattern = "^" + pattern + "$"
|
274 |
+
reg = re.compile(pattern)
|
275 |
+
|
276 |
+
for item in self.container_client.list_blobs(self.path):
|
277 |
+
if reg.match(os.path.relpath(item.name, self.path)):
|
278 |
+
yield AzureBlobPath(self.container_client, item.name)
|
279 |
+
|
280 |
+
def exists(self) -> bool:
|
281 |
+
return self.blob_client.exists()
|
282 |
+
|
283 |
+
def read_bytes(self, cache_blob: bool = False) -> bytes:
|
284 |
+
with self.open('rb', cache_blob=cache_blob) as f:
|
285 |
+
return f.read()
|
286 |
+
|
287 |
+
def read_text(self, encoding: str = 'utf-8', cache_blob: bool = False) -> str:
|
288 |
+
with self.open('r', encoding=encoding, cache_blob=cache_blob) as f:
|
289 |
+
return f.read()
|
290 |
+
|
291 |
+
def write_bytes(self, data: bytes):
|
292 |
+
self.blob_client.upload_blob(data, overwrite=True)
|
293 |
+
|
294 |
+
def write_text(self, data: str, encoding: str = 'utf-8'):
|
295 |
+
self.blob_client.upload_blob(data.encode(encoding), overwrite=True)
|
296 |
+
|
297 |
+
def unlink(self):
|
298 |
+
self.blob_client.delete_blob()
|
299 |
+
|
300 |
+
def new_client(self) -> 'AzureBlobPath':
|
301 |
+
return AzureBlobPath(self.container_client.url, self.path)
|
302 |
+
|
303 |
+
|
304 |
+
class SmartPath(Path, AzureBlobPath):
|
305 |
+
"""
|
306 |
+
Supports both local file paths and Azure Blob Storage URLs.
|
307 |
+
"""
|
308 |
+
def __new__(cls, first: Union[Path, str], *others: Union[str, PurePosixPath]) -> Union[Path, AzureBlobPath]:
|
309 |
+
if is_blob_url(str(first)):
|
310 |
+
return AzureBlobPath(str(first), *others)
|
311 |
+
return Path(first, *others)
|
312 |
+
|
313 |
+
|
314 |
+
|
moge/utils/download.py
ADDED
@@ -0,0 +1,55 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from pathlib import Path
|
2 |
+
from typing import *
|
3 |
+
import requests
|
4 |
+
|
5 |
+
from tqdm import tqdm
|
6 |
+
|
7 |
+
|
8 |
+
__all__ = ["download_file", "download_bytes"]
|
9 |
+
|
10 |
+
|
11 |
+
def download_file(url: str, filepath: Union[str, Path], headers: dict = None, resume: bool = True) -> None:
|
12 |
+
# Ensure headers is a dict if not provided
|
13 |
+
headers = headers or {}
|
14 |
+
|
15 |
+
# Initialize local variables
|
16 |
+
file_path = Path(filepath)
|
17 |
+
downloaded_bytes = 0
|
18 |
+
|
19 |
+
# Check if we should resume the download
|
20 |
+
if resume and file_path.exists():
|
21 |
+
downloaded_bytes = file_path.stat().st_size
|
22 |
+
headers['Range'] = f"bytes={downloaded_bytes}-"
|
23 |
+
|
24 |
+
# Make a GET request to fetch the file
|
25 |
+
with requests.get(url, stream=True, headers=headers) as response:
|
26 |
+
response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
|
27 |
+
|
28 |
+
# Calculate the total size to download
|
29 |
+
total_size = downloaded_bytes + int(response.headers.get('content-length', 0))
|
30 |
+
|
31 |
+
# Display a progress bar while downloading
|
32 |
+
with (
|
33 |
+
tqdm(desc=f"Downloading {file_path.name}", total=total_size, unit='B', unit_scale=True, leave=False) as pbar,
|
34 |
+
open(file_path, 'ab') as file,
|
35 |
+
):
|
36 |
+
# Set the initial position of the progress bar
|
37 |
+
pbar.update(downloaded_bytes)
|
38 |
+
|
39 |
+
# Write the content to the file in chunks
|
40 |
+
for chunk in response.iter_content(chunk_size=4096):
|
41 |
+
file.write(chunk)
|
42 |
+
pbar.update(len(chunk))
|
43 |
+
|
44 |
+
|
45 |
+
def download_bytes(url: str, headers: dict = None) -> bytes:
|
46 |
+
# Ensure headers is a dict if not provided
|
47 |
+
headers = headers or {}
|
48 |
+
|
49 |
+
# Make a GET request to fetch the file
|
50 |
+
with requests.get(url, stream=True, headers=headers) as response:
|
51 |
+
response.raise_for_status() # This will raise an HTTPError if the status is 4xx/5xx
|
52 |
+
|
53 |
+
# Read the content of the response
|
54 |
+
return response.content
|
55 |
+
|
moge/utils/geometry_numpy.py
ADDED
@@ -0,0 +1,175 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
from functools import partial
|
3 |
+
import math
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import utils3d
|
7 |
+
|
8 |
+
from .tools import timeit
|
9 |
+
|
10 |
+
def weighted_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
|
11 |
+
if w is None:
|
12 |
+
return np.mean(x, axis=axis)
|
13 |
+
else:
|
14 |
+
w = w.astype(x.dtype)
|
15 |
+
return (x * w).mean(axis=axis) / np.clip(w.mean(axis=axis), eps, None)
|
16 |
+
|
17 |
+
|
18 |
+
def harmonic_mean_numpy(x: np.ndarray, w: np.ndarray = None, axis: Union[int, Tuple[int,...]] = None, keepdims: bool = False, eps: float = 1e-7) -> np.ndarray:
|
19 |
+
if w is None:
|
20 |
+
return 1 / (1 / np.clip(x, eps, None)).mean(axis=axis)
|
21 |
+
else:
|
22 |
+
w = w.astype(x.dtype)
|
23 |
+
return 1 / (weighted_mean_numpy(1 / (x + eps), w, axis=axis, keepdims=keepdims, eps=eps) + eps)
|
24 |
+
|
25 |
+
|
26 |
+
def image_plane_uv_numpy(width: int, height: int, aspect_ratio: float = None, dtype: np.dtype = np.float32) -> np.ndarray:
|
27 |
+
"UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
|
28 |
+
if aspect_ratio is None:
|
29 |
+
aspect_ratio = width / height
|
30 |
+
|
31 |
+
span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
|
32 |
+
span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
|
33 |
+
|
34 |
+
u = np.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype)
|
35 |
+
v = np.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype)
|
36 |
+
u, v = np.meshgrid(u, v, indexing='xy')
|
37 |
+
uv = np.stack([u, v], axis=-1)
|
38 |
+
return uv
|
39 |
+
|
40 |
+
|
41 |
+
def focal_to_fov_numpy(focal: np.ndarray):
|
42 |
+
return 2 * np.arctan(0.5 / focal)
|
43 |
+
|
44 |
+
|
45 |
+
def fov_to_focal_numpy(fov: np.ndarray):
|
46 |
+
return 0.5 / np.tan(fov / 2)
|
47 |
+
|
48 |
+
|
49 |
+
def intrinsics_to_fov_numpy(intrinsics: np.ndarray) -> Tuple[np.ndarray, np.ndarray]:
|
50 |
+
fov_x = focal_to_fov_numpy(intrinsics[..., 0, 0])
|
51 |
+
fov_y = focal_to_fov_numpy(intrinsics[..., 1, 1])
|
52 |
+
return fov_x, fov_y
|
53 |
+
|
54 |
+
|
55 |
+
def solve_optimal_shift_focal(uv: np.ndarray, xyz: np.ndarray, ransac_iters: int = None, ransac_hypothetical_size: float = 0.1, ransac_threshold: float = 0.1):
|
56 |
+
"Solve `min |focal * xy / (z + shift) - uv|` with respect to shift and focal"
|
57 |
+
from scipy.optimize import least_squares
|
58 |
+
uv, xy, z = uv.reshape(-1, 2), xyz[..., :2].reshape(-1, 2), xyz[..., 2].reshape(-1)
|
59 |
+
|
60 |
+
def fn(uv: np.ndarray, xy: np.ndarray, z: np.ndarray, shift: np.ndarray):
|
61 |
+
xy_proj = xy / (z + shift)[: , None]
|
62 |
+
f = (xy_proj * uv).sum() / np.square(xy_proj).sum()
|
63 |
+
err = (f * xy_proj - uv).ravel()
|
64 |
+
return err
|
65 |
+
|
66 |
+
initial_shift = 0 #-z.min(keepdims=True) + 1.0
|
67 |
+
|
68 |
+
if ransac_iters is None:
|
69 |
+
solution = least_squares(partial(fn, uv, xy, z), x0=initial_shift, ftol=1e-3, method='lm')
|
70 |
+
optim_shift = solution['x'].squeeze().astype(np.float32)
|
71 |
+
else:
|
72 |
+
best_err, best_shift = np.inf, None
|
73 |
+
for _ in range(ransac_iters):
|
74 |
+
maybe_inliers = np.random.choice(len(z), size=int(ransac_hypothetical_size * len(z)), replace=False)
|
75 |
+
solution = least_squares(partial(fn, uv[maybe_inliers], xy[maybe_inliers], z[maybe_inliers]), x0=initial_shift, ftol=1e-3, method='lm')
|
76 |
+
maybe_shift = solution['x'].squeeze().astype(np.float32)
|
77 |
+
confirmed_inliers = np.linalg.norm(fn(uv, xy, z, maybe_shift).reshape(-1, 2), axis=-1) < ransac_threshold
|
78 |
+
if confirmed_inliers.sum() > 10:
|
79 |
+
solution = least_squares(partial(fn, uv[confirmed_inliers], xy[confirmed_inliers], z[confirmed_inliers]), x0=maybe_shift, ftol=1e-3, method='lm')
|
80 |
+
better_shift = solution['x'].squeeze().astype(np.float32)
|
81 |
+
else:
|
82 |
+
better_shift = maybe_shift
|
83 |
+
err = np.linalg.norm(fn(uv, xy, z, better_shift).reshape(-1, 2), axis=-1).clip(max=ransac_threshold).mean()
|
84 |
+
if err < best_err:
|
85 |
+
best_err, best_shift = err, better_shift
|
86 |
+
initial_shift = best_shift
|
87 |
+
|
88 |
+
optim_shift = best_shift
|
89 |
+
|
90 |
+
xy_proj = xy / (z + optim_shift)[: , None]
|
91 |
+
optim_focal = (xy_proj * uv).sum() / (xy_proj * xy_proj).sum()
|
92 |
+
|
93 |
+
return optim_shift, optim_focal
|
94 |
+
|
95 |
+
|
96 |
+
def point_map_to_depth_numpy(points: np.ndarray, mask: np.ndarray = None, downsample_size: Tuple[int, int] = (64, 64)):
|
97 |
+
import cv2
|
98 |
+
assert points.shape[-1] == 3, "Points should (H, W, 3)"
|
99 |
+
|
100 |
+
height, width = points.shape[-3], points.shape[-2]
|
101 |
+
diagonal = (height ** 2 + width ** 2) ** 0.5
|
102 |
+
|
103 |
+
uv = image_plane_uv_numpy(width=width, height=height)
|
104 |
+
|
105 |
+
if mask is None:
|
106 |
+
points_lr = cv2.resize(points, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 3)
|
107 |
+
uv_lr = cv2.resize(uv, downsample_size, interpolation=cv2.INTER_LINEAR).reshape(-1, 2)
|
108 |
+
else:
|
109 |
+
index, mask_lr = mask_aware_nearest_resize_numpy(mask, *downsample_size)
|
110 |
+
points_lr, uv_lr = points[index][mask_lr], uv[index][mask_lr]
|
111 |
+
|
112 |
+
if points_lr.size == 0:
|
113 |
+
return np.zeros((height, width)), 0, 0, 0
|
114 |
+
|
115 |
+
optim_shift, optim_focal = solve_optimal_shift_focal(uv_lr, points_lr, ransac_iters=None)
|
116 |
+
|
117 |
+
fov_x = 2 * np.arctan(width / diagonal / optim_focal)
|
118 |
+
fov_y = 2 * np.arctan(height / diagonal / optim_focal)
|
119 |
+
|
120 |
+
depth = points[:, :, 2] + optim_shift
|
121 |
+
return depth, fov_x, fov_y, optim_shift
|
122 |
+
|
123 |
+
|
124 |
+
def mask_aware_nearest_resize_numpy(mask: np.ndarray, target_width: int, target_height: int) -> Tuple[np.ndarray, np.ndarray, np.ndarray]:
|
125 |
+
"""
|
126 |
+
Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
|
127 |
+
|
128 |
+
### Parameters
|
129 |
+
- `mask`: Input 2D mask of shape (..., H, W)
|
130 |
+
- `target_width`: target width of the resized map
|
131 |
+
- `target_height`: target height of the resized map
|
132 |
+
|
133 |
+
### Returns
|
134 |
+
- `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width). Indices are like j + i * W, where j is the row index and i is the column index.
|
135 |
+
- `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
|
136 |
+
"""
|
137 |
+
height, width = mask.shape[-2:]
|
138 |
+
filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
|
139 |
+
filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
|
140 |
+
filter_size = filter_h_i * filter_w_i
|
141 |
+
padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2)
|
142 |
+
|
143 |
+
# Window the original mask and uv
|
144 |
+
uv = utils3d.numpy.image_pixel_center(width=width, height=height, dtype=np.float32)
|
145 |
+
indices = np.arange(height * width, dtype=np.int32).reshape(height, width)
|
146 |
+
padded_uv = np.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=np.float32)
|
147 |
+
padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
|
148 |
+
padded_mask = np.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=bool)
|
149 |
+
padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
|
150 |
+
padded_indices = np.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=np.int32)
|
151 |
+
padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
|
152 |
+
windowed_uv = utils3d.numpy.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, axis=(0, 1))
|
153 |
+
windowed_mask = utils3d.numpy.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, axis=(-2, -1))
|
154 |
+
windowed_indices = utils3d.numpy.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, axis=(0, 1))
|
155 |
+
|
156 |
+
# Gather the target pixels's local window
|
157 |
+
target_uv = utils3d.numpy.image_uv(width=target_width, height=target_height, dtype=np.float32) * np.array([width, height], dtype=np.float32)
|
158 |
+
target_corner = target_uv - np.array((filter_w_f / 2, filter_h_f / 2), dtype=np.float32)
|
159 |
+
target_corner = np.round(target_corner - 0.5).astype(np.int32) + np.array((padding_w, padding_h), dtype=np.int32)
|
160 |
+
|
161 |
+
target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
|
162 |
+
target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
|
163 |
+
target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
|
164 |
+
|
165 |
+
# Compute nearest neighbor in the local window for each pixel
|
166 |
+
dist = np.square(target_window_uv - target_uv[..., None])
|
167 |
+
dist = dist[..., 0, :] + dist[..., 1, :]
|
168 |
+
dist = np.where(target_window_mask, dist, np.inf) # (..., target_height, tgt_width, filter_size)
|
169 |
+
nearest_in_window = np.argmin(dist, axis=-1, keepdims=True) # (..., target_height, tgt_width, 1)
|
170 |
+
nearest_idx = np.take_along_axis(target_window_indices, nearest_in_window, axis=-1).squeeze(-1) # (..., target_height, tgt_width)
|
171 |
+
nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
|
172 |
+
target_mask = np.any(target_window_mask, axis=-1)
|
173 |
+
batch_indices = [np.arange(n).reshape([1] * i + [n] + [1] * (mask.ndim - i - 1)) for i, n in enumerate(mask.shape[:-2])]
|
174 |
+
|
175 |
+
return (*batch_indices, nearest_i, nearest_j), target_mask
|
moge/utils/geometry_torch.py
ADDED
@@ -0,0 +1,231 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import math
|
3 |
+
from collections import namedtuple
|
4 |
+
|
5 |
+
import numpy as np
|
6 |
+
import torch
|
7 |
+
import torch.nn as nn
|
8 |
+
import torch.nn.functional as F
|
9 |
+
import torch.types
|
10 |
+
import utils3d
|
11 |
+
|
12 |
+
from .tools import timeit
|
13 |
+
from .geometry_numpy import solve_optimal_shift_focal
|
14 |
+
|
15 |
+
|
16 |
+
def weighted_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
|
17 |
+
if w is None:
|
18 |
+
return x.mean(dim=dim, keepdim=keepdim)
|
19 |
+
else:
|
20 |
+
w = w.to(x.dtype)
|
21 |
+
return (x * w).mean(dim=dim, keepdim=keepdim) / w.mean(dim=dim, keepdim=keepdim).add(eps)
|
22 |
+
|
23 |
+
|
24 |
+
def harmonic_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
|
25 |
+
if w is None:
|
26 |
+
return x.add(eps).reciprocal().mean(dim=dim, keepdim=keepdim).reciprocal()
|
27 |
+
else:
|
28 |
+
w = w.to(x.dtype)
|
29 |
+
return weighted_mean(x.add(eps).reciprocal(), w, dim=dim, keepdim=keepdim, eps=eps).add(eps).reciprocal()
|
30 |
+
|
31 |
+
|
32 |
+
def geometric_mean(x: torch.Tensor, w: torch.Tensor = None, dim: Union[int, torch.Size] = None, keepdim: bool = False, eps: float = 1e-7) -> torch.Tensor:
|
33 |
+
if w is None:
|
34 |
+
return x.add(eps).log().mean(dim=dim).exp()
|
35 |
+
else:
|
36 |
+
w = w.to(x.dtype)
|
37 |
+
return weighted_mean(x.add(eps).log(), w, dim=dim, keepdim=keepdim, eps=eps).exp()
|
38 |
+
|
39 |
+
|
40 |
+
def image_plane_uv(width: int, height: int, aspect_ratio: float = None, dtype: torch.dtype = None, device: torch.device = None) -> torch.Tensor:
|
41 |
+
"UV with left-top corner as (-width / diagonal, -height / diagonal) and right-bottom corner as (width / diagonal, height / diagonal)"
|
42 |
+
if aspect_ratio is None:
|
43 |
+
aspect_ratio = width / height
|
44 |
+
|
45 |
+
span_x = aspect_ratio / (1 + aspect_ratio ** 2) ** 0.5
|
46 |
+
span_y = 1 / (1 + aspect_ratio ** 2) ** 0.5
|
47 |
+
|
48 |
+
u = torch.linspace(-span_x * (width - 1) / width, span_x * (width - 1) / width, width, dtype=dtype, device=device)
|
49 |
+
v = torch.linspace(-span_y * (height - 1) / height, span_y * (height - 1) / height, height, dtype=dtype, device=device)
|
50 |
+
u, v = torch.meshgrid(u, v, indexing='xy')
|
51 |
+
uv = torch.stack([u, v], dim=-1)
|
52 |
+
return uv
|
53 |
+
|
54 |
+
|
55 |
+
def gaussian_blur_2d(input: torch.Tensor, kernel_size: int, sigma: float) -> torch.Tensor:
|
56 |
+
kernel = torch.exp(-(torch.arange(-kernel_size // 2 + 1, kernel_size // 2 + 1, dtype=input.dtype, device=input.device) ** 2) / (2 * sigma ** 2))
|
57 |
+
kernel = kernel / kernel.sum()
|
58 |
+
kernel = (kernel[:, None] * kernel[None, :]).reshape(1, 1, kernel_size, kernel_size)
|
59 |
+
input = F.pad(input, (kernel_size // 2, kernel_size // 2, kernel_size // 2, kernel_size // 2), mode='replicate')
|
60 |
+
input = F.conv2d(input, kernel, groups=input.shape[1])
|
61 |
+
return input
|
62 |
+
|
63 |
+
|
64 |
+
def split_batch_fwd(fn: Callable, chunk_size: int, *args, **kwargs):
|
65 |
+
batch_size = next(x for x in (*args, *kwargs.values()) if isinstance(x, torch.Tensor)).shape[0]
|
66 |
+
n_chunks = batch_size // chunk_size + (batch_size % chunk_size > 0)
|
67 |
+
splited_args = tuple(arg.split(chunk_size, dim=0) if isinstance(arg, torch.Tensor) else [arg] * n_chunks for arg in args)
|
68 |
+
splited_kwargs = {k: [v.split(chunk_size, dim=0) if isinstance(v, torch.Tensor) else [v] * n_chunks] for k, v in kwargs.items()}
|
69 |
+
results = []
|
70 |
+
for i in range(n_chunks):
|
71 |
+
chunk_args = tuple(arg[i] for arg in splited_args)
|
72 |
+
chunk_kwargs = {k: v[i] for k, v in splited_kwargs.items()}
|
73 |
+
results.append(fn(*chunk_args, **chunk_kwargs))
|
74 |
+
|
75 |
+
if isinstance(results[0], tuple):
|
76 |
+
return tuple(torch.cat(r, dim=0) for r in zip(*results))
|
77 |
+
else:
|
78 |
+
return torch.cat(results, dim=0)
|
79 |
+
|
80 |
+
|
81 |
+
def focal_to_fov(focal: torch.Tensor):
|
82 |
+
return 2 * torch.atan(0.5 / focal)
|
83 |
+
|
84 |
+
|
85 |
+
def fov_to_focal(fov: torch.Tensor):
|
86 |
+
return 0.5 / torch.tan(fov / 2)
|
87 |
+
|
88 |
+
|
89 |
+
def intrinsics_to_fov(intrinsics: torch.Tensor):
|
90 |
+
"""
|
91 |
+
Returns field of view in radians from normalized intrinsics matrix.
|
92 |
+
### Parameters:
|
93 |
+
- intrinsics: torch.Tensor of shape (..., 3, 3)
|
94 |
+
|
95 |
+
### Returns:
|
96 |
+
- fov_x: torch.Tensor of shape (...)
|
97 |
+
- fov_y: torch.Tensor of shape (...)
|
98 |
+
"""
|
99 |
+
focal_x = intrinsics[..., 0, 0]
|
100 |
+
focal_y = intrinsics[..., 1, 1]
|
101 |
+
return 2 * torch.atan(0.5 / focal_x), 2 * torch.atan(0.5 / focal_y)
|
102 |
+
|
103 |
+
|
104 |
+
def point_map_to_depth_legacy(points: torch.Tensor):
|
105 |
+
height, width = points.shape[-3:-1]
|
106 |
+
diagonal = (height ** 2 + width ** 2) ** 0.5
|
107 |
+
uv = image_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
|
108 |
+
|
109 |
+
# Solve least squares problem
|
110 |
+
b = (uv * points[..., 2:]).flatten(-3, -1) # (..., H * W * 2)
|
111 |
+
A = torch.stack([points[..., :2], -uv.expand_as(points[..., :2])], dim=-1).flatten(-4, -2) # (..., H * W * 2, 2)
|
112 |
+
|
113 |
+
M = A.transpose(-2, -1) @ A
|
114 |
+
solution = (torch.inverse(M + 1e-6 * torch.eye(2).to(A)) @ (A.transpose(-2, -1) @ b[..., None])).squeeze(-1)
|
115 |
+
focal, shift = solution.unbind(-1)
|
116 |
+
|
117 |
+
depth = points[..., 2] + shift[..., None, None]
|
118 |
+
fov_x = torch.atan(width / diagonal / focal) * 2
|
119 |
+
fov_y = torch.atan(height / diagonal / focal) * 2
|
120 |
+
return depth, fov_x, fov_y, shift
|
121 |
+
|
122 |
+
|
123 |
+
def point_map_to_depth(points: torch.Tensor, mask: torch.Tensor = None, downsample_size: Tuple[int, int] = (64, 64)):
|
124 |
+
"""
|
125 |
+
Recover the depth map and FoV from a point map with unknown z shift and focal.
|
126 |
+
|
127 |
+
Note that it assumes:
|
128 |
+
- the optical center is at the center of the map
|
129 |
+
- the map is undistorted
|
130 |
+
- the map is isometric in the x and y directions
|
131 |
+
|
132 |
+
### Parameters:
|
133 |
+
- `points: torch.Tensor` of shape (..., H, W, 3)
|
134 |
+
- `downsample_size: Tuple[int, int]` in (height, width), the size of the downsampled map. Downsampling produces approximate solution and is efficient for large maps.
|
135 |
+
|
136 |
+
### Returns:
|
137 |
+
- `depth: torch.Tensor` of shape (..., H, W)
|
138 |
+
- `fov_x: torch.Tensor` of shape (...)
|
139 |
+
- `fov_y: torch.Tensor` of shape (...)
|
140 |
+
- `shift: torch.Tensor` of shape (...), the z shift, making `depth = points[..., 2] + shift`
|
141 |
+
"""
|
142 |
+
shape = points.shape
|
143 |
+
height, width = points.shape[-3], points.shape[-2]
|
144 |
+
diagonal = (height ** 2 + width ** 2) ** 0.5
|
145 |
+
|
146 |
+
points = points.reshape(-1, *shape[-3:])
|
147 |
+
mask = None if mask is None else mask.reshape(-1, *shape[-3:-1])
|
148 |
+
uv = image_plane_uv(width, height, dtype=points.dtype, device=points.device) # (H, W, 2)
|
149 |
+
|
150 |
+
points_lr = F.interpolate(points.permute(0, 3, 1, 2), downsample_size, mode='nearest').permute(0, 2, 3, 1)
|
151 |
+
uv_lr = F.interpolate(uv.unsqueeze(0).permute(0, 3, 1, 2), downsample_size, mode='nearest').squeeze(0).permute(1, 2, 0)
|
152 |
+
mask_lr = None if mask is None else F.interpolate(mask.to(torch.float32).unsqueeze(1), downsample_size, mode='nearest').squeeze(1) > 0
|
153 |
+
|
154 |
+
uv_lr_np = uv_lr.cpu().numpy()
|
155 |
+
points_lr_np = points_lr.detach().cpu().numpy()
|
156 |
+
mask_lr_np = None if mask is None else mask_lr.cpu().numpy()
|
157 |
+
optim_shift, optim_focal = [], []
|
158 |
+
for i in range(points.shape[0]):
|
159 |
+
points_lr_i_np = points_lr_np[i] if mask is None else points_lr_np[i][mask_lr_np[i]]
|
160 |
+
uv_lr_i_np = uv_lr_np if mask is None else uv_lr_np[mask_lr_np[i]]
|
161 |
+
optim_shift_i, optim_focal_i = solve_optimal_shift_focal(uv_lr_i_np, points_lr_i_np, ransac_iters=None)
|
162 |
+
optim_shift.append(float(optim_shift_i))
|
163 |
+
optim_focal.append(float(optim_focal_i))
|
164 |
+
optim_shift = torch.tensor(optim_shift, device=points.device, dtype=points.dtype)
|
165 |
+
optim_focal = torch.tensor(optim_focal, device=points.device, dtype=points.dtype)
|
166 |
+
|
167 |
+
fov_x = 2 * torch.atan(width / diagonal / optim_focal)
|
168 |
+
fov_y = 2 * torch.atan(height / diagonal / optim_focal)
|
169 |
+
|
170 |
+
depth = (points[..., 2] + optim_shift[:, None, None]).reshape(shape[:-1])
|
171 |
+
fov_x = fov_x.reshape(shape[:-3])
|
172 |
+
fov_y = fov_y.reshape(shape[:-3])
|
173 |
+
optim_shift = optim_shift.reshape(shape[:-3])
|
174 |
+
|
175 |
+
return depth, fov_x, fov_y, optim_shift
|
176 |
+
|
177 |
+
|
178 |
+
def mask_aware_nearest_resize(mask: torch.BoolTensor, target_width: int, target_height: int) -> Tuple[torch.LongTensor, torch.LongTensor, torch.BoolTensor]:
|
179 |
+
"""
|
180 |
+
Resize 2D map by nearest interpolation. Return the nearest neighbor index and mask of the resized map.
|
181 |
+
|
182 |
+
### Parameters
|
183 |
+
- `mask`: Input 2D mask of shape (..., H, W)
|
184 |
+
- `target_width`: target width of the resized map
|
185 |
+
- `target_height`: target height of the resized map
|
186 |
+
|
187 |
+
### Returns
|
188 |
+
- `nearest_idx`: Nearest neighbor index of the resized map of shape (..., target_height, target_width) for each dimension
|
189 |
+
- `target_mask`: Mask of the resized map of shape (..., target_height, target_width)
|
190 |
+
"""
|
191 |
+
height, width = mask.shape[-2:]
|
192 |
+
device = mask.device
|
193 |
+
filter_h_f, filter_w_f = max(1, height / target_height), max(1, width / target_width)
|
194 |
+
filter_h_i, filter_w_i = math.ceil(filter_h_f), math.ceil(filter_w_f)
|
195 |
+
filter_size = filter_h_i * filter_w_i
|
196 |
+
padding_h, padding_w = round(filter_h_f / 2), round(filter_w_f / 2)
|
197 |
+
|
198 |
+
# Window the original mask and uv
|
199 |
+
uv = utils3d.torch.image_pixel_center(width=width, height=height, dtype=torch.float32, device=device)
|
200 |
+
indices = torch.arange(height * width, dtype=torch.long, device=device).reshape(height, width)
|
201 |
+
padded_uv = torch.full((height + 2 * padding_h, width + 2 * padding_w, 2), 0, dtype=torch.float32, device=device)
|
202 |
+
padded_uv[padding_h:padding_h + height, padding_w:padding_w + width] = uv
|
203 |
+
padded_mask = torch.full((*mask.shape[:-2], height + 2 * padding_h, width + 2 * padding_w), False, dtype=torch.bool, device=device)
|
204 |
+
padded_mask[..., padding_h:padding_h + height, padding_w:padding_w + width] = mask
|
205 |
+
padded_indices = torch.full((height + 2 * padding_h, width + 2 * padding_w), 0, dtype=torch.long, device=device)
|
206 |
+
padded_indices[padding_h:padding_h + height, padding_w:padding_w + width] = indices
|
207 |
+
windowed_uv = utils3d.torch.sliding_window_2d(padded_uv, (filter_h_i, filter_w_i), 1, dim=(0, 1))
|
208 |
+
windowed_mask = utils3d.torch.sliding_window_2d(padded_mask, (filter_h_i, filter_w_i), 1, dim=(-2, -1))
|
209 |
+
windowed_indices = utils3d.torch.sliding_window_2d(padded_indices, (filter_h_i, filter_w_i), 1, dim=(0, 1))
|
210 |
+
|
211 |
+
# Gather the target pixels's local window
|
212 |
+
target_uv = utils3d.torch.image_uv(width=target_width, height=target_height, dtype=torch.float32, device=device) * torch.tensor([width, height], dtype=torch.float32, device=device)
|
213 |
+
target_corner = target_uv - torch.tensor((filter_w_f / 2, filter_h_f / 2), dtype=torch.float32, device=device)
|
214 |
+
target_corner = torch.round(target_corner - 0.5).long() + torch.tensor((padding_w, padding_h), dtype=torch.long, device=device)
|
215 |
+
|
216 |
+
target_window_uv = windowed_uv[target_corner[..., 1], target_corner[..., 0], :, :, :].reshape(target_height, target_width, 2, filter_size) # (target_height, tgt_width, 2, filter_size)
|
217 |
+
target_window_mask = windowed_mask[..., target_corner[..., 1], target_corner[..., 0], :, :].reshape(*mask.shape[:-2], target_height, target_width, filter_size) # (..., target_height, tgt_width, filter_size)
|
218 |
+
target_window_indices = windowed_indices[target_corner[..., 1], target_corner[..., 0], :, :].reshape(target_height, target_width, filter_size) # (target_height, tgt_width, filter_size)
|
219 |
+
target_window_indices = target_window_indices.expand_as(target_window_mask)
|
220 |
+
|
221 |
+
# Compute nearest neighbor in the local window for each pixel
|
222 |
+
dist = torch.where(target_window_mask, torch.norm(target_window_uv - target_uv[..., None], dim=-2), torch.inf) # (..., target_height, tgt_width, filter_size)
|
223 |
+
nearest = torch.argmin(dist, dim=-1, keepdim=True) # (..., target_height, tgt_width, 1)
|
224 |
+
nearest_idx = torch.gather(target_window_indices, index=nearest, dim=-1).squeeze(-1) # (..., target_height, tgt_width)
|
225 |
+
target_mask = torch.any(target_window_mask, dim=-1)
|
226 |
+
nearest_i, nearest_j = nearest_idx // width, nearest_idx % width
|
227 |
+
batch_indices = [torch.arange(n, device=device).reshape([1] * i + [n] + [1] * (mask.dim() - i - 1)) for i, n in enumerate(mask.shape[:-2])]
|
228 |
+
|
229 |
+
return (*batch_indices, nearest_i, nearest_j), target_mask
|
230 |
+
|
231 |
+
|
moge/utils/io.py
ADDED
@@ -0,0 +1,347 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
os.environ['OPENCV_IO_ENABLE_OPENEXR'] = '1'
|
3 |
+
from typing import IO
|
4 |
+
import zipfile
|
5 |
+
import json
|
6 |
+
import io
|
7 |
+
from typing import *
|
8 |
+
from pathlib import Path
|
9 |
+
import re
|
10 |
+
|
11 |
+
import numpy as np
|
12 |
+
import cv2
|
13 |
+
|
14 |
+
from .tools import timeit
|
15 |
+
|
16 |
+
|
17 |
+
LEGACY_SEGFORMER_CLASSES = [
|
18 |
+
'wall', 'building', 'sky', 'floor', 'tree', 'ceiling', 'road', 'bed ',
|
19 |
+
'windowpane', 'grass', 'cabinet', 'sidewalk', 'person', 'earth',
|
20 |
+
'door', 'table', 'mountain', 'plant', 'curtain', 'chair', 'car',
|
21 |
+
'water', 'painting', 'sofa', 'shelf', 'house', 'sea', 'mirror', 'rug',
|
22 |
+
'field', 'armchair', 'seat', 'fence', 'desk', 'rock', 'wardrobe',
|
23 |
+
'lamp', 'bathtub', 'railing', 'cushion', 'base', 'box', 'column',
|
24 |
+
'signboard', 'chest of drawers', 'counter', 'sand', 'sink',
|
25 |
+
'skyscraper', 'fireplace', 'refrigerator', 'grandstand', 'path',
|
26 |
+
'stairs', 'runway', 'case', 'pool table', 'pillow', 'screen door',
|
27 |
+
'stairway', 'river', 'bridge', 'bookcase', 'blind', 'coffee table',
|
28 |
+
'toilet', 'flower', 'book', 'hill', 'bench', 'countertop', 'stove',
|
29 |
+
'palm', 'kitchen island', 'computer', 'swivel chair', 'boat', 'bar',
|
30 |
+
'arcade machine', 'hovel', 'bus', 'towel', 'light', 'truck', 'tower',
|
31 |
+
'chandelier', 'awning', 'streetlight', 'booth', 'television receiver',
|
32 |
+
'airplane', 'dirt track', 'apparel', 'pole', 'land', 'bannister',
|
33 |
+
'escalator', 'ottoman', 'bottle', 'buffet', 'poster', 'stage', 'van',
|
34 |
+
'ship', 'fountain', 'conveyer belt', 'canopy', 'washer', 'plaything',
|
35 |
+
'swimming pool', 'stool', 'barrel', 'basket', 'waterfall', 'tent',
|
36 |
+
'bag', 'minibike', 'cradle', 'oven', 'ball', 'food', 'step', 'tank',
|
37 |
+
'trade name', 'microwave', 'pot', 'animal', 'bicycle', 'lake',
|
38 |
+
'dishwasher', 'screen', 'blanket', 'sculpture', 'hood', 'sconce',
|
39 |
+
'vase', 'traffic light', 'tray', 'ashcan', 'fan', 'pier', 'crt screen',
|
40 |
+
'plate', 'monitor', 'bulletin board', 'shower', 'radiator', 'glass',
|
41 |
+
'clock', 'flag'
|
42 |
+
]
|
43 |
+
LEGACY_SEGFORMER_LABELS = {k: i for i, k in enumerate(LEGACY_SEGFORMER_CLASSES)}
|
44 |
+
|
45 |
+
|
46 |
+
def write_rgbd_zip(
|
47 |
+
file: Union[IO, os.PathLike],
|
48 |
+
image: Union[np.ndarray, bytes],
|
49 |
+
depth: Union[np.ndarray, bytes], mask: Union[np.ndarray, bytes],
|
50 |
+
segmentation_mask: Union[np.ndarray, bytes] = None, segmentation_labels: Union[Dict[str, int], bytes] = None,
|
51 |
+
intrinsics: np.ndarray = None,
|
52 |
+
normal: np.ndarray = None, normal_mask: np.ndarray = None,
|
53 |
+
meta: Union[Dict[str, Any], bytes] = None,
|
54 |
+
*, image_quality: int = 95, depth_type: Literal['linear', 'log', 'disparity'] = 'linear', depth_format: Literal['png', 'exr'] = 'png', depth_max_dynamic_range: float = 1e4, png_compression: int = 7
|
55 |
+
):
|
56 |
+
"""
|
57 |
+
Write RGBD data as zip archive containing the image, depth, mask, segmentation_mask, and meta data.
|
58 |
+
In the zip file there will be:
|
59 |
+
- `meta.json`: The meta data as a JSON file.
|
60 |
+
- `image.jpg`: The RGB image as a JPEG file.
|
61 |
+
- `depth.png/exr`: The depth map as a PNG or EXR file, depending on the `depth_type`.
|
62 |
+
- `mask.png` (optional): The mask as a uint8 PNG file.
|
63 |
+
- `segmentation_mask.png` (optional): The segformer mask as a uint8/uint16 PNG file.
|
64 |
+
|
65 |
+
You can provided those data as np.ndarray or bytes. If you provide them as np.ndarray, they will be properly processed and encoded.
|
66 |
+
If you provide them as bytes, they will be written as is, assuming they are already encoded.
|
67 |
+
"""
|
68 |
+
if meta is None:
|
69 |
+
meta = {}
|
70 |
+
elif isinstance(meta, bytes):
|
71 |
+
meta = json.loads(meta.decode())
|
72 |
+
|
73 |
+
if isinstance(image, bytes):
|
74 |
+
image_bytes = image
|
75 |
+
elif isinstance(image, np.ndarray):
|
76 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
77 |
+
image_bytes = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, image_quality])[1].tobytes()
|
78 |
+
|
79 |
+
if isinstance(depth, bytes):
|
80 |
+
depth_bytes = depth
|
81 |
+
elif isinstance(depth, np.ndarray):
|
82 |
+
meta['depth_type'] = depth_type
|
83 |
+
if depth_type == 'linear':
|
84 |
+
if depth.dtype == np.float16:
|
85 |
+
depth_format = 'exr'
|
86 |
+
depth_bytes = cv2.imencode('.exr', depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_HALF])[1].tobytes()
|
87 |
+
elif np.issubdtype(depth.dtype, np.floating):
|
88 |
+
depth_format = 'exr'
|
89 |
+
depth_bytes = cv2.imencode('.exr', depth.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])[1].tobytes()
|
90 |
+
elif depth.dtype in [np.uint8, np.uint16]:
|
91 |
+
depth_format = 'png'
|
92 |
+
depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes()
|
93 |
+
elif depth_type == 'log':
|
94 |
+
depth_format = 'png'
|
95 |
+
depth = depth.astype(np.float32)
|
96 |
+
near = max(depth[mask].min(), 1e-3)
|
97 |
+
far = min(depth[mask].max(), near * depth_max_dynamic_range)
|
98 |
+
depth = ((np.log(depth.clip(near, far) / near) / np.log(far / near)).clip(0, 1) * 65535).astype(np.uint16)
|
99 |
+
depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes()
|
100 |
+
meta['depth_near'] = float(near)
|
101 |
+
meta['depth_far'] = float(far)
|
102 |
+
elif depth_type == 'disparity':
|
103 |
+
depth_format = 'png'
|
104 |
+
depth = depth.astype(np.float32)
|
105 |
+
depth = 1 / (depth + 1e-12)
|
106 |
+
depth = (depth / depth[mask].max()).clip(0, 1)
|
107 |
+
if np.unique(depth) < 200:
|
108 |
+
depth = (depth * 255).astype(np.uint8)
|
109 |
+
else:
|
110 |
+
depth = (depth * 65535).astype(np.uint16)
|
111 |
+
depth_bytes = cv2.imencode('.png', depth, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes()
|
112 |
+
|
113 |
+
if isinstance(mask, bytes):
|
114 |
+
mask_bytes = mask
|
115 |
+
elif isinstance(mask, np.ndarray):
|
116 |
+
mask_bytes = cv2.imencode('.png', mask.astype(np.uint8) * 255)[1].tobytes()
|
117 |
+
|
118 |
+
if segmentation_mask is not None:
|
119 |
+
if isinstance(segmentation_mask, bytes):
|
120 |
+
segmentation_mask_bytes = segmentation_mask
|
121 |
+
else:
|
122 |
+
segmentation_mask_bytes = cv2.imencode('.png', segmentation_mask)[1].tobytes()
|
123 |
+
assert segmentation_labels is not None, "You provided a segmentation mask, but not the corresponding labels."
|
124 |
+
if isinstance(segmentation_labels, bytes):
|
125 |
+
segmentation_labels = json.loads(segmentation_labels)
|
126 |
+
meta['segmentation_labels'] = segmentation_labels
|
127 |
+
|
128 |
+
if intrinsics is not None:
|
129 |
+
meta['intrinsics'] = intrinsics.tolist()
|
130 |
+
|
131 |
+
if normal is not None:
|
132 |
+
if isinstance(normal, bytes):
|
133 |
+
normal_bytes = normal
|
134 |
+
elif isinstance(normal, np.ndarray):
|
135 |
+
normal = ((normal * [0.5, -0.5, -0.5] + 0.5).clip(0, 1) * 65535).astype(np.uint16)
|
136 |
+
normal = cv2.cvtColor(normal, cv2.COLOR_RGB2BGR)
|
137 |
+
normal_bytes = cv2.imencode('.png', normal, [cv2.IMWRITE_PNG_COMPRESSION, png_compression])[1].tobytes()
|
138 |
+
if normal_mask is None:
|
139 |
+
normal_mask = np.ones(image.shape[:2], dtype=bool)
|
140 |
+
normal_mask_bytes = cv2.imencode('.png', normal_mask.astype(np.uint8) * 255)[1].tobytes()
|
141 |
+
|
142 |
+
meta_bytes = meta if isinstance(meta, bytes) else json.dumps(meta).encode()
|
143 |
+
|
144 |
+
with zipfile.ZipFile(file, 'w') as z:
|
145 |
+
z.writestr('meta.json', meta_bytes)
|
146 |
+
z.writestr('image.jpg', image_bytes)
|
147 |
+
z.writestr(f'depth.{depth_format}', depth_bytes)
|
148 |
+
z.writestr('mask.png', mask_bytes)
|
149 |
+
if segmentation_mask is not None:
|
150 |
+
z.writestr('segmentation_mask.png', segmentation_mask_bytes)
|
151 |
+
if normal is not None:
|
152 |
+
z.writestr('normal.png', normal_bytes)
|
153 |
+
z.writestr('normal_mask.png', normal_mask_bytes)
|
154 |
+
|
155 |
+
|
156 |
+
def read_rgbd_zip(file: Union[str, Path, IO], return_bytes: bool = False) -> Dict[str, Union[np.ndarray, Dict[str, Any], bytes]]:
|
157 |
+
"""
|
158 |
+
Read an RGBD zip file and return the image, depth, mask, segmentation_mask, intrinsics, and meta data.
|
159 |
+
|
160 |
+
### Parameters:
|
161 |
+
- `file: Union[str, Path, IO]`
|
162 |
+
The file path or file object to read from.
|
163 |
+
- `return_bytes: bool = False`
|
164 |
+
If True, return the image, depth, mask, and segmentation_mask as raw bytes.
|
165 |
+
|
166 |
+
### Returns:
|
167 |
+
- `Tuple[Dict[str, Union[np.ndarray, Dict[str, Any]]], Dict[str, bytes]]`
|
168 |
+
A dictionary containing: (If missing, the value will be None; if return_bytes is True, the value will be bytes)
|
169 |
+
- `image`: RGB numpy.ndarray of shape (H, W, 3).
|
170 |
+
- `depth`: float32 numpy.ndarray of shape (H, W).
|
171 |
+
- `mask`: bool numpy.ndarray of shape (H, W).
|
172 |
+
- `segformer_mask`: uint8 numpy.ndarray of shape (H, W).
|
173 |
+
- `intrinsics`: float32 numpy.ndarray of shape (3, 3).
|
174 |
+
- `meta`: Dict[str, Any].
|
175 |
+
"""
|
176 |
+
# Load & extract archive
|
177 |
+
with zipfile.ZipFile(file, 'r') as z:
|
178 |
+
meta = z.read('meta.json')
|
179 |
+
if not return_bytes:
|
180 |
+
meta = json.loads(z.read('meta.json'))
|
181 |
+
|
182 |
+
image = z.read('image.jpg')
|
183 |
+
if not return_bytes:
|
184 |
+
image = cv2.imdecode(np.frombuffer(z.read('image.jpg'), np.uint8), cv2.IMREAD_COLOR)
|
185 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
186 |
+
|
187 |
+
depth_name = next(s for s in z.namelist() if s.startswith('depth'))
|
188 |
+
depth = z.read(depth_name)
|
189 |
+
if not return_bytes:
|
190 |
+
depth = cv2.imdecode(np.frombuffer(z.read(depth_name), np.uint8), cv2.IMREAD_UNCHANGED)
|
191 |
+
|
192 |
+
if 'mask.png' in z.namelist():
|
193 |
+
mask = z.read('mask.png')
|
194 |
+
if not return_bytes:
|
195 |
+
mask = cv2.imdecode(np.frombuffer(z.read('mask.png'), np.uint8), cv2.IMREAD_UNCHANGED) > 0
|
196 |
+
else:
|
197 |
+
mask = None
|
198 |
+
|
199 |
+
if 'segformer_mask.png' in z.namelist():
|
200 |
+
# NOTE: Legacy support for segformer_mask.png
|
201 |
+
segmentation_mask = z.read('segformer_mask.png')
|
202 |
+
segmentation_labels = None
|
203 |
+
if not return_bytes:
|
204 |
+
segmentation_mask = cv2.imdecode(np.frombuffer(segmentation_mask, np.uint8), cv2.IMREAD_UNCHANGED)
|
205 |
+
segmentation_labels = LEGACY_SEGFORMER_LABELS
|
206 |
+
elif 'segmentation_mask.png' in z.namelist():
|
207 |
+
segmentation_mask = z.read('segmentation_mask.png')
|
208 |
+
segmentation_labels = None
|
209 |
+
if not return_bytes:
|
210 |
+
segmentation_mask = cv2.imdecode(np.frombuffer(segmentation_mask, np.uint8), cv2.IMREAD_UNCHANGED)
|
211 |
+
segmentation_labels = meta['segmentation_labels']
|
212 |
+
else:
|
213 |
+
segmentation_mask = None
|
214 |
+
segmentation_labels = None
|
215 |
+
|
216 |
+
if 'normal.png' in z.namelist():
|
217 |
+
normal = z.read('normal.png')
|
218 |
+
if not return_bytes:
|
219 |
+
normal = cv2.imdecode(np.frombuffer(z.read('normal.png'), np.uint8), cv2.IMREAD_UNCHANGED)
|
220 |
+
normal = cv2.cvtColor(normal, cv2.COLOR_BGR2RGB)
|
221 |
+
normal = (normal.astype(np.float32) / 65535 - 0.5) * [2.0, -2.0, -2.0]
|
222 |
+
normal = normal / np.linalg.norm(normal, axis=-1, keepdims=True)
|
223 |
+
|
224 |
+
if 'normal_mask.png' in z.namelist():
|
225 |
+
normal_mask = z.read('normal_mask.png')
|
226 |
+
normal_mask = cv2.imdecode(np.frombuffer(normal_mask, np.uint8), cv2.IMREAD_UNCHANGED) > 0
|
227 |
+
else:
|
228 |
+
normal_mask = np.ones(image.shape[:2], dtype=bool)
|
229 |
+
else:
|
230 |
+
normal, normal_mask = None, None
|
231 |
+
|
232 |
+
# recover linear depth
|
233 |
+
if not return_bytes:
|
234 |
+
if mask is None:
|
235 |
+
mask = np.ones(image.shape[:2], dtype=bool)
|
236 |
+
if meta['depth_type'] == 'linear':
|
237 |
+
depth = depth.astype(np.float32)
|
238 |
+
mask = mask & (depth > 0)
|
239 |
+
elif meta['depth_type'] == 'log':
|
240 |
+
near, far = meta['depth_near'], meta['depth_far']
|
241 |
+
if depth.dtype == np.uint16:
|
242 |
+
depth = depth.astype(np.float32) / 65535
|
243 |
+
elif depth.dtype == np.uint8:
|
244 |
+
depth = depth.astype(np.float32) / 255
|
245 |
+
depth = near ** (1 - depth) * far ** depth
|
246 |
+
mask = mask & ~np.isnan(depth)
|
247 |
+
elif meta['depth_type'] == 'disparity':
|
248 |
+
mask = mask & (depth > 0)
|
249 |
+
if depth.dtype == np.uint16:
|
250 |
+
depth = depth.astype(np.float32) / 65535
|
251 |
+
elif depth.dtype == np.uint8:
|
252 |
+
depth = depth.astype(np.float32) / 255
|
253 |
+
depth = 1 / (depth + 1e-12)
|
254 |
+
|
255 |
+
# intrinsics
|
256 |
+
if not return_bytes and 'intrinsics' in meta:
|
257 |
+
intrinsics = np.array(meta['intrinsics'], dtype=np.float32)
|
258 |
+
else:
|
259 |
+
intrinsics = None
|
260 |
+
|
261 |
+
# depth unit
|
262 |
+
if not return_bytes and 'depth_unit' in meta:
|
263 |
+
depth_unit_str = meta['depth_unit']
|
264 |
+
if r := re.match(r'([\d.]*)(\w*)', depth_unit_str):
|
265 |
+
digits, unit = r.groups()
|
266 |
+
depth_unit = float(digits or 1) * {'m': 1, 'cm': 0.01, 'mm': 0.001}[unit]
|
267 |
+
else:
|
268 |
+
depth_unit = None
|
269 |
+
else:
|
270 |
+
depth_unit = None
|
271 |
+
|
272 |
+
return_dict = {
|
273 |
+
'image': image,
|
274 |
+
'depth': depth,
|
275 |
+
'mask': mask,
|
276 |
+
'segmentation_mask': segmentation_mask,
|
277 |
+
'segmentation_labels': segmentation_labels,
|
278 |
+
'normal': normal,
|
279 |
+
'normal_mask': normal_mask,
|
280 |
+
'intrinsics': intrinsics,
|
281 |
+
'depth_unit': depth_unit,
|
282 |
+
'meta': meta,
|
283 |
+
}
|
284 |
+
return_dict = {k: v for k, v in return_dict.items() if v is not None}
|
285 |
+
|
286 |
+
return return_dict
|
287 |
+
|
288 |
+
def write_rgbxyz(file: Union[IO, Path], image: np.ndarray, points: np.ndarray, mask: np.ndarray = None, image_quality: int = 95):
|
289 |
+
if isinstance(image, bytes):
|
290 |
+
image_bytes = image
|
291 |
+
elif isinstance(image, np.ndarray):
|
292 |
+
image = cv2.cvtColor(image, cv2.COLOR_RGB2BGR)
|
293 |
+
image_bytes = cv2.imencode('.jpg', image, [cv2.IMWRITE_JPEG_QUALITY, image_quality])[1].tobytes()
|
294 |
+
|
295 |
+
if isinstance(points, bytes):
|
296 |
+
points_bytes = points
|
297 |
+
elif isinstance(points, np.ndarray):
|
298 |
+
points_bytes = cv2.imencode('.exr', points.astype(np.float32), [cv2.IMWRITE_EXR_TYPE, cv2.IMWRITE_EXR_TYPE_FLOAT])[1].tobytes()
|
299 |
+
|
300 |
+
if mask is None:
|
301 |
+
mask = np.ones(image.shape[:2], dtype=bool)
|
302 |
+
if isinstance(mask, bytes):
|
303 |
+
mask_bytes = mask
|
304 |
+
elif isinstance(mask, np.ndarray):
|
305 |
+
mask_bytes = cv2.imencode('.png', mask.astype(np.uint8) * 255)[1].tobytes()
|
306 |
+
|
307 |
+
is_archive = hasattr(file, 'write') or Path(file).suffix == '.zip'
|
308 |
+
if is_archive:
|
309 |
+
with zipfile.ZipFile(file, 'w') as z:
|
310 |
+
z.writestr('image.jpg', image_bytes)
|
311 |
+
z.writestr('points.exr', points_bytes)
|
312 |
+
if mask is not None:
|
313 |
+
z.writestr('mask.png', mask_bytes)
|
314 |
+
else:
|
315 |
+
file = Path(file)
|
316 |
+
file.mkdir(parents=True, exist_ok=True)
|
317 |
+
with open(file / 'image.jpg', 'wb') as f:
|
318 |
+
f.write(image_bytes)
|
319 |
+
with open(file / 'points.exr', 'wb') as f:
|
320 |
+
f.write(points_bytes)
|
321 |
+
if mask is not None:
|
322 |
+
with open(file / 'mask.png', 'wb') as f:
|
323 |
+
f.write(mask_bytes)
|
324 |
+
|
325 |
+
|
326 |
+
def read_rgbxyz(file: Union[IO, str, Path]) -> Tuple[np.ndarray, np.ndarray, np.ndarray, Dict[str, Any]]:
|
327 |
+
is_archive = hasattr(file, 'read') or Path(file).suffix == '.zip'
|
328 |
+
if is_archive:
|
329 |
+
with zipfile.ZipFile(file, 'r') as z:
|
330 |
+
image = cv2.imdecode(np.frombuffer(z.read('image.jpg'), np.uint8), cv2.IMREAD_COLOR)
|
331 |
+
points = cv2.imdecode(np.frombuffer(z.read('points.exr'), np.uint8), cv2.IMREAD_UNCHANGED)
|
332 |
+
if 'mask.png' in z.namelist():
|
333 |
+
mask = cv2.imdecode(np.frombuffer(z.read('mask.png'), np.uint8), cv2.IMREAD_UNCHANGED) > 0
|
334 |
+
else:
|
335 |
+
mask = np.ones(image.shape[:2], dtype=bool)
|
336 |
+
else:
|
337 |
+
file = Path(file)
|
338 |
+
file.mkdir(parents=True, exist_ok=True)
|
339 |
+
image = cv2.imread(str(file / 'image.jpg'), cv2.IMREAD_COLOR)
|
340 |
+
image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
|
341 |
+
points = cv2.imread(str(file / 'points.exr'), cv2.IMREAD_UNCHANGED)
|
342 |
+
if (file /'mask.png').exists():
|
343 |
+
mask = cv2.imread(str(file / 'mask.png'), cv2.IMREAD_UNCHANGED) > 0
|
344 |
+
else:
|
345 |
+
mask = np.ones(image.shape[:2], dtype=bool)
|
346 |
+
|
347 |
+
return image, points, mask
|
moge/utils/pipeline.py
ADDED
@@ -0,0 +1,503 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
from abc import abstractmethod
|
3 |
+
from queue import Empty, Full
|
4 |
+
from threading import Thread
|
5 |
+
from queue import Queue
|
6 |
+
from multiprocessing import Process
|
7 |
+
from threading import Thread, Event
|
8 |
+
import multiprocessing
|
9 |
+
import threading
|
10 |
+
import inspect
|
11 |
+
import time
|
12 |
+
import uuid
|
13 |
+
from copy import deepcopy
|
14 |
+
import itertools
|
15 |
+
import functools
|
16 |
+
|
17 |
+
__all__ = [
|
18 |
+
'Node',
|
19 |
+
'Link',
|
20 |
+
'ConcurrentNode',
|
21 |
+
'Worker',
|
22 |
+
'WorkerFunction',
|
23 |
+
'Provider',
|
24 |
+
'ProviderFunction',
|
25 |
+
'Sequential',
|
26 |
+
'Batch',
|
27 |
+
'Unbatch',
|
28 |
+
'Parallel',
|
29 |
+
'Graph',
|
30 |
+
'Buffer',
|
31 |
+
]
|
32 |
+
|
33 |
+
TERMINATE_CHECK_INTERVAL = 0.5
|
34 |
+
|
35 |
+
|
36 |
+
class _ItemWrapper:
|
37 |
+
def __init__(self, data: Any, id: Union[int, List[int]] = None):
|
38 |
+
self.data = data
|
39 |
+
self.id = id
|
40 |
+
|
41 |
+
|
42 |
+
class Terminate(Exception):
|
43 |
+
pass
|
44 |
+
|
45 |
+
|
46 |
+
def _get_queue_item(queue: Queue, terminate_flag: Event, timeout: float = None) -> _ItemWrapper:
|
47 |
+
while True:
|
48 |
+
try:
|
49 |
+
item: _ItemWrapper = queue.get(block=True, timeout=TERMINATE_CHECK_INTERVAL if timeout is None else min(timeout, TERMINATE_CHECK_INTERVAL))
|
50 |
+
if terminate_flag.is_set():
|
51 |
+
raise Terminate()
|
52 |
+
return item
|
53 |
+
except Empty:
|
54 |
+
if terminate_flag.is_set():
|
55 |
+
raise Terminate()
|
56 |
+
|
57 |
+
if timeout is not None:
|
58 |
+
timeout -= TERMINATE_CHECK_INTERVAL
|
59 |
+
if timeout <= 0:
|
60 |
+
raise Empty()
|
61 |
+
|
62 |
+
|
63 |
+
def _put_queue_item(queue: Queue, item: _ItemWrapper, terminate_flag: Event):
|
64 |
+
while True:
|
65 |
+
try:
|
66 |
+
queue.put(item, block=True, timeout=TERMINATE_CHECK_INTERVAL)
|
67 |
+
if terminate_flag.is_set():
|
68 |
+
raise Terminate()
|
69 |
+
return
|
70 |
+
except Full:
|
71 |
+
if terminate_flag.is_set():
|
72 |
+
raise Terminate()
|
73 |
+
|
74 |
+
class Node:
|
75 |
+
def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
|
76 |
+
self.input: Queue = Queue(maxsize=in_buffer_size)
|
77 |
+
self.output: Queue = Queue(maxsize=out_buffer_size)
|
78 |
+
self.in_buffer_size = in_buffer_size
|
79 |
+
self.out_buffer_size = out_buffer_size
|
80 |
+
|
81 |
+
@abstractmethod
|
82 |
+
def start(self):
|
83 |
+
pass
|
84 |
+
|
85 |
+
@abstractmethod
|
86 |
+
def terminate(self):
|
87 |
+
pass
|
88 |
+
|
89 |
+
def stop(self):
|
90 |
+
self.terminate()
|
91 |
+
self.join()
|
92 |
+
|
93 |
+
@abstractmethod
|
94 |
+
def join(self):
|
95 |
+
pass
|
96 |
+
|
97 |
+
def put(self, data: Any, key: str = None, block: bool = True) -> None:
|
98 |
+
item = _ItemWrapper(data)
|
99 |
+
self.input.put(item, block=block)
|
100 |
+
|
101 |
+
def get(self, key: str = None, block: bool = True) -> Any:
|
102 |
+
item: _ItemWrapper = self.output.get(block=block)
|
103 |
+
return item.data
|
104 |
+
|
105 |
+
def __enter__(self):
|
106 |
+
self.start()
|
107 |
+
return self
|
108 |
+
|
109 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
110 |
+
self.terminate()
|
111 |
+
self.join()
|
112 |
+
|
113 |
+
|
114 |
+
class ConcurrentNode(Node):
|
115 |
+
job: Union[Thread, Process]
|
116 |
+
|
117 |
+
def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
|
118 |
+
super().__init__(in_buffer_size, out_buffer_size)
|
119 |
+
self.running_as = running_as
|
120 |
+
|
121 |
+
@abstractmethod
|
122 |
+
def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
|
123 |
+
pass
|
124 |
+
|
125 |
+
def start(self):
|
126 |
+
if self.running_as == 'thread':
|
127 |
+
terminate_flag = threading.Event()
|
128 |
+
job = Thread(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
|
129 |
+
elif self.running_as == 'process':
|
130 |
+
terminate_flag = multiprocessing.Event()
|
131 |
+
job = Process(target=self._loop_fn, args=(self.input, self.output, terminate_flag))
|
132 |
+
job.start()
|
133 |
+
self.job = job
|
134 |
+
self.terminate_flag = terminate_flag
|
135 |
+
|
136 |
+
def terminate(self):
|
137 |
+
self.terminate_flag.set()
|
138 |
+
|
139 |
+
def join(self):
|
140 |
+
self.job.join()
|
141 |
+
|
142 |
+
|
143 |
+
class Worker(ConcurrentNode):
|
144 |
+
def __init__(self, running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 0, out_buffer_size: int = 0) -> None:
|
145 |
+
super().__init__(running_as, in_buffer_size, out_buffer_size)
|
146 |
+
|
147 |
+
def init(self) -> None:
|
148 |
+
"""
|
149 |
+
This method is called the the thread is started, to initialize any resources that is only held in the thread.
|
150 |
+
"""
|
151 |
+
pass
|
152 |
+
|
153 |
+
@abstractmethod
|
154 |
+
def work(self, *args, **kwargs) -> Union[Any, Dict[str, Any]]:
|
155 |
+
"""
|
156 |
+
This method defines the job that the node should do for each input item.
|
157 |
+
A item obtained from the input queue is passed as arguments to this method, and the result is placed in the output queue.
|
158 |
+
The method is executed concurrently with other nodes.
|
159 |
+
"""
|
160 |
+
pass
|
161 |
+
|
162 |
+
def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
|
163 |
+
self.init()
|
164 |
+
try:
|
165 |
+
while True:
|
166 |
+
item = _get_queue_item(input, terminate_flag)
|
167 |
+
result = self.work(item.data)
|
168 |
+
_put_queue_item(output, _ItemWrapper(result, item.id), terminate_flag)
|
169 |
+
|
170 |
+
except Terminate:
|
171 |
+
return
|
172 |
+
|
173 |
+
|
174 |
+
class Provider(ConcurrentNode):
|
175 |
+
"""
|
176 |
+
A node that provides data to successive nodes. It takes no input and provides data to the output queue.
|
177 |
+
"""
|
178 |
+
def __init__(self, running_as: Literal['thread', 'process'], out_buffer_size: int = 1) -> None:
|
179 |
+
super().__init__(running_as, 0, out_buffer_size)
|
180 |
+
|
181 |
+
def init(self) -> None:
|
182 |
+
"""
|
183 |
+
This method is called the the thread or process is started, to initialize any resources that is only held in the thread or process.
|
184 |
+
"""
|
185 |
+
pass
|
186 |
+
|
187 |
+
@abstractmethod
|
188 |
+
def provide(self) -> Generator[Any, None, None]:
|
189 |
+
pass
|
190 |
+
|
191 |
+
def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
|
192 |
+
self.init()
|
193 |
+
try:
|
194 |
+
for data in self.provide():
|
195 |
+
_put_queue_item(output, _ItemWrapper(data), terminate_flag)
|
196 |
+
except Terminate:
|
197 |
+
return
|
198 |
+
|
199 |
+
|
200 |
+
class WorkerFunction(Worker):
|
201 |
+
def __init__(self, fn: Callable, running_as: 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1) -> None:
|
202 |
+
super().__init__(running_as, in_buffer_size, out_buffer_size)
|
203 |
+
self.fn = fn
|
204 |
+
|
205 |
+
def work(self, *args, **kwargs):
|
206 |
+
return self.fn(*args, **kwargs)
|
207 |
+
|
208 |
+
|
209 |
+
class ProviderFunction(Provider):
|
210 |
+
def __init__(self, fn: Callable, running_as: 'thread', out_buffer_size: int = 1) -> None:
|
211 |
+
super().__init__(running_as, out_buffer_size)
|
212 |
+
self.fn = fn
|
213 |
+
|
214 |
+
def provide(self):
|
215 |
+
for item in self.fn():
|
216 |
+
yield item
|
217 |
+
|
218 |
+
|
219 |
+
class Link:
|
220 |
+
def __init__(self, src: Queue, dst: Queue):
|
221 |
+
self.src = src
|
222 |
+
self.dst = dst
|
223 |
+
|
224 |
+
def _thread_fn(self):
|
225 |
+
try:
|
226 |
+
while True:
|
227 |
+
item = _get_queue_item(self.src, self.terminate_flag)
|
228 |
+
_put_queue_item(self.dst, item, self.terminate_flag)
|
229 |
+
except Terminate:
|
230 |
+
return
|
231 |
+
|
232 |
+
def start(self):
|
233 |
+
self.terminate_flag = threading.Event()
|
234 |
+
self.thread = Thread(target=self._thread_fn)
|
235 |
+
self.thread.start()
|
236 |
+
|
237 |
+
def terminate(self):
|
238 |
+
self.terminate_flag.set()
|
239 |
+
|
240 |
+
def join(self):
|
241 |
+
self.thread.join()
|
242 |
+
|
243 |
+
|
244 |
+
class Graph(Node):
|
245 |
+
"""
|
246 |
+
Graph pipeline of nodes and links
|
247 |
+
"""
|
248 |
+
nodes: List[Node]
|
249 |
+
links: List[Link]
|
250 |
+
|
251 |
+
def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
|
252 |
+
super().__init__(in_buffer_size, out_buffer_size)
|
253 |
+
self.nodes = []
|
254 |
+
self.links = []
|
255 |
+
|
256 |
+
def add(self, node: Node):
|
257 |
+
self.nodes.append(node)
|
258 |
+
|
259 |
+
def link(self, src: Union[Node, Tuple[Node, str]], dst: Union[Node, Tuple[Node, str]]):
|
260 |
+
"""
|
261 |
+
Links the output of the source node to the input of the destination node.
|
262 |
+
If the source or destination node is None, the pipeline's input or output is used.
|
263 |
+
"""
|
264 |
+
src_queue = self.input if src is None else src.output
|
265 |
+
dst_queue = self.output if dst is None else dst.input
|
266 |
+
self.links.append(Link(src_queue, dst_queue))
|
267 |
+
|
268 |
+
def chain(self, nodes: Iterable[Node]):
|
269 |
+
"""
|
270 |
+
Link the output of each node to the input of the next node.
|
271 |
+
"""
|
272 |
+
nodes = list(nodes)
|
273 |
+
for i in range(len(nodes) - 1):
|
274 |
+
self.link(nodes[i], nodes[i + 1])
|
275 |
+
|
276 |
+
def start(self):
|
277 |
+
for node in self.nodes:
|
278 |
+
node.start()
|
279 |
+
for link in self.links:
|
280 |
+
link.start()
|
281 |
+
|
282 |
+
def terminate(self):
|
283 |
+
for node in self.nodes:
|
284 |
+
node.terminate()
|
285 |
+
for link in self.links:
|
286 |
+
link.terminate()
|
287 |
+
|
288 |
+
def join(self):
|
289 |
+
for node in self.nodes:
|
290 |
+
node.join()
|
291 |
+
for link in self.links:
|
292 |
+
link.join()
|
293 |
+
|
294 |
+
def __iter__(self):
|
295 |
+
providers = [node for node in self.nodes if isinstance(node, Provider)]
|
296 |
+
if len(providers) == 0:
|
297 |
+
raise ValueError("No provider node found in the pipeline. If you want to iterate over the pipeline, the pipeline must be driven by a provider node.")
|
298 |
+
with self:
|
299 |
+
# while all(provider.job.is_alive() for provider in providers):
|
300 |
+
while True:
|
301 |
+
yield self.get()
|
302 |
+
|
303 |
+
def __call__(self, data: Any) -> Any:
|
304 |
+
"""
|
305 |
+
Submit data to the pipeline's input queue, and return the output data asynchronously.
|
306 |
+
NOTE: The pipeline must be streamed (i.e., every output item is uniquely associated with an input item) for this to work.
|
307 |
+
"""
|
308 |
+
# TODO
|
309 |
+
|
310 |
+
|
311 |
+
class Sequential(Graph):
|
312 |
+
"""
|
313 |
+
Pipeline of nodes in sequential order, where each node takes the output of the previous node as input.
|
314 |
+
The order of input and output items is preserved (FIFO)
|
315 |
+
"""
|
316 |
+
def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
|
317 |
+
"""
|
318 |
+
Initialize the pipeline with a list of nodes to execute sequentially.
|
319 |
+
### Parameters:
|
320 |
+
- nodes: List of nodes or functions to execute sequentially. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes.
|
321 |
+
- function_running_as: Whether to wrap the function as a thread or process worker. Default is 'thread'.
|
322 |
+
- in_buffer_size: Maximum size of the input queue of the pipeline. Default is 0 (unlimited).
|
323 |
+
- out_buffer_size: Maximum size of the output queue of the pipeline. Default is 0 (unlimited).
|
324 |
+
"""
|
325 |
+
super().__init__(in_buffer_size, out_buffer_size)
|
326 |
+
for node in nodes:
|
327 |
+
if isinstance(node, Node):
|
328 |
+
pass
|
329 |
+
elif isinstance(node, Callable):
|
330 |
+
if inspect.isgeneratorfunction(node):
|
331 |
+
node = ProviderFunction(node, function_running_as)
|
332 |
+
else:
|
333 |
+
node = WorkerFunction(node, function_running_as)
|
334 |
+
else:
|
335 |
+
raise ValueError(f"Invalid node type: {type(node)}")
|
336 |
+
self.add(node)
|
337 |
+
self.chain([None, *self.nodes, None])
|
338 |
+
|
339 |
+
|
340 |
+
class Parallel(Node):
|
341 |
+
"""
|
342 |
+
A FIFO node that runs multiple nodes in parallel to process the input items. Each input item is handed to one of the nodes whoever is available.
|
343 |
+
NOTE: It is FIFO if and only if all the nested nodes are FIFO.
|
344 |
+
"""
|
345 |
+
nodes: List[Node]
|
346 |
+
|
347 |
+
def __init__(self, nodes: Iterable[Node], in_buffer_size: int = 1, out_buffer_size: int = 1, function_running_as: Literal['thread', 'process'] = 'thread'):
|
348 |
+
super().__init__(in_buffer_size, out_buffer_size)
|
349 |
+
self.nodes = []
|
350 |
+
for node in nodes:
|
351 |
+
if isinstance(node, Node):
|
352 |
+
pass
|
353 |
+
elif isinstance(node, Callable):
|
354 |
+
if inspect.isgeneratorfunction(node):
|
355 |
+
node = ProviderFunction(node, function_running_as)
|
356 |
+
else:
|
357 |
+
node = WorkerFunction(node, function_running_as)
|
358 |
+
else:
|
359 |
+
raise ValueError(f"Invalid node type: {type(node)}")
|
360 |
+
self.nodes.append(node)
|
361 |
+
self.output_order = Queue()
|
362 |
+
self.lock = threading.Lock()
|
363 |
+
|
364 |
+
def _in_thread_fn(self, node: Node):
|
365 |
+
try:
|
366 |
+
while True:
|
367 |
+
with self.lock:
|
368 |
+
# A better idea: first make sure its node is vacant, then get it a new item.
|
369 |
+
# Currently we will not be able to know which node is busy util there is at least one item already waiting in the queue of the node.
|
370 |
+
# This could lead to suboptimal scheduling.
|
371 |
+
item = _get_queue_item(self.input, self.terminate_flag)
|
372 |
+
self.output_order.put(node.output)
|
373 |
+
_put_queue_item(node.input, item, self.terminate_flag)
|
374 |
+
except Terminate:
|
375 |
+
return
|
376 |
+
|
377 |
+
def _out_thread_fn(self):
|
378 |
+
try:
|
379 |
+
while True:
|
380 |
+
queue = _get_queue_item(self.output_order, self.terminate_flag)
|
381 |
+
item = _get_queue_item(queue, self.terminate_flag)
|
382 |
+
_put_queue_item(self.output, item, self.terminate_flag)
|
383 |
+
except Terminate:
|
384 |
+
return
|
385 |
+
|
386 |
+
def start(self):
|
387 |
+
self.terminate_flag = threading.Event()
|
388 |
+
self.in_threads = []
|
389 |
+
for node in self.nodes:
|
390 |
+
thread = Thread(target=self._in_thread_fn, args=(node,))
|
391 |
+
thread.start()
|
392 |
+
self.in_threads.append(thread)
|
393 |
+
thread = Thread(target=self._out_thread_fn)
|
394 |
+
thread.start()
|
395 |
+
self.out_thread = thread
|
396 |
+
for node in self.nodes:
|
397 |
+
node.start()
|
398 |
+
|
399 |
+
def terminate(self):
|
400 |
+
self.terminate_flag.set()
|
401 |
+
for node in self.nodes:
|
402 |
+
node.terminate()
|
403 |
+
|
404 |
+
def join(self):
|
405 |
+
for thread in self.in_threads:
|
406 |
+
thread.join()
|
407 |
+
self.out_thread.join()
|
408 |
+
|
409 |
+
|
410 |
+
class UnorderedParallel(Graph):
|
411 |
+
"""
|
412 |
+
Pipeline of nodes in parallel, where each input item is handed to one of the nodes whoever is available.
|
413 |
+
NOTE: The order of the output items is NOT guaranteed to be the same as the input items, depending on how fast the nodes handle their input.
|
414 |
+
"""
|
415 |
+
def __init__(self, nodes: List[Union[Node, Callable]], function_running_as: Literal['thread', 'process'] = 'thread', in_buffer_size: int = 1, out_buffer_size: int = 1):
|
416 |
+
"""
|
417 |
+
Initialize the pipeline with a list of nodes to execute in parallel. If a function is given, it is wrapped in a worker node.
|
418 |
+
### Parameters:
|
419 |
+
- nodes: List of nodes or functions to execute in parallel. Generator functions are wrapped in provider nodes, and other functions are wrapped in worker nodes.
|
420 |
+
- function_running_as: Whether to wrap the function as a thread or process worker. Default is 'thread'.
|
421 |
+
- in_buffer_size: Maximum size of the input queue of the pipeline. Default is 0 (unlimited).
|
422 |
+
- out_buffer_size: Maximum size of the output queue of the pipeline. Default is 0 (unlimited).
|
423 |
+
"""
|
424 |
+
super().__init__(in_buffer_size, out_buffer_size)
|
425 |
+
for node in nodes:
|
426 |
+
if isinstance(node, Node):
|
427 |
+
pass
|
428 |
+
elif isinstance(node, Callable):
|
429 |
+
if inspect.isgeneratorfunction(node):
|
430 |
+
node = ProviderFunction(node, function_running_as)
|
431 |
+
else:
|
432 |
+
node = WorkerFunction(node, function_running_as)
|
433 |
+
else:
|
434 |
+
raise ValueError(f"Invalid node type: {type(node)}")
|
435 |
+
self.add(node)
|
436 |
+
for i in range(len(nodes)):
|
437 |
+
self.chain([None, self.nodes[i], None])
|
438 |
+
|
439 |
+
|
440 |
+
class Batch(ConcurrentNode):
|
441 |
+
"""
|
442 |
+
Groups every `batch_size` items into a batch (a list of items) and passes the batch to successive nodes.
|
443 |
+
The `patience` parameter specifies the maximum time to wait for a batch to be filled before sending it to the next node,
|
444 |
+
i.e., when the earliest item in the batch is out of `patience` seconds, the batch is sent regardless of its size.
|
445 |
+
"""
|
446 |
+
def __init__(self, batch_size: int, patience: float = None, in_buffer_size: int = 1, out_buffer_size: int = 1):
|
447 |
+
assert batch_size > 0, "Batch size must be greater than 0."
|
448 |
+
super().__init__('thread', in_buffer_size, out_buffer_size)
|
449 |
+
self.batch_size = batch_size
|
450 |
+
self.patience = patience
|
451 |
+
|
452 |
+
def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
|
453 |
+
try:
|
454 |
+
while True:
|
455 |
+
batch_id, batch_data = [], []
|
456 |
+
# Try to fill the batch
|
457 |
+
for i in range(self.batch_size):
|
458 |
+
if i == 0 or self.patience is None:
|
459 |
+
timeout = None
|
460 |
+
else:
|
461 |
+
timeout = self.patience - (time.time() - earliest_time)
|
462 |
+
if timeout < 0:
|
463 |
+
break
|
464 |
+
try:
|
465 |
+
item = _get_queue_item(input, terminate_flag, timeout)
|
466 |
+
except Empty:
|
467 |
+
break
|
468 |
+
|
469 |
+
if i == 0:
|
470 |
+
earliest_time = time.time()
|
471 |
+
batch_data.append(item.data)
|
472 |
+
batch_id.append(item.id)
|
473 |
+
|
474 |
+
batch = _ItemWrapper(batch_data, batch_id)
|
475 |
+
_put_queue_item(output, batch, terminate_flag)
|
476 |
+
except Terminate:
|
477 |
+
return
|
478 |
+
|
479 |
+
|
480 |
+
class Unbatch(ConcurrentNode):
|
481 |
+
"""
|
482 |
+
Ungroups every batch (a list of items) into individual items and passes them to successive nodes.
|
483 |
+
"""
|
484 |
+
def __init__(self, in_buffer_size: int = 1, out_buffer_size: int = 1):
|
485 |
+
super().__init__('thread', in_buffer_size, out_buffer_size)
|
486 |
+
|
487 |
+
def _loop_fn(self, input: Queue, output: Queue, terminate_flag: Event):
|
488 |
+
try:
|
489 |
+
while True:
|
490 |
+
batch = _get_queue_item(input, terminate_flag)
|
491 |
+
for id, data in zip(batch.id or itertools.repeat(None), batch.data):
|
492 |
+
item = _ItemWrapper(data, id)
|
493 |
+
_put_queue_item(output, item, terminate_flag)
|
494 |
+
except Terminate:
|
495 |
+
return
|
496 |
+
|
497 |
+
|
498 |
+
class Buffer(Node):
|
499 |
+
"A FIFO node that buffers items in a queue. Usefull achieve better temporal balance when its successor node has a variable processing time."
|
500 |
+
def __init__(self, size: int):
|
501 |
+
super().__init__(size, size)
|
502 |
+
self.size = size
|
503 |
+
self.input = self.output = Queue(maxsize=size)
|
moge/utils/tools.py
ADDED
@@ -0,0 +1,240 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import time
|
3 |
+
from pathlib import Path
|
4 |
+
from numbers import Number
|
5 |
+
|
6 |
+
|
7 |
+
def catch_exception(fn):
|
8 |
+
def wrapper(*args, **kwargs):
|
9 |
+
try:
|
10 |
+
return fn(*args, **kwargs)
|
11 |
+
except Exception as e:
|
12 |
+
import traceback
|
13 |
+
print(f"Exception in {fn.__name__}({', '.join(repr(arg) for arg in args)}, {', '.join(f'{k}={v!r}' for k, v in kwargs.items())})")
|
14 |
+
traceback.print_exc(chain=False)
|
15 |
+
time.sleep(0.1)
|
16 |
+
return None
|
17 |
+
return wrapper
|
18 |
+
|
19 |
+
|
20 |
+
class CallbackOnException:
|
21 |
+
def __init__(self, callback: Callable, exception: type):
|
22 |
+
self.exception = exception
|
23 |
+
self.callback = callback
|
24 |
+
|
25 |
+
def __enter__(self):
|
26 |
+
return self
|
27 |
+
|
28 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
29 |
+
if isinstance(exc_val, self.exception):
|
30 |
+
self.callback()
|
31 |
+
return True
|
32 |
+
return False
|
33 |
+
|
34 |
+
def traverse_nested_dict_keys(d: Dict[str, Dict]) -> Generator[Tuple[str, ...], None, None]:
|
35 |
+
for k, v in d.items():
|
36 |
+
if isinstance(v, dict):
|
37 |
+
for sub_key in traverse_nested_dict_keys(v):
|
38 |
+
yield (k, ) + sub_key
|
39 |
+
else:
|
40 |
+
yield (k, )
|
41 |
+
|
42 |
+
|
43 |
+
def get_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], default: Any = None):
|
44 |
+
for k in keys:
|
45 |
+
d = d.get(k, default)
|
46 |
+
if d is None:
|
47 |
+
break
|
48 |
+
return d
|
49 |
+
|
50 |
+
def set_nested_dict(d: Dict[str, Dict], keys: Tuple[str, ...], value: Any):
|
51 |
+
for k in keys[:-1]:
|
52 |
+
d = d.setdefault(k, {})
|
53 |
+
d[keys[-1]] = value
|
54 |
+
|
55 |
+
|
56 |
+
def key_average(list_of_dicts: list) -> Dict[str, Any]:
|
57 |
+
"""
|
58 |
+
Returns a dictionary with the average value of each key in the input list of dictionaries.
|
59 |
+
"""
|
60 |
+
_nested_dict_keys = set()
|
61 |
+
for d in list_of_dicts:
|
62 |
+
_nested_dict_keys.update(traverse_nested_dict_keys(d))
|
63 |
+
_nested_dict_keys = sorted(_nested_dict_keys)
|
64 |
+
result = {}
|
65 |
+
for k in _nested_dict_keys:
|
66 |
+
values = [
|
67 |
+
get_nested_dict(d, k) for d in list_of_dicts
|
68 |
+
if get_nested_dict(d, k) is not None
|
69 |
+
]
|
70 |
+
avg = sum(values) / len(values) if values else float('nan')
|
71 |
+
set_nested_dict(result, k, avg)
|
72 |
+
return result
|
73 |
+
|
74 |
+
|
75 |
+
def flatten_nested_dict(d: Dict[str, Any], parent_key: Tuple[str, ...] = None) -> Dict[Tuple[str, ...], Any]:
|
76 |
+
"""
|
77 |
+
Flattens a nested dictionary into a single-level dictionary, with keys as tuples.
|
78 |
+
"""
|
79 |
+
items = []
|
80 |
+
if parent_key is None:
|
81 |
+
parent_key = ()
|
82 |
+
for k, v in d.items():
|
83 |
+
new_key = parent_key + (k, )
|
84 |
+
if isinstance(v, MutableMapping):
|
85 |
+
items.extend(flatten_nested_dict(v, new_key).items())
|
86 |
+
else:
|
87 |
+
items.append((new_key, v))
|
88 |
+
return dict(items)
|
89 |
+
|
90 |
+
|
91 |
+
def unflatten_nested_dict(d: Dict[str, Any]) -> Dict[str, Any]:
|
92 |
+
"""
|
93 |
+
Unflattens a single-level dictionary into a nested dictionary, with keys as tuples.
|
94 |
+
"""
|
95 |
+
result = {}
|
96 |
+
for k, v in d.items():
|
97 |
+
sub_dict = result
|
98 |
+
for k_ in k[:-1]:
|
99 |
+
if k_ not in sub_dict:
|
100 |
+
sub_dict[k_] = {}
|
101 |
+
sub_dict = sub_dict[k_]
|
102 |
+
sub_dict[k[-1]] = v
|
103 |
+
return result
|
104 |
+
|
105 |
+
|
106 |
+
def read_jsonl(file):
|
107 |
+
import json
|
108 |
+
with open(file, 'r') as f:
|
109 |
+
data = f.readlines()
|
110 |
+
return [json.loads(line) for line in data]
|
111 |
+
|
112 |
+
|
113 |
+
def write_jsonl(data: List[dict], file):
|
114 |
+
import json
|
115 |
+
with open(file, 'w') as f:
|
116 |
+
for item in data:
|
117 |
+
f.write(json.dumps(item) + '\n')
|
118 |
+
|
119 |
+
|
120 |
+
def save_metrics(save_path: Union[str, Path], all_metrics: Dict[str, List[Dict]]):
|
121 |
+
import pandas as pd
|
122 |
+
import json
|
123 |
+
|
124 |
+
with open(save_path, 'w') as f:
|
125 |
+
json.dump(all_metrics, f, indent=4)
|
126 |
+
|
127 |
+
|
128 |
+
def to_hierachical_dataframe(data: List[Dict[Tuple[str, ...], Any]]):
|
129 |
+
import pandas as pd
|
130 |
+
data = [flatten_nested_dict(d) for d in data]
|
131 |
+
df = pd.DataFrame(data)
|
132 |
+
df = df.sort_index(axis=1)
|
133 |
+
df.columns = pd.MultiIndex.from_tuples(df.columns)
|
134 |
+
return df
|
135 |
+
|
136 |
+
|
137 |
+
def recursive_replace(d: Union[List, Dict, str], mapping: Dict[str, str]):
|
138 |
+
if isinstance(d, str):
|
139 |
+
for old, new in mapping.items():
|
140 |
+
d = d.replace(old, new)
|
141 |
+
elif isinstance(d, list):
|
142 |
+
for i, item in enumerate(d):
|
143 |
+
d[i] = recursive_replace(item, mapping)
|
144 |
+
elif isinstance(d, dict):
|
145 |
+
for k, v in d.items():
|
146 |
+
d[k] = recursive_replace(v, mapping)
|
147 |
+
return d
|
148 |
+
|
149 |
+
|
150 |
+
class timeit:
|
151 |
+
_history: Dict[str, List['timeit']] = {}
|
152 |
+
|
153 |
+
def __init__(self, name: str = None, verbose: bool = True, multiple: bool = False):
|
154 |
+
self.name = name
|
155 |
+
self.verbose = verbose
|
156 |
+
self.start = None
|
157 |
+
self.end = None
|
158 |
+
self.multiple = multiple
|
159 |
+
if multiple and name not in timeit._history:
|
160 |
+
timeit._history[name] = []
|
161 |
+
|
162 |
+
def __call__(self, func: Callable):
|
163 |
+
import inspect
|
164 |
+
if inspect.iscoroutinefunction(func):
|
165 |
+
async def wrapper(*args, **kwargs):
|
166 |
+
with timeit(self.name or func.__qualname__):
|
167 |
+
ret = await func(*args, **kwargs)
|
168 |
+
return ret
|
169 |
+
return wrapper
|
170 |
+
else:
|
171 |
+
def wrapper(*args, **kwargs):
|
172 |
+
with timeit(self.name or func.__qualname__):
|
173 |
+
ret = func(*args, **kwargs)
|
174 |
+
return ret
|
175 |
+
return wrapper
|
176 |
+
|
177 |
+
def __enter__(self):
|
178 |
+
self.start = time.time()
|
179 |
+
|
180 |
+
@property
|
181 |
+
def time(self) -> float:
|
182 |
+
assert self.start is not None, "Time not yet started."
|
183 |
+
assert self.end is not None, "Time not yet ended."
|
184 |
+
return self.end - self.start
|
185 |
+
|
186 |
+
@property
|
187 |
+
def history(self) -> List['timeit']:
|
188 |
+
return timeit._history.get(self.name, [])
|
189 |
+
|
190 |
+
def __exit__(self, exc_type, exc_val, exc_tb):
|
191 |
+
self.end = time.time()
|
192 |
+
if self.multiple:
|
193 |
+
timeit._history[self.name].append(self)
|
194 |
+
if self.verbose:
|
195 |
+
if self.multiple:
|
196 |
+
avg = sum(t.time for t in timeit._history[self.name]) / len(timeit._history[self.name])
|
197 |
+
print(f"{self.name or 'It'} took {avg} seconds in average.")
|
198 |
+
else:
|
199 |
+
print(f"{self.name or 'It'} took {self.time} seconds.")
|
200 |
+
|
201 |
+
|
202 |
+
def strip_common_prefix_suffix(strings: List[str]) -> List[str]:
|
203 |
+
first = strings[0]
|
204 |
+
|
205 |
+
for start in range(len(first)):
|
206 |
+
if any(s[start] != strings[0][start] for s in strings):
|
207 |
+
break
|
208 |
+
|
209 |
+
for end in range(1, min(len(s) for s in strings)):
|
210 |
+
if any(s[-end] != first[-end] for s in strings):
|
211 |
+
break
|
212 |
+
|
213 |
+
return [s[start:len(s) - end + 1] for s in strings]
|
214 |
+
|
215 |
+
|
216 |
+
def multithead_execute(inputs: List[Any], num_workers: int, pbar = None):
|
217 |
+
from concurrent.futures import ThreadPoolExecutor
|
218 |
+
from contextlib import nullcontext
|
219 |
+
from tqdm import tqdm
|
220 |
+
|
221 |
+
if pbar is not None:
|
222 |
+
pbar.total = len(inputs) if hasattr(inputs, '__len__') else None
|
223 |
+
else:
|
224 |
+
pbar = tqdm(total=len(inputs) if hasattr(inputs, '__len__') else None)
|
225 |
+
|
226 |
+
def decorator(fn: Callable):
|
227 |
+
with (
|
228 |
+
ThreadPoolExecutor(max_workers=num_workers) as executor,
|
229 |
+
pbar
|
230 |
+
):
|
231 |
+
pbar.refresh()
|
232 |
+
@catch_exception
|
233 |
+
def _fn(input):
|
234 |
+
ret = fn(input)
|
235 |
+
pbar.update()
|
236 |
+
return ret
|
237 |
+
executor.map(_fn, inputs)
|
238 |
+
executor.shutdown(wait=True)
|
239 |
+
|
240 |
+
return decorator
|
moge/utils/vis.py
ADDED
@@ -0,0 +1,51 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import matplotlib
|
3 |
+
|
4 |
+
|
5 |
+
def colorize_depth(depth: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
|
6 |
+
if mask is None:
|
7 |
+
depth = np.where(depth > 0, depth, np.nan)
|
8 |
+
else:
|
9 |
+
depth = np.where((depth > 0) & mask, depth, np.nan)
|
10 |
+
disp = 1 / depth
|
11 |
+
if normalize:
|
12 |
+
min_disp, max_disp = np.nanquantile(disp, 0.001), np.nanquantile(disp, 0.999)
|
13 |
+
disp = (disp - min_disp) / (max_disp - min_disp)
|
14 |
+
colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disp), 0)
|
15 |
+
colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3]
|
16 |
+
return colored
|
17 |
+
|
18 |
+
|
19 |
+
def colorize_depth_affine(depth: np.ndarray, mask: np.ndarray = None, cmap: str = 'Spectral') -> np.ndarray:
|
20 |
+
if mask is not None:
|
21 |
+
depth = np.where(mask, depth, np.nan)
|
22 |
+
|
23 |
+
min_depth, max_depth = np.nanquantile(depth, 0.001), np.nanquantile(depth, 0.999)
|
24 |
+
depth = (depth - min_depth) / (max_depth - min_depth)
|
25 |
+
colored = np.nan_to_num(matplotlib.colormaps[cmap](depth), 0)
|
26 |
+
colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3]
|
27 |
+
return colored
|
28 |
+
|
29 |
+
|
30 |
+
def colorize_disparity(disparity: np.ndarray, mask: np.ndarray = None, normalize: bool = True, cmap: str = 'Spectral') -> np.ndarray:
|
31 |
+
if mask is not None:
|
32 |
+
disparity = np.where(mask, disparity, np.nan)
|
33 |
+
|
34 |
+
if normalize:
|
35 |
+
min_disp, max_disp = np.nanquantile(disparity, 0.001), np.nanquantile(disparity, 0.999)
|
36 |
+
disparity = (disparity - min_disp) / (max_disp - min_disp)
|
37 |
+
colored = np.nan_to_num(matplotlib.colormaps[cmap](1.0 - disparity), 0)
|
38 |
+
colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3]
|
39 |
+
return colored
|
40 |
+
|
41 |
+
|
42 |
+
def colorize_segmentation(segmentation: np.ndarray, cmap: str = 'Set1') -> np.ndarray:
|
43 |
+
colored = matplotlib.colormaps[cmap]((segmentation % 20) / 20)
|
44 |
+
colored = (colored.clip(0, 1) * 255).astype(np.uint8)[:, :, :3]
|
45 |
+
return colored
|
46 |
+
|
47 |
+
|
48 |
+
def colorize_normal(normal: np.ndarray) -> np.ndarray:
|
49 |
+
normal = normal * [0.5, -0.5, -0.5] + 0.5
|
50 |
+
normal = (normal.clip(0, 1) * 255).astype(np.uint8)
|
51 |
+
return normal
|
moge/utils/webfile.py
ADDED
@@ -0,0 +1,73 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import requests
|
2 |
+
from typing import *
|
3 |
+
|
4 |
+
__all__ = ["WebFile"]
|
5 |
+
|
6 |
+
|
7 |
+
class WebFile:
|
8 |
+
def __init__(self, url: str, session: Optional[requests.Session] = None, headers: Optional[Dict[str, str]] = None, size: Optional[int] = None):
|
9 |
+
self.url = url
|
10 |
+
self.session = session or requests.Session()
|
11 |
+
self.session.headers.update(headers or {})
|
12 |
+
self._offset = 0
|
13 |
+
self.size = size if size is not None else self._fetch_size()
|
14 |
+
|
15 |
+
def _fetch_size(self):
|
16 |
+
with self.session.get(self.url, stream=True) as response:
|
17 |
+
response.raise_for_status()
|
18 |
+
content_length = response.headers.get("Content-Length")
|
19 |
+
if content_length is None:
|
20 |
+
raise ValueError("Missing Content-Length in header")
|
21 |
+
return int(content_length)
|
22 |
+
|
23 |
+
def _fetch_data(self, offset: int, n: int) -> bytes:
|
24 |
+
headers = {"Range": f"bytes={offset}-{min(offset + n - 1, self.size)}"}
|
25 |
+
response = self.session.get(self.url, headers=headers)
|
26 |
+
response.raise_for_status()
|
27 |
+
return response.content
|
28 |
+
|
29 |
+
def seekable(self) -> bool:
|
30 |
+
return True
|
31 |
+
|
32 |
+
def tell(self) -> int:
|
33 |
+
return self._offset
|
34 |
+
|
35 |
+
def available(self) -> int:
|
36 |
+
return self.size - self._offset
|
37 |
+
|
38 |
+
def seek(self, offset: int, whence: int = 0) -> None:
|
39 |
+
if whence == 0:
|
40 |
+
new_offset = offset
|
41 |
+
elif whence == 1:
|
42 |
+
new_offset = self._offset + offset
|
43 |
+
elif whence == 2:
|
44 |
+
new_offset = self.size + offset
|
45 |
+
else:
|
46 |
+
raise ValueError("Invalid value for whence")
|
47 |
+
|
48 |
+
self._offset = max(0, min(new_offset, self.size))
|
49 |
+
|
50 |
+
def read(self, n: Optional[int] = None) -> bytes:
|
51 |
+
if n is None or n < 0:
|
52 |
+
n = self.available()
|
53 |
+
else:
|
54 |
+
n = min(n, self.available())
|
55 |
+
|
56 |
+
if n == 0:
|
57 |
+
return b''
|
58 |
+
|
59 |
+
data = self._fetch_data(self._offset, n)
|
60 |
+
self._offset += len(data)
|
61 |
+
|
62 |
+
return data
|
63 |
+
|
64 |
+
def close(self) -> None:
|
65 |
+
pass
|
66 |
+
|
67 |
+
def __enter__(self):
|
68 |
+
return self
|
69 |
+
|
70 |
+
def __exit__(self, exc_type, exc_value, traceback):
|
71 |
+
pass
|
72 |
+
|
73 |
+
|
moge/utils/webzipfile.py
ADDED
@@ -0,0 +1,128 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
import io
|
3 |
+
import os
|
4 |
+
from zipfile import (
|
5 |
+
ZipInfo, BadZipFile, ZipFile, ZipExtFile,
|
6 |
+
sizeFileHeader, structFileHeader, stringFileHeader,
|
7 |
+
_FH_SIGNATURE, _FH_FILENAME_LENGTH, _FH_EXTRA_FIELD_LENGTH, _FH_GENERAL_PURPOSE_FLAG_BITS,
|
8 |
+
_MASK_COMPRESSED_PATCH, _MASK_STRONG_ENCRYPTION, _MASK_UTF_FILENAME, _MASK_ENCRYPTED
|
9 |
+
)
|
10 |
+
import struct
|
11 |
+
from requests import Session
|
12 |
+
|
13 |
+
from .webfile import WebFile
|
14 |
+
|
15 |
+
|
16 |
+
class _SharedWebFile(WebFile):
|
17 |
+
def __init__(self, webfile: WebFile, pos: int):
|
18 |
+
super().__init__(webfile.url, webfile.session, size=webfile.size)
|
19 |
+
self.seek(pos)
|
20 |
+
|
21 |
+
|
22 |
+
class WebZipFile(ZipFile):
|
23 |
+
"Lock-free version of ZipFile that reads from a WebFile, allowing for concurrent reads."
|
24 |
+
def __init__(self, url: str, session: Optional[Session] = None, headers: Optional[Dict[str, str]] = None):
|
25 |
+
"""Open the ZIP file with mode read 'r', write 'w', exclusive create 'x',
|
26 |
+
or append 'a'."""
|
27 |
+
webf = WebFile(url, session=session, headers=headers)
|
28 |
+
super().__init__(webf, mode='r')
|
29 |
+
|
30 |
+
def open(self, name, mode="r", pwd=None, *, force_zip64=False):
|
31 |
+
"""Return file-like object for 'name'.
|
32 |
+
|
33 |
+
name is a string for the file name within the ZIP file, or a ZipInfo
|
34 |
+
object.
|
35 |
+
|
36 |
+
mode should be 'r' to read a file already in the ZIP file, or 'w' to
|
37 |
+
write to a file newly added to the archive.
|
38 |
+
|
39 |
+
pwd is the password to decrypt files (only used for reading).
|
40 |
+
|
41 |
+
When writing, if the file size is not known in advance but may exceed
|
42 |
+
2 GiB, pass force_zip64 to use the ZIP64 format, which can handle large
|
43 |
+
files. If the size is known in advance, it is best to pass a ZipInfo
|
44 |
+
instance for name, with zinfo.file_size set.
|
45 |
+
"""
|
46 |
+
if mode not in {"r", "w"}:
|
47 |
+
raise ValueError('open() requires mode "r" or "w"')
|
48 |
+
if pwd and (mode == "w"):
|
49 |
+
raise ValueError("pwd is only supported for reading files")
|
50 |
+
if not self.fp:
|
51 |
+
raise ValueError(
|
52 |
+
"Attempt to use ZIP archive that was already closed")
|
53 |
+
|
54 |
+
assert mode == "r", "Only read mode is supported for now"
|
55 |
+
|
56 |
+
# Make sure we have an info object
|
57 |
+
if isinstance(name, ZipInfo):
|
58 |
+
# 'name' is already an info object
|
59 |
+
zinfo = name
|
60 |
+
elif mode == 'w':
|
61 |
+
zinfo = ZipInfo(name)
|
62 |
+
zinfo.compress_type = self.compression
|
63 |
+
zinfo._compresslevel = self.compresslevel
|
64 |
+
else:
|
65 |
+
# Get info object for name
|
66 |
+
zinfo = self.getinfo(name)
|
67 |
+
|
68 |
+
if mode == 'w':
|
69 |
+
return self._open_to_write(zinfo, force_zip64=force_zip64)
|
70 |
+
|
71 |
+
if self._writing:
|
72 |
+
raise ValueError("Can't read from the ZIP file while there "
|
73 |
+
"is an open writing handle on it. "
|
74 |
+
"Close the writing handle before trying to read.")
|
75 |
+
|
76 |
+
# Open for reading:
|
77 |
+
self._fileRefCnt += 1
|
78 |
+
zef_file = _SharedWebFile(self.fp, zinfo.header_offset)
|
79 |
+
|
80 |
+
try:
|
81 |
+
# Skip the file header:
|
82 |
+
fheader = zef_file.read(sizeFileHeader)
|
83 |
+
if len(fheader) != sizeFileHeader:
|
84 |
+
raise BadZipFile("Truncated file header")
|
85 |
+
fheader = struct.unpack(structFileHeader, fheader)
|
86 |
+
if fheader[_FH_SIGNATURE] != stringFileHeader:
|
87 |
+
raise BadZipFile("Bad magic number for file header")
|
88 |
+
|
89 |
+
fname = zef_file.read(fheader[_FH_FILENAME_LENGTH])
|
90 |
+
if fheader[_FH_EXTRA_FIELD_LENGTH]:
|
91 |
+
zef_file.seek(fheader[_FH_EXTRA_FIELD_LENGTH], whence=1)
|
92 |
+
|
93 |
+
if zinfo.flag_bits & _MASK_COMPRESSED_PATCH:
|
94 |
+
# Zip 2.7: compressed patched data
|
95 |
+
raise NotImplementedError("compressed patched data (flag bit 5)")
|
96 |
+
|
97 |
+
if zinfo.flag_bits & _MASK_STRONG_ENCRYPTION:
|
98 |
+
# strong encryption
|
99 |
+
raise NotImplementedError("strong encryption (flag bit 6)")
|
100 |
+
|
101 |
+
if fheader[_FH_GENERAL_PURPOSE_FLAG_BITS] & _MASK_UTF_FILENAME:
|
102 |
+
# UTF-8 filename
|
103 |
+
fname_str = fname.decode("utf-8")
|
104 |
+
else:
|
105 |
+
fname_str = fname.decode(self.metadata_encoding or "cp437")
|
106 |
+
|
107 |
+
if fname_str != zinfo.orig_filename:
|
108 |
+
raise BadZipFile(
|
109 |
+
'File name in directory %r and header %r differ.'
|
110 |
+
% (zinfo.orig_filename, fname))
|
111 |
+
|
112 |
+
# check for encrypted flag & handle password
|
113 |
+
is_encrypted = zinfo.flag_bits & _MASK_ENCRYPTED
|
114 |
+
if is_encrypted:
|
115 |
+
if not pwd:
|
116 |
+
pwd = self.pwd
|
117 |
+
if pwd and not isinstance(pwd, bytes):
|
118 |
+
raise TypeError("pwd: expected bytes, got %s" % type(pwd).__name__)
|
119 |
+
if not pwd:
|
120 |
+
raise RuntimeError("File %r is encrypted, password "
|
121 |
+
"required for extraction" % name)
|
122 |
+
else:
|
123 |
+
pwd = None
|
124 |
+
|
125 |
+
return ZipExtFile(zef_file, mode, zinfo, pwd, True)
|
126 |
+
except:
|
127 |
+
zef_file.close()
|
128 |
+
raise
|
packages.txt
ADDED
@@ -0,0 +1 @@
|
|
|
|
|
1 |
+
python3-opencv
|
requirements.txt
ADDED
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
opencv-python
|
2 |
+
plyfile
|
3 |
+
pygltflib
|
4 |
+
transformers
|
5 |
+
scikit-learn
|
utils3d/__init__.py
ADDED
@@ -0,0 +1,14 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
A package for common utility functions in 3D computer graphics and vision. Providing NumPy utilities in `utils3d.numpy`, PyTorch utilities in `utils3d.torch`, and IO utilities in `utils3d.io`.
|
3 |
+
"""
|
4 |
+
import importlib
|
5 |
+
|
6 |
+
__all__ = ['numpy', 'torch', 'io']
|
7 |
+
|
8 |
+
def __getattr__(module_name: str):
|
9 |
+
return importlib.import_module(f'.{module_name}', __package__)
|
10 |
+
|
11 |
+
if __name__ == '__main__':
|
12 |
+
from . import torch
|
13 |
+
from . import numpy
|
14 |
+
from . import io
|
utils3d/io/__init__.py
ADDED
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from .wavefront_obj import *
|
2 |
+
from .colmap import *
|
3 |
+
from .ply import *
|
4 |
+
from .glb import *
|
utils3d/io/colmap.py
ADDED
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
from scipy.spatial.transform import Rotation
|
6 |
+
|
7 |
+
|
8 |
+
__all__ = ['read_extrinsics_from_colmap', 'read_intrinsics_from_colmap', 'write_extrinsics_as_colmap', 'write_intrinsics_as_colmap']
|
9 |
+
|
10 |
+
|
11 |
+
def write_extrinsics_as_colmap(file: Union[str, Path], extrinsics: np.ndarray, image_names: Union[str, List[str]] = 'image_{i:04d}.png', camera_ids: List[int] = None):
|
12 |
+
"""
|
13 |
+
Write extrinsics to colmap `images.txt` file.
|
14 |
+
Args:
|
15 |
+
file: Path to `images.txt` file.
|
16 |
+
extrinsics: (N, 4, 4) array of extrinsics.
|
17 |
+
image_names: str or List of str, image names. Length is N.
|
18 |
+
If str, it should be a format string with `i` as the index. (i starts from 1, in correspondence with IMAGE_ID in colmap)
|
19 |
+
camera_ids: List of int, camera ids. Length is N.
|
20 |
+
If None, it will be set to [1, 2, ..., N].
|
21 |
+
"""
|
22 |
+
assert extrinsics.shape[1:] == (4, 4) and extrinsics.ndim == 3 or extrinsics.shape == (4, 4)
|
23 |
+
if extrinsics.ndim == 2:
|
24 |
+
extrinsics = extrinsics[np.newaxis, ...]
|
25 |
+
quats = Rotation.from_matrix(extrinsics[:, :3, :3]).as_quat()
|
26 |
+
trans = extrinsics[:, :3, 3]
|
27 |
+
if camera_ids is None:
|
28 |
+
camera_ids = list(range(1, len(extrinsics) + 1))
|
29 |
+
if isinstance(image_names, str):
|
30 |
+
image_names = [image_names.format(i=i) for i in range(1, len(extrinsics) + 1)]
|
31 |
+
assert len(extrinsics) == len(image_names) == len(camera_ids), \
|
32 |
+
f'Number of extrinsics ({len(extrinsics)}), image_names ({len(image_names)}), and camera_ids ({len(camera_ids)}) must be the same'
|
33 |
+
with open(file, 'w') as fp:
|
34 |
+
print("# IMAGE_ID, QW, QX, QY, QZ, TX, TY, TZ, CAMERA_ID, NAME", file=fp)
|
35 |
+
for i, (quat, t, name, camera_id) in enumerate(zip(quats.tolist(), trans.tolist(), image_names, camera_ids)):
|
36 |
+
# Colmap has wxyz order while scipy.spatial.transform.Rotation has xyzw order. Haha, wcnm.
|
37 |
+
qx, qy, qz, qw = quat
|
38 |
+
tx, ty, tz = t
|
39 |
+
print(f'{i + 1} {qw:f} {qx:f} {qy:f} {qz:f} {tx:f} {ty:f} {tz:f} {camera_id:d} {name}', file=fp)
|
40 |
+
print()
|
41 |
+
|
42 |
+
|
43 |
+
def write_intrinsics_as_colmap(file: Union[str, Path], intrinsics: np.ndarray, width: int, height: int, normalized: bool = False):
|
44 |
+
"""
|
45 |
+
Write intrinsics to colmap `cameras.txt` file. Currently only support PINHOLE model (no distortion)
|
46 |
+
Args:
|
47 |
+
file: Path to `cameras.txt` file.
|
48 |
+
intrinsics: (N, 3, 3) array of intrinsics.
|
49 |
+
width: Image width.
|
50 |
+
height: Image height.
|
51 |
+
normalized: Whether the intrinsics are normalized. If True, the intrinsics will unnormalized for writing.
|
52 |
+
"""
|
53 |
+
assert intrinsics.shape[1:] == (3, 3) and intrinsics.ndim == 3 or intrinsics.shape == (3, 3)
|
54 |
+
if intrinsics.ndim == 2:
|
55 |
+
intrinsics = intrinsics[np.newaxis, ...]
|
56 |
+
if normalized:
|
57 |
+
intrinsics = intrinsics * np.array([width, height, 1])[:, None]
|
58 |
+
with open(file, 'w') as fp:
|
59 |
+
print("# CAMERA_ID, MODEL, WIDTH, HEIGHT, PARAMS[]", file=fp)
|
60 |
+
for i, intr in enumerate(intrinsics):
|
61 |
+
fx, fy, cx, cy = intr[0, 0], intr[1, 1], intr[0, 2], intr[1, 2]
|
62 |
+
print(f'{i + 1} PINHOLE {width:d} {height:d} {fx:f} {fy:f} {cx:f} {cy:f}', file=fp)
|
63 |
+
|
64 |
+
|
65 |
+
def read_extrinsics_from_colmap(file: Union[str, Path]) -> Union[np.ndarray, List[int], List[str]]:
|
66 |
+
"""
|
67 |
+
Read extrinsics from colmap `images.txt` file.
|
68 |
+
Args:
|
69 |
+
file: Path to `images.txt` file.
|
70 |
+
Returns:
|
71 |
+
extrinsics: (N, 4, 4) array of extrinsics.
|
72 |
+
camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1.
|
73 |
+
image_names: List of str, image names. Length is N.
|
74 |
+
"""
|
75 |
+
with open(file) as fp:
|
76 |
+
lines = fp.readlines()
|
77 |
+
image_names, quats, trans, camera_ids = [], [], [], []
|
78 |
+
i_line = 0
|
79 |
+
for line in lines:
|
80 |
+
line = line.strip()
|
81 |
+
if line.startswith('#'):
|
82 |
+
continue
|
83 |
+
i_line += 1
|
84 |
+
if i_line % 2 == 0:
|
85 |
+
continue
|
86 |
+
image_id, qw, qx, qy, qz, tx, ty, tz, camera_id, name = line.split()
|
87 |
+
quats.append([float(qx), float(qy), float(qz), float(qw)])
|
88 |
+
trans.append([float(tx), float(ty), float(tz)])
|
89 |
+
camera_ids.append(int(camera_id))
|
90 |
+
image_names.append(name)
|
91 |
+
|
92 |
+
quats = np.array(quats, dtype=np.float32)
|
93 |
+
trans = np.array(trans, dtype=np.float32)
|
94 |
+
rotation = Rotation.from_quat(quats).as_matrix()
|
95 |
+
extrinsics = np.concatenate([
|
96 |
+
np.concatenate([rotation, trans[..., None]], axis=-1),
|
97 |
+
np.array([0, 0, 0, 1], dtype=np.float32)[None, None, :].repeat(len(quats), axis=0)
|
98 |
+
], axis=-2)
|
99 |
+
|
100 |
+
return extrinsics, camera_ids, image_names
|
101 |
+
|
102 |
+
|
103 |
+
def read_intrinsics_from_colmap(file: Union[str, Path], normalize: bool = False) -> Tuple[List[int], np.ndarray, np.ndarray]:
|
104 |
+
"""
|
105 |
+
Read intrinsics from colmap `cameras.txt` file.
|
106 |
+
Args:
|
107 |
+
file: Path to `cameras.txt` file.
|
108 |
+
normalize: Whether to normalize the intrinsics. If True, the intrinsics will be normalized. (mapping coordinates to [0, 1] range)
|
109 |
+
Returns:
|
110 |
+
camera_ids: List of int, camera ids. Length is N. Note that camera ids in colmap typically starts from 1.
|
111 |
+
intrinsics: (N, 3, 3) array of intrinsics.
|
112 |
+
distortions: (N, 5) array of distortions.
|
113 |
+
"""
|
114 |
+
with open(file) as fp:
|
115 |
+
lines = fp.readlines()
|
116 |
+
intrinsics, distortions, camera_ids = [], [], []
|
117 |
+
for line in lines:
|
118 |
+
line = line.strip()
|
119 |
+
if not line or line.startswith('#'):
|
120 |
+
continue
|
121 |
+
camera_id, model, width, height, *params = line.split()
|
122 |
+
camera_id, width, height = int(camera_id), int(width), int(height)
|
123 |
+
if model == 'PINHOLE':
|
124 |
+
fx, fy, cx, cy = map(float, params[:4])
|
125 |
+
k1 = k2 = k3 = p1 = p2 = 0.0
|
126 |
+
elif model == 'OPENCV':
|
127 |
+
fx, fy, cx, cy, k1, k2, p1, p2, k3 = *map(float, params[:8]), 0.0
|
128 |
+
elif model == 'SIMPLE_RADIAL':
|
129 |
+
f, cx, cy, k = map(float, params[:4])
|
130 |
+
fx = fy = f
|
131 |
+
k1, k2, p1, p2, k3 = k, 0.0, 0.0, 0.0, 0.0
|
132 |
+
camera_ids.append(camera_id)
|
133 |
+
if normalize:
|
134 |
+
fx, fy, cx, cy = fx / width, fy / height, cx / width, cy / height
|
135 |
+
intrinsics.append([[fx, 0, cx], [0, fy, cy], [0, 0, 1]])
|
136 |
+
distortions.append([k1, k2, p1, p2, k3])
|
137 |
+
intrinsics = np.array(intrinsics, dtype=np.float32)
|
138 |
+
distortions = np.array(distortions, dtype=np.float32)
|
139 |
+
return camera_ids, intrinsics, distortions
|
utils3d/io/glb.py
ADDED
@@ -0,0 +1,105 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from typing import *
|
2 |
+
from pathlib import Path
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
|
6 |
+
|
7 |
+
def write_glb(path: Union[str, Path], vertices: np.ndarray, faces: np.ndarray, vertex_colors: np.ndarray = None, uv: np.ndarray = None):
|
8 |
+
import pygltflib
|
9 |
+
|
10 |
+
has_colors = vertex_colors is not None
|
11 |
+
has_uv = uv is not None
|
12 |
+
|
13 |
+
triangles_bytes = faces.astype(np.uint32).flatten().tobytes()
|
14 |
+
vertices_bytes = vertices.astype(np.float32).tobytes()
|
15 |
+
vertex_colors_bytes = vertex_colors.astype(np.float32).tobytes() if has_colors else None
|
16 |
+
uv_bytes = uv.astype(np.float32).tobytes() if has_uv else None
|
17 |
+
|
18 |
+
|
19 |
+
gltf = pygltflib.GLTF2(
|
20 |
+
scene=0,
|
21 |
+
scenes=[pygltflib.Scene(nodes=[0])],
|
22 |
+
nodes=[pygltflib.Node(mesh=0)],
|
23 |
+
meshes=[
|
24 |
+
pygltflib.Mesh(
|
25 |
+
primitives=[
|
26 |
+
pygltflib.Primitive(
|
27 |
+
attributes=pygltflib.Attributes(
|
28 |
+
POSITION=1,
|
29 |
+
COLOR_0=2 if has_colors else None,
|
30 |
+
TEXCOORD_0=2 + has_colors if has_uv else None
|
31 |
+
),
|
32 |
+
indices=0
|
33 |
+
)
|
34 |
+
]
|
35 |
+
)
|
36 |
+
],
|
37 |
+
accessors=list(filter(None, [
|
38 |
+
pygltflib.Accessor( # triangles accessor
|
39 |
+
bufferView=0,
|
40 |
+
componentType=pygltflib.UNSIGNED_INT,
|
41 |
+
count=faces.size,
|
42 |
+
type=pygltflib.SCALAR,
|
43 |
+
max=[int(faces.max())],
|
44 |
+
min=[int(faces.min())],
|
45 |
+
),
|
46 |
+
pygltflib.Accessor( # vertices accessor
|
47 |
+
bufferView=1,
|
48 |
+
componentType=pygltflib.FLOAT,
|
49 |
+
count=len(vertices),
|
50 |
+
type=pygltflib.VEC3,
|
51 |
+
max=vertices.max(axis=0).tolist(),
|
52 |
+
min=vertices.min(axis=0).tolist(),
|
53 |
+
),
|
54 |
+
pygltflib.Accessor( # vertex colors accessor
|
55 |
+
bufferView=2,
|
56 |
+
componentType=pygltflib.FLOAT,
|
57 |
+
count=len(vertices),
|
58 |
+
type=pygltflib.VEC3,
|
59 |
+
max=vertex_colors.max(axis=0).tolist(),
|
60 |
+
min=vertex_colors.min(axis=0).tolist(),
|
61 |
+
) if has_colors else None,
|
62 |
+
pygltflib.Accessor( # uv accessor
|
63 |
+
bufferView=3,
|
64 |
+
componentType=pygltflib.FLOAT,
|
65 |
+
count=len(uv),
|
66 |
+
type=pygltflib.VEC2,
|
67 |
+
max=uv.max(axis=0).tolist(),
|
68 |
+
min=uv.min(axis=0).tolist(),
|
69 |
+
) if has_uv else None,
|
70 |
+
])),
|
71 |
+
bufferViews=list(filter(None, [
|
72 |
+
pygltflib.BufferView( # triangles buffer view
|
73 |
+
buffer=0,
|
74 |
+
byteLength=len(triangles_bytes),
|
75 |
+
target=pygltflib.ELEMENT_ARRAY_BUFFER,
|
76 |
+
),
|
77 |
+
pygltflib.BufferView( # vertices buffer view
|
78 |
+
buffer=0,
|
79 |
+
byteOffset=len(triangles_bytes),
|
80 |
+
byteLength=len(vertices_bytes),
|
81 |
+
target=pygltflib.ARRAY_BUFFER,
|
82 |
+
),
|
83 |
+
pygltflib.BufferView( # vertex colors buffer view
|
84 |
+
buffer=0,
|
85 |
+
byteOffset=len(triangles_bytes) + len(vertices_bytes),
|
86 |
+
byteLength=len(vertex_colors_bytes),
|
87 |
+
target=pygltflib.ARRAY_BUFFER,
|
88 |
+
) if has_colors else None,
|
89 |
+
pygltflib.BufferView( # uv buffer view
|
90 |
+
buffer=0,
|
91 |
+
byteOffset=len(triangles_bytes) + len(vertices_bytes) + (len(vertex_colors_bytes) if has_colors else 0),
|
92 |
+
byteLength=len(uv_bytes),
|
93 |
+
target=pygltflib.ARRAY_BUFFER,
|
94 |
+
) if has_uv else None,
|
95 |
+
])),
|
96 |
+
buffers=[
|
97 |
+
pygltflib.Buffer(
|
98 |
+
byteLength=len(triangles_bytes) + len(vertices_bytes) + (len(vertex_colors_bytes) if has_colors else 0) + (len(uv_bytes) if has_uv else 0),
|
99 |
+
)
|
100 |
+
]
|
101 |
+
)
|
102 |
+
gltf.set_binary_blob(triangles_bytes + vertices_bytes + (vertex_colors_bytes or b'') + (uv_bytes or b''))
|
103 |
+
with open(path, 'wb') as f:
|
104 |
+
for chunk in gltf.save_to_bytes():
|
105 |
+
f.write(chunk)
|
utils3d/io/ply.py
ADDED
@@ -0,0 +1,104 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
|
3 |
+
from typing import *
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
|
7 |
+
def read_ply(
|
8 |
+
file: Union[str, Path],
|
9 |
+
encoding: Union[str, None] = None,
|
10 |
+
ignore_unknown: bool = False
|
11 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
12 |
+
"""
|
13 |
+
Read .ply file, without preprocessing.
|
14 |
+
|
15 |
+
Args:
|
16 |
+
file (Any): filepath
|
17 |
+
encoding (str, optional):
|
18 |
+
|
19 |
+
Returns:
|
20 |
+
Tuple[np.ndarray, np.ndarray]: vertices, faces
|
21 |
+
"""
|
22 |
+
import plyfile
|
23 |
+
plydata = plyfile.PlyData.read(file)
|
24 |
+
vertices = np.stack([plydata['vertex'][k] for k in ['x', 'y', 'z']], axis=-1)
|
25 |
+
if 'face' in plydata:
|
26 |
+
faces = np.array(plydata['face']['vertex_indices'].tolist())
|
27 |
+
else:
|
28 |
+
faces = None
|
29 |
+
return vertices, faces
|
30 |
+
|
31 |
+
|
32 |
+
def write_ply(
|
33 |
+
file: Union[str, Path],
|
34 |
+
vertices: np.ndarray,
|
35 |
+
faces: np.ndarray = None,
|
36 |
+
edges: np.ndarray = None,
|
37 |
+
vertex_colors: np.ndarray = None,
|
38 |
+
edge_colors: np.ndarray = None,
|
39 |
+
text: bool = False
|
40 |
+
):
|
41 |
+
"""
|
42 |
+
Write .ply file, without preprocessing.
|
43 |
+
|
44 |
+
Args:
|
45 |
+
file (Any): filepath
|
46 |
+
vertices (np.ndarray): [N, 3]
|
47 |
+
faces (np.ndarray): [T, E]
|
48 |
+
edges (np.ndarray): [E, 2]
|
49 |
+
vertex_colors (np.ndarray, optional): [N, 3]. Defaults to None.
|
50 |
+
edge_colors (np.ndarray, optional): [E, 3]. Defaults to None.
|
51 |
+
text (bool, optional): save data in text format. Defaults to False.
|
52 |
+
"""
|
53 |
+
import plyfile
|
54 |
+
assert vertices.ndim == 2 and vertices.shape[1] == 3
|
55 |
+
vertices = vertices.astype(np.float32)
|
56 |
+
if faces is not None:
|
57 |
+
assert faces.ndim == 2
|
58 |
+
faces = faces.astype(np.int32)
|
59 |
+
if edges is not None:
|
60 |
+
assert edges.ndim == 2 and edges.shape[1] == 2
|
61 |
+
edges = edges.astype(np.int32)
|
62 |
+
|
63 |
+
if vertex_colors is not None:
|
64 |
+
assert vertex_colors.ndim == 2 and vertex_colors.shape[1] == 3
|
65 |
+
if vertex_colors.dtype in [np.float32, np.float64]:
|
66 |
+
vertex_colors = vertex_colors * 255
|
67 |
+
vertex_colors = np.clip(vertex_colors, 0, 255).astype(np.uint8)
|
68 |
+
vertices_data = np.zeros(len(vertices), dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
|
69 |
+
vertices_data['x'] = vertices[:, 0]
|
70 |
+
vertices_data['y'] = vertices[:, 1]
|
71 |
+
vertices_data['z'] = vertices[:, 2]
|
72 |
+
vertices_data['red'] = vertex_colors[:, 0]
|
73 |
+
vertices_data['green'] = vertex_colors[:, 1]
|
74 |
+
vertices_data['blue'] = vertex_colors[:, 2]
|
75 |
+
else:
|
76 |
+
vertices_data = np.array([tuple(v) for v in vertices], dtype=[('x', 'f4'), ('y', 'f4'), ('z', 'f4')])
|
77 |
+
|
78 |
+
if faces is not None:
|
79 |
+
faces_data = np.zeros(len(faces), dtype=[('vertex_indices', 'i4', (faces.shape[1],))])
|
80 |
+
faces_data['vertex_indices'] = faces
|
81 |
+
|
82 |
+
if edges is not None:
|
83 |
+
if edge_colors is not None:
|
84 |
+
assert edge_colors.ndim == 2 and edge_colors.shape[1] == 3
|
85 |
+
if edge_colors.dtype in [np.float32, np.float64]:
|
86 |
+
edge_colors = edge_colors * 255
|
87 |
+
edge_colors = np.clip(edge_colors, 0, 255).astype(np.uint8)
|
88 |
+
edges_data = np.zeros(len(edges), dtype=[('vertex1', 'i4'), ('vertex2', 'i4'), ('red', 'u1'), ('green', 'u1'), ('blue', 'u1')])
|
89 |
+
edges_data['vertex1'] = edges[:, 0]
|
90 |
+
edges_data['vertex2'] = edges[:, 1]
|
91 |
+
edges_data['red'] = edge_colors[:, 0]
|
92 |
+
edges_data['green'] = edge_colors[:, 1]
|
93 |
+
edges_data['blue'] = edge_colors[:, 2]
|
94 |
+
else:
|
95 |
+
edges_data = np.array([tuple(e) for e in edges], dtype=[('vertex1', 'i4'), ('vertex2', 'i4')])
|
96 |
+
|
97 |
+
ply_data = [plyfile.PlyElement.describe(vertices_data, 'vertex')]
|
98 |
+
if faces is not None:
|
99 |
+
ply_data.append(plyfile.PlyElement.describe(faces_data, 'face'))
|
100 |
+
if edges is not None:
|
101 |
+
ply_data.append(plyfile.PlyElement.describe(edges_data, 'edge'))
|
102 |
+
|
103 |
+
plyfile.PlyData(ply_data, text=text).write(file)
|
104 |
+
|
utils3d/io/wavefront_obj.py
ADDED
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
from io import TextIOWrapper
|
2 |
+
from typing import Dict, Any, Union, Iterable
|
3 |
+
import numpy as np
|
4 |
+
from pathlib import Path
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
'read_obj',
|
8 |
+
'write_obj',
|
9 |
+
'simple_write_obj'
|
10 |
+
]
|
11 |
+
|
12 |
+
def read_obj(
|
13 |
+
file : Union[str, Path, TextIOWrapper],
|
14 |
+
encoding: Union[str, None] = None,
|
15 |
+
ignore_unknown: bool = False
|
16 |
+
):
|
17 |
+
"""
|
18 |
+
Read wavefront .obj file, without preprocessing.
|
19 |
+
|
20 |
+
Why bothering having this read_obj() while we already have other libraries like `trimesh`?
|
21 |
+
This function read the raw format from .obj file and keeps the order of vertices and faces,
|
22 |
+
while trimesh which involves modification like merge/split vertices, which could break the orders of vertices and faces,
|
23 |
+
Those libraries are commonly aiming at geometry processing and rendering supporting various formats.
|
24 |
+
If you want mesh geometry processing, you may turn to `trimesh` for more features.
|
25 |
+
|
26 |
+
### Parameters
|
27 |
+
`file` (str, Path, TextIOWrapper): filepath or file object
|
28 |
+
encoding (str, optional):
|
29 |
+
|
30 |
+
### Returns
|
31 |
+
obj (dict): A dict containing .obj components
|
32 |
+
{
|
33 |
+
'mtllib': [],
|
34 |
+
'v': [[0,1, 0.2, 1.0], [1.2, 0.0, 0.0], ...],
|
35 |
+
'vt': [[0.5, 0.5], ...],
|
36 |
+
'vn': [[0., 0.7, 0.7], [0., -0.7, 0.7], ...],
|
37 |
+
'f': [[0, 1, 2], [2, 3, 4],...],
|
38 |
+
'usemtl': [{'name': 'mtl1', 'f': 7}]
|
39 |
+
}
|
40 |
+
"""
|
41 |
+
if hasattr(file,'read'):
|
42 |
+
lines = file.read().splitlines()
|
43 |
+
else:
|
44 |
+
with open(file, 'r', encoding=encoding) as fp:
|
45 |
+
lines = fp.read().splitlines()
|
46 |
+
mtllib = []
|
47 |
+
v, vt, vn, vp = [], [], [], [] # Vertex coordinates, Vertex texture coordinate, Vertex normal, Vertex parameter
|
48 |
+
f, ft, fn = [], [], [] # Face indices, Face texture indices, Face normal indices
|
49 |
+
o = []
|
50 |
+
s = []
|
51 |
+
usemtl = []
|
52 |
+
|
53 |
+
def pad(l: list, n: Any):
|
54 |
+
return l + [n] * (3 - len(l))
|
55 |
+
|
56 |
+
for i, line in enumerate(lines):
|
57 |
+
sq = line.strip().split()
|
58 |
+
if len(sq) == 0:
|
59 |
+
continue
|
60 |
+
if sq[0] == 'v':
|
61 |
+
assert 4 <= len(sq) <= 5, f'Invalid format of line {i}: {line}'
|
62 |
+
v.append([float(e) for e in sq[1:]][:3])
|
63 |
+
elif sq[0] == 'vt':
|
64 |
+
assert 3 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
|
65 |
+
vt.append([float(e) for e in sq[1:]][:2])
|
66 |
+
elif sq[0] == 'vn':
|
67 |
+
assert len(sq) == 4, f'Invalid format of line {i}: {line}'
|
68 |
+
vn.append([float(e) for e in sq[1:]])
|
69 |
+
elif sq[0] == 'vp':
|
70 |
+
assert 2 <= len(sq) <= 4, f'Invalid format of line {i}: {line}'
|
71 |
+
vp.append(pad([float(e) for e in sq[1:]], 0))
|
72 |
+
elif sq[0] == 'f':
|
73 |
+
spliting = [pad([int(j) - 1 for j in e.split('/')], -1) for e in sq[1:]]
|
74 |
+
f.append([e[0] for e in spliting])
|
75 |
+
ft.append([e[1] for e in spliting])
|
76 |
+
fn.append([e[2] for e in spliting])
|
77 |
+
elif sq[0] == 'usemtl':
|
78 |
+
assert len(sq) == 2
|
79 |
+
usemtl.append((sq[1], len(f)))
|
80 |
+
elif sq[0] == 'o':
|
81 |
+
assert len(sq) == 2
|
82 |
+
o.append((sq[1], len(f)))
|
83 |
+
elif sq[0] == 's':
|
84 |
+
s.append((sq[1], len(f)))
|
85 |
+
elif sq[0] == 'mtllib':
|
86 |
+
assert len(sq) == 2
|
87 |
+
mtllib.append(sq[1])
|
88 |
+
elif sq[0][0] == '#':
|
89 |
+
continue
|
90 |
+
else:
|
91 |
+
if not ignore_unknown:
|
92 |
+
raise Exception(f'Unknown keyword {sq[0]}')
|
93 |
+
|
94 |
+
min_poly_vertices = min(len(f) for f in f)
|
95 |
+
max_poly_vertices = max(len(f) for f in f)
|
96 |
+
|
97 |
+
return {
|
98 |
+
'mtllib': mtllib,
|
99 |
+
'v': np.array(v, dtype=np.float32),
|
100 |
+
'vt': np.array(vt, dtype=np.float32),
|
101 |
+
'vn': np.array(vn, dtype=np.float32),
|
102 |
+
'vp': np.array(vp, dtype=np.float32),
|
103 |
+
'f': np.array(f, dtype=np.int32) if min_poly_vertices == max_poly_vertices else f,
|
104 |
+
'ft': np.array(ft, dtype=np.int32) if min_poly_vertices == max_poly_vertices else ft,
|
105 |
+
'fn': np.array(fn, dtype=np.int32) if min_poly_vertices == max_poly_vertices else fn,
|
106 |
+
'o': o,
|
107 |
+
's': s,
|
108 |
+
'usemtl': usemtl,
|
109 |
+
}
|
110 |
+
|
111 |
+
|
112 |
+
def write_obj(
|
113 |
+
file: Union[str, Path],
|
114 |
+
obj: Dict[str, Any],
|
115 |
+
encoding: Union[str, None] = None
|
116 |
+
):
|
117 |
+
with open(file, 'w', encoding=encoding) as fp:
|
118 |
+
for k in ['v', 'vt', 'vn', 'vp']:
|
119 |
+
if k not in obj:
|
120 |
+
continue
|
121 |
+
for v in obj[k]:
|
122 |
+
print(k, *map(float, v), file=fp)
|
123 |
+
for f in obj['f']:
|
124 |
+
print('f', *((str('/').join(map(int, i)) if isinstance(int(i), Iterable) else i) for i in f), file=fp)
|
125 |
+
|
126 |
+
|
127 |
+
def simple_write_obj(
|
128 |
+
file: Union[str, Path],
|
129 |
+
vertices: np.ndarray,
|
130 |
+
faces: np.ndarray,
|
131 |
+
encoding: Union[str, None] = None
|
132 |
+
):
|
133 |
+
"""
|
134 |
+
Write wavefront .obj file, without preprocessing.
|
135 |
+
|
136 |
+
Args:
|
137 |
+
vertices (np.ndarray): [N, 3]
|
138 |
+
faces (np.ndarray): [T, 3]
|
139 |
+
file (Any): filepath
|
140 |
+
encoding (str, optional):
|
141 |
+
"""
|
142 |
+
with open(file, 'w', encoding=encoding) as fp:
|
143 |
+
for v in vertices:
|
144 |
+
print('v', *map(float, v), file=fp)
|
145 |
+
for f in faces:
|
146 |
+
print('f', *map(int, f + 1), file=fp)
|
utils3d/numpy/__init__.py
ADDED
@@ -0,0 +1,135 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
"""
|
2 |
+
3D utility functions workings with NumPy.
|
3 |
+
"""
|
4 |
+
import importlib
|
5 |
+
import itertools
|
6 |
+
import numpy
|
7 |
+
|
8 |
+
|
9 |
+
__modules_all__ = {
|
10 |
+
'mesh':[
|
11 |
+
'triangulate',
|
12 |
+
'compute_face_normal',
|
13 |
+
'compute_face_angle',
|
14 |
+
'compute_vertex_normal',
|
15 |
+
'compute_vertex_normal_weighted',
|
16 |
+
'remove_corrupted_faces',
|
17 |
+
'merge_duplicate_vertices',
|
18 |
+
'remove_unreferenced_vertices',
|
19 |
+
'subdivide_mesh_simple',
|
20 |
+
'mesh_relations',
|
21 |
+
'flatten_mesh_indices'
|
22 |
+
],
|
23 |
+
'quadmesh': [
|
24 |
+
'calc_quad_candidates',
|
25 |
+
'calc_quad_distortion',
|
26 |
+
'calc_quad_direction',
|
27 |
+
'calc_quad_smoothness',
|
28 |
+
'sovle_quad',
|
29 |
+
'sovle_quad_qp',
|
30 |
+
'tri_to_quad'
|
31 |
+
],
|
32 |
+
'utils': [
|
33 |
+
'sliding_window_1d',
|
34 |
+
'sliding_window_nd',
|
35 |
+
'sliding_window_2d',
|
36 |
+
'max_pool_1d',
|
37 |
+
'max_pool_2d',
|
38 |
+
'max_pool_nd',
|
39 |
+
'depth_edge',
|
40 |
+
'depth_aliasing',
|
41 |
+
'interpolate',
|
42 |
+
'image_scrcoord',
|
43 |
+
'image_uv',
|
44 |
+
'image_pixel_center',
|
45 |
+
'image_pixel',
|
46 |
+
'image_mesh',
|
47 |
+
'image_mesh_from_depth',
|
48 |
+
'depth_to_normal',
|
49 |
+
'point_to_normal',
|
50 |
+
'chessboard',
|
51 |
+
'cube',
|
52 |
+
'square',
|
53 |
+
'camera_frustum',
|
54 |
+
],
|
55 |
+
'transforms': [
|
56 |
+
'perspective',
|
57 |
+
'perspective_from_fov',
|
58 |
+
'perspective_from_fov_xy',
|
59 |
+
'intrinsics_from_focal_center',
|
60 |
+
'intrinsics_from_fov',
|
61 |
+
'view_look_at',
|
62 |
+
'extrinsics_look_at',
|
63 |
+
'perspective_to_intrinsics',
|
64 |
+
'perspective_to_near_far',
|
65 |
+
'intrinsics_to_perspective',
|
66 |
+
'extrinsics_to_view',
|
67 |
+
'view_to_extrinsics',
|
68 |
+
'normalize_intrinsics',
|
69 |
+
'crop_intrinsics',
|
70 |
+
'pixel_to_uv',
|
71 |
+
'pixel_to_ndc',
|
72 |
+
'uv_to_pixel',
|
73 |
+
'project_depth',
|
74 |
+
'depth_buffer_to_linear',
|
75 |
+
'unproject_cv',
|
76 |
+
'unproject_gl',
|
77 |
+
'project_cv',
|
78 |
+
'project_gl',
|
79 |
+
'quaternion_to_matrix',
|
80 |
+
'axis_angle_to_matrix',
|
81 |
+
'matrix_to_quaternion',
|
82 |
+
'extrinsics_to_essential',
|
83 |
+
'euler_axis_angle_rotation',
|
84 |
+
'euler_angles_to_matrix',
|
85 |
+
'skew_symmetric',
|
86 |
+
'rotation_matrix_from_vectors',
|
87 |
+
'ray_intersection',
|
88 |
+
'se3_matrix',
|
89 |
+
'slerp_quaternion',
|
90 |
+
'slerp_vector',
|
91 |
+
'lerp',
|
92 |
+
'lerp_se3_matrix',
|
93 |
+
'piecewise_lerp',
|
94 |
+
'piecewise_lerp_se3_matrix',
|
95 |
+
'apply_transform'
|
96 |
+
],
|
97 |
+
'spline': [
|
98 |
+
'linear_spline_interpolate',
|
99 |
+
],
|
100 |
+
'rasterization': [
|
101 |
+
'RastContext',
|
102 |
+
'rasterize_triangle_faces',
|
103 |
+
'rasterize_edges',
|
104 |
+
'texture',
|
105 |
+
'warp_image_by_depth',
|
106 |
+
],
|
107 |
+
}
|
108 |
+
|
109 |
+
|
110 |
+
__all__ = list(itertools.chain(*__modules_all__.values()))
|
111 |
+
|
112 |
+
def __getattr__(name):
|
113 |
+
try:
|
114 |
+
return globals()[name]
|
115 |
+
except KeyError:
|
116 |
+
pass
|
117 |
+
|
118 |
+
try:
|
119 |
+
module_name = next(m for m in __modules_all__ if name in __modules_all__[m])
|
120 |
+
except StopIteration:
|
121 |
+
raise AttributeError(f"module '{__name__}' has no attribute '{name}'")
|
122 |
+
module = importlib.import_module(f'.{module_name}', __name__)
|
123 |
+
for key in __modules_all__[module_name]:
|
124 |
+
globals()[key] = getattr(module, key)
|
125 |
+
|
126 |
+
return globals()[name]
|
127 |
+
|
128 |
+
|
129 |
+
if __name__ == '__main__':
|
130 |
+
from .quadmesh import *
|
131 |
+
from .transforms import *
|
132 |
+
from .mesh import *
|
133 |
+
from .utils import *
|
134 |
+
from .rasterization import *
|
135 |
+
from .spline import *
|
utils3d/numpy/_helpers.py
ADDED
@@ -0,0 +1,88 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
# decorator
|
2 |
+
import numpy as np
|
3 |
+
from numbers import Number
|
4 |
+
import inspect
|
5 |
+
|
6 |
+
|
7 |
+
def get_args_order(func, args, kwargs):
|
8 |
+
"""
|
9 |
+
Get the order of the arguments of a function.
|
10 |
+
"""
|
11 |
+
names = inspect.getfullargspec(func).args
|
12 |
+
names_idx = {name: i for i, name in enumerate(names)}
|
13 |
+
args_order = []
|
14 |
+
kwargs_order = {}
|
15 |
+
for name, arg in kwargs.items():
|
16 |
+
if name in names:
|
17 |
+
kwargs_order[name] = names_idx[name]
|
18 |
+
names.remove(name)
|
19 |
+
for i, arg in enumerate(args):
|
20 |
+
if i < len(names):
|
21 |
+
args_order.append(names_idx[names[i]])
|
22 |
+
return args_order, kwargs_order
|
23 |
+
|
24 |
+
|
25 |
+
def broadcast_args(args, kwargs, args_dim, kwargs_dim):
|
26 |
+
spatial = []
|
27 |
+
for arg, arg_dim in zip(args + list(kwargs.values()), args_dim + list(kwargs_dim.values())):
|
28 |
+
if isinstance(arg, np.ndarray) and arg_dim is not None:
|
29 |
+
arg_spatial = arg.shape[:arg.ndim-arg_dim]
|
30 |
+
if len(arg_spatial) > len(spatial):
|
31 |
+
spatial = [1] * (len(arg_spatial) - len(spatial)) + spatial
|
32 |
+
for j in range(len(arg_spatial)):
|
33 |
+
if spatial[-j] < arg_spatial[-j]:
|
34 |
+
if spatial[-j] == 1:
|
35 |
+
spatial[-j] = arg_spatial[-j]
|
36 |
+
else:
|
37 |
+
raise ValueError("Cannot broadcast arguments.")
|
38 |
+
for i, arg in enumerate(args):
|
39 |
+
if isinstance(arg, np.ndarray) and args_dim[i] is not None:
|
40 |
+
args[i] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-args_dim[i]:]])
|
41 |
+
for key, arg in kwargs.items():
|
42 |
+
if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None:
|
43 |
+
kwargs[key] = np.broadcast_to(arg, [*spatial, *arg.shape[arg.ndim-kwargs_dim[key]:]])
|
44 |
+
return args, kwargs, spatial
|
45 |
+
|
46 |
+
|
47 |
+
def batched(*dims):
|
48 |
+
"""
|
49 |
+
Decorator that allows a function to be called with batched arguments.
|
50 |
+
"""
|
51 |
+
def decorator(func):
|
52 |
+
def wrapper(*args, **kwargs):
|
53 |
+
args = list(args)
|
54 |
+
# get arguments dimensions
|
55 |
+
args_order, kwargs_order = get_args_order(func, args, kwargs)
|
56 |
+
args_dim = [dims[i] for i in args_order]
|
57 |
+
kwargs_dim = {key: dims[i] for key, i in kwargs_order.items()}
|
58 |
+
# convert to numpy array
|
59 |
+
for i, arg in enumerate(args):
|
60 |
+
if isinstance(arg, (Number, list, tuple)) and args_dim[i] is not None:
|
61 |
+
args[i] = np.array(arg)
|
62 |
+
for key, arg in kwargs.items():
|
63 |
+
if isinstance(arg, (Number, list, tuple)) and kwargs_dim[key] is not None:
|
64 |
+
kwargs[key] = np.array(arg)
|
65 |
+
# broadcast arguments
|
66 |
+
args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
|
67 |
+
for i, (arg, arg_dim) in enumerate(zip(args, args_dim)):
|
68 |
+
if isinstance(arg, np.ndarray) and arg_dim is not None:
|
69 |
+
args[i] = arg.reshape([-1, *arg.shape[arg.ndim-arg_dim:]])
|
70 |
+
for key, arg in kwargs.items():
|
71 |
+
if isinstance(arg, np.ndarray) and kwargs_dim[key] is not None:
|
72 |
+
kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim-kwargs_dim[key]:]])
|
73 |
+
# call function
|
74 |
+
results = func(*args, **kwargs)
|
75 |
+
type_results = type(results)
|
76 |
+
results = list(results) if isinstance(results, (tuple, list)) else [results]
|
77 |
+
# restore spatial dimensions
|
78 |
+
for i, result in enumerate(results):
|
79 |
+
results[i] = result.reshape([*spatial, *result.shape[1:]])
|
80 |
+
if type_results == tuple:
|
81 |
+
results = tuple(results)
|
82 |
+
elif type_results == list:
|
83 |
+
results = list(results)
|
84 |
+
else:
|
85 |
+
results = results[0]
|
86 |
+
return results
|
87 |
+
return wrapper
|
88 |
+
return decorator
|
utils3d/numpy/mesh.py
ADDED
@@ -0,0 +1,355 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
from typing import *
|
3 |
+
from ._helpers import batched
|
4 |
+
|
5 |
+
|
6 |
+
__all__ = [
|
7 |
+
'triangulate',
|
8 |
+
'compute_face_normal',
|
9 |
+
'compute_face_angle',
|
10 |
+
'compute_vertex_normal',
|
11 |
+
'compute_vertex_normal_weighted',
|
12 |
+
'remove_corrupted_faces',
|
13 |
+
'merge_duplicate_vertices',
|
14 |
+
'remove_unreferenced_vertices',
|
15 |
+
'subdivide_mesh_simple',
|
16 |
+
'mesh_relations',
|
17 |
+
'flatten_mesh_indices'
|
18 |
+
]
|
19 |
+
|
20 |
+
|
21 |
+
def triangulate(
|
22 |
+
faces: np.ndarray,
|
23 |
+
vertices: np.ndarray = None,
|
24 |
+
backslash: np.ndarray = None
|
25 |
+
) -> np.ndarray:
|
26 |
+
"""
|
27 |
+
Triangulate a polygonal mesh.
|
28 |
+
|
29 |
+
Args:
|
30 |
+
faces (np.ndarray): [L, P] polygonal faces
|
31 |
+
vertices (np.ndarray, optional): [N, 3] 3-dimensional vertices.
|
32 |
+
If given, the triangulation is performed according to the distance
|
33 |
+
between vertices. Defaults to None.
|
34 |
+
backslash (np.ndarray, optional): [L] boolean array indicating
|
35 |
+
how to triangulate the quad faces. Defaults to None.
|
36 |
+
|
37 |
+
Returns:
|
38 |
+
(np.ndarray): [L * (P - 2), 3] triangular faces
|
39 |
+
"""
|
40 |
+
if faces.shape[-1] == 3:
|
41 |
+
return faces
|
42 |
+
P = faces.shape[-1]
|
43 |
+
if vertices is not None:
|
44 |
+
assert faces.shape[-1] == 4, "now only support quad mesh"
|
45 |
+
if backslash is None:
|
46 |
+
backslash = np.linalg.norm(vertices[faces[:, 0]] - vertices[faces[:, 2]], axis=-1) < \
|
47 |
+
np.linalg.norm(vertices[faces[:, 1]] - vertices[faces[:, 3]], axis=-1)
|
48 |
+
if backslash is None:
|
49 |
+
loop_indice = np.stack([
|
50 |
+
np.zeros(P - 2, dtype=int),
|
51 |
+
np.arange(1, P - 1, 1, dtype=int),
|
52 |
+
np.arange(2, P, 1, dtype=int)
|
53 |
+
], axis=1)
|
54 |
+
return faces[:, loop_indice].reshape((-1, 3))
|
55 |
+
else:
|
56 |
+
assert faces.shape[-1] == 4, "now only support quad mesh"
|
57 |
+
faces = np.where(
|
58 |
+
backslash[:, None],
|
59 |
+
faces[:, [0, 1, 2, 0, 2, 3]],
|
60 |
+
faces[:, [0, 1, 3, 3, 1, 2]]
|
61 |
+
).reshape((-1, 3))
|
62 |
+
return faces
|
63 |
+
|
64 |
+
|
65 |
+
@batched(2, None)
|
66 |
+
def compute_face_normal(
|
67 |
+
vertices: np.ndarray,
|
68 |
+
faces: np.ndarray
|
69 |
+
) -> np.ndarray:
|
70 |
+
"""
|
71 |
+
Compute face normals of a triangular mesh
|
72 |
+
|
73 |
+
Args:
|
74 |
+
vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
|
75 |
+
faces (np.ndarray): [T, 3] triangular face indices
|
76 |
+
|
77 |
+
Returns:
|
78 |
+
normals (np.ndarray): [..., T, 3] face normals
|
79 |
+
"""
|
80 |
+
normal = np.cross(
|
81 |
+
vertices[..., faces[:, 1], :] - vertices[..., faces[:, 0], :],
|
82 |
+
vertices[..., faces[:, 2], :] - vertices[..., faces[:, 0], :]
|
83 |
+
)
|
84 |
+
normal_norm = np.linalg.norm(normal, axis=-1, keepdims=True)
|
85 |
+
normal_norm[normal_norm == 0] = 1
|
86 |
+
normal /= normal_norm
|
87 |
+
return normal
|
88 |
+
|
89 |
+
|
90 |
+
@batched(2, None)
|
91 |
+
def compute_face_angle(
|
92 |
+
vertices: np.ndarray,
|
93 |
+
faces: np.ndarray,
|
94 |
+
eps: float = 1e-12
|
95 |
+
) -> np.ndarray:
|
96 |
+
"""
|
97 |
+
Compute face angles of a triangular mesh
|
98 |
+
|
99 |
+
Args:
|
100 |
+
vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
|
101 |
+
faces (np.ndarray): [T, 3] triangular face indices
|
102 |
+
|
103 |
+
Returns:
|
104 |
+
angles (np.ndarray): [..., T, 3] face angles
|
105 |
+
"""
|
106 |
+
face_angle = np.zeros_like(faces, dtype=vertices.dtype)
|
107 |
+
for i in range(3):
|
108 |
+
edge1 = vertices[..., faces[:, (i + 1) % 3], :] - vertices[..., faces[:, i], :]
|
109 |
+
edge2 = vertices[..., faces[:, (i + 2) % 3], :] - vertices[..., faces[:, i], :]
|
110 |
+
face_angle[..., i] = np.arccos(np.sum(
|
111 |
+
edge1 / np.clip(np.linalg.norm(edge1, axis=-1, keepdims=True), eps, None) *
|
112 |
+
edge2 / np.clip(np.linalg.norm(edge2, axis=-1, keepdims=True), eps, None),
|
113 |
+
axis=-1
|
114 |
+
))
|
115 |
+
return face_angle
|
116 |
+
|
117 |
+
|
118 |
+
@batched(2, None, 2)
|
119 |
+
def compute_vertex_normal(
|
120 |
+
vertices: np.ndarray,
|
121 |
+
faces: np.ndarray,
|
122 |
+
face_normal: np.ndarray = None
|
123 |
+
) -> np.ndarray:
|
124 |
+
"""
|
125 |
+
Compute vertex normals of a triangular mesh by averaging neightboring face normals
|
126 |
+
TODO: can be improved.
|
127 |
+
|
128 |
+
Args:
|
129 |
+
vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
|
130 |
+
faces (np.ndarray): [T, 3] triangular face indices
|
131 |
+
face_normal (np.ndarray, optional): [..., T, 3] face normals.
|
132 |
+
None to compute face normals from vertices and faces. Defaults to None.
|
133 |
+
|
134 |
+
Returns:
|
135 |
+
normals (np.ndarray): [..., N, 3] vertex normals
|
136 |
+
"""
|
137 |
+
if face_normal is None:
|
138 |
+
face_normal = compute_face_normal(vertices, faces)
|
139 |
+
vertex_normal = np.zeros_like(vertices, dtype=vertices.dtype)
|
140 |
+
for n in range(vertices.shape[0]):
|
141 |
+
for i in range(3):
|
142 |
+
vertex_normal[n, :, 0] += np.bincount(faces[:, i], weights=face_normal[n, :, 0], minlength=vertices.shape[1])
|
143 |
+
vertex_normal[n, :, 1] += np.bincount(faces[:, i], weights=face_normal[n, :, 1], minlength=vertices.shape[1])
|
144 |
+
vertex_normal[n, :, 2] += np.bincount(faces[:, i], weights=face_normal[n, :, 2], minlength=vertices.shape[1])
|
145 |
+
vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True)
|
146 |
+
vertex_normal_norm[vertex_normal_norm == 0] = 1
|
147 |
+
vertex_normal /= vertex_normal_norm
|
148 |
+
return vertex_normal
|
149 |
+
|
150 |
+
|
151 |
+
@batched(2, None, 2)
|
152 |
+
def compute_vertex_normal_weighted(
|
153 |
+
vertices: np.ndarray,
|
154 |
+
faces: np.ndarray,
|
155 |
+
face_normal: np.ndarray = None
|
156 |
+
) -> np.ndarray:
|
157 |
+
"""
|
158 |
+
Compute vertex normals of a triangular mesh by weighted sum of neightboring face normals
|
159 |
+
according to the angles
|
160 |
+
|
161 |
+
Args:
|
162 |
+
vertices (np.ndarray): [..., N, 3] 3-dimensional vertices
|
163 |
+
faces (np.ndarray): [..., T, 3] triangular face indices
|
164 |
+
face_normal (np.ndarray, optional): [..., T, 3] face normals.
|
165 |
+
None to compute face normals from vertices and faces. Defaults to None.
|
166 |
+
|
167 |
+
Returns:
|
168 |
+
normals (np.ndarray): [..., N, 3] vertex normals
|
169 |
+
"""
|
170 |
+
if face_normal is None:
|
171 |
+
face_normal = compute_face_normal(vertices, faces)
|
172 |
+
face_angle = compute_face_angle(vertices, faces)
|
173 |
+
vertex_normal = np.zeros_like(vertices)
|
174 |
+
for n in range(vertices.shape[0]):
|
175 |
+
for i in range(3):
|
176 |
+
vertex_normal[n, :, 0] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 0] * face_angle[n, :, i], minlength=vertices.shape[1])
|
177 |
+
vertex_normal[n, :, 1] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 1] * face_angle[n, :, i], minlength=vertices.shape[1])
|
178 |
+
vertex_normal[n, :, 2] += np.bincount(faces[n, :, i], weights=face_normal[n, :, 2] * face_angle[n, :, i], minlength=vertices.shape[1])
|
179 |
+
vertex_normal_norm = np.linalg.norm(vertex_normal, axis=-1, keepdims=True)
|
180 |
+
vertex_normal_norm[vertex_normal_norm == 0] = 1
|
181 |
+
vertex_normal /= vertex_normal_norm
|
182 |
+
return vertex_normal
|
183 |
+
|
184 |
+
|
185 |
+
def remove_corrupted_faces(
|
186 |
+
faces: np.ndarray
|
187 |
+
) -> np.ndarray:
|
188 |
+
"""
|
189 |
+
Remove corrupted faces (faces with duplicated vertices)
|
190 |
+
|
191 |
+
Args:
|
192 |
+
faces (np.ndarray): [T, 3] triangular face indices
|
193 |
+
|
194 |
+
Returns:
|
195 |
+
np.ndarray: [T_, 3] triangular face indices
|
196 |
+
"""
|
197 |
+
corrupted = (faces[:, 0] == faces[:, 1]) | (faces[:, 1] == faces[:, 2]) | (faces[:, 2] == faces[:, 0])
|
198 |
+
return faces[~corrupted]
|
199 |
+
|
200 |
+
|
201 |
+
def merge_duplicate_vertices(
|
202 |
+
vertices: np.ndarray,
|
203 |
+
faces: np.ndarray,
|
204 |
+
tol: float = 1e-6
|
205 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
206 |
+
"""
|
207 |
+
Merge duplicate vertices of a triangular mesh.
|
208 |
+
Duplicate vertices are merged by selecte one of them, and the face indices are updated accordingly.
|
209 |
+
|
210 |
+
Args:
|
211 |
+
vertices (np.ndarray): [N, 3] 3-dimensional vertices
|
212 |
+
faces (np.ndarray): [T, 3] triangular face indices
|
213 |
+
tol (float, optional): tolerance for merging. Defaults to 1e-6.
|
214 |
+
|
215 |
+
Returns:
|
216 |
+
vertices (np.ndarray): [N_, 3] 3-dimensional vertices
|
217 |
+
faces (np.ndarray): [T, 3] triangular face indices
|
218 |
+
"""
|
219 |
+
vertices_round = np.round(vertices / tol)
|
220 |
+
_, uni_i, uni_inv = np.unique(vertices_round, return_index=True, return_inverse=True, axis=0)
|
221 |
+
vertices = vertices[uni_i]
|
222 |
+
faces = uni_inv[faces]
|
223 |
+
return vertices, faces
|
224 |
+
|
225 |
+
|
226 |
+
def remove_unreferenced_vertices(
|
227 |
+
faces: np.ndarray,
|
228 |
+
*vertice_attrs,
|
229 |
+
return_indices: bool = False
|
230 |
+
) -> Tuple[np.ndarray, ...]:
|
231 |
+
"""
|
232 |
+
Remove unreferenced vertices of a mesh.
|
233 |
+
Unreferenced vertices are removed, and the face indices are updated accordingly.
|
234 |
+
|
235 |
+
Args:
|
236 |
+
faces (np.ndarray): [T, P] face indices
|
237 |
+
*vertice_attrs: vertex attributes
|
238 |
+
|
239 |
+
Returns:
|
240 |
+
faces (np.ndarray): [T, P] face indices
|
241 |
+
*vertice_attrs: vertex attributes
|
242 |
+
indices (np.ndarray, optional): [N] indices of vertices that are kept. Defaults to None.
|
243 |
+
"""
|
244 |
+
P = faces.shape[-1]
|
245 |
+
fewer_indices, inv_map = np.unique(faces, return_inverse=True)
|
246 |
+
faces = inv_map.astype(np.int32).reshape(-1, P)
|
247 |
+
ret = [faces]
|
248 |
+
for attr in vertice_attrs:
|
249 |
+
ret.append(attr[fewer_indices])
|
250 |
+
if return_indices:
|
251 |
+
ret.append(fewer_indices)
|
252 |
+
return tuple(ret)
|
253 |
+
|
254 |
+
|
255 |
+
def subdivide_mesh_simple(
|
256 |
+
vertices: np.ndarray,
|
257 |
+
faces: np.ndarray,
|
258 |
+
n: int = 1
|
259 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
260 |
+
"""
|
261 |
+
Subdivide a triangular mesh by splitting each triangle into 4 smaller triangles.
|
262 |
+
NOTE: All original vertices are kept, and new vertices are appended to the end of the vertex list.
|
263 |
+
|
264 |
+
Args:
|
265 |
+
vertices (np.ndarray): [N, 3] 3-dimensional vertices
|
266 |
+
faces (np.ndarray): [T, 3] triangular face indices
|
267 |
+
n (int, optional): number of subdivisions. Defaults to 1.
|
268 |
+
|
269 |
+
Returns:
|
270 |
+
vertices (np.ndarray): [N_, 3] subdivided 3-dimensional vertices
|
271 |
+
faces (np.ndarray): [4 * T, 3] subdivided triangular face indices
|
272 |
+
"""
|
273 |
+
for _ in range(n):
|
274 |
+
edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=0)
|
275 |
+
edges = np.sort(edges, axis=2)
|
276 |
+
uni_edges, uni_inv = np.unique(edges.reshape(-1, 2), return_inverse=True, axis=0)
|
277 |
+
uni_inv = uni_inv.reshape(3, -1)
|
278 |
+
midpoints = (vertices[uni_edges[:, 0]] + vertices[uni_edges[:, 1]]) / 2
|
279 |
+
|
280 |
+
n_vertices = vertices.shape[0]
|
281 |
+
vertices = np.concatenate([vertices, midpoints], axis=0)
|
282 |
+
faces = np.concatenate([
|
283 |
+
np.stack([faces[:, 0], n_vertices + uni_inv[0], n_vertices + uni_inv[2]], axis=1),
|
284 |
+
np.stack([faces[:, 1], n_vertices + uni_inv[1], n_vertices + uni_inv[0]], axis=1),
|
285 |
+
np.stack([faces[:, 2], n_vertices + uni_inv[2], n_vertices + uni_inv[1]], axis=1),
|
286 |
+
np.stack([n_vertices + uni_inv[0], n_vertices + uni_inv[1], n_vertices + uni_inv[2]], axis=1),
|
287 |
+
], axis=0)
|
288 |
+
return vertices, faces
|
289 |
+
|
290 |
+
|
291 |
+
def mesh_relations(
|
292 |
+
faces: np.ndarray,
|
293 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
294 |
+
"""
|
295 |
+
Calculate the relation between vertices and faces.
|
296 |
+
NOTE: The input mesh must be a manifold triangle mesh.
|
297 |
+
|
298 |
+
Args:
|
299 |
+
faces (np.ndarray): [T, 3] triangular face indices
|
300 |
+
|
301 |
+
Returns:
|
302 |
+
edges (np.ndarray): [E, 2] edge indices
|
303 |
+
edge2face (np.ndarray): [E, 2] edge to face relation. The second column is -1 if the edge is boundary.
|
304 |
+
face2edge (np.ndarray): [T, 3] face to edge relation
|
305 |
+
face2face (np.ndarray): [T, 3] face to face relation
|
306 |
+
"""
|
307 |
+
T = faces.shape[0]
|
308 |
+
edges = np.stack([faces[:, [0, 1]], faces[:, [1, 2]], faces[:, [2, 0]]], axis=1).reshape(-1, 2) # [3T, 2]
|
309 |
+
edges = np.sort(edges, axis=1) # [3T, 2]
|
310 |
+
edges, face2edge, occurence = np.unique(edges, axis=0, return_inverse=True, return_counts=True) # [E, 2], [3T], [E]
|
311 |
+
E = edges.shape[0]
|
312 |
+
assert np.all(occurence <= 2), "The input mesh is not a manifold mesh."
|
313 |
+
|
314 |
+
# Edge to face relation
|
315 |
+
padding = np.arange(E, dtype=np.int32)[occurence == 1]
|
316 |
+
padded_face2edge = np.concatenate([face2edge, padding], axis=0) # [2E]
|
317 |
+
edge2face = np.argsort(padded_face2edge, kind='stable').reshape(-1, 2) // 3 # [E, 2]
|
318 |
+
edge2face_valid = edge2face[:, 1] < T # [E]
|
319 |
+
edge2face[~edge2face_valid, 1] = -1
|
320 |
+
|
321 |
+
# Face to edge relation
|
322 |
+
face2edge = face2edge.reshape(-1, 3) # [T, 3]
|
323 |
+
|
324 |
+
# Face to face relation
|
325 |
+
face2face = edge2face[face2edge] # [T, 3, 2]
|
326 |
+
face2face = face2face[face2face != np.arange(T)[:, None, None]].reshape(T, 3) # [T, 3]
|
327 |
+
|
328 |
+
return edges, edge2face, face2edge, face2face
|
329 |
+
|
330 |
+
|
331 |
+
@overload
|
332 |
+
def flatten_mesh_indices(faces1: np.ndarray, attr1: np.ndarray, *other_faces_attrs_pairs: np.ndarray) -> Tuple[np.ndarray, ...]:
|
333 |
+
"""
|
334 |
+
Rearrange the indices of a mesh to a flattened version. Vertices will be no longer shared.
|
335 |
+
|
336 |
+
### Parameters:
|
337 |
+
- `faces1`: [T, P] face indices of the first attribute
|
338 |
+
- `attr1`: [N1, ...] attributes of the first mesh
|
339 |
+
- ...
|
340 |
+
|
341 |
+
### Returns:
|
342 |
+
- `faces`: [T, P] flattened face indices, contigous from 0 to T * P - 1
|
343 |
+
- `attr1`: [T * P, ...] attributes of the first mesh, where every P values correspond to a face
|
344 |
+
_ ...
|
345 |
+
"""
|
346 |
+
def flatten_mesh_indices(*args: np.ndarray) -> Tuple[np.ndarray, ...]:
|
347 |
+
assert len(args) % 2 == 0, "The number of arguments must be even."
|
348 |
+
T, P = args[0].shape
|
349 |
+
assert all(arg.shape[0] == T and arg.shape[1] == P for arg in args[::2]), "The faces must have the same shape."
|
350 |
+
attr_flat = []
|
351 |
+
for faces_, attr_ in zip(args[::2], args[1::2]):
|
352 |
+
attr_flat_ = attr_[faces_].reshape(-1, *attr_.shape[1:])
|
353 |
+
attr_flat.append(attr_flat_)
|
354 |
+
faces_flat = np.arange(T * P, dtype=np.int32).reshape(T, P)
|
355 |
+
return faces_flat, *attr_flat
|
utils3d/numpy/quadmesh.py
ADDED
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import numpy as np
|
2 |
+
import scipy as sp
|
3 |
+
import scipy.optimize as spopt
|
4 |
+
import piqp
|
5 |
+
from typing import *
|
6 |
+
|
7 |
+
|
8 |
+
__all__ = [
|
9 |
+
'calc_quad_candidates',
|
10 |
+
'calc_quad_distortion',
|
11 |
+
'calc_quad_direction',
|
12 |
+
'calc_quad_smoothness',
|
13 |
+
'sovle_quad',
|
14 |
+
'sovle_quad_qp',
|
15 |
+
'tri_to_quad'
|
16 |
+
]
|
17 |
+
|
18 |
+
|
19 |
+
def calc_quad_candidates(
|
20 |
+
edges: np.ndarray,
|
21 |
+
face2edge: np.ndarray,
|
22 |
+
edge2face: np.ndarray,
|
23 |
+
):
|
24 |
+
"""
|
25 |
+
Calculate the candidate quad faces.
|
26 |
+
|
27 |
+
Args:
|
28 |
+
edges (np.ndarray): [E, 2] edge indices
|
29 |
+
face2edge (np.ndarray): [T, 3] face to edge relation
|
30 |
+
edge2face (np.ndarray): [E, 2] edge to face relation
|
31 |
+
|
32 |
+
Returns:
|
33 |
+
quads (np.ndarray): [Q, 4] quad candidate indices
|
34 |
+
quad2edge (np.ndarray): [Q, 4] edge to quad candidate relation
|
35 |
+
quad2adj (np.ndarray): [Q, 8] adjacent quad candidates of each quad candidate
|
36 |
+
quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
|
37 |
+
"""
|
38 |
+
E = edges.shape[0]
|
39 |
+
T = face2edge.shape[0]
|
40 |
+
|
41 |
+
quads_valid = edge2face[:, 1] != -1
|
42 |
+
Q = quads_valid.sum()
|
43 |
+
quad2face = edge2face[quads_valid] # [Q, 2]
|
44 |
+
quad2edge = face2edge[quad2face] # [Q, 2, 3]
|
45 |
+
flag = quad2edge == np.arange(E)[quads_valid][:, None, None] # [Q, 2, 3]
|
46 |
+
flag = flag.argmax(axis=-1) # [Q, 2]
|
47 |
+
quad2edge = np.stack([
|
48 |
+
quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 1) % 3],
|
49 |
+
quad2edge[np.arange(Q)[:, None], np.arange(2)[None, :], (flag + 2) % 3],
|
50 |
+
], axis=-1).reshape(Q, 4) # [Q, 4]
|
51 |
+
|
52 |
+
quads = np.concatenate([
|
53 |
+
np.where(
|
54 |
+
(edges[quad2edge[:, 0:1], 1:] == edges[quad2edge[:, 1:2], :]).any(axis=-1),
|
55 |
+
edges[quad2edge[:, 0:1], [[0, 1]]],
|
56 |
+
edges[quad2edge[:, 0:1], [[1, 0]]],
|
57 |
+
),
|
58 |
+
np.where(
|
59 |
+
(edges[quad2edge[:, 2:3], 1:] == edges[quad2edge[:, 3:4], :]).any(axis=-1),
|
60 |
+
edges[quad2edge[:, 2:3], [[0, 1]]],
|
61 |
+
edges[quad2edge[:, 2:3], [[1, 0]]],
|
62 |
+
),
|
63 |
+
], axis=1) # [Q, 4]
|
64 |
+
|
65 |
+
quad2adj = edge2face[quad2edge] # [Q, 4, 2]
|
66 |
+
quad2adj = quad2adj[quad2adj != quad2face[:, [0,0,1,1], None]].reshape(Q, 4) # [Q, 4]
|
67 |
+
quad2adj_valid = quad2adj != -1
|
68 |
+
quad2adj = face2edge[quad2adj] # [Q, 4, 3]
|
69 |
+
quad2adj[~quad2adj_valid, 0] = quad2edge[~quad2adj_valid]
|
70 |
+
quad2adj[~quad2adj_valid, 1:] = -1
|
71 |
+
quad2adj = quad2adj[quad2adj != quad2edge[..., None]].reshape(Q, 8) # [Q, 8]
|
72 |
+
edge_valid = -np.ones(E, dtype=np.int32)
|
73 |
+
edge_valid[quads_valid] = np.arange(Q)
|
74 |
+
quad2adj_valid = quad2adj != -1
|
75 |
+
quad2adj[quad2adj_valid] = edge_valid[quad2adj[quad2adj_valid]] # [Q, 8]
|
76 |
+
|
77 |
+
return quads, quad2edge, quad2adj, quads_valid
|
78 |
+
|
79 |
+
|
80 |
+
def calc_quad_distortion(
|
81 |
+
vertices: np.ndarray,
|
82 |
+
quads: np.ndarray,
|
83 |
+
):
|
84 |
+
"""
|
85 |
+
Calculate the distortion of each candidate quad face.
|
86 |
+
|
87 |
+
Args:
|
88 |
+
vertices (np.ndarray): [N, 3] 3-dimensional vertices
|
89 |
+
quads (np.ndarray): [Q, 4] quad face indices
|
90 |
+
|
91 |
+
Returns:
|
92 |
+
distortion (np.ndarray): [Q] distortion of each quad face
|
93 |
+
"""
|
94 |
+
edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3]
|
95 |
+
edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3]
|
96 |
+
edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3]
|
97 |
+
edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3]
|
98 |
+
cross = vertices[quads[:, 0]] - vertices[quads[:, 2]] # [Q, 3]
|
99 |
+
|
100 |
+
len0 = np.maximum(np.linalg.norm(edge0, axis=-1), 1e-10) # [Q]
|
101 |
+
len1 = np.maximum(np.linalg.norm(edge1, axis=-1), 1e-10) # [Q]
|
102 |
+
len2 = np.maximum(np.linalg.norm(edge2, axis=-1), 1e-10) # [Q]
|
103 |
+
len3 = np.maximum(np.linalg.norm(edge3, axis=-1), 1e-10) # [Q]
|
104 |
+
len_cross = np.maximum(np.linalg.norm(cross, axis=-1), 1e-10) # [Q]
|
105 |
+
|
106 |
+
angle0 = np.arccos(np.clip(np.sum(-edge0 * edge1, axis=-1) / (len0 * len1), -1, 1)) # [Q]
|
107 |
+
angle1 = np.arccos(np.clip(np.sum(-edge1 * cross, axis=-1) / (len1 * len_cross), -1, 1)) \
|
108 |
+
+ np.arccos(np.clip(np.sum(cross * edge2, axis=-1) / (len_cross * len2), -1, 1)) # [Q]
|
109 |
+
angle2 = np.arccos(np.clip(np.sum(-edge2 * edge3, axis=-1) / (len2 * len3), -1, 1)) # [Q]
|
110 |
+
angle3 = np.arccos(np.clip(np.sum(-edge3 * -cross, axis=-1) / (len3 * len_cross), -1, 1)) \
|
111 |
+
+ np.arccos(np.clip(np.sum(-cross * edge0, axis=-1) / (len_cross * len0), -1, 1)) # [Q]
|
112 |
+
|
113 |
+
normal0 = np.cross(edge0, edge1) # [Q, 3]
|
114 |
+
normal1 = np.cross(edge2, edge3) # [Q, 3]
|
115 |
+
normal0 = normal0 / np.maximum(np.linalg.norm(normal0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
|
116 |
+
normal1 = normal1 / np.maximum(np.linalg.norm(normal1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
|
117 |
+
angle_normal = np.arccos(np.clip(np.sum(normal0 * normal1, axis=-1), -1, 1)) # [Q]
|
118 |
+
|
119 |
+
D90 = np.pi / 2
|
120 |
+
D180 = np.pi
|
121 |
+
D360 = np.pi * 2
|
122 |
+
ang_eng = (np.abs(angle0 - D90)**2 + np.abs(angle1 - D90)**2 + np.abs(angle2 - D90)**2 + np.abs(angle3 - D90)**2) / 4 # [Q]
|
123 |
+
dist_eng = np.abs(angle0 - angle2)**2 / np.minimum(np.maximum(np.minimum(angle0, angle2), 1e-10), np.maximum(D180 - np.maximum(angle0, angle2), 1e-10)) \
|
124 |
+
+ np.abs(angle1 - angle3)**2 / np.minimum(np.maximum(np.minimum(angle1, angle3), 1e-10), np.maximum(D180 - np.maximum(angle1, angle3), 1e-10)) # [Q]
|
125 |
+
plane_eng = np.where(angle_normal < D90/2, np.abs(angle_normal)**2, 1e10) # [Q]
|
126 |
+
eng = ang_eng + 2 * dist_eng + 2 * plane_eng # [Q]
|
127 |
+
|
128 |
+
return eng
|
129 |
+
|
130 |
+
|
131 |
+
def calc_quad_direction(
|
132 |
+
vertices: np.ndarray,
|
133 |
+
quads: np.ndarray,
|
134 |
+
):
|
135 |
+
"""
|
136 |
+
Calculate the direction of each candidate quad face.
|
137 |
+
|
138 |
+
Args:
|
139 |
+
vertices (np.ndarray): [N, 3] 3-dimensional vertices
|
140 |
+
quads (np.ndarray): [Q, 4] quad face indices
|
141 |
+
|
142 |
+
Returns:
|
143 |
+
direction (np.ndarray): [Q, 4] direction of each quad face.
|
144 |
+
Represented by the angle between the crossing and each edge.
|
145 |
+
"""
|
146 |
+
mid0 = (vertices[quads[:, 0]] + vertices[quads[:, 1]]) / 2 # [Q, 3]
|
147 |
+
mid1 = (vertices[quads[:, 1]] + vertices[quads[:, 2]]) / 2 # [Q, 3]
|
148 |
+
mid2 = (vertices[quads[:, 2]] + vertices[quads[:, 3]]) / 2 # [Q, 3]
|
149 |
+
mid3 = (vertices[quads[:, 3]] + vertices[quads[:, 0]]) / 2 # [Q, 3]
|
150 |
+
|
151 |
+
cross0 = mid2 - mid0 # [Q, 3]
|
152 |
+
cross1 = mid3 - mid1 # [Q, 3]
|
153 |
+
cross0 = cross0 / np.maximum(np.linalg.norm(cross0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
|
154 |
+
cross1 = cross1 / np.maximum(np.linalg.norm(cross1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
|
155 |
+
|
156 |
+
edge0 = vertices[quads[:, 1]] - vertices[quads[:, 0]] # [Q, 3]
|
157 |
+
edge1 = vertices[quads[:, 2]] - vertices[quads[:, 1]] # [Q, 3]
|
158 |
+
edge2 = vertices[quads[:, 3]] - vertices[quads[:, 2]] # [Q, 3]
|
159 |
+
edge3 = vertices[quads[:, 0]] - vertices[quads[:, 3]] # [Q, 3]
|
160 |
+
edge0 = edge0 / np.maximum(np.linalg.norm(edge0, axis=-1, keepdims=True), 1e-10) # [Q, 3]
|
161 |
+
edge1 = edge1 / np.maximum(np.linalg.norm(edge1, axis=-1, keepdims=True), 1e-10) # [Q, 3]
|
162 |
+
edge2 = edge2 / np.maximum(np.linalg.norm(edge2, axis=-1, keepdims=True), 1e-10) # [Q, 3]
|
163 |
+
edge3 = edge3 / np.maximum(np.linalg.norm(edge3, axis=-1, keepdims=True), 1e-10) # [Q, 3]
|
164 |
+
|
165 |
+
direction = np.stack([
|
166 |
+
np.arccos(np.clip(np.sum(cross0 * edge0, axis=-1), -1, 1)),
|
167 |
+
np.arccos(np.clip(np.sum(cross1 * edge1, axis=-1), -1, 1)),
|
168 |
+
np.arccos(np.clip(np.sum(-cross0 * edge2, axis=-1), -1, 1)),
|
169 |
+
np.arccos(np.clip(np.sum(-cross1 * edge3, axis=-1), -1, 1)),
|
170 |
+
], axis=-1) # [Q, 4]
|
171 |
+
|
172 |
+
return direction
|
173 |
+
|
174 |
+
|
175 |
+
def calc_quad_smoothness(
|
176 |
+
quad2edge: np.ndarray,
|
177 |
+
quad2adj: np.ndarray,
|
178 |
+
quads_direction: np.ndarray,
|
179 |
+
):
|
180 |
+
"""
|
181 |
+
Calculate the smoothness of each candidate quad face connection.
|
182 |
+
|
183 |
+
Args:
|
184 |
+
quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
|
185 |
+
quads_direction (np.ndarray): [Q, 4] direction of each quad face
|
186 |
+
|
187 |
+
Returns:
|
188 |
+
smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
|
189 |
+
"""
|
190 |
+
Q = quad2adj.shape[0]
|
191 |
+
quad2adj_valid = quad2adj != -1
|
192 |
+
connections = np.stack([
|
193 |
+
np.arange(Q)[:, None].repeat(8, axis=1),
|
194 |
+
quad2adj,
|
195 |
+
], axis=-1)[quad2adj_valid] # [C, 2]
|
196 |
+
shared_edge_idx_0 = np.array([[0, 0, 1, 1, 2, 2, 3, 3]]).repeat(Q, axis=0)[quad2adj_valid] # [C]
|
197 |
+
shared_edge_idx_1 = np.argmax(quad2edge[quad2adj][quad2adj_valid] == quad2edge[connections[:, 0], shared_edge_idx_0][:, None], axis=-1) # [C]
|
198 |
+
valid_smoothness = np.abs(quads_direction[connections[:, 0], shared_edge_idx_0] - quads_direction[connections[:, 1], shared_edge_idx_1])**2 # [C]
|
199 |
+
smoothness = np.zeros([Q, 8], dtype=np.float32)
|
200 |
+
smoothness[quad2adj_valid] = valid_smoothness
|
201 |
+
return smoothness
|
202 |
+
|
203 |
+
|
204 |
+
def sovle_quad(
|
205 |
+
face2edge: np.ndarray,
|
206 |
+
edge2face: np.ndarray,
|
207 |
+
quad2adj: np.ndarray,
|
208 |
+
quads_distortion: np.ndarray,
|
209 |
+
quads_smoothness: np.ndarray,
|
210 |
+
quads_valid: np.ndarray,
|
211 |
+
):
|
212 |
+
"""
|
213 |
+
Solve the quad mesh from the candidate quad faces.
|
214 |
+
|
215 |
+
Args:
|
216 |
+
face2edge (np.ndarray): [T, 3] face to edge relation
|
217 |
+
edge2face (np.ndarray): [E, 2] edge to face relation
|
218 |
+
quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
|
219 |
+
quads_distortion (np.ndarray): [Q] distortion of each quad face
|
220 |
+
quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
|
221 |
+
quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
|
222 |
+
|
223 |
+
Returns:
|
224 |
+
weights (np.ndarray): [Q] weight of each valid quad face
|
225 |
+
"""
|
226 |
+
T = face2edge.shape[0]
|
227 |
+
E = edge2face.shape[0]
|
228 |
+
Q = quads_distortion.shape[0]
|
229 |
+
edge_valid = -np.ones(E, dtype=np.int32)
|
230 |
+
edge_valid[quads_valid] = np.arange(Q)
|
231 |
+
|
232 |
+
quads_connection = np.stack([
|
233 |
+
np.arange(Q)[:, None].repeat(8, axis=1),
|
234 |
+
quad2adj,
|
235 |
+
], axis=-1)[quad2adj != -1] # [C, 2]
|
236 |
+
quads_connection = np.sort(quads_connection, axis=-1) # [C, 2]
|
237 |
+
quads_connection, quads_connection_idx = np.unique(quads_connection, axis=0, return_index=True) # [C, 2], [C]
|
238 |
+
quads_smoothness = quads_smoothness[quad2adj != -1] # [C]
|
239 |
+
quads_smoothness = quads_smoothness[quads_connection_idx] # [C]
|
240 |
+
C = quads_connection.shape[0]
|
241 |
+
|
242 |
+
# Construct the linear programming problem
|
243 |
+
|
244 |
+
# Variables:
|
245 |
+
# quads_weight: [Q] weight of each quad face
|
246 |
+
# tri_min_weight: [T] minimum weight of each triangle face
|
247 |
+
# conn_min_weight: [C] minimum weight of each quad face connection
|
248 |
+
# conn_max_weight: [C] maximum weight of each quad face connection
|
249 |
+
# Objective:
|
250 |
+
# mimi
|
251 |
+
|
252 |
+
c = np.concatenate([
|
253 |
+
quads_distortion - 3,
|
254 |
+
quads_smoothness*4 - 2,
|
255 |
+
quads_smoothness*4,
|
256 |
+
], axis=0) # [Q+C]
|
257 |
+
|
258 |
+
A_ub_triplet = np.concatenate([
|
259 |
+
np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3]
|
260 |
+
np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3]
|
261 |
+
np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3]
|
262 |
+
np.stack([np.arange(T, T+C), np.arange(Q, Q+C), np.ones(C)], axis=1), # [C, 3]
|
263 |
+
np.stack([np.arange(T, T+C), quads_connection[:, 0], -np.ones(C)], axis=1), # [C, 3]
|
264 |
+
np.stack([np.arange(T, T+C), quads_connection[:, 1], -np.ones(C)], axis=1), # [C, 3]
|
265 |
+
np.stack([np.arange(T+C, T+2*C), np.arange(Q+C, Q+2*C), -np.ones(C)], axis=1), # [C, 3]
|
266 |
+
np.stack([np.arange(T+C, T+2*C), quads_connection[:, 0], np.ones(C)], axis=1), # [C, 3]
|
267 |
+
np.stack([np.arange(T+C, T+2*C), quads_connection[:, 1], np.ones(C)], axis=1), # [C, 3]
|
268 |
+
], axis=0) # [3T+6C, 3]
|
269 |
+
A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3]
|
270 |
+
A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T+2*C, Q+2*C]) # [T,
|
271 |
+
b_ub = np.concatenate([np.ones(T), -np.ones(C), np.ones(C)], axis=0) # [T+2C]
|
272 |
+
bound = np.stack([
|
273 |
+
np.concatenate([np.zeros(Q), -np.ones(C), np.zeros(C)], axis=0),
|
274 |
+
np.concatenate([np.ones(Q), np.ones(C), np.ones(C)], axis=0),
|
275 |
+
], axis=1) # [Q+2C, 2]
|
276 |
+
A_eq = None
|
277 |
+
b_eq = None
|
278 |
+
|
279 |
+
print('Solver statistics:')
|
280 |
+
print(f' #T = {T}')
|
281 |
+
print(f' #Q = {Q}')
|
282 |
+
print(f' #C = {C}')
|
283 |
+
|
284 |
+
# Solve the linear programming problem
|
285 |
+
last_num_valid = 0
|
286 |
+
for i in range(100):
|
287 |
+
res_ = spopt.linprog(c, A_ub=A_ub, b_ub=b_ub, A_eq=A_eq, b_eq=b_eq, bounds=bound)
|
288 |
+
if not res_.success:
|
289 |
+
print(f' Iter {i} | Failed with {res_.message}')
|
290 |
+
break
|
291 |
+
res = res_
|
292 |
+
weights = res.x[:Q]
|
293 |
+
valid = (weights > 0.5)
|
294 |
+
num_valid = valid.sum()
|
295 |
+
print(f' Iter {i} | #Q_valid = {num_valid}')
|
296 |
+
if num_valid == last_num_valid:
|
297 |
+
break
|
298 |
+
last_num_valid = num_valid
|
299 |
+
A_eq_triplet = np.stack([
|
300 |
+
np.arange(num_valid),
|
301 |
+
np.arange(Q)[valid],
|
302 |
+
np.ones(num_valid),
|
303 |
+
], axis=1) # [num_valid, 3]
|
304 |
+
A_eq = sp.sparse.coo_matrix((A_eq_triplet[:, 2], (A_eq_triplet[:, 0], A_eq_triplet[:, 1])), shape=[num_valid, Q+2*C]) # [num_valid, Q+C]
|
305 |
+
b_eq = np.where(weights[valid] > 0.5, 1, 0) # [num_valid]
|
306 |
+
|
307 |
+
# Return the result
|
308 |
+
quads_weight = res.x[:Q]
|
309 |
+
conn_min_weight = res.x[Q:Q+C]
|
310 |
+
conn_max_weight = res.x[Q+C:Q+2*C]
|
311 |
+
return quads_weight, conn_min_weight, conn_max_weight
|
312 |
+
|
313 |
+
|
314 |
+
def sovle_quad_qp(
|
315 |
+
face2edge: np.ndarray,
|
316 |
+
edge2face: np.ndarray,
|
317 |
+
quad2adj: np.ndarray,
|
318 |
+
quads_distortion: np.ndarray,
|
319 |
+
quads_smoothness: np.ndarray,
|
320 |
+
quads_valid: np.ndarray,
|
321 |
+
):
|
322 |
+
"""
|
323 |
+
Solve the quad mesh from the candidate quad faces.
|
324 |
+
|
325 |
+
Args:
|
326 |
+
face2edge (np.ndarray): [T, 3] face to edge relation
|
327 |
+
edge2face (np.ndarray): [E, 2] edge to face relation
|
328 |
+
quad2adj (np.ndarray): [Q, 8] adjacent quad faces of each quad face
|
329 |
+
quads_distortion (np.ndarray): [Q] distortion of each quad face
|
330 |
+
quads_smoothness (np.ndarray): [Q, 8] smoothness of each quad face connection
|
331 |
+
quads_valid (np.ndarray): [E] whether the quad corresponding to the edge is valid
|
332 |
+
|
333 |
+
Returns:
|
334 |
+
weights (np.ndarray): [Q] weight of each valid quad face
|
335 |
+
"""
|
336 |
+
T = face2edge.shape[0]
|
337 |
+
E = edge2face.shape[0]
|
338 |
+
Q = quads_distortion.shape[0]
|
339 |
+
edge_valid = -np.ones(E, dtype=np.int32)
|
340 |
+
edge_valid[quads_valid] = np.arange(Q)
|
341 |
+
|
342 |
+
# Construct the quadratic programming problem
|
343 |
+
C_smoothness_triplet = np.stack([
|
344 |
+
np.arange(Q)[:, None].repeat(8, axis=1)[quad2adj != -1],
|
345 |
+
quad2adj[quad2adj != -1],
|
346 |
+
5 * quads_smoothness[quad2adj != -1],
|
347 |
+
], axis=-1) # [C, 3]
|
348 |
+
# C_smoothness_triplet = np.concatenate([
|
349 |
+
# C_smoothness_triplet,
|
350 |
+
# np.stack([np.arange(Q), np.arange(Q), 20*np.ones(Q)], axis=1),
|
351 |
+
# ], axis=0) # [C+Q, 3]
|
352 |
+
C_smoothness = sp.sparse.coo_matrix((C_smoothness_triplet[:, 2], (C_smoothness_triplet[:, 0], C_smoothness_triplet[:, 1])), shape=[Q, Q]) # [Q, Q]
|
353 |
+
C_smoothness = C_smoothness.tocsc()
|
354 |
+
C_dist = quads_distortion - 20 # [Q]
|
355 |
+
|
356 |
+
A_eq = sp.sparse.coo_matrix((np.zeros(Q), (np.zeros(Q), np.arange(Q))), shape=[1, Q]) # [1, Q]\
|
357 |
+
A_eq = A_eq.tocsc()
|
358 |
+
b_eq = np.array([0])
|
359 |
+
|
360 |
+
A_ub_triplet = np.concatenate([
|
361 |
+
np.stack([np.arange(T), edge_valid[face2edge[:, 0]], np.ones(T)], axis=1), # [T, 3]
|
362 |
+
np.stack([np.arange(T), edge_valid[face2edge[:, 1]], np.ones(T)], axis=1), # [T, 3]
|
363 |
+
np.stack([np.arange(T), edge_valid[face2edge[:, 2]], np.ones(T)], axis=1), # [T, 3]
|
364 |
+
], axis=0) # [3T, 3]
|
365 |
+
A_ub_triplet = A_ub_triplet[A_ub_triplet[:, 1] != -1] # [3T', 3]
|
366 |
+
A_ub = sp.sparse.coo_matrix((A_ub_triplet[:, 2], (A_ub_triplet[:, 0], A_ub_triplet[:, 1])), shape=[T, Q]) # [T, Q]
|
367 |
+
A_ub = A_ub.tocsc()
|
368 |
+
b_ub = np.ones(T)
|
369 |
+
|
370 |
+
lb = np.zeros(Q)
|
371 |
+
ub = np.ones(Q)
|
372 |
+
|
373 |
+
solver = piqp.SparseSolver()
|
374 |
+
solver.settings.verbose = True
|
375 |
+
solver.settings.compute_timings = True
|
376 |
+
solver.setup(C_smoothness, C_dist, A_eq, b_eq, A_ub, b_ub, lb, ub)
|
377 |
+
|
378 |
+
status = solver.solve()
|
379 |
+
|
380 |
+
# x = cp.Variable(Q)
|
381 |
+
# prob = cp.Problem(
|
382 |
+
# cp.Minimize(cp.quad_form(x, C_smoothness) + C_dist.T @ x),
|
383 |
+
# [
|
384 |
+
# A_ub @ x <= b_ub,
|
385 |
+
# x >= 0, x <= 1,
|
386 |
+
# ]
|
387 |
+
# )
|
388 |
+
|
389 |
+
# # Solve the quadratic programming problem
|
390 |
+
# prob.solve(solver=cp.PIQP, verbose=True)
|
391 |
+
|
392 |
+
# Return the result
|
393 |
+
weights = solver.result.x
|
394 |
+
return weights
|
395 |
+
|
396 |
+
|
397 |
+
def tri_to_quad(
|
398 |
+
vertices: np.ndarray,
|
399 |
+
faces: np.ndarray,
|
400 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
401 |
+
"""
|
402 |
+
Convert a triangle mesh to a quad mesh.
|
403 |
+
NOTE: The input mesh must be a manifold mesh.
|
404 |
+
|
405 |
+
Args:
|
406 |
+
vertices (np.ndarray): [N, 3] 3-dimensional vertices
|
407 |
+
faces (np.ndarray): [T, 3] triangular face indices
|
408 |
+
|
409 |
+
Returns:
|
410 |
+
vertices (np.ndarray): [N_, 3] 3-dimensional vertices
|
411 |
+
faces (np.ndarray): [Q, 4] quad face indices
|
412 |
+
"""
|
413 |
+
raise NotImplementedError
|
414 |
+
|
415 |
+
|
416 |
+
if __name__ == '__main__':
|
417 |
+
import os
|
418 |
+
import sys
|
419 |
+
sys.path.insert(0, os.path.abspath(os.path.join(os.path.dirname(__file__), '..', '..', '..')))
|
420 |
+
import utils3d
|
421 |
+
import numpy as np
|
422 |
+
import cv2
|
423 |
+
from vis import vis_edge_color
|
424 |
+
|
425 |
+
file = 'miku'
|
426 |
+
|
427 |
+
vertices, faces = utils3d.io.read_ply(f'test/assets/{file}.ply')
|
428 |
+
edges, edge2face, face2edge, face2face = calc_relations(faces)
|
429 |
+
quad_cands, quad2edge, quad2adj, quad_valid = calc_quad_candidates(edges, face2edge, edge2face)
|
430 |
+
distortion = calc_quad_distortion(vertices, quad_cands)
|
431 |
+
direction = calc_quad_direction(vertices, quad_cands)
|
432 |
+
smoothness = calc_quad_smoothness(quad2edge, quad2adj, direction)
|
433 |
+
boundary_edges = edges[edge2face[:, 1] == -1]
|
434 |
+
quads_weight, conn_min_weight, conn_max_weight = sovle_quad(face2edge, edge2face, quad2adj, distortion, smoothness, quad_valid)
|
435 |
+
quads = quad_cands[quads_weight > 0.5]
|
436 |
+
print('Mesh statistics')
|
437 |
+
print(f' #V = {vertices.shape[0]}')
|
438 |
+
print(f' #F = {faces.shape[0]}')
|
439 |
+
print(f' #E = {edges.shape[0]}')
|
440 |
+
print(f' #B = {boundary_edges.shape[0]}')
|
441 |
+
print(f' #Q_cand = {quad_cands.shape[0]}')
|
442 |
+
print(f' #Q = {quads.shape[0]}')
|
443 |
+
|
444 |
+
utils3d.io.write_ply(f'test/assets/{file}_boundary_edges.ply', vertices=vertices, edges=boundary_edges)
|
445 |
+
utils3d.io.write_ply(f'test/assets/{file}_quad_candidates.ply', vertices=vertices, faces=quads)
|
446 |
+
|
447 |
+
edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8)
|
448 |
+
distortion = (distortion - distortion.min()) / (distortion.max() - distortion.min())
|
449 |
+
distortion = (distortion * 255).astype(np.uint8)
|
450 |
+
edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap(distortion, cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
|
451 |
+
utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_distortion.ply', **vis_edge_color(vertices, edges, edge_colors))
|
452 |
+
|
453 |
+
edge_colors = np.zeros([edges.shape[0], 3], dtype=np.uint8)
|
454 |
+
edge_colors[quad_valid] = cv2.cvtColor(cv2.applyColorMap((quads_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
|
455 |
+
utils3d.io.write_ply(f'test/assets/{file}_quad_candidates_weights.ply', **vis_edge_color(vertices, edges, edge_colors))
|
456 |
+
utils3d.io.write_ply(f'test/assets/{file}_quad.ply', vertices=vertices, faces=quads)
|
457 |
+
|
458 |
+
quad_centers = vertices[quad_cands].mean(axis=1)
|
459 |
+
conns = np.stack([
|
460 |
+
np.arange(quad_cands.shape[0])[:, None].repeat(8, axis=1),
|
461 |
+
quad2adj,
|
462 |
+
], axis=-1)[quad2adj != -1] # [C, 2]
|
463 |
+
conns, conns_idx = np.unique(np.sort(conns, axis=-1), axis=0, return_index=True) # [C, 2], [C]
|
464 |
+
smoothness = smoothness[quad2adj != -1][conns_idx] # [C]
|
465 |
+
conns_color = cv2.cvtColor(cv2.applyColorMap((smoothness * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
|
466 |
+
utils3d.io.write_ply(f'test/assets/{file}_quad_conn_smoothness.ply', **vis_edge_color(quad_centers, conns, conns_color))
|
467 |
+
conns_color = cv2.cvtColor(cv2.applyColorMap((conn_min_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
|
468 |
+
utils3d.io.write_ply(f'test/assets/{file}_quad_conn_min.ply', **vis_edge_color(quad_centers, conns, conns_color))
|
469 |
+
conns_color = cv2.cvtColor(cv2.applyColorMap((conn_max_weight * 255).astype(np.uint8), cv2.COLORMAP_JET), cv2.COLOR_BGR2RGB).reshape(-1, 3)
|
470 |
+
utils3d.io.write_ply(f'test/assets/{file}_quad_conn_max.ply', **vis_edge_color(quad_centers, conns, conns_color))
|
471 |
+
|
472 |
+
|
utils3d/numpy/rasterization.py
ADDED
@@ -0,0 +1,471 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
import os
|
2 |
+
from typing import *
|
3 |
+
|
4 |
+
import numpy as np
|
5 |
+
import moderngl
|
6 |
+
|
7 |
+
from . import transforms, utils, mesh
|
8 |
+
|
9 |
+
|
10 |
+
__all__ = [
|
11 |
+
'RastContext',
|
12 |
+
'rasterize_triangle_faces',
|
13 |
+
'rasterize_edges',
|
14 |
+
'texture',
|
15 |
+
'warp_image_by_depth',
|
16 |
+
]
|
17 |
+
|
18 |
+
|
19 |
+
def map_np_dtype(dtype) -> str:
|
20 |
+
if dtype == int:
|
21 |
+
return 'i4'
|
22 |
+
elif dtype == np.uint8:
|
23 |
+
return 'u1'
|
24 |
+
elif dtype == np.uint32:
|
25 |
+
return 'u2'
|
26 |
+
elif dtype == np.float16:
|
27 |
+
return 'f2'
|
28 |
+
elif dtype == np.float32:
|
29 |
+
return 'f4'
|
30 |
+
|
31 |
+
|
32 |
+
def one_value(dtype):
|
33 |
+
if dtype == 'u1':
|
34 |
+
return 255
|
35 |
+
elif dtype == 'u2':
|
36 |
+
return 65535
|
37 |
+
else:
|
38 |
+
return 1
|
39 |
+
|
40 |
+
|
41 |
+
class RastContext:
|
42 |
+
def __init__(self, standalone: bool = True, backend: str = None, **kwargs):
|
43 |
+
"""
|
44 |
+
Create a moderngl context.
|
45 |
+
|
46 |
+
Args:
|
47 |
+
standalone (bool, optional): whether to create a standalone context. Defaults to True.
|
48 |
+
backend (str, optional): backend to use. Defaults to None.
|
49 |
+
|
50 |
+
Keyword Args:
|
51 |
+
See moderngl.create_context
|
52 |
+
"""
|
53 |
+
if backend is None:
|
54 |
+
self.mgl_ctx = moderngl.create_context(standalone=standalone, **kwargs)
|
55 |
+
else:
|
56 |
+
self.mgl_ctx = moderngl.create_context(standalone=standalone, backend=backend, **kwargs)
|
57 |
+
|
58 |
+
self.__prog_src = {}
|
59 |
+
self.__prog = {}
|
60 |
+
|
61 |
+
def __del__(self):
|
62 |
+
self.mgl_ctx.release()
|
63 |
+
|
64 |
+
def screen_quad(self) -> moderngl.VertexArray:
|
65 |
+
self.screen_quad_vbo = self.mgl_ctx.buffer(np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype='f4'))
|
66 |
+
self.screen_quad_ibo = self.mgl_ctx.buffer(np.array([0, 1, 2, 0, 2, 3], dtype=np.int32))
|
67 |
+
|
68 |
+
def program_vertex_attribute(self, n: int) -> moderngl.Program:
|
69 |
+
assert n in [1, 2, 3, 4], 'vertex attribute only supports channels 1, 2, 3, 4'
|
70 |
+
|
71 |
+
if 'vertex_attribute_vsh' not in self.__prog_src:
|
72 |
+
with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.vsh'), 'r') as f:
|
73 |
+
self.__prog_src['vertex_attribute_vsh'] = f.read()
|
74 |
+
if 'vertex_attribute_fsh' not in self.__prog_src:
|
75 |
+
with open(os.path.join(os.path.dirname(__file__), 'shaders', 'vertex_attribute.fsh'), 'r') as f:
|
76 |
+
self.__prog_src['vertex_attribute_fsh'] = f.read()
|
77 |
+
|
78 |
+
if f'vertex_attribute_{n}' not in self.__prog:
|
79 |
+
vsh = self.__prog_src['vertex_attribute_vsh'].replace('vecN', f'vec{n}')
|
80 |
+
fsh = self.__prog_src['vertex_attribute_fsh'].replace('vecN', f'vec{n}')
|
81 |
+
self.__prog[f'vertex_attribute_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh)
|
82 |
+
|
83 |
+
return self.__prog[f'vertex_attribute_{n}']
|
84 |
+
|
85 |
+
def program_texture(self, n: int) -> moderngl.Program:
|
86 |
+
assert n in [1, 2, 3, 4], 'texture only supports channels 1, 2, 3, 4'
|
87 |
+
|
88 |
+
if 'texture_vsh' not in self.__prog_src:
|
89 |
+
with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.vsh'), 'r') as f:
|
90 |
+
self.__prog_src['texture_vsh'] = f.read()
|
91 |
+
if 'texture_fsh' not in self.__prog_src:
|
92 |
+
with open(os.path.join(os.path.dirname(__file__), 'shaders', 'texture.fsh'), 'r') as f:
|
93 |
+
self.__prog_src['texture_fsh'] = f.read()
|
94 |
+
|
95 |
+
if f'texture_{n}' not in self.__prog:
|
96 |
+
vsh = self.__prog_src['texture_vsh'].replace('vecN', f'vec{n}')
|
97 |
+
fsh = self.__prog_src['texture_fsh'].replace('vecN', f'vec{n}')
|
98 |
+
self.__prog[f'texture_{n}'] = self.mgl_ctx.program(vertex_shader=vsh, fragment_shader=fsh)
|
99 |
+
self.__prog[f'texture_{n}']['tex'] = 0
|
100 |
+
self.__prog[f'texture_{n}']['uv'] = 1
|
101 |
+
|
102 |
+
return self.__prog[f'texture_{n}']
|
103 |
+
|
104 |
+
|
105 |
+
def rasterize_triangle_faces(
|
106 |
+
ctx: RastContext,
|
107 |
+
vertices: np.ndarray,
|
108 |
+
faces: np.ndarray,
|
109 |
+
attr: np.ndarray,
|
110 |
+
width: int,
|
111 |
+
height: int,
|
112 |
+
transform: np.ndarray = None,
|
113 |
+
cull_backface: bool = True,
|
114 |
+
return_depth: bool = False,
|
115 |
+
image: np.ndarray = None,
|
116 |
+
depth: np.ndarray = None
|
117 |
+
) -> Tuple[np.ndarray, np.ndarray]:
|
118 |
+
"""
|
119 |
+
Rasterize vertex attribute.
|
120 |
+
|
121 |
+
Args:
|
122 |
+
vertices (np.ndarray): [N, 3]
|
123 |
+
faces (np.ndarray): [T, 3]
|
124 |
+
attr (np.ndarray): [N, C]
|
125 |
+
width (int): width of rendered image
|
126 |
+
height (int): height of rendered image
|
127 |
+
transform (np.ndarray): [4, 4] model-view-projection transformation matrix.
|
128 |
+
cull_backface (bool): whether to cull backface
|
129 |
+
image: (np.ndarray): [H, W, C] background image
|
130 |
+
depth: (np.ndarray): [H, W] background depth
|
131 |
+
|
132 |
+
Returns:
|
133 |
+
image (np.ndarray): [H, W, C] rendered image
|
134 |
+
depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.
|
135 |
+
"""
|
136 |
+
assert vertices.ndim == 2 and vertices.shape[1] == 3
|
137 |
+
assert faces.ndim == 2 and faces.shape[1] == 3, f"Faces should be a 2D array with shape (T, 3), but got {faces.shape}"
|
138 |
+
assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}'
|
139 |
+
assert vertices.shape[0] == attr.shape[0]
|
140 |
+
assert vertices.dtype == np.float32
|
141 |
+
assert faces.dtype == np.uint32 or faces.dtype == np.int32
|
142 |
+
assert attr.dtype == np.float32, "Attribute should be float32"
|
143 |
+
|
144 |
+
C = attr.shape[1]
|
145 |
+
prog = ctx.program_vertex_attribute(C)
|
146 |
+
|
147 |
+
transform = np.eye(4, np.float32) if transform is None else transform
|
148 |
+
|
149 |
+
# Create buffers
|
150 |
+
ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(faces, dtype='i4'))
|
151 |
+
vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4'))
|
152 |
+
vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4'))
|
153 |
+
vao = ctx.mgl_ctx.vertex_array(
|
154 |
+
prog,
|
155 |
+
[
|
156 |
+
(vbo_vertices, '3f', 'i_position'),
|
157 |
+
(vbo_attr, f'{C}f', 'i_attr'),
|
158 |
+
],
|
159 |
+
ibo,
|
160 |
+
mode=moderngl.TRIANGLES,
|
161 |
+
)
|
162 |
+
|
163 |
+
# Create framebuffer
|
164 |
+
image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None)
|
165 |
+
depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None)
|
166 |
+
fbo = ctx.mgl_ctx.framebuffer(
|
167 |
+
color_attachments=[image_tex],
|
168 |
+
depth_attachment=depth_tex,
|
169 |
+
)
|
170 |
+
|
171 |
+
# Render
|
172 |
+
prog['u_mvp'].write(transform.transpose().copy().astype('f4'))
|
173 |
+
fbo.use()
|
174 |
+
fbo.viewport = (0, 0, width, height)
|
175 |
+
ctx.mgl_ctx.depth_func = '<'
|
176 |
+
ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST)
|
177 |
+
if cull_backface:
|
178 |
+
ctx.mgl_ctx.enable(ctx.mgl_ctx.CULL_FACE)
|
179 |
+
else:
|
180 |
+
ctx.mgl_ctx.disable(ctx.mgl_ctx.CULL_FACE)
|
181 |
+
vao.render()
|
182 |
+
ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST)
|
183 |
+
|
184 |
+
# Read
|
185 |
+
image = np.zeros((height, width, C), dtype='f4')
|
186 |
+
image_tex.read_into(image)
|
187 |
+
image = image[::-1, :, :]
|
188 |
+
if return_depth:
|
189 |
+
depth = np.zeros((height, width), dtype='f4')
|
190 |
+
depth_tex.read_into(depth)
|
191 |
+
depth = depth[::-1, :]
|
192 |
+
else:
|
193 |
+
depth = None
|
194 |
+
|
195 |
+
# Release
|
196 |
+
vao.release()
|
197 |
+
ibo.release()
|
198 |
+
vbo_vertices.release()
|
199 |
+
vbo_attr.release()
|
200 |
+
fbo.release()
|
201 |
+
image_tex.release()
|
202 |
+
depth_tex.release()
|
203 |
+
|
204 |
+
return image, depth
|
205 |
+
|
206 |
+
|
207 |
+
def rasterize_edges(
|
208 |
+
ctx: RastContext,
|
209 |
+
vertices: np.ndarray,
|
210 |
+
edges: np.ndarray,
|
211 |
+
attr: np.ndarray,
|
212 |
+
width: int,
|
213 |
+
height: int,
|
214 |
+
transform: np.ndarray = None,
|
215 |
+
line_width: float = 1.0,
|
216 |
+
return_depth: bool = False,
|
217 |
+
image: np.ndarray = None,
|
218 |
+
depth: np.ndarray = None
|
219 |
+
) -> Tuple[np.ndarray, ...]:
|
220 |
+
"""
|
221 |
+
Rasterize vertex attribute.
|
222 |
+
|
223 |
+
Args:
|
224 |
+
vertices (np.ndarray): [N, 3]
|
225 |
+
faces (np.ndarray): [T, 3]
|
226 |
+
attr (np.ndarray): [N, C]
|
227 |
+
width (int): width of rendered image
|
228 |
+
height (int): height of rendered image
|
229 |
+
transform (np.ndarray): [4, 4] model-view-projection matrix
|
230 |
+
line_width (float): width of line. Defaults to 1.0. NOTE: Values other than 1.0 may not work across all platforms.
|
231 |
+
cull_backface (bool): whether to cull backface
|
232 |
+
|
233 |
+
Returns:
|
234 |
+
image (np.ndarray): [H, W, C] rendered image
|
235 |
+
depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.
|
236 |
+
"""
|
237 |
+
assert vertices.ndim == 2 and vertices.shape[1] == 3
|
238 |
+
assert edges.ndim == 2 and edges.shape[1] == 2, f"Edges should be a 2D array with shape (T, 2), but got {edges.shape}"
|
239 |
+
assert attr.ndim == 2 and attr.shape[1] in [1, 2, 3, 4], f'Vertex attribute only supports channels 1, 2, 3, 4, but got {attr.shape}'
|
240 |
+
assert vertices.shape[0] == attr.shape[0]
|
241 |
+
assert vertices.dtype == np.float32
|
242 |
+
assert edges.dtype == np.uint32 or edges.dtype == np.int32
|
243 |
+
assert attr.dtype == np.float32, "Attribute should be float32"
|
244 |
+
|
245 |
+
C = attr.shape[1]
|
246 |
+
prog = ctx.program_vertex_attribute(C)
|
247 |
+
|
248 |
+
transform = transform if transform is not None else np.eye(4, np.float32)
|
249 |
+
|
250 |
+
# Create buffers
|
251 |
+
ibo = ctx.mgl_ctx.buffer(np.ascontiguousarray(edges, dtype='i4'))
|
252 |
+
vbo_vertices = ctx.mgl_ctx.buffer(np.ascontiguousarray(vertices, dtype='f4'))
|
253 |
+
vbo_attr = ctx.mgl_ctx.buffer(np.ascontiguousarray(attr, dtype='f4'))
|
254 |
+
vao = ctx.mgl_ctx.vertex_array(
|
255 |
+
prog,
|
256 |
+
[
|
257 |
+
(vbo_vertices, '3f', 'i_position'),
|
258 |
+
(vbo_attr, f'{C}f', 'i_attr'),
|
259 |
+
],
|
260 |
+
ibo,
|
261 |
+
mode=moderngl.LINES,
|
262 |
+
)
|
263 |
+
|
264 |
+
# Create framebuffer
|
265 |
+
image_tex = ctx.mgl_ctx.texture((width, height), C, dtype='f4', data=np.ascontiguousarray(image[::-1, :, :]) if image is not None else None)
|
266 |
+
depth_tex = ctx.mgl_ctx.depth_texture((width, height), data=np.ascontiguousarray(depth[::-1, :]) if depth is not None else None)
|
267 |
+
fbo = ctx.mgl_ctx.framebuffer(
|
268 |
+
color_attachments=[image_tex],
|
269 |
+
depth_attachment=depth_tex,
|
270 |
+
)
|
271 |
+
|
272 |
+
# Render
|
273 |
+
prog['u_mvp'].write(transform.transpose().copy().astype('f4'))
|
274 |
+
fbo.use()
|
275 |
+
fbo.viewport = (0, 0, width, height)
|
276 |
+
ctx.mgl_ctx.depth_func = '<'
|
277 |
+
ctx.mgl_ctx.enable(ctx.mgl_ctx.DEPTH_TEST)
|
278 |
+
ctx.mgl_ctx.line_width = line_width
|
279 |
+
vao.render()
|
280 |
+
ctx.mgl_ctx.disable(ctx.mgl_ctx.DEPTH_TEST)
|
281 |
+
|
282 |
+
# Read
|
283 |
+
image = np.zeros((height, width, C), dtype='f4')
|
284 |
+
image_tex.read_into(image)
|
285 |
+
image = image[::-1, :, :]
|
286 |
+
if return_depth:
|
287 |
+
depth = np.zeros((height, width), dtype='f4')
|
288 |
+
depth_tex.read_into(depth)
|
289 |
+
depth = depth[::-1, :]
|
290 |
+
else:
|
291 |
+
depth = None
|
292 |
+
|
293 |
+
# Release
|
294 |
+
vao.release()
|
295 |
+
ibo.release()
|
296 |
+
vbo_vertices.release()
|
297 |
+
vbo_attr.release()
|
298 |
+
fbo.release()
|
299 |
+
image_tex.release()
|
300 |
+
depth_tex.release()
|
301 |
+
|
302 |
+
return image, depth
|
303 |
+
|
304 |
+
|
305 |
+
def texture(
|
306 |
+
ctx: RastContext,
|
307 |
+
uv: np.ndarray,
|
308 |
+
texture: np.ndarray,
|
309 |
+
interpolation: str= 'linear',
|
310 |
+
wrap: str = 'clamp'
|
311 |
+
) -> np.ndarray:
|
312 |
+
"""
|
313 |
+
Given an UV image, texturing from the texture map
|
314 |
+
"""
|
315 |
+
assert len(texture.shape) == 3 and 1 <= texture.shape[2] <= 4
|
316 |
+
assert uv.shape[2] == 2
|
317 |
+
height, width = uv.shape[:2]
|
318 |
+
texture_dtype = map_np_dtype(texture.dtype)
|
319 |
+
|
320 |
+
# Create VAO
|
321 |
+
screen_quad_vbo = ctx.mgl_ctx.buffer(np.array([[-1, -1], [1, -1], [1, 1], [-1, 1]], dtype='f4'))
|
322 |
+
screen_quad_ibo = ctx.mgl_ctx.buffer(np.array([0, 1, 2, 0, 2, 3], dtype=np.int32))
|
323 |
+
screen_quad_vao = ctx.mgl_ctx.vertex_array(ctx.program_texture(texture.shape[2]), [(screen_quad_vbo, '2f4', 'in_vert')], index_buffer=screen_quad_ibo, index_element_size=4)
|
324 |
+
|
325 |
+
# Create texture, set filter and bind. TODO: min mag filter, mipmap
|
326 |
+
texture_tex = ctx.mgl_ctx.texture((texture.shape[1], texture.shape[0]), texture.shape[2], dtype=texture_dtype, data=np.ascontiguousarray(texture))
|
327 |
+
if interpolation == 'linear':
|
328 |
+
texture_tex.filter = (moderngl.LINEAR, moderngl.LINEAR)
|
329 |
+
elif interpolation == 'nearest':
|
330 |
+
texture_tex.filter = (moderngl.NEAREST, moderngl.NEAREST)
|
331 |
+
texture_tex.use(location=0)
|
332 |
+
texture_uv = ctx.mgl_ctx.texture((width, height), 2, dtype='f4', data=np.ascontiguousarray(uv.astype('f4', copy=False)))
|
333 |
+
texture_uv.filter = (moderngl.NEAREST, moderngl.NEAREST)
|
334 |
+
texture_uv.use(location=1)
|
335 |
+
|
336 |
+
# Create render buffer and frame buffer
|
337 |
+
rb = ctx.mgl_ctx.renderbuffer((uv.shape[1], uv.shape[0]), texture.shape[2], dtype=texture_dtype)
|
338 |
+
fbo = ctx.mgl_ctx.framebuffer(color_attachments=[rb])
|
339 |
+
|
340 |
+
# Render
|
341 |
+
fbo.use()
|
342 |
+
fbo.viewport = (0, 0, width, height)
|
343 |
+
ctx.mgl_ctx.disable(ctx.mgl_ctx.BLEND)
|
344 |
+
screen_quad_vao.render()
|
345 |
+
|
346 |
+
# Read buffer
|
347 |
+
image_buffer = np.frombuffer(fbo.read(components=texture.shape[2], attachment=0, dtype=texture_dtype), dtype=texture_dtype).reshape((height, width, texture.shape[2]))
|
348 |
+
|
349 |
+
# Release
|
350 |
+
texture_tex.release()
|
351 |
+
rb.release()
|
352 |
+
fbo.release()
|
353 |
+
|
354 |
+
return image_buffer
|
355 |
+
|
356 |
+
|
357 |
+
def warp_image_by_depth(
|
358 |
+
ctx: RastContext,
|
359 |
+
src_depth: np.ndarray,
|
360 |
+
src_image: np.ndarray = None,
|
361 |
+
width: int = None,
|
362 |
+
height: int = None,
|
363 |
+
*,
|
364 |
+
extrinsics_src: np.ndarray = None,
|
365 |
+
extrinsics_tgt: np.ndarray = None,
|
366 |
+
intrinsics_src: np.ndarray = None,
|
367 |
+
intrinsics_tgt: np.ndarray = None,
|
368 |
+
near: float = 0.1,
|
369 |
+
far: float = 100.0,
|
370 |
+
cull_backface: bool = True,
|
371 |
+
ssaa: int = 1,
|
372 |
+
return_depth: bool = False,
|
373 |
+
) -> Tuple[np.ndarray, ...]:
|
374 |
+
"""
|
375 |
+
Warp image by depth map.
|
376 |
+
|
377 |
+
Args:
|
378 |
+
ctx (RastContext): rasterizer context
|
379 |
+
src_depth (np.ndarray): [H, W]
|
380 |
+
src_image (np.ndarray, optional): [H, W, C]. The image to warp. Defaults to None (use uv coordinates).
|
381 |
+
width (int, optional): width of the output image. None to use depth map width. Defaults to None.
|
382 |
+
height (int, optional): height of the output image. None to use depth map height. Defaults to None.
|
383 |
+
extrinsics_src (np.ndarray, optional): extrinsics matrix of the source camera. Defaults to None (identity).
|
384 |
+
extrinsics_tgt (np.ndarray, optional): extrinsics matrix of the target camera. Defaults to None (identity).
|
385 |
+
intrinsics_src (np.ndarray, optional): intrinsics matrix of the source camera. Defaults to None (use the same as intrinsics_tgt).
|
386 |
+
intrinsics_tgt (np.ndarray, optional): intrinsics matrix of the target camera. Defaults to None (use the same as intrinsics_src).
|
387 |
+
cull_backface (bool, optional): whether to cull backface. Defaults to True.
|
388 |
+
ssaa (int, optional): super sampling anti-aliasing. Defaults to 1.
|
389 |
+
|
390 |
+
Returns:
|
391 |
+
tgt_image (np.ndarray): [H, W, C] warped image (or uv coordinates if image is None).
|
392 |
+
tgt_depth (np.ndarray): [H, W] screen space depth, ranging from 0 to 1. If return_depth is False, it is None.
|
393 |
+
"""
|
394 |
+
assert src_depth.ndim == 2
|
395 |
+
|
396 |
+
if width is None:
|
397 |
+
width = src_depth.shape[1]
|
398 |
+
if height is None:
|
399 |
+
height = src_depth.shape[0]
|
400 |
+
if src_image is not None:
|
401 |
+
assert src_image.shape[-2:] == src_depth.shape[-2:], f'Shape of source image {src_image.shape} does not match shape of source depth {src_depth.shape}'
|
402 |
+
|
403 |
+
# set up default camera parameters
|
404 |
+
extrinsics_src = np.eye(4) if extrinsics_src is None else extrinsics_src
|
405 |
+
extrinsics_tgt = np.eye(4) if extrinsics_tgt is None else extrinsics_tgt
|
406 |
+
intrinsics_src = intrinsics_tgt if intrinsics_src is None else intrinsics_src
|
407 |
+
intrinsics_tgt = intrinsics_src if intrinsics_tgt is None else intrinsics_tgt
|
408 |
+
|
409 |
+
assert all(x is not None for x in [extrinsics_src, extrinsics_tgt, intrinsics_src, intrinsics_tgt]), "Make sure you have provided all the necessary camera parameters."
|
410 |
+
|
411 |
+
# check shapes
|
412 |
+
assert extrinsics_src.shape == (4, 4) and extrinsics_tgt.shape == (4, 4)
|
413 |
+
assert intrinsics_src.shape == (3, 3) and intrinsics_tgt.shape == (3, 3)
|
414 |
+
|
415 |
+
# convert to view and perspective matrices
|
416 |
+
view_tgt = transforms.extrinsics_to_view(extrinsics_tgt)
|
417 |
+
perspective_tgt = transforms.intrinsics_to_perspective(intrinsics_tgt, near=near, far=far)
|
418 |
+
|
419 |
+
# unproject depth map
|
420 |
+
uv, faces = utils.image_mesh(*src_depth.shape[-2:])
|
421 |
+
pts = transforms.unproject_cv(uv, src_depth.reshape(-1), extrinsics_src, intrinsics_src)
|
422 |
+
faces = mesh.triangulate(faces, vertices=pts)
|
423 |
+
|
424 |
+
# rasterize attributes
|
425 |
+
if src_image is not None:
|
426 |
+
attr = src_image.reshape(-1, src_image.shape[-1])
|
427 |
+
else:
|
428 |
+
attr = uv
|
429 |
+
|
430 |
+
tgt_image, tgt_depth = rasterize_triangle_faces(
|
431 |
+
ctx,
|
432 |
+
pts,
|
433 |
+
faces,
|
434 |
+
attr,
|
435 |
+
width * ssaa,
|
436 |
+
height * ssaa,
|
437 |
+
transform=perspective_tgt @ view_tgt,
|
438 |
+
cull_backface=cull_backface,
|
439 |
+
return_depth=return_depth,
|
440 |
+
)
|
441 |
+
|
442 |
+
if ssaa > 1:
|
443 |
+
tgt_image = tgt_image.reshape(height, ssaa, width, ssaa, -1).mean(axis=(1, 3))
|
444 |
+
tgt_depth = tgt_depth.reshape(height, ssaa, width, ssaa, -1).mean(axis=(1, 3)) if return_depth else None
|
445 |
+
|
446 |
+
return tgt_image, tgt_depth
|
447 |
+
|
448 |
+
def test():
|
449 |
+
"""
|
450 |
+
Test if rasterization works. It will render a cube with random colors and save it as a CHECKME.png file.
|
451 |
+
"""
|
452 |
+
ctx = RastContext(backend='egl')
|
453 |
+
vertices, faces = utils.cube(tri=True)
|
454 |
+
attr = np.random.rand(len(vertices), 3).astype(np.float32)
|
455 |
+
perspective = transforms.perspective(np.deg2rad(60), 1, 0.01, 100)
|
456 |
+
view = transforms.view_look_at(np.array([2, 2, 2]), np.array([0, 0, 0]), np.array([0, 1, 0]))
|
457 |
+
image, _ = rasterize_triangle_faces(
|
458 |
+
ctx,
|
459 |
+
vertices,
|
460 |
+
faces,
|
461 |
+
attr,
|
462 |
+
512, 512,
|
463 |
+
view=view,
|
464 |
+
projection=perspective,
|
465 |
+
cull_backface=True,
|
466 |
+
ssaa=1,
|
467 |
+
return_depth=True,
|
468 |
+
)
|
469 |
+
import cv2
|
470 |
+
cv2.imwrite('CHECKME.png', cv2.cvtColor((image.clip(0, 1) * 255).astype(np.uint8), cv2.COLOR_RGB2BGR))
|
471 |
+
|