File size: 5,002 Bytes
9223079
e15a186
 
 
 
9223079
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
import kornia
from kornia.feature.laf import (
    laf_from_center_scale_ori,
    extract_patches_from_pyramid,
)
import numpy as np
import torch
import pycolmap

from ..utils.base_model import BaseModel


# Small tolerance shared across this module: it is added to the norm
# denominators and used as the clipping floor in sift_to_rootsift, and it is
# the slack allowed when DoG._forward checks that the image lies in [0, 1].
EPS = 1e-6


def sift_to_rootsift(x):
    """Convert SIFT descriptors into RootSIFT descriptors.

    Every descriptor (a vector along the last axis of ``x``) is
    L1-normalized, clipped from below at ``EPS``, square-rooted
    element-wise and finally L2-normalized. ``EPS`` is added to both norms
    so that an all-zero descriptor does not cause a division by zero.

    Args:
        x: NumPy array whose last axis holds the descriptor entries.

    Returns:
        A new array of the same shape containing the RootSIFT descriptors.
    """
    l1_norm = np.linalg.norm(x, ord=1, axis=-1, keepdims=True)
    rooted = np.sqrt((x / (l1_norm + EPS)).clip(min=EPS))
    l2_norm = np.linalg.norm(rooted, axis=-1, keepdims=True)
    return rooted / (l2_norm + EPS)


