from typing import Iterable, Optional

import cv2
import numpy as np
import PIL.Image
import torch
import wandb

from pytorch_grad_cam import GradCAM

from utils.val_loop_hook import ValidationLoopHook

def _get_grad_cam_target(model):
    """
    Determines an appropriate Grad-CAM target layer for ``model``.

    Prefers an explicit ``features`` attribute (torchvision-style models);
    otherwise searches the registered modules in reverse, picking the module
    registered directly before the last pooling layer, or the last convolution
    if no pooling layer is found. Returns ``None`` if nothing suitable exists.
    """

    # very naive check for torchvision-style feature extractors
    if hasattr(model, "features"):
        return model.features

    pooling = (torch.nn.AdaptiveAvgPool1d, torch.nn.AvgPool1d, torch.nn.MaxPool1d, torch.nn.AdaptiveMaxPool1d,
               torch.nn.AdaptiveAvgPool2d, torch.nn.AvgPool2d, torch.nn.MaxPool2d, torch.nn.AdaptiveMaxPool2d,
               torch.nn.AdaptiveAvgPool3d, torch.nn.AvgPool3d, torch.nn.MaxPool3d, torch.nn.AdaptiveMaxPool3d)
    convolutions = (torch.nn.Conv1d, torch.nn.Conv2d, torch.nn.Conv3d)

    # reverse search starting from the final module
    inverted_modules = list(model.modules())[::-1]
    for i, module in enumerate(inverted_modules):
        if isinstance(module, pooling) and i + 1 < len(inverted_modules):
            # if a pooling layer was hit, pick the module registered directly before it
            return inverted_modules[i + 1]
        elif isinstance(module, convolutions):
            # if a convolution was hit (but no pooling layer), pick that one instead
            return module
        elif isinstance(module, torch.nn.Sequential):
            # if a sequential module is hit, explore its children in reverse
            for child in list(module.children())[::-1]:
                sequential_result = _get_grad_cam_target(child)
                if sequential_result is not None:
                    return sequential_result
    return None
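
# A hedged usage sketch (torchvision and the model choice are assumptions,
# not dependencies of this module):
#
#   import torchvision
#   resnet = torchvision.models.resnet18()
#   target_layer = _get_grad_cam_target(resnet)
#   # resnet18 has no "features" attribute, so the reverse search returns the
#   # module registered directly before the final average-pooling layer.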

def _show_cam_on_image(img: np.ndarray, mask: np.ndarray, use_rgb: bool = False, colormap: int = cv2.COLORMAP_JET) -> np.ndarray:
    """ This function overlays the CAM mask on the image as a heatmap.
    By default the heatmap is in BGR format.
    :param img: The base image in RGB or BGR format.
    :param mask: The CAM mask.
    :param use_rgb: Whether to use an RGB or BGR heatmap; set this to True if 'img' is in RGB format.
    :param colormap: The OpenCV colormap to be used.
    :returns: The base image with the CAM overlay, as a uint8 array.
    """
    heatmap = cv2.applyColorMap(np.uint8(255 * mask), colormap)
    if use_rgb:
        heatmap = cv2.cvtColor(heatmap, cv2.COLOR_BGR2RGB)
    heatmap = np.float32(heatmap) / 255

    def normalize(x):
        # rescale to [0, 1]; the epsilon avoids division by zero for constant images
        return (x - np.min(x)) / (np.ptp(x) + 1e-8)

    cam = 0.6 * heatmap + normalize(img)
    cam = cam / np.max(cam)
    return np.uint8(255 * cam)
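
# Minimal shape sketch (values are illustrative; ranges follow the code above):
#
#   img = np.random.rand(224, 224, 3).astype(np.float32)   # base image, HxWxC
#   mask = np.random.rand(224, 224).astype(np.float32)     # CAM mask in [0, 1]
#   overlay = _show_cam_on_image(img, mask, use_rgb=True)  # uint8 heatmap overlay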

def _strip_image_from_grid_row(row, gap=5, bg=255):
    """Arranges a (rows, cols, height, width, channels) tensor of images into a single PIL image grid."""
    strip = torch.full(
            (row.shape[0] * (row.shape[2] + gap) - gap,
             row.shape[1] * (row.shape[3] + gap) - gap,
             row.shape[4]), bg, dtype=row.dtype)
    for i in range(0, row.shape[0] * row.shape[1]):
        strip[(i // row.shape[1]) * (row.shape[2] + gap) : ((i // row.shape[1])+1) * (row.shape[2] + gap) - gap,
              (i % row.shape[1]) * (row.shape[3] + gap) : ((i % row.shape[1])+1) * (row.shape[3] + gap) - gap,
              :] = row[i // row.shape[1]][i % row.shape[1]]
    return PIL.Image.fromarray(strip.numpy())
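
# Illustrative call; the (rows, cols, height, width, channels) layout is the
# convention assumed by the indexing above:
#
#   row = torch.randint(0, 256, (2, 4, 64, 64, 3), dtype=torch.uint8)
#   grid = _strip_image_from_grid_row(row)  # -> 133x271 PIL image, 2x4 tiles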


class GradCAMBuilder(ValidationLoopHook):
    def __init__(self, image_shape: Iterable[int], target_category: Optional[int] = None, num_images: int = 5):
        self.image_shape = image_shape
        self.target_category = target_category
        self.num_images = num_images

        # running buffers holding the best-activating validation samples seen so far
        self.targets = torch.zeros(self.num_images)
        self.activations = torch.zeros(self.num_images)
        self.images = torch.zeros(torch.Size([self.num_images]) + torch.Size(self.image_shape))

    def process(self, batch, target_batch, logits_batch, prediction_batch):
        image_batch = batch["image"]

        with torch.no_grad():
            if self.target_category is not None:
                local_activations = logits_batch[:, self.target_category]
            else:
                local_activations = torch.amax(logits_batch, dim=-1)

        # keep only samples where the prediction lines up with the target
        target_match = (prediction_batch == target_batch)

        # keep only public dataset samples
        public = torch.tensor(["verse" in sample_id for sample_id in batch["verse_id"]]).type_as(target_match)

        mask = target_match & public

        if not torch.any(mask):
            # no samples in this batch match the criteria, skip
            return

        # identify better activations and replace the stored ones accordingly
        local_top_idx = torch.argsort(local_activations, descending=True)
        # drop indices of samples that were filtered out
        local_top_idx = local_top_idx[mask[local_top_idx]]
        current_idx = 0

        # bound the loop by the number of matching samples to avoid indexing past the end
        limit = min(self.num_images, len(local_top_idx))
        while current_idx < limit and local_activations[local_top_idx[current_idx]] > torch.min(self.activations):
            # the next local sample matches the criteria and has a higher activation
            # than the weakest stored one, so replace that entry
            idx_to_replace = torch.argsort(self.activations)[0]

            self.activations[idx_to_replace] = local_activations[local_top_idx[current_idx]]
            self.images[idx_to_replace] = image_batch[local_top_idx[current_idx]]
            self.targets[idx_to_replace] = target_batch[local_top_idx[current_idx]]

            current_idx += 1

    def trigger(self, module):
        model = module.backbone
        module.eval()

        # determine the Grad-CAM target module/layer
        grad_cam_target = _get_grad_cam_target(model)

        cam = GradCAM(model, [grad_cam_target], use_cuda=torch.cuda.is_available())

        # determine final order such that the highest activations are placed on top
        sorted_idx = torch.argsort(self.activations, descending=True)

        self.activations = self.activations[sorted_idx]
        self.images = self.images[sorted_idx]
        self.targets = self.targets[sorted_idx]

        # if a Polyaxon experiment crashes here, remove the GradCAMBuilder instance
        # from the model.validation_hooks list
        grad_cams = cam(input_tensor=self.images, target_category=self.target_category)

        module.train()

        if len(self.images.shape) == 5:
            # 3D data, visualize evenly spaced slices along the last axis
            cam_res = grad_cams.shape[-1]
            img_res = self.images.shape[-1]
            img_slices = torch.linspace(int(img_res / cam_res / 2), img_res - int(img_res / cam_res / 2), cam_res, dtype=torch.long)

            # show all image slices in one larger combined grid image
            num_rows = min(self.num_images, grad_cams.shape[0])
            grad_cams_image = _strip_image_from_grid_row(
                torch.stack([
                    torch.stack([
                        torch.tensor(
                            _show_cam_on_image(self.images[i, 0, ..., img_slices[s]].unsqueeze(-1).repeat(1, 1, 3).numpy(), grad_cams[i, ..., s], use_rgb=True)
                        )
                    for s in range(cam_res)])
                for i in range(num_rows)])
            )

        elif len(self.images.shape) == 4:
            # 2D data, one image per row
            num_rows = min(self.num_images, grad_cams.shape[0])
            grad_cams_image = _strip_image_from_grid_row(
                torch.stack([
                    torch.stack([
                        torch.tensor(
                            _show_cam_on_image(self.images[i, 0, ...].unsqueeze(-1).repeat(1, 1, 3).numpy(), grad_cams[i, ...], use_rgb=True)
                        )
                    ])
                for i in range(num_rows)])
            )

        else:
            raise RuntimeError("Attempting to build Grad-CAMs for data that is neither 2D nor 3D")

        module.logger.experiment.log({
            "val/grad_cam": wandb.Image(grad_cams_image)
        })

    def reset(self):
        self.targets = torch.zeros(self.num_images)
        self.activations = torch.zeros(self.num_images)
        self.images = torch.zeros(torch.Size([self.num_images]) + torch.Size(self.image_shape))
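
# Hedged wiring sketch; "backbone", the logger, and the batch keys are
# assumptions about the surrounding LightningModule-style training code,
# mirroring the attributes used in trigger() above:
#
#   hook = GradCAMBuilder(image_shape=(1, 64, 64), target_category=1)
#   for batch in val_dataloader:
#       logits = module(batch["image"])
#       hook.process(batch, batch["target"], logits, logits.argmax(dim=-1))
#   hook.trigger(module)  # builds the Grad-CAM grid and logs it to W&B
#   hook.reset()          # clear the buffers before the next validation epoch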