File size: 1,927 Bytes
a80d6bb
 
 
 
 
 
 
 
 
c74a070
 
a80d6bb
 
 
 
 
 
 
c74a070
 
 
 
a80d6bb
 
 
 
 
 
 
 
c74a070
 
 
a80d6bb
 
 
 
 
 
c74a070
a80d6bb
 
 
 
 
 
 
c74a070
 
 
 
a80d6bb
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
import torch
from torch import nn


class NN2(nn.Module):
    """Parameter-free mutual-nearest-neighbour descriptor matcher.

    Two descriptors are matched when each is the other's highest-similarity
    partner under the dot product. Output keys follow the
    ``matches0``/``matches1``/``matching_scores0``/``matching_scores1``
    convention used by the empty-input branch.
    """

    def __init__(self):
        super().__init__()

    def forward(self, data):
        """Match ``descriptors0`` against ``descriptors1``.

        Args:
            data: mapping with tensors under ``"descriptors0"``,
                ``"descriptors1"``, ``"keypoints0"`` and ``"keypoints1"``.
                A batch of size 1 is assumed: descriptors are read as
                ``(1, D, N)`` and keypoints as ``(1, N, ...)``
                (TODO confirm against callers).

        Returns:
            dict with
              - ``"matches0"``: ``(1, N0)`` int64; index into image 1 of each
                image-0 keypoint's match, or -1 if unmatched.
              - ``"matches1"``: ``(1, N1)`` int64; index into image 0 of each
                image-1 keypoint's match, or -1 if unmatched.
              - ``"matching_scores0"`` / ``"matching_scores1"``: float64
                similarity of each match, -1 where unmatched.
            These tensors live on the CPU. When either image has no
            keypoints, the early-return branch instead yields int32 indices
            and zero scores on the compute device (unchanged behaviour).
        """
        # Prefer the GPU, but don't crash on CPU-only machines.
        device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
        desc1 = data["descriptors0"].to(device)
        desc2 = data["descriptors1"].to(device)
        kpts1 = data["keypoints0"].to(device)
        kpts2 = data["keypoints1"].to(device)

        if kpts1.shape[1] == 0 or kpts2.shape[1] == 0:  # no keypoints
            shape0, shape1 = kpts1.shape[:-1], kpts2.shape[:-1]
            return {
                "matches0": kpts1.new_full(shape0, -1, dtype=torch.int),
                "matches1": kpts2.new_full(shape1, -1, dtype=torch.int),
                "matching_scores0": kpts1.new_zeros(shape0),
                "matching_scores1": kpts2.new_zeros(shape1),
            }

        n0, n1 = desc1.shape[-1], desc2.shape[-1]

        # Index the batch explicitly instead of .squeeze(): squeeze would also
        # drop the keypoint axis when a set has exactly one keypoint.
        sim = torch.matmul(desc1[0].T, desc2[0])  # (n0, n1)
        ids1 = torch.arange(n0, device=sim.device)
        nn12 = torch.argmax(sim, dim=1)  # best image-1 partner per image-0 row
        nn21 = torch.argmax(sim, dim=0)  # best image-0 partner per image-1 col

        # Keep only pairs that choose each other (mutual nearest neighbours).
        mutual = nn21[nn12] == ids1
        m0 = ids1[mutual]
        m1 = nn12[mutual]
        scores = sim[m0, m1].cpu().to(torch.float64)
        m0, m1 = m0.cpu(), m1.cpu()

        indices0 = torch.full((1, n0), -1, dtype=torch.long)
        indices1 = torch.full((1, n1), -1, dtype=torch.long)
        mscores0 = torch.full((1, n0), -1.0, dtype=torch.float64)
        mscores1 = torch.full((1, n1), -1.0, dtype=torch.float64)

        # Mutual matches are one-to-one, so these scatters never collide.
        indices0[0, m0] = m1
        indices1[0, m1] = m0
        mscores0[0, m0] = scores
        mscores1[0, m1] = scores

        return {
            "matches0": indices0,  # use -1 for invalid match
            "matches1": indices1,  # use -1 for invalid match
            "matching_scores0": mscores0,
            "matching_scores1": mscores1,
        }