class DoG(BaseModel):
    """DoG keypoint detector built on top of ``pycolmap.Sift``.

    Keypoint locations, scales, orientations and scores always come from the
    pycolmap SIFT extractor. The descriptors are either the ones returned by
    pycolmap ("sift" / "rootsift") or are computed by a kornia network
    ("sosnet" / "hardnet") on patches extracted around the keypoints.
    """

    default_conf = {
        # Forwarded to pycolmap.SiftExtractionOptions (a "normalization"
        # entry is added at extractor creation time).
        "options": {
            "first_octave": 0,
            "peak_threshold": 0.01,
        },
        # One of "sift", "rootsift", "sosnet", "hardnet".
        "descriptor": "rootsift",
        # -1 keeps every keypoint; otherwise the best-scoring ones are kept.
        "max_keypoints": -1,
        # Side length (PS) of the patches fed to the learned descriptors.
        "patch_size": 32,
        # Multiplier applied to the keypoint scale to get the patch extent.
        "mr_size": 12,
    }
    required_inputs = ["image"]
    # Not read anywhere in this class; presumably consumed by other code.
    detection_noise = 1.0
    # Maximum number of patches described in one call of the network.
    max_batch_size = 1024

    def _init(self, conf):
        """Create the optional learned descriptor and reset the lazy state.

        Args:
            conf: configuration mapping; only ``conf["descriptor"]`` is read.

        Raises:
            ValueError: if the descriptor name is not supported.
        """
        if conf["descriptor"] == "sosnet":
            self.describe = kornia.feature.SOSNet(pretrained=True)
        elif conf["descriptor"] == "hardnet":
            self.describe = kornia.feature.HardNet(pretrained=True)
        elif conf["descriptor"] not in ["sift", "rootsift"]:
            raise ValueError(f'Unknown descriptor: {conf["descriptor"]}')

        self.sift = None  # lazily instantiated on the first image
        self.device = torch.device("cpu")

    def to(self, *args, **kwargs):
        """Move the module and remember the requested device.

        The device is taken from the ``device`` keyword or, failing that,
        from the first positional argument that is a ``torch.device`` or a
        string. It is only used to pick the pycolmap device when the SIFT
        extractor is created, so a move performed after the first forward
        pass does not affect an extractor that already exists.
        """
        device = kwargs.get("device")
        if device is None:
            device = next(
                (a for a in args if isinstance(a, (torch.device, str))), None
            )
        if device is not None:
            self.device = torch.device(device)
        return super().to(*args, **kwargs)

    def _build_sift(self):
        """Instantiate the pycolmap SIFT extractor from the configuration."""
        use_gpu = pycolmap.has_cuda and self.device.type == "cuda"
        options = {**self.conf["options"]}
        if self.conf["descriptor"] == "rootsift":
            options["normalization"] = pycolmap.Normalization.L1_ROOT
        else:
            options["normalization"] = pycolmap.Normalization.L2
        return pycolmap.Sift(
            options=pycolmap.SiftExtractionOptions(options),
            device=getattr(pycolmap.Device, "cuda" if use_gpu else "cpu"),
        )

    def _describe_patches(self, image, keypoints, scales, oris):
        """Describe patches around the keypoints with ``self.describe``.

        Args:
            image: the input image tensor passed to ``_forward``.
            keypoints: NumPy array whose first two columns are x and y.
            scales: NumPy array with one scale per keypoint.
            oris: NumPy array with one orientation (in degrees) per keypoint.

        Returns:
            Tensor with one 128-dimensional row per extracted patch, located
            on the same device as the patches.
        """
        # NOTE(review): the 0.5 offset and the orientation sign flip
        # presumably convert from COLMAP to kornia conventions - confirm.
        center = keypoints[:, :2] + 0.5
        laf_scale = scales * self.conf["mr_size"] / 2
        laf_ori = -oris
        lafs = laf_from_center_scale_ori(
            torch.from_numpy(center)[None],
            torch.from_numpy(laf_scale)[None, :, None, None],
            torch.from_numpy(laf_ori)[None, :, None],
        ).to(image.device)
        patches = extract_patches_from_pyramid(
            image, lafs, PS=self.conf["patch_size"]
        )[0]
        descriptors = patches.new_zeros((len(patches), 128))
        # Run the network in chunks to bound memory usage. With no patches
        # the range is empty and the zero-row tensor is returned as is;
        # slicing clamps the last chunk to the number of patches.
        for start_idx in range(0, len(patches), self.max_batch_size):
            end_idx = start_idx + self.max_batch_size
            descriptors[start_idx:end_idx] = self.describe(
                patches[start_idx:end_idx]
            )
        return descriptors

    def _forward(self, data):
        """Detect and describe keypoints in a single-channel image.

        Args:
            data: mapping with an ``"image"`` tensor that has exactly one
                channel and values in [0, 1] (up to ``EPS``). Only the first
                batch element is processed.

        Returns:
            A dict with ``"keypoints"`` (x, y), ``"scales"``, ``"oris"``
            (degrees), ``"scores"`` and ``"descriptors"``, each with a
            leading axis of size one. The descriptors are transposed so that
            the keypoint axis comes last.

        Raises:
            AssertionError: if the image does not have exactly one channel
                or its values fall outside [0, 1].
            ValueError: if the configured descriptor is not supported.
        """
        image = data["image"]
        # Validate the layout before paying for the host copy and indexing.
        assert image.shape[1] == 1
        image_np = image.cpu().numpy()[0, 0]
        assert image_np.min() >= -EPS and image_np.max() <= 1 + EPS

        if self.sift is None:
            self.sift = self._build_sift()

        keypoints, scores, descriptors = self.sift.extract(image_np)
        scales = keypoints[:, 2]
        oris = np.rad2deg(keypoints[:, 3])

        if self.conf["descriptor"] in ["sift", "rootsift"]:
            # We still renormalize because COLMAP does not normalize well,
            # maybe due to numerical errors
            if self.conf["descriptor"] == "rootsift":
                descriptors = sift_to_rootsift(descriptors)
            descriptors = torch.from_numpy(descriptors)
        elif self.conf["descriptor"] in ("sosnet", "hardnet"):
            descriptors = self._describe_patches(image, keypoints, scales, oris)
        else:
            raise ValueError(f'Unknown descriptor: {self.conf["descriptor"]}')

        keypoints = torch.from_numpy(keypoints[:, :2])  # keep only x, y
        scales = torch.from_numpy(scales)
        oris = torch.from_numpy(oris)
        scores = torch.from_numpy(scores)
        if self.conf["max_keypoints"] != -1:
            # TODO: check that the scores from PyCOLMAP are 100% correct,
            # follow https://github.com/mihaidusmanu/pycolmap/issues/8
            max_number = min(scores.shape[0], self.conf["max_keypoints"])
            _, indices = torch.topk(scores, max_number)
            keypoints = keypoints[indices]
            scales = scales[indices]
            oris = oris[indices]
            scores = scores[indices]
            descriptors = descriptors[indices]

        return {
            "keypoints": keypoints[None],
            "scales": scales[None],
            "oris": oris[None],
            "scores": scores[None],
            "descriptors": descriptors.T[None],
        }