Spaces:
Sleeping
Sleeping
import argparse
import os
from typing import Dict, Iterator, List, Optional, Tuple

import numpy as np
import torch
import torch.nn.functional
from matplotlib import pyplot as plt
from PIL import Image

from atoms_detection.dataset import CoordinatesDataset
from atoms_detection.image_preprocessing import dl_prepro_image
from atoms_detection.model import BasicCNN
from utils.constants import ModelArgs, Split
from utils.paths import ACTIVATIONS_VIS_PATH
class ConvLayerVisualizer:
    """Build per-pixel activation maps for three layers of a ``BasicCNN``.

    The input image is zero-padded and scanned with a sliding window. For
    every window the outputs of ``model.features`` at indices 0, 3 and 6 are
    summed over channels and accumulated into image-sized maps, which are
    finally divided by the number of windows that contributed to each pixel.
    """

    # Keys of the dictionaries returned by the activation methods.
    CONV_0 = 'Conv0'
    CONV_3 = 'Conv3'
    CONV_6 = 'Conv6'

    def __init__(self, model_name: "ModelArgs", ckpt_filename: str):
        """Store the configuration; the checkpoint is not loaded here.

        Args:
            model_name: Architecture identifier (stored, not otherwise used
                by this class).
            ckpt_filename: Path to the checkpoint read by ``load_model``.
        """
        self.model_name = model_name
        self.ckpt_filename = ckpt_filename
        self.device = self.get_torch_device()
        self.batch_size = 64
        self.stride = 1
        self.padding = 10
        # (width, height): index 0 is used along x, index 1 along y.
        self.window_size = (21, 21)

    @staticmethod
    def get_torch_device() -> torch.device:
        """Return the CUDA device when available, otherwise the CPU."""
        use_cuda = torch.cuda.is_available()
        return torch.device("cuda" if use_cuda else "cpu")

    @staticmethod
    def _offset_to_center(size: int) -> int:
        """Return the offset from a span's start to its centre element.

        Equals ``size // 2`` for odd sizes and ``size // 2 - 1`` for even
        sizes, i.e. the lower of the two middle elements.
        """
        return (size - 1) // 2

    def sliding_window(self, image: np.ndarray) -> Iterator[Tuple[int, int, np.ndarray]]:
        """Yield ``(center_x, center_y, crop)`` for each window position.

        Windows are visited row by row with step ``self.stride``; only
        windows that fit entirely inside ``image`` are produced, so an image
        smaller than the window yields nothing.
        """
        win_w, win_h = self.window_size
        x_to_center = self._offset_to_center(win_w)
        y_to_center = self._offset_to_center(win_h)
        for y in range(0, image.shape[0] - win_h + 1, self.stride):
            for x in range(0, image.shape[1] - win_w + 1, self.stride):
                yield x + x_to_center, y + y_to_center, image[y:y + win_h, x:x + win_w]

    def padding_image(self, img: np.ndarray) -> np.ndarray:
        """Return ``img`` surrounded by ``self.padding`` zeros on every side.

        The result is a new float64 array (``np.zeros`` default dtype).
        """
        pad = self.padding
        height, width = img.shape[0], img.shape[1]
        image_padded = np.zeros((height + pad * 2, width + pad * 2))
        # Explicit end indices: a ``-pad`` slice end would be empty for pad == 0.
        image_padded[pad:pad + height, pad:pad + width] = img
        return image_padded

    def images_to_torch_input(self, image: np.ndarray) -> torch.Tensor:
        """Convert an array to a float tensor on ``self.device``.

        Two leading singleton axes are prepended, so a 2-D crop becomes a
        ``(1, 1, H, W)`` tensor.
        """
        expanded_img = np.expand_dims(image, axis=(0, 1))
        input_tensor = torch.from_numpy(expanded_img).float()
        return input_tensor.to(self.device)

    def load_model(self) -> "BasicCNN":
        """Load the checkpoint into a fresh ``BasicCNN`` set to eval mode."""
        # NOTE(review): torch.load unpickles the file; only load checkpoints
        # from a trusted source.
        checkpoint = torch.load(self.ckpt_filename, map_location=self.device)
        model = BasicCNN(num_classes=2).to(self.device)
        model.load_state_dict(checkpoint['state_dict'])
        model.eval()
        return model

    @staticmethod
    def center_to_slice(x_center: int, y_center: int, width: int, height: int) -> Tuple[slice, slice]:
        """Return ``(x_slice, y_slice)`` for a ``width`` x ``height`` patch.

        The patch is positioned so that its centre element (see
        ``_offset_to_center``) lands on ``(x_center, y_center)``.
        """
        x = x_center - ConvLayerVisualizer._offset_to_center(width)
        y = y_center - ConvLayerVisualizer._offset_to_center(height)
        return slice(x, x + width), slice(y, y + height)

    def get_prediction_map(self, padded_image: np.ndarray) -> Dict[str, np.ndarray]:
        """Return the averaged activation map of each tracked layer.

        Each window's channel-summed activation patch is added, centred on
        the window centre, to an accumulator of the padded image's shape. A
        counter records how many patches covered each pixel. Rows and columns
        never covered are removed before the sum is divided by the count.
        """
        _shape = padded_image.shape
        layer_keys = (self.CONV_0, self.CONV_3, self.CONV_6)
        counting_maps = {key: np.zeros(_shape) for key in layer_keys}
        output_maps = {key: np.zeros(_shape) for key in layer_keys}

        model = self.load_model()
        for x, y, image_crop in self.sliding_window(padded_image):
            torch_input = self.images_to_torch_input(image_crop)
            conv_outputs = self.get_conv_activations(torch_input, model)
            for conv_layer_key, activations_blob in conv_outputs.items():
                activation_map = self.sum_channels(activations_blob)
                h, w = activation_map.shape
                x_slice, y_slice = self.center_to_slice(x, y, w, h)
                counting_maps[conv_layer_key][y_slice, x_slice] += 1
                output_maps[conv_layer_key][y_slice, x_slice] += activation_map

        activations_dict = {}
        for conv_layer_key in layer_keys:
            counting_map = counting_maps[conv_layer_key]
            # Counts are non-negative, so a zero sum means "never covered".
            covered_rows = counting_map.sum(axis=1) != 0
            covered_cols = counting_map.sum(axis=0) != 0
            clean_output_map = output_maps[conv_layer_key][covered_rows][:, covered_cols]
            clean_counting_map = counting_map[covered_rows][:, covered_cols]
            activations_dict[conv_layer_key] = clean_output_map / clean_counting_map
        return activations_dict

    def get_conv_activations(self, input_image: torch.Tensor, model: "BasicCNN") -> Dict[str, np.ndarray]:
        """Run ``model.features`` and capture the outputs at indices 0, 3, 6.

        Returns a dict keyed by ``CONV_0`` / ``CONV_3`` / ``CONV_6`` holding
        each captured output with its leading batch axis squeezed.
        """
        tracked_layers = {0: self.CONV_0, 3: self.CONV_3, 6: self.CONV_6}
        conv_activations = {}
        activations = input_image
        # Inference only: skip autograd bookkeeping.
        with torch.no_grad():
            for i, layer in enumerate(model.features):
                activations = layer(activations)
                key = tracked_layers.get(i)
                if key is not None:
                    # Copy the data: on CPU ``.numpy()`` shares memory with
                    # the tensor, so a following in-place layer (e.g.
                    # ReLU(inplace=True)) would silently overwrite the
                    # recorded values and make CPU and CUDA runs disagree.
                    conv_activations[key] = activations.squeeze(0).detach().cpu().numpy().copy()
        return conv_activations

    @staticmethod
    def sum_channels(activations: np.ndarray) -> np.ndarray:
        """Sum ``activations`` over its first axis (the channel axis here)."""
        return np.sum(activations, axis=0)

    def image_to_pred_map(self, img: np.ndarray) -> Dict[str, np.ndarray]:
        """Preprocess and pad ``img``, then compute its activation maps.

        Preprocessing is delegated to ``dl_prepro_image``; its result is
        assumed to be a 2-D array -- TODO confirm against that helper.
        """
        preprocessed_img = dl_prepro_image(img)
        padded_image = self.padding_image(preprocessed_img)
        return self.get_prediction_map(padded_image)
