Spaces:
No application file
No application file
import torch | |
import numpy as np | |
import os | |
from collections import OrderedDict, namedtuple | |
import sys | |
ROOT_DIR = os.path.abspath(os.path.join(os.path.dirname(__file__), "..")) | |
sys.path.insert(0, ROOT_DIR) | |
from sgmnet import matcher as SGM_Model | |
from superglue import matcher as SG_Model | |
from utils import evaluation_utils | |
class GNN_Matcher(object): | |
def __init__(self, config, model_name): | |
assert model_name == "SGM" or model_name == "SG" | |
config = namedtuple("config", config.keys())(*config.values()) | |
self.p_th = config.p_th | |
self.model = SGM_Model(config) if model_name == "SGM" else SG_Model(config) | |
self.model.cuda(), self.model.eval() | |
checkpoint = torch.load(os.path.join(config.model_dir, "model_best.pth")) | |
# for ddp model | |
if list(checkpoint["state_dict"].items())[0][0].split(".")[0] == "module": | |
new_stat_dict = OrderedDict() | |
for key, value in checkpoint["state_dict"].items(): | |
new_stat_dict[key[7:]] = value | |
checkpoint["state_dict"] = new_stat_dict | |
self.model.load_state_dict(checkpoint["state_dict"]) | |
def run(self, test_data): | |
norm_x1, norm_x2 = evaluation_utils.normalize_size( | |
test_data["x1"][:, :2], test_data["size1"] | |
), evaluation_utils.normalize_size(test_data["x2"][:, :2], test_data["size2"]) | |
x1, x2 = np.concatenate( | |
[norm_x1, test_data["x1"][:, 2, np.newaxis]], axis=-1 | |
), np.concatenate([norm_x2, test_data["x2"][:, 2, np.newaxis]], axis=-1) | |
feed_data = { | |
"x1": torch.from_numpy(x1[np.newaxis]).cuda().float(), | |
"x2": torch.from_numpy(x2[np.newaxis]).cuda().float(), | |
"desc1": torch.from_numpy(test_data["desc1"][np.newaxis]).cuda().float(), | |
"desc2": torch.from_numpy(test_data["desc2"][np.newaxis]).cuda().float(), | |
} | |
with torch.no_grad(): | |
res = self.model(feed_data, test_mode=True) | |
p = res["p"] | |
index1, index2 = self.match_p(p[0, :-1, :-1]) | |
corr1, corr2 = ( | |
test_data["x1"][:, :2][index1.cpu()], | |
test_data["x2"][:, :2][index2.cpu()], | |
) | |
if len(corr1.shape) == 1: | |
corr1, corr2 = corr1[np.newaxis], corr2[np.newaxis] | |
return corr1, corr2 | |
def match_p(self, p): # p N*M | |
score, index = torch.topk(p, k=1, dim=-1) | |
_, index2 = torch.topk(p, k=1, dim=-2) | |
mask_th, index, index2 = score[:, 0] > self.p_th, index[:, 0], index2.squeeze(0) | |
mask_mc = index2[index] == torch.arange(len(p)).cuda() | |
mask = mask_th & mask_mc | |
index1, index2 = torch.nonzero(mask).squeeze(1), index[mask] | |
return index1, index2 | |
class NN_Matcher(object): | |
def __init__(self, config): | |
config = namedtuple("config", config.keys())(*config.values()) | |
self.mutual_check = config.mutual_check | |
self.ratio_th = config.ratio_th | |
def run(self, test_data): | |
desc1, desc2, x1, x2 = ( | |
test_data["desc1"], | |
test_data["desc2"], | |
test_data["x1"], | |
test_data["x2"], | |
) | |
desc_mat = np.sqrt( | |
abs( | |
(desc1**2).sum(-1)[:, np.newaxis] | |
+ (desc2**2).sum(-1)[np.newaxis] | |
- 2 * desc1 @ desc2.T | |
) | |
) | |
nn_index = np.argpartition(desc_mat, kth=(1, 2), axis=-1) | |
dis_value12 = np.take_along_axis(desc_mat, nn_index, axis=-1) | |
ratio_score = dis_value12[:, 0] / dis_value12[:, 1] | |
nn_index1 = nn_index[:, 0] | |
nn_index2 = np.argmin(desc_mat, axis=0) | |
mask_ratio, mask_mutual = ( | |
ratio_score < self.ratio_th, | |
np.arange(len(x1)) == nn_index2[nn_index1], | |
) | |
corr1, corr2 = x1[:, :2], x2[:, :2][nn_index1] | |
if self.mutual_check: | |
mask = mask_ratio & mask_mutual | |
else: | |
mask = mask_ratio | |
corr1, corr2 = corr1[mask], corr2[mask] | |
return corr1, corr2 | |