Commit 2947428
Author: Vincentqyw
Parent(s): a9f1fc6
Message: fix:roma

Files changed:
- common/utils.py (+4 -4)
- third_party/Roma/roma/models/encoders.py (+33 -81)
- third_party/Roma/roma/models/matcher.py (+145 -267)
common/utils.py
CHANGED
@@ -49,7 +49,7 @@ def gen_examples():
        "topicfm",
        "superpoint+superglue",
        "disk+dualsoftmax",
-        "
+        "roma",
    ]

    def gen_images_pairs(path: str, count: int = 5):
@@ -452,12 +452,11 @@ ransac_zoo = {
 
 # Matchers collections
 matcher_zoo = {
-    "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
-    "sold2": {"config": match_dense.confs["sold2"], "dense": True},
     # 'dedode-sparse': {
     #     'config': match_dense.confs['dedode_sparse'],
     #     'dense': True  # dense mode, we need 2 images
     # },
+    "roma": {"config": match_dense.confs["roma"], "dense": True},
     "loftr": {"config": match_dense.confs["loftr"], "dense": True},
     "topicfm": {"config": match_dense.confs["topicfm"], "dense": True},
     "aspanformer": {"config": match_dense.confs["aspanformer"], "dense": True},
@@ -556,6 +555,7 @@ matcher_zoo = {
         "config_feature": extract_features.confs["sift"],
         "dense": False,
     },
-
+    "gluestick": {"config": match_dense.confs["gluestick"], "dense": True},
+    "sold2": {"config": match_dense.confs["sold2"], "dense": True},
     # "DKMv3": {"config": match_dense.confs["dkm"], "dense": True},
 }
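As context for the hunks above: the commit adds "roma" both to the examples list in gen_examples() and to matcher_zoo, and moves the "gluestick" and "sold2" entries to the end of the dict. A minimal sketch of the lookup pattern such an entry supports; run_matcher is a hypothetical helper written only for illustration, not code from this repo:

# Sketch (hypothetical): resolving an entry such as matcher_zoo["roma"].
# matcher_zoo and match_dense.confs exist in this repo; run_matcher does not.
def run_matcher(name: str, zoo: dict):
    entry = zoo[name]              # e.g. zoo["roma"]
    conf = entry["config"]         # a config dict taken from match_dense.confs
    if entry["dense"]:
        # dense matchers (roma, loftr, ...) consume the image pair directly
        return ("dense", conf)
    # sparse entries also carry entry["config_feature"] for feature extraction
    return ("sparse", conf, entry["config_feature"])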
third_party/Roma/roma/models/encoders.py
CHANGED
@@ -6,59 +6,37 @@ import torch.nn.functional as F
 import torchvision.models as tvm
 import gc
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
 
 class ResNet50(nn.Module):
-    def __init__(
-        self,
-        pretrained=False,
-        high_res=False,
-        weights=None,
-        dilation=None,
-        freeze_bn=True,
-        anti_aliased=False,
-        early_exit=False,
-        amp=False,
-    ) -> None:
+    def __init__(self, pretrained=False, high_res = False, weights = None, dilation = None, freeze_bn = True, anti_aliased = False, early_exit = False, amp = False) -> None:
         super().__init__()
         if dilation is None:
-            dilation = [False, False, False]
+            dilation = [False,False,False]
         if anti_aliased:
             pass
         else:
             if weights is not None:
-                self.net = tvm.resnet50(
-                    weights=weights, replace_stride_with_dilation=dilation
-                )
+                self.net = tvm.resnet50(weights = weights,replace_stride_with_dilation=dilation)
             else:
-                self.net = tvm.resnet50(
-                    pretrained=pretrained, replace_stride_with_dilation=dilation
-                )
-
+                self.net = tvm.resnet50(pretrained=pretrained,replace_stride_with_dilation=dilation)
+
         self.high_res = high_res
         self.freeze_bn = freeze_bn
         self.early_exit = early_exit
         self.amp = amp
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
     def forward(self, x, **kwargs):
-        with torch.autocast(
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
             net = self.net
-            feats = {1: x}
+            feats = {1:x}
             x = net.conv1(x)
             x = net.bn1(x)
             x = net.relu(x)
-            feats[2] = x
+            feats[2] = x
             x = net.maxpool(x)
             x = net.layer1(x)
-            feats[4] = x
+            feats[4] = x
             x = net.layer2(x)
             feats[8] = x
             if self.early_exit:
@@ -77,48 +55,35 @@ class ResNet50(nn.Module):
                 m.eval()
         pass
 
-
 class VGG19(nn.Module):
-    def __init__(self, pretrained=False, amp=False) -> None:
+    def __init__(self, pretrained=False, amp = False) -> None:
         super().__init__()
         self.layers = nn.ModuleList(tvm.vgg19_bn(pretrained=pretrained).features[:40])
         self.amp = amp
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
 
     def forward(self, x, **kwargs):
-        with torch.autocast(
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
             feats = {}
             scale = 1
             for layer in self.layers:
                 if isinstance(layer, nn.MaxPool2d):
                     feats[scale] = x
-                    scale = scale * 2
+                    scale = scale*2
                 x = layer(x)
             return feats
 
-
 class CNNandDinov2(nn.Module):
-    def __init__(self, cnn_kwargs=None, amp=False, use_vgg=False, dinov2_weights=None):
+    def __init__(self, cnn_kwargs = None, amp = False, use_vgg = False, dinov2_weights = None):
         super().__init__()
         if dinov2_weights is None:
-            dinov2_weights = torch.hub.load_state_dict_from_url(
-                "https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth",
-                map_location="cpu",
-            )
+            dinov2_weights = torch.hub.load_state_dict_from_url("https://dl.fbaipublicfiles.com/dinov2/dinov2_vitl14/dinov2_vitl14_pretrain.pth", map_location="cpu")
         from .transformer import vit_large
-
-        vit_kwargs = dict(
-            img_size=518,
-            patch_size=14,
-            init_values=1.0,
-            ffn_layer="mlp",
-            block_chunks=0,
+        vit_kwargs = dict(img_size= 518,
+            patch_size= 14,
+            init_values = 1.0,
+            ffn_layer = "mlp",
+            block_chunks = 0,
         )
 
         dinov2_vitl14 = vit_large(**vit_kwargs).eval()
@@ -129,38 +94,25 @@ class CNNandDinov2(nn.Module):
         else:
             self.cnn = VGG19(**cnn_kwargs)
         self.amp = amp
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
         if self.amp:
             dinov2_vitl14 = dinov2_vitl14.to(self.amp_dtype)
-        self.dinov2_vitl14 = [dinov2_vitl14]
-
+        self.dinov2_vitl14 = [dinov2_vitl14] # ugly hack to not show parameters to DDP
+
+
     def train(self, mode: bool = True):
         return self.cnn.train(mode)
-
-    def forward(self, x, upsample=False):
-        B, C, H, W = x.shape
+
+    def forward(self, x, upsample = False):
+        B,C,H,W = x.shape
         feature_pyramid = self.cnn(x)
-
+
         if not upsample:
            with torch.no_grad():
                 if self.dinov2_vitl14[0].device != x.device:
-                    self.dinov2_vitl14[0] = (
-                        self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
-                    )
-                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(
-                    x.to(self.amp_dtype)
-                )
-                features_16 = (
-                    dinov2_features_16["x_norm_patchtokens"]
-                    .permute(0, 2, 1)
-                    .reshape(B, 1024, H // 14, W // 14)
-                )
+                    self.dinov2_vitl14[0] = self.dinov2_vitl14[0].to(x.device).to(self.amp_dtype)
+                dinov2_features_16 = self.dinov2_vitl14[0].forward_features(x.to(self.amp_dtype))
+                features_16 = dinov2_features_16['x_norm_patchtokens'].permute(0,2,1).reshape(B,1024,H//14, W//14)
                 del dinov2_features_16
                 feature_pyramid[16] = features_16
-        return feature_pyramid
+        return feature_pyramid
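Two notes on the hunks above. The new one-liner self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16 drops the old float32 CPU fallback, so it assumes a CUDA-capable setup. The patch-token reshape in CNNandDinov2.forward can be checked in isolation; a self-contained sketch on dummy tensors, with the shapes implied by the diff (ViT-L/14, 1024-dim tokens):

import torch

# DINOv2 ViT-L/14 returns patch tokens of shape (B, (H//14)*(W//14), 1024);
# the decoder expects a CNN-style map (B, 1024, H//14, W//14), stored as
# feature_pyramid[16]. Dummy data stands in for x_norm_patchtokens.
B, H, W, D = 2, 448, 448, 1024
tokens = torch.randn(B, (H // 14) * (W // 14), D)
features_16 = tokens.permute(0, 2, 1).reshape(B, D, H // 14, W // 14)
assert features_16.shape == (B, 1024, 32, 32)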
third_party/Roma/roma/models/matcher.py
CHANGED
@@ -14,9 +14,6 @@ from roma.utils.local_correlation import local_correlation
 from roma.utils.utils import cls_to_flow_refine
 from roma.utils.kde import kde
 
-device = "cuda" if torch.cuda.is_available() else "cpu"
-
-
 class ConvRefiner(nn.Module):
     def __init__(
         self,
@@ -26,29 +23,25 @@ class ConvRefiner(nn.Module):
         dw=False,
         kernel_size=5,
         hidden_blocks=3,
-        displacement_emb=None,
-        displacement_emb_dim=None,
-        local_corr_radius=None,
-        corr_in_other=None,
-        no_im_B_fm=False,
-        amp=False,
-        concat_logits=False,
-        use_bias_block_1=True,
-        use_cosine_corr=False,
-        disable_local_corr_grad=False,
-        is_classifier=False,
-        sample_mode="bilinear",
-        norm_type=nn.BatchNorm2d,
-        bn_momentum=0.1,
+        displacement_emb = None,
+        displacement_emb_dim = None,
+        local_corr_radius = None,
+        corr_in_other = None,
+        no_im_B_fm = False,
+        amp = False,
+        concat_logits = False,
+        use_bias_block_1 = True,
+        use_cosine_corr = False,
+        disable_local_corr_grad = False,
+        is_classifier = False,
+        sample_mode = "bilinear",
+        norm_type = nn.BatchNorm2d,
+        bn_momentum = 0.1,
     ):
         super().__init__()
         self.bn_momentum = bn_momentum
         self.block1 = self.create_block(
-            in_dim,
-            hidden_dim,
-            dw=dw,
-            kernel_size=kernel_size,
-            bias=use_bias_block_1,
+            in_dim, hidden_dim, dw=dw, kernel_size=kernel_size, bias = use_bias_block_1,
         )
         self.hidden_blocks = nn.Sequential(
             *[
@@ -66,7 +59,7 @@ class ConvRefiner(nn.Module):
         self.out_conv = nn.Conv2d(hidden_dim, out_dim, 1, 1, 0)
         if displacement_emb:
             self.has_displacement_emb = True
-            self.disp_emb = nn.Conv2d(2, displacement_emb_dim, 1, 1, 0)
+            self.disp_emb = nn.Conv2d(2,displacement_emb_dim,1,1,0)
         else:
             self.has_displacement_emb = False
         self.local_corr_radius = local_corr_radius
@@ -78,22 +71,16 @@ class ConvRefiner(nn.Module):
         self.disable_local_corr_grad = disable_local_corr_grad
         self.is_classifier = is_classifier
         self.sample_mode = sample_mode
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
-
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
     def create_block(
         self,
         in_dim,
         out_dim,
         dw=False,
         kernel_size=5,
-        bias=True,
-        norm_type=nn.BatchNorm2d,
+        bias = True,
+        norm_type = nn.BatchNorm2d,
     ):
         num_groups = 1 if not dw else in_dim
         if dw:
@@ -109,56 +96,38 @@ class ConvRefiner(nn.Module):
             groups=num_groups,
             bias=bias,
         )
-        norm = (
-            norm_type(out_dim, momentum=self.bn_momentum)
-            if norm_type is nn.BatchNorm2d
-            else norm_type(num_channels=out_dim)
-        )
+        norm = norm_type(out_dim, momentum = self.bn_momentum) if norm_type is nn.BatchNorm2d else norm_type(num_channels = out_dim)
         relu = nn.ReLU(inplace=True)
         conv2 = nn.Conv2d(out_dim, out_dim, 1, 1, 0)
         return nn.Sequential(conv1, norm, relu, conv2)
-
-    def forward(self, x, y, flow, scale_factor=1, logits=None):
-        b, c, hs, ws = x.shape
-        with torch.autocast(device, enabled=self.amp, dtype=self.amp_dtype):
+
+
+    def forward(self, x, y, flow, scale_factor = 1, logits = None):
+        b,c,hs,ws = x.shape
+        with torch.autocast("cuda", enabled=self.amp, dtype = self.amp_dtype):
             with torch.no_grad():
-                x_hat = F.grid_sample(
-                    y,
-                    flow.permute(0, 2, 3, 1),
-                    align_corners=False,
-                    mode=self.sample_mode,
-                )
+                x_hat = F.grid_sample(y, flow.permute(0, 2, 3, 1), align_corners=False, mode = self.sample_mode)
             if self.has_displacement_emb:
                 im_A_coords = torch.meshgrid(
-                    (
-                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
-                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
-                    )
+                    (
+                        torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
+                        torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
+                    )
                 )
                 im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
                 im_A_coords = im_A_coords[None].expand(b, 2, hs, ws)
-                in_displacement = flow - im_A_coords
-                emb_in_displacement = self.disp_emb(
-                    40 / 32 * scale_factor * in_displacement
-                )
+                in_displacement = flow-im_A_coords
+                emb_in_displacement = self.disp_emb(40/32 * scale_factor * in_displacement)
                 if self.local_corr_radius:
                     if self.corr_in_other:
                         # Corr in other means take a kxk grid around the predicted coordinate in other image
-                        local_corr = local_correlation(
-                            x,
-                            y,
-                            local_radius=self.local_corr_radius,
-                            flow=flow,
-                            sample_mode=self.sample_mode,
-                        )
+                        local_corr = local_correlation(x,y,local_radius=self.local_corr_radius,flow = flow,
+                                                       sample_mode = self.sample_mode)
                     else:
-                        raise NotImplementedError(
-                            "Local corr in own frame should not be used."
-                        )
+                        raise NotImplementedError("Local corr in own frame should not be used.")
                     if self.no_im_B_fm:
                         x_hat = torch.zeros_like(x)
                     d = torch.cat((x, x_hat, emb_in_displacement, local_corr), dim=1)
-                else:
+                else:
                     d = torch.cat((x, x_hat, emb_in_displacement), dim=1)
             else:
                 if self.no_im_B_fm:
@@ -172,7 +141,6 @@ class ConvRefiner(nn.Module):
         displacement, certainty = d[:, :-1], d[:, -1:]
         return displacement, certainty
 
-
 class CosKernel(nn.Module):  # similar to softmax kernel
     def __init__(self, T, learn_temperature=False):
         super().__init__()
@@ -193,7 +161,6 @@ class CosKernel(nn.Module):  # similar to softmax kernel
         K = ((c - 1.0) / T).exp()
         return K
 
-
 class GP(nn.Module):
     def __init__(
         self,
@@ -207,7 +174,7 @@ class GP(nn.Module):
         only_nearest_neighbour=False,
         sigma_noise=0.1,
         no_cov=False,
-        predict_features=False,
+        predict_features = False,
     ):
         super().__init__()
         self.K = kernel(T=T, learn_temperature=learn_temperature)
@@ -295,9 +262,7 @@ class GP(nn.Module):
         mu_x = rearrange(mu_x, "b (h w) d -> b d h w", h=h1, w=w1)
         if not self.no_cov:
             cov_x = K_xx - K_xy.matmul(K_yy_inv.matmul(K_yx))
-            cov_x = rearrange(
-                cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1
-            )
+            cov_x = rearrange(cov_x, "b (h w) (r c) -> b h w r c", h=h1, w=w1, r=h1, c=w1)
             local_cov_x = self.get_local_cov(cov_x)
             local_cov_x = rearrange(local_cov_x, "b h w K -> b K h w")
             gp_feats = torch.cat((mu_x, local_cov_x), dim=1)
@@ -305,22 +270,11 @@ class GP(nn.Module):
             gp_feats = mu_x
         return gp_feats
 
-
 class Decoder(nn.Module):
     def __init__(
-        self,
-        embedding_decoder,
-        gps,
-        proj,
-        conv_refiner,
-        detach=False,
-        scales="all",
-        pos_embeddings=None,
-        num_refinement_steps_per_scale=1,
-        warp_noise_std=0.0,
-        displacement_dropout_p=0.0,
-        gm_warp_dropout_p=0.0,
-        flow_upsample_mode="bilinear",
+        self, embedding_decoder, gps, proj, conv_refiner, detach=False, scales="all", pos_embeddings = None,
+        num_refinement_steps_per_scale = 1, warp_noise_std = 0.0, displacement_dropout_p = 0.0, gm_warp_dropout_p = 0.0,
+        flow_upsample_mode = "bilinear"
     ):
         super().__init__()
         self.embedding_decoder = embedding_decoder
@@ -342,14 +296,8 @@ class Decoder(nn.Module):
         self.displacement_dropout_p = displacement_dropout_p
         self.gm_warp_dropout_p = gm_warp_dropout_p
         self.flow_upsample_mode = flow_upsample_mode
-        if torch.cuda.is_available():
-            if torch.cuda.is_bf16_supported():
-                self.amp_dtype = torch.bfloat16
-            else:
-                self.amp_dtype = torch.float16
-        else:
-            self.amp_dtype = torch.float32
-
+        self.amp_dtype = torch.bfloat16 if torch.cuda.is_bf16_supported() else torch.float16
+
     def get_placeholder_flow(self, b, h, w, device):
         coarse_coords = torch.meshgrid(
             (
@@ -362,8 +310,8 @@ class Decoder(nn.Module):
         ].expand(b, h, w, 2)
         coarse_coords = rearrange(coarse_coords, "b h w d -> b d h w")
         return coarse_coords
-
-    def get_positional_embedding(self, b, h, w, device):
+
+    def get_positional_embedding(self, b, h ,w, device):
         coarse_coords = torch.meshgrid(
             (
                 torch.linspace(-1 + 1 / h, 1 - 1 / h, h, device=device),
@@ -378,29 +326,16 @@ class Decoder(nn.Module):
         coarse_embedded_coords = self.pos_embedding(coarse_coords)
         return coarse_embedded_coords
 
-    def forward(
-        self,
-        f1,
-        f2,
-        gt_warp=None,
-        gt_prob=None,
-        upsample=False,
-        flow=None,
-        certainty=None,
-        scale_factor=1,
-    ):
+    def forward(self, f1, f2, gt_warp = None, gt_prob = None, upsample = False, flow = None, certainty = None, scale_factor = 1):
         coarse_scales = self.embedding_decoder.scales()
-        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
+        all_scales = self.scales if not upsample else ["8", "4", "2", "1"]
         sizes = {scale: f1[scale].shape[-2:] for scale in f1}
         h, w = sizes[1]
         b = f1[1].shape[0]
         device = f1[1].device
         coarsest_scale = int(all_scales[0])
         old_stuff = torch.zeros(
-            b,
-            self.embedding_decoder.hidden_dim,
-            *sizes[coarsest_scale],
-            device=f1[coarsest_scale].device,
+            b, self.embedding_decoder.hidden_dim, *sizes[coarsest_scale], device=f1[coarsest_scale].device
         )
         corresps = {}
         if not upsample:
@@ -408,24 +343,24 @@ class Decoder(nn.Module):
             certainty = 0.0
         else:
             flow = F.interpolate(
-                flow,
-                size=sizes[coarsest_scale],
-                align_corners=False,
-                mode="bilinear",
-            )
+                flow,
+                size=sizes[coarsest_scale],
+                align_corners=False,
+                mode="bilinear",
+            )
             certainty = F.interpolate(
-                certainty,
-                size=sizes[coarsest_scale],
-                align_corners=False,
-                mode="bilinear",
-            )
+                certainty,
+                size=sizes[coarsest_scale],
+                align_corners=False,
+                mode="bilinear",
+            )
             displacement = 0.0
         for new_scale in all_scales:
             ins = int(new_scale)
             corresps[ins] = {}
             f1_s, f2_s = f1[ins], f2[ins]
             if new_scale in self.proj:
-                with torch.autocast(device, self.amp_dtype):
+                with torch.autocast("cuda", self.amp_dtype):
                     f1_s, f2_s = self.proj[new_scale](f1_s), self.proj[new_scale](f2_s)
 
             if ins in coarse_scales:
@@ -436,59 +371,32 @@ class Decoder(nn.Module):
                 gm_warp_or_cls, certainty, old_stuff = self.embedding_decoder(
                     gp_posterior, f1_s, old_stuff, new_scale
                 )
-
+
                 if self.embedding_decoder.is_classifier:
                     flow = cls_to_flow_refine(
                         gm_warp_or_cls,
-                    ).permute(0, 3, 1, 2)
-                    corresps[ins].update(
-                        {
-                            "gm_cls": gm_warp_or_cls,
-                            "gm_certainty": certainty,
-                        }
-                    ) if self.training else None
+                    ).permute(0,3,1,2)
+                    corresps[ins].update({"gm_cls": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
                 else:
-                    corresps[ins].update(
-                        {
-                            "gm_flow": gm_warp_or_cls,
-                            "gm_certainty": certainty,
-                        }
-                    ) if self.training else None
+                    corresps[ins].update({"gm_flow": gm_warp_or_cls,"gm_certainty": certainty,}) if self.training else None
                     flow = gm_warp_or_cls.detach()
-
+
             if new_scale in self.conv_refiner:
-                corresps[ins].update(
-                    {"flow_pre_delta": flow}
-                ) if self.training else None
+                corresps[ins].update({"flow_pre_delta": flow}) if self.training else None
                 delta_flow, delta_certainty = self.conv_refiner[new_scale](
-                    f1_s,
-                    f2_s,
-                    flow,
-                    scale_factor=scale_factor,
-                    logits=certainty,
-                )
-                corresps[ins].update(
-                    {
-                        "delta_flow": delta_flow,
-                    }
-                ) if self.training else None
-                displacement = ins * torch.stack(
-                    (
-                        delta_flow[:, 0].float() / (self.refine_init * w),
-                        delta_flow[:, 1].float() / (self.refine_init * h),
-                    ),
-                    dim=1,
-                )
+                    f1_s, f2_s, flow, scale_factor = scale_factor, logits = certainty,
+                )
+                corresps[ins].update({"delta_flow": delta_flow,}) if self.training else None
+                displacement = ins*torch.stack((delta_flow[:, 0].float() / (self.refine_init * w),
+                                                delta_flow[:, 1].float() / (self.refine_init * h),),dim=1,)
                 flow = flow + displacement
                 certainty = (
                     certainty + delta_certainty
                 )  # predict both certainty and displacement
-                corresps[ins].update(
-                    {
-                        "certainty": certainty,
-                        "flow": flow,
-                    }
-                )
+                corresps[ins].update({
+                    "certainty": certainty,
+                    "flow": flow,
+                })
             if new_scale != "1":
                 flow = F.interpolate(
                     flow,
@@ -503,7 +411,7 @@ class Decoder(nn.Module):
         if self.detach:
             flow = flow.detach()
             certainty = certainty.detach()
-        # torch.cuda.empty_cache()
+        #torch.cuda.empty_cache()
         return corresps
 
 
@@ -514,11 +422,11 @@ class RegressionMatcher(nn.Module):
         decoder,
         h=448,
         w=448,
-        sample_mode="threshold",
-        upsample_preds=False,
-        symmetric=False,
-        name=None,
-        attenuate_cert=None,
+        sample_mode = "threshold",
+        upsample_preds = False,
+        symmetric = False,
+        name = None,
+        attenuate_cert = None,
     ):
         super().__init__()
         self.attenuate_cert = attenuate_cert
@@ -530,26 +438,24 @@ class RegressionMatcher(nn.Module):
         self.og_transforms = get_tuple_transform_ops(resize=None, normalize=True)
         self.sample_mode = sample_mode
         self.upsample_preds = upsample_preds
-        self.upsample_res = (14 * 16 * 6, 14 * 16 * 6)
+        self.upsample_res = (14*16*6, 14*16*6)
         self.symmetric = symmetric
         self.sample_thresh = 0.05
-
+
     def get_output_resolution(self):
         if not self.upsample_preds:
             return self.h_resized, self.w_resized
         else:
             return self.upsample_res
-
-    def extract_backbone_features(self, batch, batched=True, upsample=False):
+
+    def extract_backbone_features(self, batch, batched = True, upsample = False):
         x_q = batch["im_A"]
         x_s = batch["im_B"]
         if batched:
-            X = torch.cat((x_q, x_s), dim=0)
-            feature_pyramid = self.encoder(X, upsample=upsample)
+            X = torch.cat((x_q, x_s), dim = 0)
+            feature_pyramid = self.encoder(X, upsample = upsample)
         else:
-            feature_pyramid = self.encoder(x_q, upsample=upsample), self.encoder(
-                x_s, upsample=upsample
-            )
+            feature_pyramid = self.encoder(x_q, upsample = upsample), self.encoder(x_s, upsample = upsample)
         return feature_pyramid
 
     def sample(
@@ -567,28 +473,22 @@ class RegressionMatcher(nn.Module):
             certainty.reshape(-1),
         )
         expansion_factor = 4 if "balanced" in self.sample_mode else 1
-        good_samples = torch.multinomial(
-            certainty,
-            num_samples=min(expansion_factor * num, len(certainty)),
-            replacement=False,
-        )
+        good_samples = torch.multinomial(certainty,
+                          num_samples = min(expansion_factor*num, len(certainty)),
+                          replacement=False)
         good_matches, good_certainty = matches[good_samples], certainty[good_samples]
         if "balanced" not in self.sample_mode:
            return good_matches, good_certainty
         density = kde(good_matches, std=0.1)
-        p = 1 / (density + 1)
-        p[
-            density < 10
-        ] = 1e-7  # Basically should have at least 10 perfect neighbours, or around 100 ok ones
-        balanced_samples = torch.multinomial(
-            p, num_samples=min(num, len(good_certainty)), replacement=False
-        )
+        p = 1 / (density+1)
+        p[density < 10] = 1e-7 # Basically should have at least 10 perfect neighbours, or around 100 ok ones
+        balanced_samples = torch.multinomial(p,
+                              num_samples = min(num,len(good_certainty)),
+                              replacement=False)
         return good_matches[balanced_samples], good_certainty[balanced_samples]
 
-    def forward(self, batch, batched=True, upsample=False, scale_factor=1):
-        feature_pyramid = self.extract_backbone_features(
-            batch, batched=batched, upsample=upsample
-        )
+    def forward(self, batch, batched = True, upsample = False, scale_factor = 1):
+        feature_pyramid = self.extract_backbone_features(batch, batched=batched, upsample = upsample)
         if batched:
             f_q_pyramid = {
                 scale: f_scale.chunk(2)[0] for scale, f_scale in feature_pyramid.items()
@@ -598,42 +498,32 @@ class RegressionMatcher(nn.Module):
             }
         else:
            f_q_pyramid, f_s_pyramid = feature_pyramid
-        corresps = self.decoder(
-            f_q_pyramid,
-            f_s_pyramid,
-            upsample=upsample,
-            **(batch["corresps"] if "corresps" in batch else {}),
-            scale_factor=scale_factor,
-        )
-
+        corresps = self.decoder(f_q_pyramid,
+                                f_s_pyramid,
+                                upsample = upsample,
+                                **(batch["corresps"] if "corresps" in batch else {}),
+                                scale_factor=scale_factor)
+
         return corresps
 
-    def forward_symmetric(self, batch, batched=True, upsample=False, scale_factor=1):
-        feature_pyramid = self.extract_backbone_features(
-            batch, batched=batched, upsample=upsample
-        )
+    def forward_symmetric(self, batch, batched = True, upsample = False, scale_factor = 1):
+        feature_pyramid = self.extract_backbone_features(batch, batched = batched, upsample = upsample)
         f_q_pyramid = feature_pyramid
         f_s_pyramid = {
-            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim=0)
+            scale: torch.cat((f_scale.chunk(2)[1], f_scale.chunk(2)[0]), dim = 0)
             for scale, f_scale in feature_pyramid.items()
         }
-        corresps = self.decoder(
-            f_q_pyramid,
-            f_s_pyramid,
-            upsample=upsample,
-            **(batch["corresps"] if "corresps" in batch else {}),
-            scale_factor=scale_factor,
-        )
+        corresps = self.decoder(f_q_pyramid,
+                                f_s_pyramid,
+                                upsample = upsample,
+                                **(batch["corresps"] if "corresps" in batch else {}),
+                                scale_factor=scale_factor)
         return corresps
-
+
     def to_pixel_coordinates(self, matches, H_A, W_A, H_B, W_B):
-        kpts_A, kpts_B = matches[..., :2], matches[..., 2:]
-        kpts_A = torch.stack(
-            (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), axis=-1
-        )
-        kpts_B = torch.stack(
-            (W_B / 2 * (kpts_B[..., 0] + 1), H_B / 2 * (kpts_B[..., 1] + 1)), axis=-1
-        )
+        kpts_A, kpts_B = matches[...,:2], matches[...,2:]
+        kpts_A = torch.stack((W_A/2 * (kpts_A[...,0]+1), H_A/2 * (kpts_A[...,1]+1)),axis=-1)
+        kpts_B = torch.stack((W_B/2 * (kpts_B[...,0]+1), H_B/2 * (kpts_B[...,1]+1)),axis=-1)
         return kpts_A, kpts_B
 
     def match(
@@ -642,12 +532,11 @@ class RegressionMatcher(nn.Module):
         im_B_path,
         *args,
         batched=False,
-        device=None,
+        device = None,
     ):
         if device is None:
-            device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
+            device = torch.device('cuda' if torch.cuda.is_available() else 'cpu')
         from PIL import Image
-
        if isinstance(im_A_path, (str, os.PathLike)):
            im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
        else:
@@ -663,9 +552,9 @@ class RegressionMatcher(nn.Module):
         # Get images in good format
         ws = self.w_resized
         hs = self.h_resized
-
+
         test_transform = get_tuple_transform_ops(
-            resize=(hs, ws), normalize=True, clahe=False
+            resize=(hs, ws), normalize=True, clahe = False
         )
         im_A, im_B = test_transform((im_A, im_B))
         batch = {"im_A": im_A[None].to(device), "im_B": im_B[None].to(device)}
@@ -675,32 +564,25 @@ class RegressionMatcher(nn.Module):
         assert w == w2 and h == h2, "For batched images we assume same size"
         batch = {"im_A": im_A.to(device), "im_B": im_B.to(device)}
         if h != self.h_resized or self.w_resized != w:
-            warn(
-                "Model resolution and batch resolution differ, may produce unexpected results"
-            )
+            warn("Model resolution and batch resolution differ, may produce unexpected results")
             hs, ws = h, w
         finest_scale = 1
         # Run matcher
         if symmetric:
-            corresps = self.forward_symmetric(batch)
+            corresps = self.forward_symmetric(batch)
         else:
-            corresps = self.forward(batch, batched=True)
+            corresps = self.forward(batch, batched = True)
 
         if self.upsample_preds:
             hs, ws = self.upsample_res
-
+
         if self.attenuate_cert:
             low_res_certainty = F.interpolate(
-                corresps[16]["certainty"],
-                size=(hs, ws),
-                align_corners=False,
-                mode="bilinear",
+                corresps[16]["certainty"], size=(hs, ws), align_corners=False, mode="bilinear"
             )
             cert_clamp = 0
             factor = 0.5
-            low_res_certainty = (
-                factor * low_res_certainty * (low_res_certainty < cert_clamp)
-            )
+            low_res_certainty = factor*low_res_certainty*(low_res_certainty < cert_clamp)
 
         if self.upsample_preds:
             finest_corresps = corresps[finest_scale]
@@ -711,38 +593,30 @@ class RegressionMatcher(nn.Module):
             im_A, im_B = Image.open(im_A_path), Image.open(im_B_path)
             im_A, im_B = test_transform((im_A, im_B))
             im_A, im_B = im_A[None].to(device), im_B[None].to(device)
-            scale_factor = math.sqrt(
-                self.upsample_res[0]
-                * self.upsample_res[1]
-                / (self.w_resized * self.h_resized)
-            )
+            scale_factor = math.sqrt(self.upsample_res[0] * self.upsample_res[1] / (self.w_resized * self.h_resized))
             batch = {"im_A": im_A, "im_B": im_B, "corresps": finest_corresps}
             if symmetric:
-                corresps = self.forward_symmetric(
-                    batch, upsample=True, batched=True, scale_factor=scale_factor
-                )
+                corresps = self.forward_symmetric(batch, upsample = True, batched=True, scale_factor = scale_factor)
             else:
-                corresps = self.forward(
-                    batch, batched=True, upsample=True, scale_factor=scale_factor
-                )
-
-        im_A_to_im_B = corresps[finest_scale]["flow"]
-        certainty = corresps[finest_scale]["certainty"] - (
-            low_res_certainty if self.attenuate_cert else 0
-        )
+                corresps = self.forward(batch, batched = True, upsample=True, scale_factor = scale_factor)
+
+        im_A_to_im_B = corresps[finest_scale]["flow"]
+        certainty = corresps[finest_scale]["certainty"] - (low_res_certainty if self.attenuate_cert else 0)
         if finest_scale != 1:
             im_A_to_im_B = F.interpolate(
-                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
+                im_A_to_im_B, size=(hs, ws), align_corners=False, mode="bilinear"
             )
             certainty = F.interpolate(
-                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
-            )
-        im_A_to_im_B = im_A_to_im_B.permute(0, 2, 3, 1)
+                certainty, size=(hs, ws), align_corners=False, mode="bilinear"
+            )
+        im_A_to_im_B = im_A_to_im_B.permute(
+            0, 2, 3, 1
+        )
         # Create im_A meshgrid
         im_A_coords = torch.meshgrid(
             (
-                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device=device),
-                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device=device),
+                torch.linspace(-1 + 1 / hs, 1 - 1 / hs, hs, device="cuda"),
+                torch.linspace(-1 + 1 / ws, 1 - 1 / ws, ws, device="cuda"),
             )
         )
         im_A_coords = torch.stack((im_A_coords[1], im_A_coords[0]))
@@ -751,21 +625,25 @@ class RegressionMatcher(nn.Module):
         im_A_coords = im_A_coords.permute(0, 2, 3, 1)
         if (im_A_to_im_B.abs() > 1).any() and True:
             wrong = (im_A_to_im_B.abs() > 1).sum(dim=-1) > 0
-            certainty[wrong[:, None]] = 0
+            certainty[wrong[:,None]] = 0
         im_A_to_im_B = torch.clamp(im_A_to_im_B, -1, 1)
         if symmetric:
            A_to_B, B_to_A = im_A_to_im_B.chunk(2)
            q_warp = torch.cat((im_A_coords, A_to_B), dim=-1)
            im_B_coords = im_A_coords
            s_warp = torch.cat((B_to_A, im_B_coords), dim=-1)
-            warp = torch.cat((q_warp, s_warp), dim=2)
+            warp = torch.cat((q_warp, s_warp),dim=2)
            certainty = torch.cat(certainty.chunk(2), dim=3)
        else:
            warp = torch.cat((im_A_coords, im_A_to_im_B), dim=-1)
        if batched:
-            return (warp, certainty[:, 0])
+            return (
+                warp,
+                certainty[:, 0]
+            )
        else:
            return (
                warp[0],
                certainty[0, 0],
            )
+
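Note that these hunks restore hardcoded device="cuda" tensors in ConvRefiner.forward and RegressionMatcher.match, consistent with the removal of the module-level device fallback at the top of the file. The reformatting of to_pixel_coordinates leaves the coordinate convention unchanged: warps are stored in normalized [-1, 1] coordinates and map to pixels as x_px = W / 2 * (x + 1), y_px = H / 2 * (y + 1). A self-contained sketch with made-up matches:

import torch

# Each match row is (x_A, y_A, x_B, y_B) in normalized [-1, 1] coordinates.
# Pixel coordinates follow x_px = W / 2 * (x + 1), y_px = H / 2 * (y + 1),
# exactly as in to_pixel_coordinates above.
H_A, W_A = 480, 640
matches = torch.tensor(
    [[-1.0, -1.0, 0.0, 0.0],   # top-left of image A, centre of image B
     [ 1.0,  1.0, 0.5, -0.5]]  # bottom-right of image A
)
kpts_A = matches[..., :2]
kpts_A_px = torch.stack(
    (W_A / 2 * (kpts_A[..., 0] + 1), H_A / 2 * (kpts_A[..., 1] + 1)), dim=-1
)
print(kpts_A_px)  # tensor([[  0.,   0.], [640., 480.]])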
|