|
|
|
import cv2 |
|
import torch |
|
|
|
import numpy as np |
|
from PIL import Image |
|
from typing import List, Callable, Optional |
|
from functools import partial |
|
|
|
from pytorch_grad_cam import GradCAM |
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
|
|
|
|
# Model wrapper that returns the logits tensor instead of the Hugging Face output object.
|
class HuggingfaceToTensorModelWrapper(torch.nn.Module):
    """Adapt a model whose output exposes ``.logits`` so it returns that tensor.

    The wrapped model is called with a single positional input and the
    ``logits`` attribute of whatever it returns is handed back, so callers
    that expect a plain tensor output can use the model unchanged.
    """

    def __init__(self, model):
        """Store ``model``; it is invoked as ``model(x)`` in ``forward``."""
        super().__init__()
        self.model = model

    def forward(self, x):
        """Run the wrapped model on ``x`` and return the ``logits`` of its output."""
        return self.model(x).logits
|
|
|
|
|
class ClassActivationMap:
    """Build Grad-CAM overlays for a Hugging Face Swin V2 image classifier.

    The model is wrapped in ``HuggingfaceToTensorModelWrapper`` so the CAM
    method receives a plain logits tensor, and the layer at
    ``model.swinv2.layernorm`` is used as the single CAM target layer.
    """

    def __init__(self, model, processor):
        """Prepare the wrapped model, target layer and image processor.

        Args:
            model: Classifier exposing ``swinv2.layernorm`` and returning an
                output object with a ``logits`` attribute.
            processor: Image processor; it must expose ``size['height']`` and
                ``size['width']`` and be callable as
                ``processor(images=..., return_tensors="pt")``.
        """
        self.model = HuggingfaceToTensorModelWrapper(model)
        # The CAM API takes a list of layers; only one is used here.
        self.target_layer = [model.swinv2.layernorm]
        self.processor = processor

    def swinT_reshape_transform_huggingface(self, tensor, width, height):
        """Turn token activations into an image-like feature map.

        Args:
            tensor: Activations indexed as (batch, tokens, channels), where
                ``tokens == height * width``.
            width: Number of token columns.
            height: Number of token rows.

        Returns:
            A view of ``tensor`` laid out as (batch, channels, height, width).
        """
        grid = tensor.reshape(tensor.size(0), height, width, tensor.size(2))
        # (B, H, W, C) -> (B, C, H, W): channels first, as CAM code expects.
        return grid.permute(0, 3, 1, 2)

    def run_grad_cam_on_image(self,
                              targets_for_gradcam: List[Callable],
                              reshape_transform: Optional[Callable],
                              input_tensor: torch.Tensor,
                              input_image: Image.Image,
                              method: Callable = GradCAM):
        """Compute one CAM overlay per target and stack them horizontally.

        Args:
            targets_for_gradcam: Non-empty list of CAM targets; one overlay is
                produced for each entry.
            reshape_transform: Callable applied by the CAM method to the target
                layer's activations, or ``None``.
            input_tensor: Pre-processed image without a batch axis
                (assumed (channels, height, width)).
            input_image: Image the heat-map is blended onto; its pixel values
                are divided by 255 before blending, so it should be 8-bit and
                match the CAM's spatial size.
            method: CAM class to instantiate; defaults to ``GradCAM``.

        Returns:
            The overlays joined side by side with ``np.hstack``.

        Raises:
            ValueError: If ``targets_for_gradcam`` is empty.
        """
        if not targets_for_gradcam:
            raise ValueError("targets_for_gradcam must contain at least one target")

        # Replicate the image once per target so that all targets are
        # evaluated in a single batched forward/backward pass.
        batch = input_tensor.unsqueeze(0).repeat(len(targets_for_gradcam), 1, 1, 1)

        with method(model=self.model,
                    target_layers=self.target_layer,
                    reshape_transform=reshape_transform) as cam:
            batch_results = cam(input_tensor=batch, targets=targets_for_gradcam)

        # The background image is identical for every target: convert it once.
        background = np.float32(input_image) / 255
        overlays = [
            show_cam_on_image(background, grayscale_cam, use_rgb=True)
            for grayscale_cam in batch_results
        ]
        return np.hstack(overlays)

    def get_cam(self, image, category_id):
        """Return the CAM overlay of ``image`` for class ``category_id``.

        Args:
            image: Array accepted by ``PIL.Image.fromarray``.
            category_id: Index of the classifier output to explain.

        Returns:
            The overlay produced by ``run_grad_cam_on_image`` for that class.
        """
        size = self.processor.size
        # PIL's resize takes (width, height). Converting to RGB lets
        # single-channel or RGBA input blend with the 3-channel heat-map.
        image = (Image.fromarray(image)
                 .convert("RGB")
                 .resize((size['width'], size['height'])))

        # Drop only the batch axis; a bare squeeze() would also remove any
        # other axis that happens to have length 1.
        pixel_values = self.processor(images=image, return_tensors="pt")['pixel_values']
        img_tensor = pixel_values.squeeze(0)

        targets_for_gradcam = [ClassifierOutputTarget(category_id)]
        # 32 is the assumed overall down-sampling factor between the input
        # image and the final token grid of the backbone.
        reshape_transform = partial(self.swinT_reshape_transform_huggingface,
                                    width=img_tensor.shape[2] // 32,
                                    height=img_tensor.shape[1] // 32)

        return self.run_grad_cam_on_image(input_tensor=img_tensor,
                                          input_image=image,
                                          targets_for_gradcam=targets_for_gradcam,
                                          reshape_transform=reshape_transform)