def get_args():
    """Parse the command line and return the resulting namespace.

    Three positional arguments are expected, in this order:
    ``architecture`` (a ``ModelArgs`` member), ``ckpt_filename`` and
    ``coords_csv``.
    """
    # (name, add_argument keyword options) for each positional argument.
    positional_specs = (
        ("architecture", {"type": ModelArgs, "choices": ModelArgs, "help": "Architecture name"}),
        ("ckpt_filename", {"type": str, "help": "Path to model checkpoint"}),
        ("coords_csv", {"type": str, "help": "Coordinates CSV file to use as input"}),
    )

    cli = argparse.ArgumentParser()
    for arg_name, options in positional_specs:
        cli.add_argument(arg_name, **options)
    return cli.parse_args()
# Script entry point: for every image of the TEST split, compute the
# per-layer activation maps and save one PNG per layer under
# ACTIVATIONS_VIS_PATH/<image name>/.
if __name__ == "__main__":
    args = get_args()
    print(args)

    conv_visualizer = ConvLayerVisualizer(
        model_name=args.architecture,
        ckpt_filename=args.ckpt_filename
    )

    coordinates_dataset = CoordinatesDataset(args.coords_csv)
    # Only the image path is used; the coordinates file is ignored here.
    for image_path, _coordinates_path in coordinates_dataset.iterate_data(Split.TEST):
        # Context manager so the image file handle is closed once the pixel
        # data has been copied into the array.
        with Image.open(image_path) as img:
            np_img = np.array(img)
        activations_dict = conv_visualizer.image_to_pred_map(np_img)

        img_name = os.path.splitext(os.path.basename(image_path))[0]
        output_folder = os.path.join(ACTIVATIONS_VIS_PATH, f"{img_name}")
        # exist_ok avoids the check-then-create race of exists() + makedirs().
        os.makedirs(output_folder, exist_ok=True)

        for conv_layer_key, activation_map in activations_dict.items():
            fig = plt.figure()
            plt.title(f"{conv_layer_key} -- {img_name}")
            plt.imshow(activation_map)
            output_path = os.path.join(output_folder, f"{conv_layer_key}_{img_name}.png")
            plt.savefig(output_path, bbox_inches='tight')
            # Close explicitly so figures do not accumulate across images.
            plt.close(fig)