import os
from typing import List, Tuple, Optional, Dict, Iterator
import argparse

from PIL import Image
import numpy as np
import torch
import torch.nn.functional
from matplotlib import pyplot as plt

from atoms_detection.dataset import CoordinatesDataset
from atoms_detection.image_preprocessing import dl_prepro_image
from atoms_detection.model import BasicCNN
from utils.constants import ModelArgs, Split
from utils.paths import ACTIVATIONS_VIS_PATH


def _offset_to_center(size: int) -> int:
    """Return the offset from a window's first pixel to its centre pixel.

    For even sizes the upper/left of the two middle pixels is used as centre.
    """
    return size // 2 - 1 if size % 2 == 0 else size // 2


class ConvLayerVisualizer:
    """Build whole-image activation maps for selected layers of a ``BasicCNN``.

    A fixed-size window is slid over the zero-padded image. For every window
    the channel-summed outputs of layers 0, 3 and 6 of ``model.features``
    (labelled Conv0/Conv3/Conv6 -- presumably the convolution layers; confirm
    against ``BasicCNN``) are pasted back centred on the window centre, and
    overlapping contributions are averaged per pixel.
    """

    CONV_0 = 'Conv0'
    CONV_3 = 'Conv3'
    CONV_6 = 'Conv6'

    # Index in ``model.features`` -> key used in the returned dicts.
    _LAYER_KEYS: Dict[int, str] = {0: CONV_0, 3: CONV_3, 6: CONV_6}

    def __init__(self, model_name: ModelArgs, ckpt_filename: str):
        self.model_name = model_name
        self.ckpt_filename = ckpt_filename
        self.device = self.get_torch_device()
        # Number of sliding-window crops sent through the network per forward pass.
        self.batch_size = 64
        self.stride = 1
        # Zero border added around the image so windows can be centred on edge pixels.
        self.padding = 10
        # (width, height) of the sliding window.
        self.window_size = (21, 21)
        # Model loaded lazily on first use and then reused for every image.
        self._model: Optional[BasicCNN] = None

    @staticmethod
    def get_torch_device() -> torch.device:
        """Return the CUDA device when available, otherwise the CPU."""
        return torch.device("cuda" if torch.cuda.is_available() else "cpu")

    def sliding_window(self, image: np.ndarray) -> Iterator[Tuple[int, int, np.ndarray]]:
        """Yield ``(center_x, center_y, crop)`` for every window position.

        Windows of ``self.window_size`` are moved row by row with step
        ``self.stride``; only windows lying fully inside ``image`` are
        produced. ``crop`` is a view into ``image``.
        """
        win_w, win_h = self.window_size
        x_to_center = _offset_to_center(win_w)
        y_to_center = _offset_to_center(win_h)
        for y in range(0, image.shape[0] - win_h + 1, self.stride):
            for x in range(0, image.shape[1] - win_w + 1, self.stride):
                yield x + x_to_center, y + y_to_center, image[y:y + win_h, x:x + win_w]

    def padding_image(self, img: np.ndarray) -> np.ndarray:
        """Return a float64 copy of the 2-D ``img`` framed by ``self.padding`` zeros.

        Explicit end indices are used instead of ``-self.padding`` so that
        ``padding == 0`` works (``a[0:-0]`` would be an empty slice).
        """
        pad = self.padding
        height, width = img.shape[0], img.shape[1]
        image_padded = np.zeros((height + 2 * pad, width + 2 * pad))
        image_padded[pad:pad + height, pad:pad + width] = img
        return image_padded

    def images_to_torch_input(self, image: np.ndarray) -> torch.Tensor:
        """Convert a 2-D crop to a ``(1, 1, H, W)`` float tensor on ``self.device``."""
        expanded_img = np.expand_dims(image, axis=(0, 1))
        input_tensor = torch.from_numpy(expanded_img).float()
        input_tensor = input_tensor.to(self.device)
        return input_tensor

    def load_model(self) -> BasicCNN:
        """Load a fresh ``BasicCNN`` from ``self.ckpt_filename`` in eval mode.

        NOTE: ``torch.load`` unpickles the file -- only load trusted checkpoints.
        """
        checkpoint = torch.load(self.ckpt_filename, map_location=self.device)
        model = BasicCNN(num_classes=2).to(self.device)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
        return model

    def _get_model(self) -> BasicCNN:
        """Return the cached model, loading the checkpoint only on first use."""
        model = getattr(self, '_model', None)
        if model is None:
            model = self.load_model()
            self._model = model
        return model

    @staticmethod
    def center_to_slice(x_center: int, y_center: int, width: int, height: int) -> Tuple[slice, slice]:
        """Return ``(x_slice, y_slice)`` of a ``width`` x ``height`` box centred at the given point."""
        x = x_center - _offset_to_center(width)
        y = y_center - _offset_to_center(height)
        return slice(x, x + width), slice(y, y + height)

    def _batched_windows(self, image: np.ndarray) -> Iterator[Tuple[List[Tuple[int, int]], np.ndarray]]:
        """Group ``sliding_window`` output into ``(centers, crops)`` batches.

        ``crops`` is a ``(B, H, W)`` array with ``B <= self.batch_size``.
        """
        batch_size = max(1, int(self.batch_size))
        centers: List[Tuple[int, int]] = []
        crops: List[np.ndarray] = []
        for x, y, crop in self.sliding_window(image):
            centers.append((x, y))
            crops.append(crop)
            if len(crops) == batch_size:
                yield centers, np.stack(crops)
                centers, crops = [], []
        if crops:
            yield centers, np.stack(crops)

    @staticmethod
    def _crop_and_average(counting_map: np.ndarray, output_map: np.ndarray) -> np.ndarray:
        """Drop rows/columns that never received a contribution; return the per-pixel mean."""
        rows = counting_map.any(axis=1)
        cols = counting_map.any(axis=0)
        keep = np.ix_(rows, cols)
        return output_map[keep] / counting_map[keep]

    def get_prediction_map(self, padded_image: np.ndarray) -> Dict[str, np.ndarray]:
        """Average each tracked layer's channel-summed activations over all windows.

        Returns a dict mapping layer key (``CONV_0``/``CONV_3``/``CONV_6``) to a
        2-D map cropped to the pixels that received at least one contribution.
        """
        shape = padded_image.shape
        counting_maps = {key: np.zeros(shape) for key in self._LAYER_KEYS.values()}
        output_maps = {key: np.zeros(shape) for key in self._LAYER_KEYS.values()}
        model = self._get_model()

        for centers, crops in self._batched_windows(padded_image):
            # (B, H, W) -> (B, 1, H, W): a single input channel per crop.
            torch_input = torch.from_numpy(crops[:, np.newaxis]).float().to(self.device)
            conv_outputs = self._get_conv_activations_batch(torch_input, model)
            for conv_layer_key, activations_blob in conv_outputs.items():
                # Sum over channels: (B, C, h, w) -> (B, h, w).
                activation_maps = activations_blob.sum(axis=1)
                h, w = activation_maps.shape[1:]
                for (x, y), activation_map in zip(centers, activation_maps):
                    x_slice, y_slice = self.center_to_slice(x, y, w, h)
                    counting_maps[conv_layer_key][y_slice, x_slice] += 1
                    output_maps[conv_layer_key][y_slice, x_slice] += activation_map

        return {
            key: self._crop_and_average(counting_maps[key], output_maps[key])
            for key in counting_maps
        }

    def _get_conv_activations_batch(self, input_batch: torch.Tensor, model: BasicCNN) -> Dict[str, np.ndarray]:
        """Run ``model.features`` on a batch and keep the tracked layers' outputs.

        Each returned array keeps the batch dimension. Arrays are copied because,
        on CPU, ``Tensor.numpy()`` shares memory with the tensor, and a later
        in-place layer (e.g. ``ReLU(inplace=True)``) would otherwise overwrite
        the stored activation.
        """
        conv_activations = {}
        last_index = max(self._LAYER_KEYS)
        activations = input_batch
        with torch.no_grad():  # inference only: no autograd graph needed
            for i, layer in enumerate(model.features):
                activations = layer(activations)
                key = self._LAYER_KEYS.get(i)
                if key is not None:
                    conv_activations[key] = activations.detach().cpu().numpy().copy()
                if i >= last_index:
                    break  # later layers do not contribute to any tracked output
        return conv_activations

    def get_conv_activations(self, input_image: torch.Tensor, model: BasicCNN) -> Dict[str, np.ndarray]:
        """Return the tracked layers' outputs for one input, with a leading batch axis of size 1 removed."""
        outputs = self._get_conv_activations_batch(input_image, model)
        # Mirrors torch ``squeeze(0)``: only drop axis 0 when it has size 1.
        return {
            key: (blob[0] if blob.ndim > 0 and blob.shape[0] == 1 else blob)
            for key, blob in outputs.items()
        }

    @staticmethod
    def sum_channels(activations: np.ndarray):
        """Sum ``activations`` over its first (channel) axis."""
        aggregated_activations = np.sum(activations, axis=0)
        return aggregated_activations

    def image_to_pred_map(self, img: np.ndarray) -> Dict[str, np.ndarray]:
        """Preprocess and pad ``img``, then build the per-layer activation maps."""
        preprocessed_img = dl_prepro_image(img)
        padded_image = self.padding_image(preprocessed_img)
        activations_dict = self.get_prediction_map(padded_image)
        return activations_dict


