Upload folder using huggingface_hub
This view is limited to 50 files because it contains too many changes. See the raw diff for the full change set.
- depth_anything_v2/__init__.py +1 -0
- depth_anything_v2/__pycache__/__init__.cpython-310.pyc +0 -0
- depth_anything_v2/__pycache__/dinov2.cpython-310.pyc +0 -0
- depth_anything_v2/__pycache__/dpt.cpython-310.pyc +0 -0
- depth_anything_v2/__pycache__/moments_dataset.cpython-310.pyc +0 -0
- depth_anything_v2/__pycache__/processing_utils.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2.py +415 -0
- depth_anything_v2/dinov2_layers/__init__.py +11 -0
- depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc +0 -0
- depth_anything_v2/dinov2_layers/attention.py +83 -0
- depth_anything_v2/dinov2_layers/block.py +252 -0
- depth_anything_v2/dinov2_layers/drop_path.py +35 -0
- depth_anything_v2/dinov2_layers/layer_scale.py +28 -0
- depth_anything_v2/dinov2_layers/mlp.py +41 -0
- depth_anything_v2/dinov2_layers/patch_embed.py +89 -0
- depth_anything_v2/dinov2_layers/swiglu_ffn.py +63 -0
- depth_anything_v2/dpt.py +224 -0
- depth_anything_v2/moments_dataset.py +54 -0
- depth_anything_v2/moments_processing.py +345 -0
- depth_anything_v2/processing_utils.py +318 -0
- depth_anything_v2/softmax-splatting/README.md +90 -0
- depth_anything_v2/softmax-splatting/__pycache__/softsplat.cpython-310.pyc +0 -0
- depth_anything_v2/softmax-splatting/benchmark_middlebury.py +35 -0
- depth_anything_v2/softmax-splatting/benchmark_xiph.py +91 -0
- depth_anything_v2/softmax-splatting/correlation/README.md +1 -0
- depth_anything_v2/softmax-splatting/correlation/correlation.py +400 -0
- depth_anything_v2/softmax-splatting/images/README.md +1 -0
- depth_anything_v2/softmax-splatting/images/flow.flo +0 -0
- depth_anything_v2/softmax-splatting/images/one.png +0 -0
- depth_anything_v2/softmax-splatting/images/two.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/Beanbags/frame10.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/Beanbags/frame10i11.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/Beanbags/frame11.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/Dimetrodon/frame10.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/Dimetrodon/frame10i11.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/Dimetrodon/frame11.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/DogDance/frame10.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/DogDance/frame10i11.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/DogDance/frame11.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/Grove2/frame10.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/Grove2/frame10i11.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/Grove2/frame11.png +0 -0
- depth_anything_v2/softmax-splatting/middlebury/Grove3/frame10.png +0 -0
depth_anything_v2/__init__.py
ADDED
@@ -0,0 +1 @@
+from .dpt import DepthAnythingV2
depth_anything_v2/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (193 Bytes).
depth_anything_v2/__pycache__/dinov2.cpython-310.pyc
ADDED
Binary file (12.2 kB).
depth_anything_v2/__pycache__/dpt.cpython-310.pyc
ADDED
Binary file (6.17 kB).
depth_anything_v2/__pycache__/moments_dataset.cpython-310.pyc
ADDED
Binary file (1.97 kB).
depth_anything_v2/__pycache__/processing_utils.cpython-310.pyc
ADDED
Binary file (7.67 kB).
depth_anything_v2/dinov2.py
ADDED
@@ -0,0 +1,415 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+#
+# This source code is licensed under the Apache License, Version 2.0
+# found in the LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/main/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+from functools import partial
+import math
+import logging
+from typing import Sequence, Tuple, Union, Callable
+
+import torch
+import torch.nn as nn
+import torch.utils.checkpoint
+from torch.nn.init import trunc_normal_
+
+from .dinov2_layers import Mlp, PatchEmbed, SwiGLUFFNFused, MemEffAttention, NestedTensorBlock as Block
+
+
+logger = logging.getLogger("dinov2")
+
+
+def named_apply(fn: Callable, module: nn.Module, name="", depth_first=True, include_root=False) -> nn.Module:
+    if not depth_first and include_root:
+        fn(module=module, name=name)
+    for child_name, child_module in module.named_children():
+        child_name = ".".join((name, child_name)) if name else child_name
+        named_apply(fn=fn, module=child_module, name=child_name, depth_first=depth_first, include_root=True)
+    if depth_first and include_root:
+        fn(module=module, name=name)
+    return module
+
+
+class BlockChunk(nn.ModuleList):
+    def forward(self, x):
+        for b in self:
+            x = b(x)
+        return x
+
+
+class DinoVisionTransformer(nn.Module):
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        ffn_bias=True,
+        proj_bias=True,
+        drop_path_rate=0.0,
+        drop_path_uniform=False,
+        init_values=None,  # for layerscale: None or 0 => no layerscale
+        embed_layer=PatchEmbed,
+        act_layer=nn.GELU,
+        block_fn=Block,
+        ffn_layer="mlp",
+        block_chunks=1,
+        num_register_tokens=0,
+        interpolate_antialias=False,
+        interpolate_offset=0.1,
+    ):
+        """
+        Args:
+            img_size (int, tuple): input image size
+            patch_size (int, tuple): patch size
+            in_chans (int): number of input channels
+            embed_dim (int): embedding dimension
+            depth (int): depth of transformer
+            num_heads (int): number of attention heads
+            mlp_ratio (int): ratio of mlp hidden dim to embedding dim
+            qkv_bias (bool): enable bias for qkv if True
+            proj_bias (bool): enable bias for proj in attn if True
+            ffn_bias (bool): enable bias for ffn if True
+            drop_path_rate (float): stochastic depth rate
+            drop_path_uniform (bool): apply uniform drop rate across blocks
+            weight_init (str): weight init scheme
+            init_values (float): layer-scale init values
+            embed_layer (nn.Module): patch embedding layer
+            act_layer (nn.Module): MLP activation layer
+            block_fn (nn.Module): transformer block class
+            ffn_layer (str): "mlp", "swiglu", "swiglufused" or "identity"
+            block_chunks: (int) split block sequence into block_chunks units for FSDP wrap
+            num_register_tokens: (int) number of extra cls tokens (so-called "registers")
+            interpolate_antialias: (str) flag to apply anti-aliasing when interpolating positional embeddings
+            interpolate_offset: (float) work-around offset to apply when interpolating positional embeddings
+        """
+        super().__init__()
+        norm_layer = partial(nn.LayerNorm, eps=1e-6)
+
+        self.num_features = self.embed_dim = embed_dim  # num_features for consistency with other models
+        self.num_tokens = 1
+        self.n_blocks = depth
+        self.num_heads = num_heads
+        self.patch_size = patch_size
+        self.num_register_tokens = num_register_tokens
+        self.interpolate_antialias = interpolate_antialias
+        self.interpolate_offset = interpolate_offset
+
+        self.patch_embed = embed_layer(img_size=img_size, patch_size=patch_size, in_chans=in_chans, embed_dim=embed_dim)
+        num_patches = self.patch_embed.num_patches
+
+        self.cls_token = nn.Parameter(torch.zeros(1, 1, embed_dim))
+        self.pos_embed = nn.Parameter(torch.zeros(1, num_patches + self.num_tokens, embed_dim))
+        assert num_register_tokens >= 0
+        self.register_tokens = (
+            nn.Parameter(torch.zeros(1, num_register_tokens, embed_dim)) if num_register_tokens else None
+        )
+
+        if drop_path_uniform is True:
+            dpr = [drop_path_rate] * depth
+        else:
+            dpr = [x.item() for x in torch.linspace(0, drop_path_rate, depth)]  # stochastic depth decay rule
+
+        if ffn_layer == "mlp":
+            logger.info("using MLP layer as FFN")
+            ffn_layer = Mlp
+        elif ffn_layer == "swiglufused" or ffn_layer == "swiglu":
+            logger.info("using SwiGLU layer as FFN")
+            ffn_layer = SwiGLUFFNFused
+        elif ffn_layer == "identity":
+            logger.info("using Identity layer as FFN")
+
+            def f(*args, **kwargs):
+                return nn.Identity()
+
+            ffn_layer = f
+        else:
+            raise NotImplementedError
+
+        blocks_list = [
+            block_fn(
+                dim=embed_dim,
+                num_heads=num_heads,
+                mlp_ratio=mlp_ratio,
+                qkv_bias=qkv_bias,
+                proj_bias=proj_bias,
+                ffn_bias=ffn_bias,
+                drop_path=dpr[i],
+                norm_layer=norm_layer,
+                act_layer=act_layer,
+                ffn_layer=ffn_layer,
+                init_values=init_values,
+            )
+            for i in range(depth)
+        ]
+        if block_chunks > 0:
+            self.chunked_blocks = True
+            chunked_blocks = []
+            chunksize = depth // block_chunks
+            for i in range(0, depth, chunksize):
+                # this is to keep the block index consistent if we chunk the block list
+                chunked_blocks.append([nn.Identity()] * i + blocks_list[i : i + chunksize])
+            self.blocks = nn.ModuleList([BlockChunk(p) for p in chunked_blocks])
+        else:
+            self.chunked_blocks = False
+            self.blocks = nn.ModuleList(blocks_list)
+
+        self.norm = norm_layer(embed_dim)
+        self.head = nn.Identity()
+
+        self.mask_token = nn.Parameter(torch.zeros(1, embed_dim))
+
+        self.init_weights()
+
+    def init_weights(self):
+        trunc_normal_(self.pos_embed, std=0.02)
+        nn.init.normal_(self.cls_token, std=1e-6)
+        if self.register_tokens is not None:
+            nn.init.normal_(self.register_tokens, std=1e-6)
+        named_apply(init_weights_vit_timm, self)
+
+    def interpolate_pos_encoding(self, x, w, h):
+        previous_dtype = x.dtype
+        npatch = x.shape[1] - 1
+        N = self.pos_embed.shape[1] - 1
+        if npatch == N and w == h:
+            return self.pos_embed
+        pos_embed = self.pos_embed.float()
+        class_pos_embed = pos_embed[:, 0]
+        patch_pos_embed = pos_embed[:, 1:]
+        dim = x.shape[-1]
+        w0 = w // self.patch_size
+        h0 = h // self.patch_size
+        # we add a small number to avoid floating point error in the interpolation
+        # see discussion at https://github.com/facebookresearch/dino/issues/8
+        # DINOv2 with register modify the interpolate_offset from 0.1 to 0.0
+        w0, h0 = w0 + self.interpolate_offset, h0 + self.interpolate_offset
+        # w0, h0 = w0 + 0.1, h0 + 0.1
+
+        sqrt_N = math.sqrt(N)
+        sx, sy = float(w0) / sqrt_N, float(h0) / sqrt_N
+        patch_pos_embed = nn.functional.interpolate(
+            patch_pos_embed.reshape(1, int(sqrt_N), int(sqrt_N), dim).permute(0, 3, 1, 2),
+            scale_factor=(sx, sy),
+            # (int(w0), int(h0)), # to solve the upsampling shape issue
+            mode="bicubic",
+            antialias=self.interpolate_antialias
+        )
+
+        assert int(w0) == patch_pos_embed.shape[-2]
+        assert int(h0) == patch_pos_embed.shape[-1]
+        patch_pos_embed = patch_pos_embed.permute(0, 2, 3, 1).view(1, -1, dim)
+        return torch.cat((class_pos_embed.unsqueeze(0), patch_pos_embed), dim=1).to(previous_dtype)
+
+    def prepare_tokens_with_masks(self, x, masks=None):
+        B, nc, w, h = x.shape
+        x = self.patch_embed(x)
+        if masks is not None:
+            x = torch.where(masks.unsqueeze(-1), self.mask_token.to(x.dtype).unsqueeze(0), x)
+
+        x = torch.cat((self.cls_token.expand(x.shape[0], -1, -1), x), dim=1)
+        x = x + self.interpolate_pos_encoding(x, w, h)
+
+        if self.register_tokens is not None:
+            x = torch.cat(
+                (
+                    x[:, :1],
+                    self.register_tokens.expand(x.shape[0], -1, -1),
+                    x[:, 1:],
+                ),
+                dim=1,
+            )
+
+        return x
+
+    def forward_features_list(self, x_list, masks_list):
+        x = [self.prepare_tokens_with_masks(x, masks) for x, masks in zip(x_list, masks_list)]
+        for blk in self.blocks:
+            x = blk(x)
+
+        all_x = x
+        output = []
+        for x, masks in zip(all_x, masks_list):
+            x_norm = self.norm(x)
+            output.append(
+                {
+                    "x_norm_clstoken": x_norm[:, 0],
+                    "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+                    "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+                    "x_prenorm": x,
+                    "masks": masks,
+                }
+            )
+        return output
+
+    def forward_features(self, x, masks=None):
+        if isinstance(x, list):
+            return self.forward_features_list(x, masks)
+
+        x = self.prepare_tokens_with_masks(x, masks)
+
+        for blk in self.blocks:
+            x = blk(x)
+
+        x_norm = self.norm(x)
+        return {
+            "x_norm_clstoken": x_norm[:, 0],
+            "x_norm_regtokens": x_norm[:, 1 : self.num_register_tokens + 1],
+            "x_norm_patchtokens": x_norm[:, self.num_register_tokens + 1 :],
+            "x_prenorm": x,
+            "masks": masks,
+        }
+
+    def _get_intermediate_layers_not_chunked(self, x, n=1):
+        x = self.prepare_tokens_with_masks(x)
+        # If n is an int, take the n last blocks. If it's a list, take them
+        output, total_block_len = [], len(self.blocks)
+        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        for i, blk in enumerate(self.blocks):
+            x = blk(x)
+            if i in blocks_to_take:
+                output.append(x)
+        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+        return output
+
+    def _get_intermediate_layers_chunked(self, x, n=1):
+        x = self.prepare_tokens_with_masks(x)
+        output, i, total_block_len = [], 0, len(self.blocks[-1])
+        # If n is an int, take the n last blocks. If it's a list, take them
+        blocks_to_take = range(total_block_len - n, total_block_len) if isinstance(n, int) else n
+        for block_chunk in self.blocks:
+            for blk in block_chunk[i:]:  # Passing the nn.Identity()
+                x = blk(x)
+                if i in blocks_to_take:
+                    output.append(x)
+                i += 1
+        assert len(output) == len(blocks_to_take), f"only {len(output)} / {len(blocks_to_take)} blocks found"
+        return output
+
+    def get_intermediate_layers(
+        self,
+        x: torch.Tensor,
+        n: Union[int, Sequence] = 1,  # Layers or n last layers to take
+        reshape: bool = False,
+        return_class_token: bool = False,
+        norm=True
+    ) -> Tuple[Union[torch.Tensor, Tuple[torch.Tensor]]]:
+        if self.chunked_blocks:
+            outputs = self._get_intermediate_layers_chunked(x, n)
+        else:
+            outputs = self._get_intermediate_layers_not_chunked(x, n)
+        if norm:
+            outputs = [self.norm(out) for out in outputs]
+        class_tokens = [out[:, 0] for out in outputs]
+        outputs = [out[:, 1 + self.num_register_tokens:] for out in outputs]
+        if reshape:
+            B, _, w, h = x.shape
+            outputs = [
+                out.reshape(B, w // self.patch_size, h // self.patch_size, -1).permute(0, 3, 1, 2).contiguous()
+                for out in outputs
+            ]
+        if return_class_token:
+            return tuple(zip(outputs, class_tokens))
+        return tuple(outputs)
+
+    def forward(self, *args, is_training=False, **kwargs):
+        ret = self.forward_features(*args, **kwargs)
+        if is_training:
+            return ret
+        else:
+            return self.head(ret["x_norm_clstoken"])
+
+
+def init_weights_vit_timm(module: nn.Module, name: str = ""):
+    """ViT weight initialization, original timm impl (for reproducibility)"""
+    if isinstance(module, nn.Linear):
+        trunc_normal_(module.weight, std=0.02)
+        if module.bias is not None:
+            nn.init.zeros_(module.bias)
+
+
+def vit_small(patch_size=16, num_register_tokens=0, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=384,
+        depth=12,
+        num_heads=6,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        num_register_tokens=num_register_tokens,
+        **kwargs,
+    )
+    return model
+
+
+def vit_base(patch_size=16, num_register_tokens=0, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        num_register_tokens=num_register_tokens,
+        **kwargs,
+    )
+    return model
+
+
+def vit_large(patch_size=16, num_register_tokens=0, **kwargs):
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=1024,
+        depth=24,
+        num_heads=16,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        num_register_tokens=num_register_tokens,
+        **kwargs,
+    )
+    return model
+
+
+def vit_giant2(patch_size=16, num_register_tokens=0, **kwargs):
+    """
+    Close to ViT-giant, with embed-dim 1536 and 24 heads => embed-dim per head 64
+    """
+    model = DinoVisionTransformer(
+        patch_size=patch_size,
+        embed_dim=1536,
+        depth=40,
+        num_heads=24,
+        mlp_ratio=4,
+        block_fn=partial(Block, attn_class=MemEffAttention),
+        num_register_tokens=num_register_tokens,
+        **kwargs,
+    )
+    return model
+
+
+def DINOv2(model_name):
+    model_zoo = {
+        "vits": vit_small,
+        "vitb": vit_base,
+        "vitl": vit_large,
+        "vitg": vit_giant2
+    }
+
+    return model_zoo[model_name](
+        img_size=518,
+        patch_size=14,
+        init_values=1.0,
+        ffn_layer="mlp" if model_name != "vitg" else "swiglufused",
+        block_chunks=0,
+        num_register_tokens=0,
+        interpolate_antialias=False,
+        interpolate_offset=0.1
+    )
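A minimal usage sketch for the `DINOv2` factory added above, not part of the commit itself. It assumes the `depth_anything_v2` package is on the Python path, uses randomly initialized weights (checkpoint loading is out of scope), and picks an arbitrary 224x224 input; any resolution that is a multiple of the 14-pixel patch size works, and xFormers is optional (only a warning is logged without it).

```python
import torch

from depth_anything_v2.dinov2 import DINOv2

# Build the ViT-S/14 backbone defined above (518x518 pretraining resolution, patch size 14).
backbone = DINOv2(model_name="vits").eval()

# Dummy input; 224 is a multiple of 14, so the patch grid is 16x16.
x = torch.randn(1, 3, 224, 224)

with torch.no_grad():
    # Take the last 4 blocks and reshape patch tokens back into feature maps.
    feats = backbone.get_intermediate_layers(x, n=4, reshape=True, return_class_token=False)

for f in feats:
    print(f.shape)  # torch.Size([1, 384, 16, 16]) for each of the 4 selected blocks
```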
depth_anything_v2/dinov2_layers/__init__.py
ADDED
@@ -0,0 +1,11 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from .mlp import Mlp
+from .patch_embed import PatchEmbed
+from .swiglu_ffn import SwiGLUFFN, SwiGLUFFNFused
+from .block import NestedTensorBlock
+from .attention import MemEffAttention
depth_anything_v2/dinov2_layers/__pycache__/__init__.cpython-310.pyc
ADDED
Binary file (406 Bytes).
depth_anything_v2/dinov2_layers/__pycache__/attention.cpython-310.pyc
ADDED
Binary file (2.37 kB).
depth_anything_v2/dinov2_layers/__pycache__/block.cpython-310.pyc
ADDED
Binary file (7.98 kB).
depth_anything_v2/dinov2_layers/__pycache__/drop_path.cpython-310.pyc
ADDED
Binary file (1.21 kB).
depth_anything_v2/dinov2_layers/__pycache__/layer_scale.cpython-310.pyc
ADDED
Binary file (1.01 kB).
depth_anything_v2/dinov2_layers/__pycache__/mlp.cpython-310.pyc
ADDED
Binary file (1.2 kB).
depth_anything_v2/dinov2_layers/__pycache__/patch_embed.cpython-310.pyc
ADDED
Binary file (2.65 kB).
depth_anything_v2/dinov2_layers/__pycache__/swiglu_ffn.cpython-310.pyc
ADDED
Binary file (2 kB).
|
depth_anything_v2/dinov2_layers/attention.py
ADDED
@@ -0,0 +1,83 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/models/vision_transformer.py
+
+import logging
+
+from torch import Tensor
+from torch import nn
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+    from xformers.ops import memory_efficient_attention, unbind, fmha
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    logger.warning("xFormers not available")
+    XFORMERS_AVAILABLE = False
+
+
+class Attention(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int = 8,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        attn_drop: float = 0.0,
+        proj_drop: float = 0.0,
+    ) -> None:
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = head_dim**-0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim, bias=proj_bias)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+    def forward(self, x: Tensor) -> Tensor:
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads).permute(2, 0, 3, 1, 4)
+
+        q, k, v = qkv[0] * self.scale, qkv[1], qkv[2]
+        attn = q @ k.transpose(-2, -1)
+
+        attn = attn.softmax(dim=-1)
+        attn = self.attn_drop(attn)
+
+        x = (attn @ v).transpose(1, 2).reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
+class MemEffAttention(Attention):
+    def forward(self, x: Tensor, attn_bias=None) -> Tensor:
+        if not XFORMERS_AVAILABLE:
+            assert attn_bias is None, "xFormers is required for nested tensors usage"
+            return super().forward(x)
+
+        B, N, C = x.shape
+        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, C // self.num_heads)
+
+        q, k, v = unbind(qkv, 2)
+
+        x = memory_efficient_attention(q, k, v, attn_bias=attn_bias)
+        x = x.reshape([B, N, C])
+
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
+
+
depth_anything_v2/dinov2_layers/block.py
ADDED
@@ -0,0 +1,252 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+import logging
+from typing import Callable, List, Any, Tuple, Dict
+
+import torch
+from torch import nn, Tensor
+
+from .attention import Attention, MemEffAttention
+from .drop_path import DropPath
+from .layer_scale import LayerScale
+from .mlp import Mlp
+
+
+logger = logging.getLogger("dinov2")
+
+
+try:
+    from xformers.ops import fmha
+    from xformers.ops import scaled_index_add, index_select_cat
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    logger.warning("xFormers not available")
+    XFORMERS_AVAILABLE = False
+
+
+class Block(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        num_heads: int,
+        mlp_ratio: float = 4.0,
+        qkv_bias: bool = False,
+        proj_bias: bool = True,
+        ffn_bias: bool = True,
+        drop: float = 0.0,
+        attn_drop: float = 0.0,
+        init_values=None,
+        drop_path: float = 0.0,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        norm_layer: Callable[..., nn.Module] = nn.LayerNorm,
+        attn_class: Callable[..., nn.Module] = Attention,
+        ffn_layer: Callable[..., nn.Module] = Mlp,
+    ) -> None:
+        super().__init__()
+        # print(f"biases: qkv: {qkv_bias}, proj: {proj_bias}, ffn: {ffn_bias}")
+        self.norm1 = norm_layer(dim)
+        self.attn = attn_class(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            proj_bias=proj_bias,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+        )
+        self.ls1 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path1 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = ffn_layer(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=drop,
+            bias=ffn_bias,
+        )
+        self.ls2 = LayerScale(dim, init_values=init_values) if init_values else nn.Identity()
+        self.drop_path2 = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+
+        self.sample_drop_ratio = drop_path
+
+    def forward(self, x: Tensor) -> Tensor:
+        def attn_residual_func(x: Tensor) -> Tensor:
+            return self.ls1(self.attn(self.norm1(x)))
+
+        def ffn_residual_func(x: Tensor) -> Tensor:
+            return self.ls2(self.mlp(self.norm2(x)))
+
+        if self.training and self.sample_drop_ratio > 0.1:
+            # the overhead is compensated only for a drop path rate larger than 0.1
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+            x = drop_add_residual_stochastic_depth(
+                x,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+            )
+        elif self.training and self.sample_drop_ratio > 0.0:
+            x = x + self.drop_path1(attn_residual_func(x))
+            x = x + self.drop_path1(ffn_residual_func(x))  # FIXME: drop_path2
+        else:
+            x = x + attn_residual_func(x)
+            x = x + ffn_residual_func(x)
+        return x
+
+
+def drop_add_residual_stochastic_depth(
+    x: Tensor,
+    residual_func: Callable[[Tensor], Tensor],
+    sample_drop_ratio: float = 0.0,
+) -> Tensor:
+    # 1) extract subset using permutation
+    b, n, d = x.shape
+    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+    x_subset = x[brange]
+
+    # 2) apply residual_func to get residual
+    residual = residual_func(x_subset)
+
+    x_flat = x.flatten(1)
+    residual = residual.flatten(1)
+
+    residual_scale_factor = b / sample_subset_size
+
+    # 3) add the residual
+    x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+    return x_plus_residual.view_as(x)
+
+
+def get_branges_scales(x, sample_drop_ratio=0.0):
+    b, n, d = x.shape
+    sample_subset_size = max(int(b * (1 - sample_drop_ratio)), 1)
+    brange = (torch.randperm(b, device=x.device))[:sample_subset_size]
+    residual_scale_factor = b / sample_subset_size
+    return brange, residual_scale_factor
+
+
+def add_residual(x, brange, residual, residual_scale_factor, scaling_vector=None):
+    if scaling_vector is None:
+        x_flat = x.flatten(1)
+        residual = residual.flatten(1)
+        x_plus_residual = torch.index_add(x_flat, 0, brange, residual.to(dtype=x.dtype), alpha=residual_scale_factor)
+    else:
+        x_plus_residual = scaled_index_add(
+            x, brange, residual.to(dtype=x.dtype), scaling=scaling_vector, alpha=residual_scale_factor
+        )
+    return x_plus_residual
+
+
+attn_bias_cache: Dict[Tuple, Any] = {}
+
+
+def get_attn_bias_and_cat(x_list, branges=None):
+    """
+    this will perform the index select, cat the tensors, and provide the attn_bias from cache
+    """
+    batch_sizes = [b.shape[0] for b in branges] if branges is not None else [x.shape[0] for x in x_list]
+    all_shapes = tuple((b, x.shape[1]) for b, x in zip(batch_sizes, x_list))
+    if all_shapes not in attn_bias_cache.keys():
+        seqlens = []
+        for b, x in zip(batch_sizes, x_list):
+            for _ in range(b):
+                seqlens.append(x.shape[1])
+        attn_bias = fmha.BlockDiagonalMask.from_seqlens(seqlens)
+        attn_bias._batch_sizes = batch_sizes
+        attn_bias_cache[all_shapes] = attn_bias
+
+    if branges is not None:
+        cat_tensors = index_select_cat([x.flatten(1) for x in x_list], branges).view(1, -1, x_list[0].shape[-1])
+    else:
+        tensors_bs1 = tuple(x.reshape([1, -1, *x.shape[2:]]) for x in x_list)
+        cat_tensors = torch.cat(tensors_bs1, dim=1)
+
+    return attn_bias_cache[all_shapes], cat_tensors
+
+
+def drop_add_residual_stochastic_depth_list(
+    x_list: List[Tensor],
+    residual_func: Callable[[Tensor, Any], Tensor],
+    sample_drop_ratio: float = 0.0,
+    scaling_vector=None,
+) -> Tensor:
+    # 1) generate random set of indices for dropping samples in the batch
+    branges_scales = [get_branges_scales(x, sample_drop_ratio=sample_drop_ratio) for x in x_list]
+    branges = [s[0] for s in branges_scales]
+    residual_scale_factors = [s[1] for s in branges_scales]
+
+    # 2) get attention bias and index+concat the tensors
+    attn_bias, x_cat = get_attn_bias_and_cat(x_list, branges)
+
+    # 3) apply residual_func to get residual, and split the result
+    residual_list = attn_bias.split(residual_func(x_cat, attn_bias=attn_bias))  # type: ignore
+
+    outputs = []
+    for x, brange, residual, residual_scale_factor in zip(x_list, branges, residual_list, residual_scale_factors):
+        outputs.append(add_residual(x, brange, residual, residual_scale_factor, scaling_vector).view_as(x))
+    return outputs
+
+
+class NestedTensorBlock(Block):
+    def forward_nested(self, x_list: List[Tensor]) -> List[Tensor]:
+        """
+        x_list contains a list of tensors to nest together and run
+        """
+        assert isinstance(self.attn, MemEffAttention)
+
+        if self.training and self.sample_drop_ratio > 0.0:
+
+            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.attn(self.norm1(x), attn_bias=attn_bias)
+
+            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.mlp(self.norm2(x))
+
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=attn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls1.gamma if isinstance(self.ls1, LayerScale) else None,
+            )
+            x_list = drop_add_residual_stochastic_depth_list(
+                x_list,
+                residual_func=ffn_residual_func,
+                sample_drop_ratio=self.sample_drop_ratio,
+                scaling_vector=self.ls2.gamma if isinstance(self.ls1, LayerScale) else None,
+            )
+            return x_list
+        else:
+
+            def attn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.ls1(self.attn(self.norm1(x), attn_bias=attn_bias))
+
+            def ffn_residual_func(x: Tensor, attn_bias=None) -> Tensor:
+                return self.ls2(self.mlp(self.norm2(x)))
+
+            attn_bias, x = get_attn_bias_and_cat(x_list)
+            x = x + attn_residual_func(x, attn_bias=attn_bias)
+            x = x + ffn_residual_func(x)
+            return attn_bias.split(x)
+
+    def forward(self, x_or_x_list):
+        if isinstance(x_or_x_list, Tensor):
+            return super().forward(x_or_x_list)
+        elif isinstance(x_or_x_list, list):
+            assert XFORMERS_AVAILABLE, "Please install xFormers for nested tensors usage"
+            return self.forward_nested(x_or_x_list)
+        else:
+            raise AssertionError
depth_anything_v2/dinov2_layers/drop_path.py
ADDED
@@ -0,0 +1,35 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/drop.py
+
+
+from torch import nn
+
+
+def drop_path(x, drop_prob: float = 0.0, training: bool = False):
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)  # work with diff dim tensors, not just 2D ConvNets
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0:
+        random_tensor.div_(keep_prob)
+    output = x * random_tensor
+    return output
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks)."""
+
+    def __init__(self, drop_prob=None):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training)
depth_anything_v2/dinov2_layers/layer_scale.py
ADDED
@@ -0,0 +1,28 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# Modified from: https://github.com/huggingface/pytorch-image-models/blob/main/timm/models/vision_transformer.py#L103-L110
+
+from typing import Union
+
+import torch
+from torch import Tensor
+from torch import nn
+
+
+class LayerScale(nn.Module):
+    def __init__(
+        self,
+        dim: int,
+        init_values: Union[float, Tensor] = 1e-5,
+        inplace: bool = False,
+    ) -> None:
+        super().__init__()
+        self.inplace = inplace
+        self.gamma = nn.Parameter(init_values * torch.ones(dim))
+
+    def forward(self, x: Tensor) -> Tensor:
+        return x.mul_(self.gamma) if self.inplace else x * self.gamma
depth_anything_v2/dinov2_layers/mlp.py
ADDED
@@ -0,0 +1,41 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/mlp.py
+
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+
+
+class Mlp(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = nn.GELU,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.fc1 = nn.Linear(in_features, hidden_features, bias=bias)
+        self.act = act_layer()
+        self.fc2 = nn.Linear(hidden_features, out_features, bias=bias)
+        self.drop = nn.Dropout(drop)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop(x)
+        x = self.fc2(x)
+        x = self.drop(x)
+        return x
depth_anything_v2/dinov2_layers/patch_embed.py
ADDED
@@ -0,0 +1,89 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+# References:
+#   https://github.com/facebookresearch/dino/blob/master/vision_transformer.py
+#   https://github.com/rwightman/pytorch-image-models/tree/master/timm/layers/patch_embed.py
+
+from typing import Callable, Optional, Tuple, Union
+
+from torch import Tensor
+import torch.nn as nn
+
+
+def make_2tuple(x):
+    if isinstance(x, tuple):
+        assert len(x) == 2
+        return x
+
+    assert isinstance(x, int)
+    return (x, x)
+
+
+class PatchEmbed(nn.Module):
+    """
+    2D image to patch embedding: (B,C,H,W) -> (B,N,D)
+
+    Args:
+        img_size: Image size.
+        patch_size: Patch token size.
+        in_chans: Number of input image channels.
+        embed_dim: Number of linear projection output channels.
+        norm_layer: Normalization layer.
+    """
+
+    def __init__(
+        self,
+        img_size: Union[int, Tuple[int, int]] = 224,
+        patch_size: Union[int, Tuple[int, int]] = 16,
+        in_chans: int = 3,
+        embed_dim: int = 768,
+        norm_layer: Optional[Callable] = None,
+        flatten_embedding: bool = True,
+    ) -> None:
+        super().__init__()
+
+        image_HW = make_2tuple(img_size)
+        patch_HW = make_2tuple(patch_size)
+        patch_grid_size = (
+            image_HW[0] // patch_HW[0],
+            image_HW[1] // patch_HW[1],
+        )
+
+        self.img_size = image_HW
+        self.patch_size = patch_HW
+        self.patches_resolution = patch_grid_size
+        self.num_patches = patch_grid_size[0] * patch_grid_size[1]
+
+        self.in_chans = in_chans
+        self.embed_dim = embed_dim
+
+        self.flatten_embedding = flatten_embedding
+
+        self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_HW, stride=patch_HW)
+        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()
+
+    def forward(self, x: Tensor) -> Tensor:
+        _, _, H, W = x.shape
+        patch_H, patch_W = self.patch_size
+
+        assert H % patch_H == 0, f"Input image height {H} is not a multiple of patch height {patch_H}"
+        assert W % patch_W == 0, f"Input image width {W} is not a multiple of patch width: {patch_W}"
+
+        x = self.proj(x)  # B C H W
+        H, W = x.size(2), x.size(3)
+        x = x.flatten(2).transpose(1, 2)  # B HW C
+        x = self.norm(x)
+        if not self.flatten_embedding:
+            x = x.reshape(-1, H, W, self.embed_dim)  # B H W C
+        return x
+
+    def flops(self) -> float:
+        Ho, Wo = self.patches_resolution
+        flops = Ho * Wo * self.embed_dim * self.in_chans * (self.patch_size[0] * self.patch_size[1])
+        if self.norm is not None:
+            flops += Ho * Wo * self.embed_dim
+        return flops
depth_anything_v2/dinov2_layers/swiglu_ffn.py
ADDED
@@ -0,0 +1,63 @@
+# Copyright (c) Meta Platforms, Inc. and affiliates.
+# All rights reserved.
+#
+# This source code is licensed under the license found in the
+# LICENSE file in the root directory of this source tree.
+
+from typing import Callable, Optional
+
+from torch import Tensor, nn
+import torch.nn.functional as F
+
+
+class SwiGLUFFN(nn.Module):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        self.w12 = nn.Linear(in_features, 2 * hidden_features, bias=bias)
+        self.w3 = nn.Linear(hidden_features, out_features, bias=bias)
+
+    def forward(self, x: Tensor) -> Tensor:
+        x12 = self.w12(x)
+        x1, x2 = x12.chunk(2, dim=-1)
+        hidden = F.silu(x1) * x2
+        return self.w3(hidden)
+
+
+try:
+    from xformers.ops import SwiGLU
+
+    XFORMERS_AVAILABLE = True
+except ImportError:
+    SwiGLU = SwiGLUFFN
+    XFORMERS_AVAILABLE = False
+
+
+class SwiGLUFFNFused(SwiGLU):
+    def __init__(
+        self,
+        in_features: int,
+        hidden_features: Optional[int] = None,
+        out_features: Optional[int] = None,
+        act_layer: Callable[..., nn.Module] = None,
+        drop: float = 0.0,
+        bias: bool = True,
+    ) -> None:
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+        hidden_features = (int(hidden_features * 2 / 3) + 7) // 8 * 8
+        super().__init__(
+            in_features=in_features,
+            hidden_features=hidden_features,
+            out_features=out_features,
+            bias=bias,
+        )
depth_anything_v2/dpt.py
ADDED
@@ -0,0 +1,224 @@
+import pdb
+
+import cv2
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+from torchvision.transforms import Compose
+import numpy as np
+
+from .dinov2 import DINOv2
+from .util.blocks import FeatureFusionBlock, _make_scratch
+from .util.transform import Resize, NormalizeImage, PrepareForNet
+
+
+def _make_fusion_block(features, use_bn, size=None):
+    return FeatureFusionBlock(
+        features,
+        nn.ReLU(False),
+        deconv=False,
+        bn=use_bn,
+        expand=False,
+        align_corners=True,
+        size=size,
+    )
+
+
+class ConvBlock(nn.Module):
+    def __init__(self, in_feature, out_feature):
+        super().__init__()
+
+        self.conv_block = nn.Sequential(
+            nn.Conv2d(in_feature, out_feature, kernel_size=3, stride=1, padding=1),
+            nn.BatchNorm2d(out_feature),
+            nn.ReLU(True)
+        )
+
+    def forward(self, x):
+        return self.conv_block(x)
+
+
+class DPTHead(nn.Module):
+    def __init__(
+        self,
+        in_channels,
+        features=256,
+        use_bn=False,
+        out_channels=[256, 512, 1024, 1024],
+        use_clstoken=False
+    ):
+        super(DPTHead, self).__init__()
+
+        self.use_clstoken = use_clstoken
+
+        self.projects = nn.ModuleList([
+            nn.Conv2d(
+                in_channels=in_channels,
+                out_channels=out_channel,
+                kernel_size=1,
+                stride=1,
+                padding=0,
+            ) for out_channel in out_channels
+        ])
+
+        self.resize_layers = nn.ModuleList([
+            nn.ConvTranspose2d(
+                in_channels=out_channels[0],
+                out_channels=out_channels[0],
+                kernel_size=4,
+                stride=4,
+                padding=0),
+            nn.ConvTranspose2d(
+                in_channels=out_channels[1],
+                out_channels=out_channels[1],
+                kernel_size=2,
+                stride=2,
+                padding=0),
+            nn.Identity(),
+            nn.Conv2d(
+                in_channels=out_channels[3],
+                out_channels=out_channels[3],
+                kernel_size=3,
+                stride=2,
+                padding=1)
+        ])
+
+        if use_clstoken:
+            self.readout_projects = nn.ModuleList()
+            for _ in range(len(self.projects)):
+                self.readout_projects.append(
+                    nn.Sequential(
+                        nn.Linear(2 * in_channels, in_channels),
+                        nn.GELU()))
+
+        self.scratch = _make_scratch(
+            out_channels,
+            features,
+            groups=1,
+            expand=False,
+        )
+
+        self.scratch.stem_transpose = None
+
+        self.scratch.refinenet1 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet2 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet3 = _make_fusion_block(features, use_bn)
+        self.scratch.refinenet4 = _make_fusion_block(features, use_bn)
+
+        head_features_1 = features
+        head_features_2 = 32
+
+        self.scratch.output_conv1 = nn.Conv2d(head_features_1, head_features_1 // 2, kernel_size=3, stride=1, padding=1)
+        self.scratch.output_conv2 = nn.Sequential(
+            nn.Conv2d(head_features_1 // 2, head_features_2, kernel_size=3, stride=1, padding=1),
+            nn.ReLU(True),
+            nn.Conv2d(head_features_2, 1, kernel_size=1, stride=1, padding=0),
+            nn.ReLU(True),
+            nn.Identity(),
+        )
+
+    def forward(self, out_features, patch_h, patch_w):
+        out = []
+        for i, x in enumerate(out_features):
+            if self.use_clstoken:
+                x, cls_token = x[0], x[1]
+                readout = cls_token.unsqueeze(1).expand_as(x)
+                x = self.readout_projects[i](torch.cat((x, readout), -1))
+            else:
+                x = x[0]
+
+            x = x.permute(0, 2, 1).reshape((x.shape[0], x.shape[-1], patch_h, patch_w))
+
+            x = self.projects[i](x)
+            x = self.resize_layers[i](x)
+
+            out.append(x)
+
+        layer_1, layer_2, layer_3, layer_4 = out
+
+        layer_1_rn = self.scratch.layer1_rn(layer_1)
+        layer_2_rn = self.scratch.layer2_rn(layer_2)
+        layer_3_rn = self.scratch.layer3_rn(layer_3)
+        layer_4_rn = self.scratch.layer4_rn(layer_4)
+
+        path_4 = self.scratch.refinenet4(layer_4_rn, size=layer_3_rn.shape[2:])
+        path_3 = self.scratch.refinenet3(path_4, layer_3_rn, size=layer_2_rn.shape[2:])
+        path_2 = self.scratch.refinenet2(path_3, layer_2_rn, size=layer_1_rn.shape[2:])
+        path_1 = self.scratch.refinenet1(path_2, layer_1_rn)
+
+        out = self.scratch.output_conv1(path_1)
+        out = F.interpolate(out, (int(patch_h * 14), int(patch_w * 14)), mode="bilinear", align_corners=True)
+        out = self.scratch.output_conv2(out)
+
+        return out
+
+
+class DepthAnythingV2(nn.Module):
+    def __init__(
+        self,
+        encoder='vitl',
+        features=256,
+        out_channels=[256, 512, 1024, 1024],
+        use_bn=False,
+        use_clstoken=False
+    ):
+        super(DepthAnythingV2, self).__init__()
+
+        self.intermediate_layer_idx = {
+            'vits': [2, 5, 8, 11],
+            'vitb': [2, 5, 8, 11],
+            'vitl': [4, 11, 17, 23],
+            'vitg': [9, 19, 29, 39]
+        }
+
+        self.encoder = encoder
+        self.pretrained = DINOv2(model_name=encoder)
+
+        self.depth_head = DPTHead(self.pretrained.embed_dim, features, use_bn, out_channels=out_channels, use_clstoken=use_clstoken)
+
+    @torch.no_grad()
+    def forward(self, image, input_size=518, device='cuda:0'):
+        x, (h, w) = self.image2tensor(image, input_size, device)
+
+        patch_h, patch_w = x.shape[-2] // 14, x.shape[-1] // 14
+        features = self.pretrained.get_intermediate_layers(x, self.intermediate_layer_idx[self.encoder], return_class_token=True)
+
+        depth = self.depth_head(features, patch_h, patch_w)
+        depth = F.relu(depth).squeeze(1)
+        depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True).squeeze()
+        return depth
+
+    @torch.no_grad()
+    def infer_image(self, raw_image, input_size=518):
+        image, (h, w) = self.image2tensor(raw_image, input_size)
+
+        depth = self.forward(image)
+
+        depth = F.interpolate(depth[:, None], (h, w), mode="bilinear", align_corners=True)[0, 0]
+
+        return depth
+
+    def image2tensor(self, raw_image, input_size=518, device='cuda'):
+        transform = Compose([
+            Resize(
+                width=input_size,
+                height=input_size,
+                resize_target=False,
+                keep_aspect_ratio=True,
+                ensure_multiple_of=14,
+                resize_method='lower_bound',
+                image_interpolation_method=cv2.INTER_CUBIC,
+            ),
+            NormalizeImage(mean=[0.485, 0.456, 0.406], std=[0.229, 0.224, 0.225]),
+            PrepareForNet(),
+        ])
+        # raw_image (bs, 3, h, w)
+        h, w = raw_image.shape[-2:]
+        raw_image = np.moveaxis(raw_image, 1, -1)
+        images = []
+        for i, single_image in enumerate(raw_image):
+            image = cv2.cvtColor(single_image, cv2.COLOR_BGR2RGB) / 255.0
+            image = transform({'image': image})['image']
+            images.append(torch.from_numpy(image))
+        images = torch.stack(images, dim=0)
+        images = images.to(device)
+        return images, (h, w)
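A minimal inference sketch for the DepthAnythingV2 module added above, not part of the commit. It assumes the `depth_anything_v2.util` submodule referenced by dpt.py is present in the repository (it is not shown in this truncated diff), uses randomly initialized weights (checkpoint loading is omitted), runs on CPU, and feeds a dummy BGR frame; the `features`/`out_channels` values are an example pairing for the small encoder, not a setting taken from this commit.

```python
import numpy as np

from depth_anything_v2.dpt import DepthAnythingV2

# Small encoder with an illustrative head configuration.
model = DepthAnythingV2(encoder="vits", features=64, out_channels=[48, 96, 192, 384]).eval()

# forward() expects a batch of BGR frames shaped (bs, 3, h, w); image2tensor()
# handles the BGR->RGB conversion, resizing to a multiple of 14, and normalization.
frames = (np.random.rand(1, 3, 518, 518) * 255).astype(np.uint8)

depth = model(frames, input_size=518, device="cpu")
print(depth.shape)  # (518, 518) after the final interpolation and squeeze
```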
depth_anything_v2/moments_dataset.py
ADDED
@@ -0,0 +1,54 @@
# Copyright 2024 Adobe. All rights reserved.

#%%
import glob
import torch
import torchvision
import matplotlib.pyplot as plt
from torch.utils.data import Dataset
import numpy as np


# %%
class MomentsDataset(Dataset):
    def __init__(self, videos_folder, num_frames, samples_per_video, frame_size=512) -> None:
        super().__init__()

        self.videos_paths = glob.glob(f'{videos_folder}/*mp4')
        self.resize = torchvision.transforms.Resize(size=frame_size)
        self.center_crop = torchvision.transforms.CenterCrop(size=frame_size)
        self.num_samples_per_video = samples_per_video
        self.num_frames = num_frames

    def __len__(self):
        return len(self.videos_paths) * self.num_samples_per_video

    def __getitem__(self, idx):
        video_idx = idx // self.num_samples_per_video
        video_path = self.videos_paths[video_idx]

        try:
            start_idx = np.random.randint(0, 20)

            unsampled_video_frames, audio_frames, info = torchvision.io.read_video(video_path, output_format="TCHW")
            sampled_indices = torch.tensor(np.linspace(start_idx, len(unsampled_video_frames)-1, self.num_frames).astype(int))
            sampled_frames = unsampled_video_frames[sampled_indices]
            processed_frames = []

            for frame in sampled_frames:
                resized_cropped_frame = self.center_crop(self.resize(frame))
                processed_frames.append(resized_cropped_frame)
            frames = torch.stack(processed_frames, dim=0)
            frames = frames.float() / 255.0
        except Exception as e:
            print('oops', e)
            rand_idx = np.random.randint(0, len(self))
            return self.__getitem__(rand_idx)

        out_dict = {'frames': frames,
                    'caption': 'none',
                    'keywords': 'none'}

        return out_dict
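A minimal sketch of how `MomentsDataset` is consumed, mirroring `process_video_folder` in `moments_processing.py`; `./example_videos` is a placeholder folder of `.mp4` files:

```
from torch.utils.data import DataLoader

from moments_dataset import MomentsDataset

dataset = MomentsDataset(videos_folder='./example_videos', num_frames=20, samples_per_video=5)
loader = DataLoader(dataset, batch_size=4, shuffle=True)

batch = next(iter(loader))
print(batch['frames'].shape)   # (4, 20, 3, 512, 512), values in [0, 1]
print(batch['caption'])        # placeholder captions, all 'none'
```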
depth_anything_v2/moments_processing.py
ADDED
@@ -0,0 +1,345 @@
# Copyright 2024 Adobe. All rights reserved.

#%%
import numpy as np
import torchvision
import cv2
import tqdm
import torchvision.transforms.functional as F
from PIL import Image
from torchvision.utils import save_image
import time
import os
import pathlib
from torch.utils.data import DataLoader
# %matplotlib inline
from kornia.filters.median import MedianBlur

median_filter = MedianBlur(kernel_size=(15, 15))
from moments_dataset import MomentsDataset

try:
    from processing_utils import aggregate_frames
    import processing_utils
except Exception as e:
    print(e)
    print('process failed')
    exit()

import torch


# %%

def load_image(img_path, resize_size=None, crop_size=None):
    img1_pil = Image.open(img_path)
    img1_frames = torchvision.transforms.functional.pil_to_tensor(img1_pil)

    if resize_size:
        img1_frames = torchvision.transforms.functional.resize(img1_frames, resize_size)

    if crop_size:
        img1_frames = torchvision.transforms.functional.center_crop(img1_frames, crop_size)

    img1_batch = torch.unsqueeze(img1_frames, dim=0)

    return img1_batch

def get_grid(size):
    y = np.repeat(np.arange(size)[None, ...], size)
    y = y.reshape(size, size)
    x = y.transpose()
    out = np.stack([y, x], -1)
    return out

def collage_from_frames(frames_t):
    # decide whether to play the clip forward or backward
    if np.random.randint(0, 2) == 0:
        # flip
        frames_t = frames_t.flip(0)

    # decide how far ahead the target frame should be
    tgt_idx_guess = np.random.randint(1, min(len(frames_t), 20))
    tgt_idx = 1
    pairwise_flows = []
    flow = None
    init_time = time.time()
    unsmoothed_agg = None
    for cur_idx in range(1, tgt_idx_guess+1):
        cur_flow, pairwise_flows = aggregate_frames(frames_t[:cur_idx+1], pairwise_flows, unsmoothed_agg)  # passing pairwise flows for efficiency
        unsmoothed_agg = cur_flow.clone()
        agg_cur_flow = median_filter(cur_flow)

        flow_norm = torch.norm(agg_cur_flow.squeeze(), dim=0).flatten()
        flow_90 = np.percentile(flow_norm.cpu().numpy(), 90)
        flow_95 = np.percentile(flow_norm.cpu().numpy(), 95)

        if cur_idx == 5:  # if the flow is still small at this point, drop the sample
            if flow_95 < 20.0:
                # no motion in the frame, skip
                print('flow is tiny :(')
                return None

        if cur_idx == tgt_idx_guess-1:  # if the flow is still small at this point, drop the sample
            if flow_95 < 50.0:
                # no motion in the frame, skip
                print('flow is tiny :(')
                return None

        if flow is None:  # first iteration
            if flow_90 < 1.0:
                # no motion in the frame, skip
                return None
            flow = agg_cur_flow

        if flow_90 <= 300:  # maybe this threshold should be increased
            # update idx
            tgt_idx = cur_idx
            flow = agg_cur_flow
        else:
            break
    final_time = time.time()
    print('time guessing idx', final_time - init_time)

    _, flow_warping_mask = processing_utils.forward_warp(frames_t[0], frames_t[tgt_idx], flow, grid=None, alpha_mask=None)
    flow_warping_mask = flow_warping_mask.squeeze().numpy() > 0.5

    if np.mean(flow_warping_mask) < 0.6:
        return

    src_array = frames_t[0].moveaxis(0, -1).cpu().numpy() * 1.0
    init_time = time.time()
    depth = get_depth_from_array(frames_t[0])
    finish_time = time.time()
    print('time getting depth', finish_time - init_time)

    src_array_uint = src_array * 255.0
    src_array_uint = src_array_uint.astype(np.uint8)
    segments = processing_utils.mask_generator.generate(src_array_uint)

    size = src_array.shape[1]
    grid_np = get_grid(size).astype(np.float16) / size  # 512 x 512 x 2
    grid_t = torch.tensor(grid_np).moveaxis(-1, 0)      # 2 x 512 x 512

    collage, canvas_alpha, lost_alpha = collage_warp(src_array, flow.squeeze(), depth, segments, grid_array=grid_np)
    lost_alpha_t = torch.tensor(lost_alpha).squeeze().unsqueeze(0)
    warping_alpha = (lost_alpha_t < 0.5).float()

    rgb_grid_splatted, actual_warped_mask = processing_utils.forward_warp(frames_t[0], frames_t[tgt_idx], flow, grid=grid_t, alpha_mask=warping_alpha)

    # basic blending
    warped_src = (rgb_grid_splatted * actual_warped_mask).moveaxis(0, -1).cpu().numpy()
    canvas_alpha_mask = canvas_alpha == 0.0
    collage_mask = canvas_alpha.squeeze() + actual_warped_mask.squeeze().cpu().numpy()
    collage_mask = collage_mask > 0.5

    composite_grid = warped_src * canvas_alpha_mask + collage
    rgb_grid_splatted_np = rgb_grid_splatted.moveaxis(0, -1).cpu().numpy()

    return frames_t[0], frames_t[tgt_idx], rgb_grid_splatted_np, composite_grid, flow_warping_mask, collage_mask

def collage_warp(rgb_array, flow, depth, segments, grid_array):
    avg_depths = []
    avg_flows = []

    src_array = np.concatenate([rgb_array, grid_array], axis=-1)
    canvas = np.zeros_like(src_array)
    canvas_alpha = np.zeros_like(canvas[..., -1:]).astype(float)
    lost_regions = np.zeros_like(canvas[..., -1:]).astype(float)
    z_buffer = np.ones_like(depth)[..., None] * -1.0
    unsqueezed_depth = depth[..., None]

    affine_transforms = []

    filtered_segments = []
    for segment in segments:
        if segment['area'] > 300:
            filtered_segments.append(segment)

    for segment in filtered_segments:
        seg_mask = segment['segmentation']
        avg_flow = torch.mean(flow[:, seg_mask], dim=1)
        avg_flows.append(avg_flow)
        # median depth (conversion from disparity)
        avg_depth = torch.median(1.0 / (depth[seg_mask] + 1e-6))
        avg_depths.append(avg_depth)

        all_y, all_x = np.nonzero(segment['segmentation'])
        rand_indices = np.random.randint(0, len(all_y), size=50)
        rand_x = [all_x[i] for i in rand_indices]
        rand_y = [all_y[i] for i in rand_indices]

        src_pairs = [(x, y) for x, y in zip(rand_x, rand_y)]
        tgt_pairs = []
        # estimate a per-segment affine motion from sampled flow vectors (TODO: this can be faster)
        for i in range(len(src_pairs)):
            x, y = src_pairs[i]
            dx, dy = flow[:, y, x]
            tgt_pairs.append((x+dx, y+dy))

        affine_trans, inliers = cv2.estimateAffinePartial2D(np.array(src_pairs).astype(np.float32), np.array(tgt_pairs).astype(np.float32))
        affine_transforms.append(affine_trans)

    depth_sorted_indices = np.arange(len(avg_depths))
    depth_sorted_indices = sorted(depth_sorted_indices, key=lambda x: avg_depths[x])
    for idx in depth_sorted_indices:
        alpha_mask = filtered_segments[idx]['segmentation'][..., None] * (lost_regions < 0.5).astype(float)
        src_rgba = np.concatenate([src_array, alpha_mask, unsqueezed_depth], axis=-1)
        warp_dst = cv2.warpAffine(src_rgba, affine_transforms[idx], (src_array.shape[1], src_array.shape[0]))
        warped_mask = warp_dst[..., -2:-1]  # this is the warped alpha
        warped_depth = warp_dst[..., -1:]
        warped_rgb = warp_dst[..., :-2]

        good_z_region = warped_depth > z_buffer

        warped_mask = np.logical_and(warped_mask > 0.5, good_z_region).astype(float)

        kernel = np.ones((3, 3), float)
        canvas_alpha += cv2.erode(warped_mask, kernel)[..., None]

        lost_regions += alpha_mask
        canvas = canvas * (1.0 - warped_mask) + warped_mask * warped_rgb  # TODO check if dilation is needed here
        z_buffer = z_buffer * (1.0 - warped_mask) + warped_mask * warped_depth
    return canvas, canvas_alpha, lost_regions

def get_depth_from_array(img_t):
    # NOTE: relies on the MiDaS model/transform in processing_utils, which is commented out
    # there in favor of DepthAnythingV2
    img_arr = img_t.moveaxis(0, -1).cpu().numpy() * 1.0
    img_arr *= 255.0
    img_arr = img_arr.astype(np.uint8)
    input_batch = processing_utils.depth_transform(img_arr).cuda()

    with torch.no_grad():
        prediction = processing_utils.midas(input_batch)

        prediction = torch.nn.functional.interpolate(
            prediction.unsqueeze(1),
            size=img_arr.shape[:2],
            mode="bicubic",
            align_corners=False,
        ).squeeze()

    output = prediction.cpu()
    return output


# %%

def main():
    print('starting main')
    video_folder = './example_videos'
    save_dir = pathlib.Path('./processed_data')
    process_video_folder(video_folder, save_dir)

def process_video_folder(video_folder, save_dir):
    all_counter = 0
    success_counter = 0

    os.makedirs(save_dir, exist_ok=True)

    dataset = MomentsDataset(videos_folder=video_folder, num_frames=20, samples_per_video=5)
    batch_size = 4
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)

    with torch.no_grad():
        for i, batch in tqdm.tqdm(enumerate(dataloader), total=len(dataset)//batch_size):
            frames_to_visualize = batch["frames"]
            bs = frames_to_visualize.shape[0]

            for j in range(bs):
                frames = frames_to_visualize[j]
                caption = batch["caption"][j]

                collage_init_time = time.time()
                out = collage_from_frames(frames)
                collage_finish_time = time.time()
                print('collage processing time', collage_finish_time - collage_init_time)
                all_counter += 1
                if out is not None:
                    src_image, tgt_image, splatted, collage, flow_mask, collage_mask = out

                    splatted_rgb = splatted[..., :3]
                    splatted_grid = splatted[..., 3:].astype(np.float16)

                    collage_rgb = collage[..., :3]
                    collage_grid = collage[..., 3:].astype(np.float16)
                    success_counter += 1
                else:
                    continue

                id_str = f'{success_counter:08d}'

                src_path = str(save_dir / f'src_{id_str}.png')
                tgt_path = str(save_dir / f'tgt_{id_str}.png')
                flow_warped_path = str(save_dir / f'flow_warped_{id_str}.png')
                composite_path = str(save_dir / f'composite_{id_str}.png')
                flow_mask_path = str(save_dir / f'flow_mask_{id_str}.png')
                composite_mask_path = str(save_dir / f'composite_mask_{id_str}.png')

                flow_grid_path = str(save_dir / f'flow_warped_grid_{id_str}.npy')
                composite_grid_path = str(save_dir / f'composite_grid_{id_str}.npy')

                save_image(src_image, src_path)
                save_image(tgt_image, tgt_path)

                collage_pil = Image.fromarray((collage_rgb * 255).astype(np.uint8))
                collage_pil.save(composite_path)

                splatted_pil = Image.fromarray((splatted_rgb * 255).astype(np.uint8))
                splatted_pil.save(flow_warped_path)

                flow_mask_pil = Image.fromarray((flow_mask.astype(float) * 255).astype(np.uint8))
                flow_mask_pil.save(flow_mask_path)

                composite_mask_pil = Image.fromarray((collage_mask.astype(float) * 255).astype(np.uint8))
                composite_mask_pil.save(composite_mask_path)

                splatted_grid_t = torch.tensor(splatted_grid).moveaxis(-1, 0)
                splatted_grid_resized = torchvision.transforms.functional.resize(splatted_grid_t, (64, 64))

                collage_grid_t = torch.tensor(collage_grid).moveaxis(-1, 0)
                collage_grid_resized = torchvision.transforms.functional.resize(collage_grid_t, (64, 64))
                np.save(flow_grid_path, splatted_grid_resized.cpu().numpy())
                np.save(composite_grid_path, collage_grid_resized.cpu().numpy())

                del out
                del splatted_grid
                del collage_grid
                del frames

            del frames_to_visualize

#%%

if __name__ == '__main__':
    try:
        main()
    except Exception as e:
        print(e)
        print('process failed')
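For reference, the coordinate grid that `collage_from_frames` splats alongside the RGB channels is just the per-pixel (y, x) indices scaled into [0, 1); a tiny self-contained sketch of the same construction as `get_grid`, with `size=4` standing in for the 512x512 frames:

```
import numpy as np
import torch

size = 4                                                  # tiny stand-in for the 512x512 frames
y = np.repeat(np.arange(size)[None, ...], size).reshape(size, size)
x = y.transpose()
grid = np.stack([y, x], -1).astype(np.float16) / size     # same construction as get_grid(size) / size

print(grid.shape)     # (4, 4, 2): per-pixel (y, x) coordinates in [0, 1)
print(grid[0, 3])     # [0. 0.75] -> top row (y=0), right-most column (x=0.75)

grid_t = torch.tensor(grid).moveaxis(-1, 0)               # (2, 4, 4), channel-first, as passed to forward_warp
```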
depth_anything_v2/processing_utils.py
ADDED
@@ -0,0 +1,318 @@
import torch
import cv2
import numpy as np
import sys
import torchvision
from PIL import Image
from torchvision.models.optical_flow import Raft_Large_Weights
from torchvision.models.optical_flow import raft_large
from segment_anything import sam_model_registry, SamAutomaticMaskGenerator, SamPredictor
import matplotlib.pyplot as plt
import torchvision.transforms.functional as F
sys.path.append('./softmax-splatting')
import softsplat

sam_checkpoint = "./sam_vit_h_4b8939.pth"
model_type = "vit_h"

device = "cuda"

sam = sam_model_registry[model_type](checkpoint=sam_checkpoint)
sam.to(device=device)
# mask_generator = SamAutomaticMaskGenerator(sam,
#                                            crop_overlap_ratio=0.05,
#                                            box_nms_thresh=0.2,
#                                            points_per_side=32,
#                                            pred_iou_thresh=0.86,
#                                            stability_score_thresh=0.8,
#                                            min_mask_region_area=100,)
# mask_generator = SamAutomaticMaskGenerator(sam)
mask_generator = SamAutomaticMaskGenerator(sam,
                                           # box_nms_thresh=0.5,
                                           # crop_overlap_ratio=0.75,
                                           # min_mask_region_area=200,
                                           )

def get_mask(img_path):
    image = cv2.imread(img_path)
    image = cv2.cvtColor(image, cv2.COLOR_BGR2RGB)
    masks = mask_generator.generate(image)
    return masks

def get_mask_from_array(arr):
    return mask_generator.generate(arr)

# depth model

import cv2
import torch
import urllib.request

import matplotlib.pyplot as plt

# potentially downgrade this; only rough depths are needed. benchmark this
# model_type = "DPT_Large"     # MiDaS v3 - Large   (highest accuracy, slowest inference speed)
# #model_type = "DPT_Hybrid"   # MiDaS v3 - Hybrid  (medium accuracy, medium inference speed)
# #model_type = "MiDaS_small"  # MiDaS v2.1 - Small (lowest accuracy, highest inference speed)
#
# # midas = torch.hub.load("intel-isl/MiDaS", model_type)
# midas = torch.hub.load("/sensei-fs/users/halzayer/collage2photo/model_cache/intel-isl_MiDaS_master", model_type, source='local')
#
# device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")
# midas.to(device)
# midas.eval()
#
# # midas_transforms = torch.hub.load("intel-isl/MiDaS", "transforms")
# midas_transforms = torch.hub.load("/sensei-fs/users/halzayer/collage2photo/model_cache/intel-isl_MiDaS_master", "transforms", source='local')
#
# if model_type == "DPT_Large" or model_type == "DPT_Hybrid":
#     depth_transform = midas_transforms.dpt_transform
# else:
#     depth_transform = midas_transforms.small_transform
from dpt import DepthAnythingV2

model_configs = {
    'vits': {'encoder': 'vits', 'features': 64, 'out_channels': [48, 96, 192, 384]},
    'vitb': {'encoder': 'vitb', 'features': 128, 'out_channels': [96, 192, 384, 768]},
    'vitl': {'encoder': 'vitl', 'features': 256, 'out_channels': [256, 512, 1024, 1024]},
    'vitg': {'encoder': 'vitg', 'features': 384, 'out_channels': [1536, 1536, 1536, 1536]}
}

depth_anything = DepthAnythingV2(**model_configs['vitl'])
depth_anything.load_state_dict(torch.load(f'/home/aiops/wangzh/Depth-Anything-V2/checkpoints/depth_anything_v2_vitl.pth', map_location='cpu'))
depth_anything = depth_anything.to(device).eval()

# img_path = '/sensei-fs/users/halzayer/valid/JPEGImages/45597680/00005.jpg'
def get_depth(img_path):
    img = cv2.imread(img_path)

    with torch.no_grad():
        depth = depth_anything.infer_image(img, 518)

    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0

    # keep the prediction as a tensor and add batch/channel dims before resizing back
    # to the input resolution (assumes infer_image returns an (H, W) tensor)
    prediction = torch.nn.functional.interpolate(
        depth[None, None],
        size=img.shape[:2],
        mode="bicubic",
        align_corners=False,
    ).squeeze()

    output = prediction.cpu()
    return output

def get_depth_from_array(img):
    input_batch = img.to(device)

    with torch.no_grad():
        depth = depth_anything.infer_image(input_batch, 518)

    depth = (depth - depth.min()) / (depth.max() - depth.min()) * 255.0

    # keep the prediction as a tensor and add batch/channel dims before resizing back
    prediction = torch.nn.functional.interpolate(
        depth[None, None],
        size=img.shape[:2],
        mode="bicubic",
        align_corners=False,
    ).squeeze()

    output = prediction.cpu()
    return output


def load_image(img_path):
    img1_names = [img_path]

    img1_pil = [Image.open(fn) for fn in img1_names]
    img1_frames = [torchvision.transforms.functional.pil_to_tensor(fn) for fn in img1_pil]

    img1_batch = torch.stack(img1_frames)

    return img1_batch

weights = Raft_Large_Weights.DEFAULT
transforms = weights.transforms()

device = "cuda" if torch.cuda.is_available() else "cpu"

model = raft_large(weights=Raft_Large_Weights.DEFAULT, progress=False).to(device)
model = model.eval()

print('created model')

def preprocess(img1_batch, img2_batch, size=[520, 960], transform_batch=True):
    img1_batch = F.resize(img1_batch, size=size, antialias=False)
    img2_batch = F.resize(img2_batch, size=size, antialias=False)
    if transform_batch:
        return transforms(img1_batch, img2_batch)
    else:
        return img1_batch, img2_batch

def compute_flow(img_path_1, img_path_2):
    img1_batch_og, img2_batch_og = load_image(img_path_1), load_image(img_path_2)
    B, C, H, W = img1_batch_og.shape

    img1_batch, img2_batch = preprocess(img1_batch_og, img2_batch_og, transform_batch=False)
    img1_batch_t, img2_batch_t = transforms(img1_batch, img2_batch)

    # If you can, run this on a GPU, it will be a lot faster.
    with torch.no_grad():
        list_of_flows = model(img1_batch_t.to(device), img2_batch_t.to(device))
    predicted_flows = list_of_flows[-1]

    resized_flow = F.resize(predicted_flows, size=(H, W), antialias=False)

    _, _, flow_H, flow_W = predicted_flows.shape

    resized_flow[:, 0] *= (W / flow_W)
    resized_flow[:, 1] *= (H / flow_H)

    return resized_flow.detach().cpu().squeeze()

def compute_flow_from_tensors(img1_batch_og, img2_batch_og):
    if len(img1_batch_og.shape) < 4:
        img1_batch_og = img1_batch_og.unsqueeze(0)
    if len(img2_batch_og.shape) < 4:
        img2_batch_og = img2_batch_og.unsqueeze(0)

    B, C, H, W = img1_batch_og.shape
    img1_batch, img2_batch = preprocess(img1_batch_og, img2_batch_og, transform_batch=False)
    img1_batch_t, img2_batch_t = transforms(img1_batch, img2_batch)

    with torch.no_grad():
        list_of_flows = model(img1_batch_t.to(device), img2_batch_t.to(device))
    predicted_flows = list_of_flows[-1]

    resized_flow = F.resize(predicted_flows, size=(H, W), antialias=False)

    _, _, flow_H, flow_W = predicted_flows.shape

    resized_flow[:, 0] *= (W / flow_W)
    resized_flow[:, 1] *= (H / flow_H)

    return resized_flow.detach().cpu().squeeze()


backwarp_tenGrid = {}

def backwarp(tenIn, tenFlow):
    if str(tenFlow.shape) not in backwarp_tenGrid:
        tenHor = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[3], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, 1, -1).repeat(1, 1, tenFlow.shape[2], 1)
        tenVer = torch.linspace(start=-1.0, end=1.0, steps=tenFlow.shape[2], dtype=tenFlow.dtype, device=tenFlow.device).view(1, 1, -1, 1).repeat(1, 1, 1, tenFlow.shape[3])

        backwarp_tenGrid[str(tenFlow.shape)] = torch.cat([tenHor, tenVer], 1).cuda()
    # end

    tenFlow = torch.cat([tenFlow[:, 0:1, :, :] / ((tenIn.shape[3] - 1.0) / 2.0), tenFlow[:, 1:2, :, :] / ((tenIn.shape[2] - 1.0) / 2.0)], 1)

    return torch.nn.functional.grid_sample(input=tenIn, grid=(backwarp_tenGrid[str(tenFlow.shape)] + tenFlow).permute(0, 2, 3, 1), mode='bilinear', padding_mode='zeros', align_corners=True)

torch.backends.cudnn.enabled = True  # make sure to use cudnn for computational performance

##########################################################
def forward_splt(src, tgt, flow, partial=False):
    tenTwo = tgt.unsqueeze(0).cuda()
    tenOne = src.unsqueeze(0).cuda()
    tenFlow = flow.unsqueeze(0).cuda()

    if not partial:
        tenMetric = torch.nn.functional.l1_loss(input=tenOne, target=backwarp(tenIn=tenTwo, tenFlow=tenFlow), reduction='none').mean([1], True)
    else:
        tenMetric = torch.nn.functional.l1_loss(input=tenOne[:, :3], target=backwarp(tenIn=tenTwo[:, :3], tenFlow=tenFlow[:, :3]), reduction='none').mean([1], True)
    tenSoftmax = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow, tenMetric=(-20.0 * tenMetric).clip(-20.0, 20.0), strMode='soft')  # -20.0 is a hyperparameter, called 'alpha' in the paper, that could be learned using a torch.Parameter

    return tenSoftmax.cpu()


def aggregate_frames(frames, pairwise_flows=None, agg_flow=None):
    if pairwise_flows is None:
        # store pairwise flows
        pairwise_flows = []

    if agg_flow is None:
        start_idx = 0
    else:
        start_idx = len(pairwise_flows)

    og_image = frames[start_idx]
    prev_frame = og_image

    for i in range(start_idx, len(frames)-1):
        tgt_frame = frames[i+1]

        if i < len(pairwise_flows):
            flow = pairwise_flows[i]
        else:
            flow = compute_flow_from_tensors(prev_frame, tgt_frame)
            pairwise_flows.append(flow.clone())

        _, H, W = flow.shape
        B = 1

        xx = torch.arange(0, W).view(1, -1).repeat(H, 1)
        yy = torch.arange(0, H).view(-1, 1).repeat(1, W)
        xx = xx.view(1, 1, H, W).repeat(B, 1, 1, 1)
        yy = yy.view(1, 1, H, W).repeat(B, 1, 1, 1)
        grid = torch.cat((xx, yy), 1).float()

        flow = flow.unsqueeze(0)
        if agg_flow is None:
            agg_flow = torch.zeros_like(flow)

        vgrid = grid + agg_flow
        vgrid[:, 0, :, :] = 2.0 * vgrid[:, 0, :, :].clone() / max(W-1, 1) - 1
        vgrid[:, 1, :, :] = 2.0 * vgrid[:, 1, :, :].clone() / max(H-1, 1) - 1

        flow_out = torch.nn.functional.grid_sample(flow, vgrid.permute(0, 2, 3, 1), 'nearest')

        agg_flow += flow_out

        prev_frame = tgt_frame

    return agg_flow, pairwise_flows


def forward_warp(src_frame, tgt_frame, flow, grid=None, alpha_mask=None):
    if alpha_mask is None:
        alpha_mask = torch.ones_like(src_frame[:1])

    if grid is not None:
        src_list = [src_frame, grid, alpha_mask]
        tgt_list = [tgt_frame, grid, alpha_mask]
    else:
        src_list = [src_frame, alpha_mask]
        tgt_list = [tgt_frame, alpha_mask]

    og_image_padded = torch.concat(src_list, dim=0)
    tgt_frame_padded = torch.concat(tgt_list, dim=0)

    og_splatted_img = forward_splt(og_image_padded, tgt_frame_padded, flow.squeeze(), partial=True).squeeze()

    actual_warped_mask = og_splatted_img[-1:]
    splatted_rgb_grid = og_splatted_img[:-1]

    return splatted_rgb_grid, actual_warped_mask
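A minimal sketch of the flow and splatting helpers above, assuming a CUDA device and the checkpoints loaded at module import (SAM, RAFT, Depth Anything V2); the random frames are placeholders for real video frames in [0, 1]:

```
import torch

import processing_utils

frame_a = torch.rand(3, 256, 256)   # placeholder source frame
frame_b = torch.rand(3, 256, 256)   # placeholder target frame

flow = processing_utils.compute_flow_from_tensors(frame_a, frame_b)   # (2, H, W) forward flow a -> b
warped, mask = processing_utils.forward_warp(frame_a, frame_b, flow)  # softmax-splatted frame + coverage mask
print(flow.shape, warped.shape, mask.shape)
```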
depth_anything_v2/softmax-splatting/README.md
ADDED
@@ -0,0 +1,90 @@
# softmax-splatting
This is a reference implementation of the softmax splatting operator, which has been proposed in Softmax Splatting for Video Frame Interpolation [1], using PyTorch. Softmax splatting is a well-motivated approach for differentiable forward warping. It uses a translational invariant importance metric to disambiguate cases where multiple source pixels map to the same target pixel. Should you be making use of our work, please cite our paper [1].

<a href="https://arxiv.org/abs/2003.05534" rel="Paper"><img src="http://content.sniklaus.com/softsplat/paper.jpg" alt="Paper" width="100%"></a>

For our previous work on SepConv, see: https://github.com/sniklaus/revisiting-sepconv

## setup
The softmax splatting is implemented in CUDA using CuPy, which is why CuPy is a required dependency. It can be installed using `pip install cupy` or alternatively using one of the provided [binary packages](https://docs.cupy.dev/en/stable/install.html#installing-cupy) as outlined in the CuPy repository.

If you plan to process videos, then please also make sure to have `pip install moviepy` installed.

## usage
To run it on your own pair of frames, use the following command.

```
python run.py --model lf --one ./images/one.png --two ./images/two.png --out ./out.png
```

To run it on a video, use the following command.

```
python run.py --model lf --video ./videos/car-turn.mp4 --out ./out.mp4
```

For a quick benchmark using examples from the Middlebury benchmark for optical flow, run `python benchmark_middlebury.py`. You can use it to easily verify that the provided implementation runs as expected.

## warping
We provide a small script to replicate the third figure of our paper [1]. You can simply run the following to obtain the comparison between summation splatting, average splatting, linear splatting, and softmax splatting.

The example script is using OpenCV to load and display images, as well as to read the provided optical flow file. An easy way to install OpenCV for Python is using the `pip install opencv-contrib-python` package.

```
import cv2
import numpy
import torch

import run

import softsplat  # the custom softmax splatting layer

##########################################################

torch.set_grad_enabled(False)  # make sure to not compute gradients for computational performance

torch.backends.cudnn.enabled = True  # make sure to use cudnn for computational performance

##########################################################

tenOne = torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/one.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda()
tenTwo = torch.FloatTensor(numpy.ascontiguousarray(cv2.imread(filename='./images/two.png', flags=-1).transpose(2, 0, 1)[None, :, :, :].astype(numpy.float32) * (1.0 / 255.0))).cuda()
tenFlow = torch.FloatTensor(numpy.ascontiguousarray(run.read_flo('./images/flow.flo').transpose(2, 0, 1)[None, :, :, :])).cuda()

tenMetric = torch.nn.functional.l1_loss(input=tenOne, target=run.backwarp(tenIn=tenTwo, tenFlow=tenFlow), reduction='none').mean([1], True)

for intTime, fltTime in enumerate(numpy.linspace(0.0, 1.0, 11).tolist()):
    tenSummation = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=None, strMode='sum')
    tenAverage = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=None, strMode='avg')
    tenLinear = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=(0.3 - tenMetric).clip(0.001, 1.0), strMode='linear')  # finding a good linear metric is difficult, and it is not invariant to translations
    tenSoftmax = softsplat.softsplat(tenIn=tenOne, tenFlow=tenFlow * fltTime, tenMetric=(-20.0 * tenMetric).clip(-20.0, 20.0), strMode='soft')  # -20.0 is a hyperparameter, called 'alpha' in the paper, that could be learned using a torch.Parameter

    cv2.imshow(winname='summation', mat=tenSummation[0, :, :, :].cpu().numpy().transpose(1, 2, 0))
    cv2.imshow(winname='average', mat=tenAverage[0, :, :, :].cpu().numpy().transpose(1, 2, 0))
    cv2.imshow(winname='linear', mat=tenLinear[0, :, :, :].cpu().numpy().transpose(1, 2, 0))
    cv2.imshow(winname='softmax', mat=tenSoftmax[0, :, :, :].cpu().numpy().transpose(1, 2, 0))
    cv2.waitKey(delay=0)
# end
```

## xiph
In our paper, we propose to use 4K video clips from Xiph to evaluate video frame interpolation on high-resolution footage. Please see the supplementary `benchmark_xiph.py` on how to reproduce the shown metrics.

## video
<a href="http://content.sniklaus.com/softsplat/video.mp4" rel="Video"><img src="http://content.sniklaus.com/softsplat/video.jpg" alt="Video" width="100%"></a>

## license
The provided implementation is strictly for academic purposes only. Should you be interested in using our technology for any commercial use, please feel free to contact us.

## references
```
[1]  @inproceedings{Niklaus_CVPR_2020,
         author = {Simon Niklaus and Feng Liu},
         title = {Softmax Splatting for Video Frame Interpolation},
         booktitle = {IEEE Conference on Computer Vision and Pattern Recognition},
         year = {2020}
     }
```

## acknowledgment
The video above uses materials under a Creative Common license as detailed at the end.
depth_anything_v2/softmax-splatting/__pycache__/softsplat.cpython-310.pyc
ADDED
Binary file (18.2 kB).
depth_anything_v2/softmax-splatting/benchmark_middlebury.py
ADDED
@@ -0,0 +1,35 @@
#!/usr/bin/env python

import glob
import numpy
import PIL
import PIL.Image
import skimage
import skimage.metrics
import torch

import run

##########################################################

run.args_strModel = 'l1'

##########################################################

if __name__ == '__main__':
    fltPsnr = []
    fltSsim = []

    for strTruth in sorted(glob.glob('./middlebury/*/frame10i11.png')):
        tenOne = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(strTruth.replace('frame10i11', 'frame10')))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
        tenTwo = torch.FloatTensor(numpy.ascontiguousarray(numpy.array(PIL.Image.open(strTruth.replace('frame10i11', 'frame11')))[:, :, ::-1].transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))

        npyEstimate = (run.estimate(tenOne, tenTwo, [0.5])[0].clip(0.0, 1.0).numpy().transpose(1, 2, 0) * 255.0).round().astype(numpy.uint8)

        fltPsnr.append(skimage.metrics.peak_signal_noise_ratio(image_true=numpy.array(PIL.Image.open(strTruth))[:, :, ::-1], image_test=npyEstimate, data_range=255))
        fltSsim.append(skimage.metrics.structural_similarity(im1=numpy.array(PIL.Image.open(strTruth))[:, :, ::-1], im2=npyEstimate, data_range=255, channel_axis=2))
    # end

    print('computed average psnr', numpy.mean(fltPsnr))
    print('computed average ssim', numpy.mean(fltSsim))
# end
depth_anything_v2/softmax-splatting/benchmark_xiph.py
ADDED
@@ -0,0 +1,91 @@
#!/usr/bin/env python

import cv2
import glob
import numpy
import os
import skimage
import skimage.metrics
import sys
import torch

import run

##########################################################

run.args_strModel = 'l1'

##########################################################

os.makedirs(name='./netflix', exist_ok=True)

if len(glob.glob('./netflix/BoxingPractice-*.png')) != 100:
    os.system('ffmpeg -i https://media.xiph.org/video/derf/ElFuente/Netflix_BoxingPractice_4096x2160_60fps_10bit_420.y4m -pix_fmt rgb24 -vframes 100 ./netflix/BoxingPractice-%03d.png')
# end

if len(glob.glob('./netflix/Crosswalk-*.png')) != 100:
    os.system('ffmpeg -i https://media.xiph.org/video/derf/ElFuente/Netflix_Crosswalk_4096x2160_60fps_10bit_420.y4m -pix_fmt rgb24 -vframes 100 ./netflix/Crosswalk-%03d.png')
# end

if len(glob.glob('./netflix/DrivingPOV-*.png')) != 100:
    os.system('ffmpeg -i https://media.xiph.org/video/derf/Chimera/Netflix_DrivingPOV_4096x2160_60fps_10bit_420.y4m -pix_fmt rgb24 -vframes 100 ./netflix/DrivingPOV-%03d.png')
# end

if len(glob.glob('./netflix/FoodMarket-*.png')) != 100:
    os.system('ffmpeg -i https://media.xiph.org/video/derf/ElFuente/Netflix_FoodMarket_4096x2160_60fps_10bit_420.y4m -pix_fmt rgb24 -vframes 100 ./netflix/FoodMarket-%03d.png')
# end

if len(glob.glob('./netflix/FoodMarket2-*.png')) != 100:
    os.system('ffmpeg -i https://media.xiph.org/video/derf/ElFuente/Netflix_FoodMarket2_4096x2160_60fps_10bit_420.y4m -pix_fmt rgb24 -vframes 100 ./netflix/FoodMarket2-%03d.png')
# end

if len(glob.glob('./netflix/RitualDance-*.png')) != 100:
    os.system('ffmpeg -i https://media.xiph.org/video/derf/ElFuente/Netflix_RitualDance_4096x2160_60fps_10bit_420.y4m -pix_fmt rgb24 -vframes 100 ./netflix/RitualDance-%03d.png')
# end

if len(glob.glob('./netflix/SquareAndTimelapse-*.png')) != 100:
    os.system('ffmpeg -i https://media.xiph.org/video/derf/ElFuente/Netflix_SquareAndTimelapse_4096x2160_60fps_10bit_420.y4m -pix_fmt rgb24 -vframes 100 ./netflix/SquareAndTimelapse-%03d.png')
# end

if len(glob.glob('./netflix/Tango-*.png')) != 100:
    os.system('ffmpeg -i https://media.xiph.org/video/derf/ElFuente/Netflix_Tango_4096x2160_60fps_10bit_420.y4m -pix_fmt rgb24 -vframes 100 ./netflix/Tango-%03d.png')
# end

##########################################################

for strCategory in ['resized', 'cropped']:
    fltPsnr = []
    fltSsim = []

    for strFile in ['BoxingPractice', 'Crosswalk', 'DrivingPOV', 'FoodMarket', 'FoodMarket2', 'RitualDance', 'SquareAndTimelapse', 'Tango']:
        for intFrame in range(2, 99, 2):
            npyOne = cv2.imread(filename='./netflix/' + strFile + '-' + str(intFrame - 1).zfill(3) + '.png', flags=-1)
            npyTwo = cv2.imread(filename='./netflix/' + strFile + '-' + str(intFrame + 1).zfill(3) + '.png', flags=-1)
            npyTruth = cv2.imread(filename='./netflix/' + strFile + '-' + str(intFrame).zfill(3) + '.png', flags=-1)

            if strCategory == 'resized':
                npyOne = cv2.resize(src=npyOne, dsize=(2048, 1080), fx=0.0, fy=0.0, interpolation=cv2.INTER_AREA)
                npyTwo = cv2.resize(src=npyTwo, dsize=(2048, 1080), fx=0.0, fy=0.0, interpolation=cv2.INTER_AREA)
                npyTruth = cv2.resize(src=npyTruth, dsize=(2048, 1080), fx=0.0, fy=0.0, interpolation=cv2.INTER_AREA)

            elif strCategory == 'cropped':
                npyOne = npyOne[540:-540, 1024:-1024, :]
                npyTwo = npyTwo[540:-540, 1024:-1024, :]
                npyTruth = npyTruth[540:-540, 1024:-1024, :]

            # end

            tenOne = torch.FloatTensor(numpy.ascontiguousarray(npyOne.transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))
            tenTwo = torch.FloatTensor(numpy.ascontiguousarray(npyTwo.transpose(2, 0, 1).astype(numpy.float32) * (1.0 / 255.0)))

            npyEstimate = (run.estimate(tenOne, tenTwo, [0.5])[0].clip(0.0, 1.0).numpy().transpose(1, 2, 0) * 255.0).round().astype(numpy.uint8)

            fltPsnr.append(skimage.metrics.peak_signal_noise_ratio(image_true=npyTruth, image_test=npyEstimate, data_range=255))
            fltSsim.append(skimage.metrics.structural_similarity(im1=npyTruth, im2=npyEstimate, data_range=255, channel_axis=2))
        # end
    # end

    print('category', strCategory)
    print('computed average psnr', numpy.mean(fltPsnr))
    print('computed average ssim', numpy.mean(fltSsim))
# end
depth_anything_v2/softmax-splatting/correlation/README.md
ADDED
@@ -0,0 +1 @@
This is an adaptation of the <a href="https://github.com/lmb-freiburg/flownet2">FlowNet2 implementation</a> in order to compute cost volumes. Should you be making use of this work, please make sure to adhere to the <a href="https://github.com/lmb-freiburg/flownet2#license-and-citation">licensing terms</a> of the original authors. Should you be making use of or modifying this particular implementation, please acknowledge it appropriately.
depth_anything_v2/softmax-splatting/correlation/correlation.py
ADDED
@@ -0,0 +1,400 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
1 |
+
#!/usr/bin/env python
|
2 |
+
|
3 |
+
import cupy
|
4 |
+
import os
|
5 |
+
import re
|
6 |
+
import torch
|
7 |
+
|
8 |
+
kernel_Correlation_rearrange = '''
|
9 |
+
extern "C" __global__ void kernel_Correlation_rearrange(
|
10 |
+
const int n,
|
11 |
+
const float* input,
|
12 |
+
float* output
|
13 |
+
) {
|
14 |
+
int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x;
|
15 |
+
|
16 |
+
if (intIndex >= n) {
|
17 |
+
return;
|
18 |
+
}
|
19 |
+
|
20 |
+
int intSample = blockIdx.z;
|
21 |
+
int intChannel = blockIdx.y;
|
22 |
+
|
23 |
+
float fltValue = input[(((intSample * SIZE_1(input)) + intChannel) * SIZE_2(input) * SIZE_3(input)) + intIndex];
|
24 |
+
|
25 |
+
__syncthreads();
|
26 |
+
|
27 |
+
int intPaddedY = (intIndex / SIZE_3(input)) + 4;
|
28 |
+
int intPaddedX = (intIndex % SIZE_3(input)) + 4;
|
29 |
+
int intRearrange = ((SIZE_3(input) + 8) * intPaddedY) + intPaddedX;
|
30 |
+
|
31 |
+
output[(((intSample * SIZE_1(output) * SIZE_2(output)) + intRearrange) * SIZE_1(input)) + intChannel] = fltValue;
|
32 |
+
}
|
33 |
+
'''
|
34 |
+
|
35 |
+
kernel_Correlation_updateOutput = '''
|
36 |
+
extern "C" __global__ void kernel_Correlation_updateOutput(
|
37 |
+
const int n,
|
38 |
+
const float* rbot0,
|
39 |
+
const float* rbot1,
|
40 |
+
float* top
|
41 |
+
) {
|
42 |
+
extern __shared__ char patch_data_char[];
|
43 |
+
|
44 |
+
float *patch_data = (float *)patch_data_char;
|
45 |
+
|
46 |
+
// First (upper left) position of kernel upper-left corner in current center position of neighborhood in image 1
|
47 |
+
int x1 = blockIdx.x + 4;
|
48 |
+
int y1 = blockIdx.y + 4;
|
49 |
+
int item = blockIdx.z;
|
50 |
+
int ch_off = threadIdx.x;
|
51 |
+
|
52 |
+
// Load 3D patch into shared shared memory
|
53 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
54 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
55 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
56 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
57 |
+
int idx1 = ((item * SIZE_1(rbot0) + y1+j) * SIZE_2(rbot0) + x1+i) * SIZE_3(rbot0) + ch;
|
58 |
+
int idxPatchData = ji_off + ch;
|
59 |
+
patch_data[idxPatchData] = rbot0[idx1];
|
60 |
+
}
|
61 |
+
}
|
62 |
+
}
|
63 |
+
|
64 |
+
__syncthreads();
|
65 |
+
|
66 |
+
__shared__ float sum[32];
|
67 |
+
|
68 |
+
// Compute correlation
|
69 |
+
for (int top_channel = 0; top_channel < SIZE_1(top); top_channel++) {
|
70 |
+
sum[ch_off] = 0;
|
71 |
+
|
72 |
+
int s2o = top_channel % 9 - 4;
|
73 |
+
int s2p = top_channel / 9 - 4;
|
74 |
+
|
75 |
+
for (int j = 0; j < 1; j++) { // HEIGHT
|
76 |
+
for (int i = 0; i < 1; i++) { // WIDTH
|
77 |
+
int ji_off = (j + i) * SIZE_3(rbot0);
|
78 |
+
for (int ch = ch_off; ch < SIZE_3(rbot0); ch += 32) { // CHANNELS
|
79 |
+
int x2 = x1 + s2o;
|
80 |
+
int y2 = y1 + s2p;
|
81 |
+
|
82 |
+
int idxPatchData = ji_off + ch;
|
83 |
+
int idx2 = ((item * SIZE_1(rbot0) + y2+j) * SIZE_2(rbot0) + x2+i) * SIZE_3(rbot0) + ch;
|
84 |
+
|
85 |
+
sum[ch_off] += patch_data[idxPatchData] * rbot1[idx2];
|
86 |
+
}
|
87 |
+
}
|
88 |
+
}
|
89 |
+
|
90 |
+
__syncthreads();
|
91 |
+
|
92 |
+
if (ch_off == 0) {
|
93 |
+
float total_sum = 0;
|
94 |
+
for (int idx = 0; idx < 32; idx++) {
|
95 |
+
total_sum += sum[idx];
|
96 |
+
}
|
97 |
+
const int sumelems = SIZE_3(rbot0);
|
98 |
+
const int index = ((top_channel*SIZE_2(top) + blockIdx.y)*SIZE_3(top))+blockIdx.x;
|
99 |
+
top[index + item*SIZE_1(top)*SIZE_2(top)*SIZE_3(top)] = total_sum / (float)sumelems;
|
100 |
+
}
|
101 |
+
}
|
102 |
+
}
|
103 |
+
'''
|
104 |
+
|
105 |
+
kernel_Correlation_updateGradOne = '''
|
106 |
+
#define ROUND_OFF 50000
|
107 |
+
|
108 |
+
extern "C" __global__ void kernel_Correlation_updateGradOne(
|
109 |
+
const int n,
|
110 |
+
const int intSample,
|
111 |
+
const float* rbot0,
|
112 |
+
const float* rbot1,
|
113 |
+
const float* gradOutput,
|
114 |
+
float* gradOne,
|
115 |
+
float* gradTwo
|
116 |
+
) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
|
117 |
+
int n = intIndex % SIZE_1(gradOne); // channels
|
118 |
+
int l = (intIndex / SIZE_1(gradOne)) % SIZE_3(gradOne) + 4; // w-pos
|
119 |
+
int m = (intIndex / SIZE_1(gradOne) / SIZE_3(gradOne)) % SIZE_2(gradOne) + 4; // h-pos
|
120 |
+
|
121 |
+
// round_off is a trick to enable integer division with ceil, even for negative numbers
|
122 |
+
// We use a large offset, for the inner part not to become negative.
|
123 |
+
const int round_off = ROUND_OFF;
|
124 |
+
const int round_off_s1 = round_off;
|
125 |
+
|
126 |
+
// We add round_off before_s1 the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
|
127 |
+
int xmin = (l - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
|
128 |
+
int ymin = (m - 4 + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4)
|
129 |
+
|
130 |
+
// Same here:
|
131 |
+
int xmax = (l - 4 + round_off_s1) - round_off; // floor (l - 4)
|
132 |
+
int ymax = (m - 4 + round_off_s1) - round_off; // floor (m - 4)
|
133 |
+
|
134 |
+
float sum = 0;
|
135 |
+
if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
|
136 |
+
xmin = max(0,xmin);
|
137 |
+
xmax = min(SIZE_3(gradOutput)-1,xmax);
|
138 |
+
|
139 |
+
ymin = max(0,ymin);
|
140 |
+
ymax = min(SIZE_2(gradOutput)-1,ymax);
|
141 |
+
|
142 |
+
for (int p = -4; p <= 4; p++) {
|
143 |
+
for (int o = -4; o <= 4; o++) {
|
144 |
+
// Get rbot1 data:
|
145 |
+
int s2o = o;
|
146 |
+
int s2p = p;
|
147 |
+
int idxbot1 = ((intSample * SIZE_1(rbot0) + (m+s2p)) * SIZE_2(rbot0) + (l+s2o)) * SIZE_3(rbot0) + n;
|
148 |
+
float bot1tmp = rbot1[idxbot1]; // rbot1[l+s2o,m+s2p,n]
|
149 |
+
|
150 |
+
// Index offset for gradOutput in following loops:
|
151 |
+
int op = (p+4) * 9 + (o+4); // index[o,p]
|
152 |
+
int idxopoffset = (intSample * SIZE_1(gradOutput) + op);
|
153 |
+
|
154 |
+
for (int y = ymin; y <= ymax; y++) {
|
155 |
+
for (int x = xmin; x <= xmax; x++) {
|
156 |
+
int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
|
157 |
+
sum += gradOutput[idxgradOutput] * bot1tmp;
|
158 |
+
}
|
159 |
+
}
|
160 |
+
}
|
161 |
+
}
|
162 |
+
}
|
163 |
+
const int sumelems = SIZE_1(gradOne);
|
164 |
+
const int bot0index = ((n * SIZE_2(gradOne)) + (m-4)) * SIZE_3(gradOne) + (l-4);
|
165 |
+
gradOne[bot0index + intSample*SIZE_1(gradOne)*SIZE_2(gradOne)*SIZE_3(gradOne)] = sum / (float)sumelems;
|
166 |
+
} }
|
167 |
+
'''

kernel_Correlation_updateGradTwo = '''
    #define ROUND_OFF 50000

    extern "C" __global__ void kernel_Correlation_updateGradTwo(
        const int n,
        const int intSample,
        const float* rbot0,
        const float* rbot1,
        const float* gradOutput,
        float* gradOne,
        float* gradTwo
    ) { for (int intIndex = (blockIdx.x * blockDim.x) + threadIdx.x; intIndex < n; intIndex += blockDim.x * gridDim.x) {
        int n = intIndex % SIZE_1(gradTwo); // channels
        int l = (intIndex / SIZE_1(gradTwo)) % SIZE_3(gradTwo) + 4; // w-pos
        int m = (intIndex / SIZE_1(gradTwo) / SIZE_3(gradTwo)) % SIZE_2(gradTwo) + 4; // h-pos

        // round_off is a trick to enable integer division with ceil, even for negative numbers
        // We use a large offset, for the inner part not to become negative.
        const int round_off = ROUND_OFF;
        const int round_off_s1 = round_off;

        float sum = 0;
        for (int p = -4; p <= 4; p++) {
            for (int o = -4; o <= 4; o++) {
                int s2o = o;
                int s2p = p;

                // Get X,Y ranges and clamp
                // We add round_off before the int division and subtract round_off after it, to ensure the formula matches ceil behavior:
                int xmin = (l - 4 - s2o + round_off_s1 - 1) + 1 - round_off; // ceil (l - 4 - s2o)
                int ymin = (m - 4 - s2p + round_off_s1 - 1) + 1 - round_off; // ceil (m - 4 - s2p)

                // Same here:
                int xmax = (l - 4 - s2o + round_off_s1) - round_off; // floor (l - 4 - s2o)
                int ymax = (m - 4 - s2p + round_off_s1) - round_off; // floor (m - 4 - s2p)

                if (xmax>=0 && ymax>=0 && (xmin<=SIZE_3(gradOutput)-1) && (ymin<=SIZE_2(gradOutput)-1)) {
                    xmin = max(0,xmin);
                    xmax = min(SIZE_3(gradOutput)-1,xmax);

                    ymin = max(0,ymin);
                    ymax = min(SIZE_2(gradOutput)-1,ymax);

                    // Get rbot0 data:
                    int idxbot0 = ((intSample * SIZE_1(rbot0) + (m-s2p)) * SIZE_2(rbot0) + (l-s2o)) * SIZE_3(rbot0) + n;
                    float bot0tmp = rbot0[idxbot0]; // rbot0[l-s2o,m-s2p,n]

                    // Index offset for gradOutput in following loops:
                    int op = (p+4) * 9 + (o+4); // index[o,p]
                    int idxopoffset = (intSample * SIZE_1(gradOutput) + op);

                    for (int y = ymin; y <= ymax; y++) {
                        for (int x = xmin; x <= xmax; x++) {
                            int idxgradOutput = (idxopoffset * SIZE_2(gradOutput) + y) * SIZE_3(gradOutput) + x; // gradOutput[x,y,o,p]
                            sum += gradOutput[idxgradOutput] * bot0tmp;
                        }
                    }
                }
            }
        }
        const int sumelems = SIZE_1(gradTwo);
        const int bot1index = ((n * SIZE_2(gradTwo)) + (m-4)) * SIZE_3(gradTwo) + (l-4);
        gradTwo[bot1index + intSample*SIZE_1(gradTwo)*SIZE_2(gradTwo)*SIZE_3(gradTwo)] = sum / (float)sumelems;
    } }
'''
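As a side note (a sketch, not part of the shipped file): the 9x9 displacement window scanned by these kernels is flattened into 81 output channels with op = (p + 4) * 9 + (o + 4), where o and p are the horizontal and vertical displacements in [-4, 4]. A tiny illustration of the mapping and its inverse:

def channel_of(o, p):
    # o, p are the horizontal/vertical displacements in [-4, 4]
    return (p + 4) * 9 + (o + 4)

def offset_of(op):
    return op % 9 - 4, op // 9 - 4  # (o, p)

assert channel_of(0, 0) == 40  # zero displacement lands in the middle of the 81 channels
assert offset_of(channel_of(-4, 4)) == (-4, 4)
assert len({channel_of(o, p) for p in range(-4, 5) for o in range(-4, 5)}) == 81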

def cupy_kernel(strFunction, objVariables):
    strKernel = globals()[strFunction]

    while True:
        objMatch = re.search('(SIZE_)([0-4])(\()([^\)]*)(\))', strKernel)

        if objMatch is None:
            break
        # end

        intArg = int(objMatch.group(2))

        strTensor = objMatch.group(4)
        intSizes = objVariables[strTensor].size()

        strKernel = strKernel.replace(objMatch.group(), str(intSizes[intArg] if torch.is_tensor(intSizes[intArg]) == False else intSizes[intArg].item()))

    while True:
        objMatch = re.search('(VALUE_)([0-4])(\()([^\)]+)(\))', strKernel)

        if objMatch is None:
            break
        # end

        intArgs = int(objMatch.group(2))
        strArgs = objMatch.group(4).split(',')

        strTensor = strArgs[0]
        intStrides = objVariables[strTensor].stride()
        strIndex = [ '((' + strArgs[intArg + 1].replace('{', '(').replace('}', ')').strip() + ')*' + str(intStrides[intArg] if torch.is_tensor(intStrides[intArg]) == False else intStrides[intArg].item()) + ')' for intArg in range(intArgs) ]

        strKernel = strKernel.replace(objMatch.group(0), strTensor + '[' + str.join('+', strIndex) + ']')
    # end

    return strKernel
# end
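For orientation (a sketch, not part of the file): cupy_kernel specializes the CUDA source for the concrete tensors passed in objVariables by textually replacing every SIZE_i(tensor) token with that tensor's size along dimension i (and every VALUE_i(...) token with a strided element access). Roughly equivalent, for the SIZE_ substitution only and a hypothetical 'top' tensor:

import re
import torch

top = torch.zeros(2, 81, 64, 96)  # hypothetical output tensor
source = 'if (x < SIZE_3(top) && y < SIZE_2(top)) { ... }'
specialized = re.sub(r'SIZE_([0-4])\(top\)', lambda m: str(top.size(int(m.group(1)))), source)
assert specialized == 'if (x < 96 && y < 64) { ... }'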

@cupy.memoize(for_each_device=True)
def cupy_launch(strFunction, strKernel):
    if 'CUDA_HOME' not in os.environ:
        os.environ['CUDA_HOME'] = cupy.cuda.get_cuda_path()
    # end

    return cupy.RawKernel(strKernel, strFunction, tuple(['-I ' + os.environ['CUDA_HOME'], '-I ' + os.environ['CUDA_HOME'] + '/include']))
# end
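Because cupy_launch is wrapped in cupy.memoize(for_each_device=True), each (function name, kernel source) pair is compiled at most once per device and later calls return the cached cupy.RawKernel. A hypothetical illustration (strSource stands in for an already-specialized kernel string, not a name from this file):

kernelA = cupy_launch('kernel_Correlation_updateOutput', strSource)  # compiles on first use
kernelB = cupy_launch('kernel_Correlation_updateOutput', strSource)  # returned from the per-device cache
assert kernelA is kernelB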

class _FunctionCorrelation(torch.autograd.Function):
    @staticmethod
    def forward(self, one, two):
        rbot0 = one.new_zeros([ one.shape[0], one.shape[2] + 8, one.shape[3] + 8, one.shape[1] ])
        rbot1 = one.new_zeros([ one.shape[0], one.shape[2] + 8, one.shape[3] + 8, one.shape[1] ])

        one = one.contiguous(); assert(one.is_cuda == True)
        two = two.contiguous(); assert(two.is_cuda == True)

        output = one.new_zeros([ one.shape[0], 81, one.shape[2], one.shape[3] ])

        if one.is_cuda == True:
            n = one.shape[2] * one.shape[3]
            cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
                'input': one,
                'output': rbot0
            }))(
                grid=tuple([ int((n + 16 - 1) / 16), one.shape[1], one.shape[0] ]),
                block=tuple([ 16, 1, 1 ]),
                args=[ cupy.int32(n), one.data_ptr(), rbot0.data_ptr() ]
            )

            n = two.shape[2] * two.shape[3]
            cupy_launch('kernel_Correlation_rearrange', cupy_kernel('kernel_Correlation_rearrange', {
                'input': two,
                'output': rbot1
            }))(
                grid=tuple([ int((n + 16 - 1) / 16), two.shape[1], two.shape[0] ]),
                block=tuple([ 16, 1, 1 ]),
                args=[ cupy.int32(n), two.data_ptr(), rbot1.data_ptr() ]
            )

            n = output.shape[1] * output.shape[2] * output.shape[3]
            cupy_launch('kernel_Correlation_updateOutput', cupy_kernel('kernel_Correlation_updateOutput', {
                'rbot0': rbot0,
                'rbot1': rbot1,
                'top': output
            }))(
                grid=tuple([ output.shape[3], output.shape[2], output.shape[0] ]),
                block=tuple([ 32, 1, 1 ]),
                shared_mem=one.shape[1] * 4,
                args=[ cupy.int32(n), rbot0.data_ptr(), rbot1.data_ptr(), output.data_ptr() ]
            )

        elif one.is_cuda == False:
            raise NotImplementedError()

        # end

        self.save_for_backward(one, two, rbot0, rbot1)

        return output
    # end

    @staticmethod
    def backward(self, gradOutput):
        one, two, rbot0, rbot1 = self.saved_tensors

        gradOutput = gradOutput.contiguous(); assert(gradOutput.is_cuda == True)

        gradOne = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[0] == True else None
        gradTwo = one.new_zeros([ one.shape[0], one.shape[1], one.shape[2], one.shape[3] ]) if self.needs_input_grad[1] == True else None

        if one.is_cuda == True:
            if gradOne is not None:
                for intSample in range(one.shape[0]):
                    n = one.shape[1] * one.shape[2] * one.shape[3]
                    cupy_launch('kernel_Correlation_updateGradOne', cupy_kernel('kernel_Correlation_updateGradOne', {
                        'rbot0': rbot0,
                        'rbot1': rbot1,
                        'gradOutput': gradOutput,
                        'gradOne': gradOne,
                        'gradTwo': None
                    }))(
                        grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
                        block=tuple([ 512, 1, 1 ]),
                        args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), gradOne.data_ptr(), None ]
                    )
                # end
            # end

            if gradTwo is not None:
                for intSample in range(one.shape[0]):
                    n = one.shape[1] * one.shape[2] * one.shape[3]
                    cupy_launch('kernel_Correlation_updateGradTwo', cupy_kernel('kernel_Correlation_updateGradTwo', {
                        'rbot0': rbot0,
                        'rbot1': rbot1,
                        'gradOutput': gradOutput,
                        'gradOne': None,
                        'gradTwo': gradTwo
                    }))(
                        grid=tuple([ int((n + 512 - 1) / 512), 1, 1 ]),
                        block=tuple([ 512, 1, 1 ]),
                        args=[ cupy.int32(n), intSample, rbot0.data_ptr(), rbot1.data_ptr(), gradOutput.data_ptr(), None, gradTwo.data_ptr() ]
                    )
                # end
            # end

        elif one.is_cuda == False:
            raise NotImplementedError()

        # end

        return gradOne, gradTwo
    # end
# end

def FunctionCorrelation(tenOne, tenTwo):
    return _FunctionCorrelation.apply(tenOne, tenTwo)
# end

class ModuleCorrelation(torch.nn.Module):
    def __init__(self):
        super().__init__()
    # end

    def forward(self, tenOne, tenTwo):
        return _FunctionCorrelation.apply(tenOne, tenTwo)
    # end
# end
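A usage sketch (not part of the shipped file): forward first rearranges each input into a zero-padded, channel-last buffer of shape [B, H+8, W+8, C] (4 pixels of padding on each side for the 9x9 window), then correlating two CUDA feature maps of shape [B, C, H, W] yields one channel per displacement, i.e. a [B, 81, H, W] tensor.

import torch

tenOne = torch.rand(1, 32, 64, 96).cuda()
tenTwo = torch.rand(1, 32, 64, 96).cuda()

tenCorr = FunctionCorrelation(tenOne, tenTwo)  # ModuleCorrelation()(tenOne, tenTwo) is equivalent
print(tenCorr.shape)  # torch.Size([1, 81, 64, 96])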
depth_anything_v2/softmax-splatting/images/README.md
ADDED
@@ -0,0 +1 @@
The used example originates from the DAVIS challenge: https://davischallenge.org/
depth_anything_v2/softmax-splatting/images/flow.flo
ADDED
Binary file (819 kB)
depth_anything_v2/softmax-splatting/images/one.png
ADDED
depth_anything_v2/softmax-splatting/images/two.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/Beanbags/frame10.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/Beanbags/frame10i11.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/Beanbags/frame11.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/Dimetrodon/frame10.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/Dimetrodon/frame10i11.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/Dimetrodon/frame11.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/DogDance/frame10.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/DogDance/frame10i11.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/DogDance/frame11.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/Grove2/frame10.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/Grove2/frame10i11.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/Grove2/frame11.png
ADDED
depth_anything_v2/softmax-splatting/middlebury/Grove3/frame10.png
ADDED