import pytorch_lightning as pl
import torch
import torch.nn as nn
from torchvision.models import resnet50
from monai.transforms import (
    Compose,
    Resize,
    ResizeWithPadOrCrop,
)
from pytorch_grad_cam import GradCAM
import matplotlib.colors as mcolors
import matplotlib.pyplot as plt
import numpy as np
from PIL import Image
from io import BytesIO

import utils


class ResNet(pl.LightningModule):
    def __init__(self):
        super().__init__()
        self.save_hyperparameters()

        backbone = resnet50()

        # Replace the first convolution so the network accepts single-channel
        # (grayscale) input instead of RGB.
        num_input_channels = 1
        layer = backbone.conv1
        new_layer = nn.Conv2d(
            in_channels=num_input_channels,
            out_channels=layer.out_channels,
            kernel_size=layer.kernel_size,
            stride=layer.stride,
            padding=layer.padding,
            bias=layer.bias is not None,
        )
        # Initialise the new conv by summing the original RGB filters over
        # the channel dimension.
        new_layer.weight = nn.Parameter(layer.weight.sum(dim=1, keepdim=True))
        backbone.conv1 = new_layer

        # Replace the classification head with a small MLP for 2 output classes.
        backbone.fc = nn.Sequential(
            nn.Linear(2048, 1024),
            nn.ReLU(),
            nn.BatchNorm1d(1024),
            nn.Dropout(0),
            nn.Linear(1024, 2),
        )
        self.model = backbone

    def forward(self, x):
        return self.model(x)


# Validation-time preprocessing: CLAHE, resize the longest side to 628,
# then pad/crop to a fixed 416x628 input.
val_transforms_416x628 = Compose(
    [
        utils.CustomCLAHE(),
        Resize(spatial_size=628, mode="bilinear", align_corners=True, size_mode="longest"),
        ResizeWithPadOrCrop(spatial_size=(416, 628)),
    ]
)

checkpoint = torch.load("classification_model.ckpt", map_location=torch.device("cpu"))
model = ResNet()
model.load_state_dict(checkpoint["state_dict"])
model.eval()


def load_and_classify_image(image_path, device):
    # nn.Module.to() moves parameters in place; do not rebind the module-level
    # `model`, which would otherwise raise UnboundLocalError inside this function.
    model.to(device)
    image = val_transforms_416x628(image_path)
    image = image.unsqueeze(0).to(device)
    with torch.no_grad():
        prediction = model(image)
    prediction = torch.nn.functional.softmax(prediction, dim=1).squeeze(0)
    return prediction.to("cpu"), image.to("cpu")


def make_GradCAM(image, device):
    model.to(device)
    model.eval()
    image = image.to(device)
    target_layers = [model.model.layer4[-1]]

    # Keep a CPU copy of the input for plotting.
    arr = image.detach().cpu().numpy().squeeze()

    cam = GradCAM(model=model, target_layers=target_layers)
    # targets=None lets GradCAM use the highest-scoring class for each input.
    grayscale_cam = cam(
        input_tensor=image,
        targets=None,
        aug_smooth=False,
        eigen_smooth=True,
    )
    # GradCAM returns a NumPy array of shape (batch, H, W), not a tensor.
    grayscale_cam = grayscale_cam.squeeze()

    # Build an "inferno" colormap whose lowest value is fully black so that
    # inactive regions do not tint the underlying image.
    inferno = plt.colormaps.get_cmap("inferno")
    newcolors = inferno(np.linspace(0, 1, 256))
    newcolors[0, :3] = 0
    overlay_cmap = mcolors.ListedColormap(newcolors)

    plt.figure(figsize=(10, 10))
    plt.imshow(arr, cmap="gray")
    plt.imshow(grayscale_cam, cmap=overlay_cmap, alpha=0.5)
    plt.axis("off")

    buffer = BytesIO()
    plt.savefig(buffer, format="png", bbox_inches="tight", pad_inches=0)
    plt.close()
    buffer.seek(0)
    gradcam_image = np.array(Image.open(buffer)).squeeze()
    return gradcam_image
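

# Example usage (a minimal sketch, not part of the original script): runs the
# classifier on a single image and saves the Grad-CAM overlay. The input path
# "example_image.png", the output path "gradcam_overlay.png", and the CUDA
# check are assumptions for illustration; load_and_classify_image is assumed
# to accept whatever input utils.CustomCLAHE expects (here, a file path).
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Hypothetical input image.
    image_path = "example_image.png"

    # Class probabilities (softmax over 2 classes) and the preprocessed tensor.
    probabilities, preprocessed = load_and_classify_image(image_path, device)
    predicted_class = int(probabilities.argmax())
    print(f"class probabilities: {probabilities.tolist()}")
    print(f"predicted class: {predicted_class}")

    # Overlay the Grad-CAM heatmap on the preprocessed image and save it.
    overlay = make_GradCAM(preprocessed, device)
    Image.fromarray(overlay).save("gradcam_overlay.png")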