def get_args():
    """Parse the command-line arguments: architecture, checkpoint path, coordinates CSV."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "architecture",
        type=ModelArgs,
        choices=ModelArgs,
        help="Architecture name"
    )
    parser.add_argument(
        "ckpt_filename",
        type=str,
        help="Path to model checkpoint"
    )
    parser.add_argument(
        "coords_csv",
        type=str,
        help="Coordinates CSV file to use as input"
    )
    return parser.parse_args()


def main() -> None:
    """Save one activation-map PNG per tracked layer for every TEST-split image."""
    args = get_args()
    print(args)
    conv_visualizer = ConvLayerVisualizer(
        model_name=args.architecture,
        ckpt_filename=args.ckpt_filename
    )
    coordinates_dataset = CoordinatesDataset(args.coords_csv)
    for image_path, _coordinates_path in coordinates_dataset.iterate_data(Split.TEST):
        # Context manager closes the file handle once the pixels are copied out.
        with Image.open(image_path) as img:
            np_img = np.array(img)
        activations_dict = conv_visualizer.image_to_pred_map(np_img)

        img_name = os.path.splitext(os.path.basename(image_path))[0]
        output_folder = os.path.join(ACTIVATIONS_VIS_PATH, f"{img_name}")
        os.makedirs(output_folder, exist_ok=True)

        for conv_layer_key, activation_map in activations_dict.items():
            fig = plt.figure()
            plt.title(f"{conv_layer_key} -- {img_name}")
            plt.imshow(activation_map)
            output_path = os.path.join(output_folder, f"{conv_layer_key}_{img_name}.png")
            plt.savefig(output_path, bbox_inches='tight')
            plt.close(fig)


if __name__ == "__main__":
    main()