|
import gradio as gr |
|
import timm |
|
import torch |
|
from PIL import Image |
|
import requests |
|
from io import BytesIO |
|
import numpy as np |
|
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad |
|
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget |
|
from pytorch_grad_cam.utils.image import show_cam_on_image |
|
from timm.data import create_transform |
|
from timm.data import infer_imagenet_subset, ImageNetInfo |
|
|
|
|
|
# Model names reported by timm.list_pretrained() (models with pretrained
# weights available); these populate the "Select Model" dropdown.
MODELS = list(timm.list_pretrained())
|
|
|
|
|
# Display name -> pytorch-grad-cam class. Insertion order is the order in
# which the methods appear in the "Select CAM Method" dropdown.
CAM_METHODS = dict(
    [
        ("GradCAM", GradCAM),
        ("HiResCAM", HiResCAM),
        ("ScoreCAM", ScoreCAM),
        ("GradCAM++", GradCAMPlusPlus),
        ("AblationCAM", AblationCAM),
        ("XGradCAM", XGradCAM),
        ("EigenCAM", EigenCAM),
        ("FullGrad", FullGrad),
    ]
)
|
|
|
class CustomDatasetInfo:
    """Index -> label lookup for models whose classes are not an ImageNet subset.

    Exposes the same ``index_to_description(index, detailed=...)`` call that
    the rest of this file uses on ``ImageNetInfo``, so the two can be used
    interchangeably.
    """

    def __init__(self, label_names, label_descriptions=None):
        self.label_names = label_names
        # Without (non-empty) descriptions, reuse the short names so that
        # detailed lookups still return something meaningful.
        self.label_descriptions = label_names if not label_descriptions else label_descriptions

    def index_to_description(self, index, detailed=False):
        """Return the label for class ``index``.

        Args:
            index: Class index.
            detailed: When true, prefer the long description over the short name.
        """
        use_long_form = detailed and self.label_descriptions
        lookup = self.label_descriptions if use_long_form else self.label_names
        return lookup[index]
|
|
|
def load_model(model_name):
    """Create the timm model ``model_name`` with pretrained weights, in eval mode."""
    # nn.Module.eval() returns the module itself, so the call can be chained.
    return timm.create_model(model_name, pretrained=True).eval()
|
|
|
def process_image(image_path, model, timeout=10):
    """Load an image from a local path or HTTP(S) URL and preprocess it for ``model``.

    Args:
        image_path: Local file path, or a URL starting with ``http://`` or
            ``https://``.
        model: timm model whose ``pretrained_cfg`` supplies the input size,
            crop percentage, normalization mean/std and interpolation.
        timeout: Seconds to wait for a remote image before giving up.

    Returns:
        The transformed image tensor with a leading batch dimension of size 1.

    Raises:
        ValueError: If no image path was provided.
        requests.HTTPError: If downloading a remote image returns an error status.
    """
    if not image_path:
        raise ValueError("No image provided.")

    # Match an explicit scheme so a local file named e.g. "http_cat.jpg" is
    # not mistaken for a URL.
    if image_path.startswith(("http://", "https://")):
        response = requests.get(image_path, timeout=timeout)
        # Surface HTTP failures directly instead of handing an error page to PIL.
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path

    # Force 3-channel RGB: grayscale / RGBA / palette inputs would otherwise
    # yield a tensor whose channel count does not match the 3-element mean/std
    # or the model's first layer. The context manager closes the file handle.
    with Image.open(source) as raw:
        image = raw.convert("RGB")

    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False,
    )

    return transform(image).unsqueeze(0)
|
|
|
def get_cam_image(model, image, target_layer, cam_method, target_class):
    """Compute a CAM heatmap for ``image`` and overlay it on the de-normalized input.

    Args:
        model: timm model (eval mode) to explain.
        image: Preprocessed input tensor with a leading batch dimension of 1,
            normalized with ``model.pretrained_cfg`` mean/std.
        target_layer: Module whose activations the CAM method inspects.
        cam_method: Key into ``CAM_METHODS``.
        target_class: Class index to explain, or ``None`` / ``"highest scoring"``
            to pass ``targets=None`` to the CAM method.

    Returns:
        PIL image of the heatmap blended over the input.

    Raises:
        KeyError: If ``cam_method`` is not a key of ``CAM_METHODS``.
        RuntimeError: If the CAM method produced no heatmap.
    """
    if target_class is None or target_class == "highest scoring":
        targets = None
    else:
        targets = [ClassifierOutputTarget(target_class)]

    grayscale_cam = None
    # Using the CAM object as a context manager releases the hooks it
    # registers on the model as soon as the heatmap is computed, instead of
    # relying on garbage collection.
    with CAM_METHODS[cam_method](model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(input_tensor=image, targets=targets)
    if grayscale_cam is None:
        # NOTE(review): some pytorch-grad-cam versions suppress certain
        # exceptions in __exit__; fail loudly rather than with a NameError.
        raise RuntimeError(f"{cam_method} did not produce a heatmap for this model/layer.")

    # Undo Normalize(mean, std) so the overlay is drawn on the image the model
    # actually saw, clipped to the [0, 1] range show_cam_on_image expects.
    config = model.pretrained_cfg
    mean = torch.tensor(config['mean'], dtype=image.dtype, device=image.device).view(3, 1, 1)
    std = torch.tensor(config['std'], dtype=image.dtype, device=image.device).view(3, 1, 1)
    rgb_img = (image.squeeze(0) * std + mean).permute(1, 2, 0).detach().cpu().numpy()
    rgb_img = np.clip(rgb_img, 0, 1)

    cam_image = show_cam_on_image(rgb_img, grayscale_cam[0, :], use_rgb=True)
    return Image.fromarray(cam_image)
|
|
|
def get_feature_info(model):
    """Return the ``'module'`` names listed in ``model.feature_info``.

    Models without a ``feature_info`` attribute yield an empty list.
    """
    try:
        entries = model.feature_info
    except AttributeError:
        return []
    return [entry['module'] for entry in entries]
|
|
|
def get_target_layer(model, target_layer_name):
    """Look up a submodule of ``model`` by its dotted name.

    Args:
        model: torch module to search.
        target_layer_name: Dotted submodule path (e.g. ``"blocks.3"``).
            ``None`` or an empty string means "no explicit choice".

    Returns:
        The submodule, or ``None`` when no name was given or the name does not
        exist (a warning is printed in that case), so the caller can fall back
        to an automatically chosen layer.
    """
    # An empty name must not reach get_submodule(): it would return the whole
    # model, which is not a usable CAM target layer.
    if not target_layer_name:
        return None

    try:
        return model.get_submodule(target_layer_name)
    except AttributeError:
        print(f"WARNING: Layer '{target_layer_name}' not found in the model.")
        return None
|
|
|
def get_class_names(model):
    """Return an object exposing ``index_to_description`` for ``model``'s classes.

    Resolution order:
      * ``label_names`` present in ``model.pretrained_cfg`` -> ``CustomDatasetInfo``
        built from them (plus ``label_descriptions`` if present);
      * otherwise, an ImageNet subset inferred by ``infer_imagenet_subset`` ->
        ``ImageNetInfo``;
      * otherwise, ``CustomDatasetInfo`` with ``LABEL_<i>`` placeholders for
        each of ``model.num_classes`` classes.
    """
    cfg = model.pretrained_cfg
    names = cfg.get("label_names", None)
    descriptions = cfg.get("label_descriptions", None)

    if names is None:
        subset = infer_imagenet_subset(model)
        if subset:
            return ImageNetInfo(subset)
        names = [f"LABEL_{i}" for i in range(model.num_classes)]

    return CustomDatasetInfo(label_names=names, label_descriptions=descriptions)
|
|
|
def _resolve_target_layer(model, feature_module):
    """Pick the CAM layer: requested module, else last timm feature module, else last Conv2d."""
    target_layer = get_target_layer(model, feature_module)
    if target_layer is not None:
        return target_layer

    feature_info = get_feature_info(model)
    if feature_info:
        target_layer = get_target_layer(model, feature_info[-1])
        if target_layer is not None:
            print(f"Using last feature module: {feature_info[-1]}")
            return target_layer

    # Still nothing usable (no feature_info, or its last entry could not be
    # resolved): scan from the output side for the last convolution.
    for name, module in reversed(list(model.named_modules())):
        if isinstance(module, torch.nn.Conv2d):
            print(f"Fallback: Using last convolutional layer: {name}")
            return module

    raise ValueError("Could not find a suitable target layer.")


def _parse_target_class(target_class):
    """Map a class-dropdown entry like "281: label" to 281; None for "highest scoring"/no selection."""
    if not target_class or target_class == "highest scoring":
        return None
    return int(target_class.split(':', 1)[0])


def _top_k_predictions(model, image, k=5):
    """Return "index: description (prob%)" lines for the (up to) k most probable classes."""
    with torch.no_grad():
        probabilities = model(image).squeeze(0).softmax(dim=0)
        # topk raises when k exceeds the number of classes, so clamp it.
        values, indices = torch.topk(probabilities, min(k, probabilities.numel()))
        dataset_info = get_class_names(model)
        return [
            f"{idx}: {dataset_info.index_to_description(idx, detailed=True)} ({prob:.2%})"
            for idx, prob in zip(indices.tolist(), values.tolist())
        ]


def explain_image(model_name, image_path, cam_method, feature_module, target_class):
    """Produce a CAM overlay and the top-5 predictions for one image.

    Args:
        model_name: timm model name (from the model dropdown).
        image_path: Local path or URL of the image to explain.
        cam_method: Key into ``CAM_METHODS``.
        feature_module: Optional dotted name of the layer to explain. When it
            is empty or cannot be found, the last timm feature module and then
            the last ``Conv2d`` are tried.
        target_class: ``"highest scoring"``, an entry of the form
            ``"<index>: <label>"``, or ``None``/empty (treated as
            ``"highest scoring"``).

    Returns:
        ``(cam_image, predictions)``: the heatmap overlay as a PIL image and
        the top predictions joined by newlines.

    Raises:
        ValueError: On a missing model/image, an unknown CAM method, or when no
            target layer can be determined.
    """
    # Validate cheap inputs before paying for a model load.
    if not model_name:
        raise ValueError("Please select a model.")
    if not image_path:
        raise ValueError("Please upload an image.")
    if cam_method not in CAM_METHODS:
        raise ValueError(f"Unknown CAM method: {cam_method!r}")

    model = load_model(model_name)
    image = process_image(image_path, model)

    target_layer = _resolve_target_layer(model, feature_module)
    cam_image = get_cam_image(model, image, target_layer, cam_method, _parse_target_class(target_class))

    labels = _top_k_predictions(model, image)
    return cam_image, "\n".join(labels)
|
|
|
def update_feature_modules(model_name):
    """Rebuild the feature-module dropdown for the newly selected model.

    The choices are the model's ``feature_info`` module names. The last one is
    preselected, or nothing is selected when the model exposes none.
    """
    modules = get_feature_info(load_model(model_name))
    default_module = modules[-1] if modules else None
    return gr.Dropdown(choices=modules, value=default_module)
|
|
|
def update_class_dropdown(model_name):
    """Rebuild the target-class dropdown for the newly selected model.

    The first choice is "highest scoring". It is followed by one
    "<index>: <description>" entry per class, and "highest scoring" is
    preselected.
    """
    model = load_model(model_name)
    info = get_class_names(model)
    choices = ["highest scoring"]
    choices.extend(
        f"{idx}: {info.index_to_description(idx, detailed=True)}"
        for idx in range(model.num_classes)
    )
    return gr.Dropdown(choices=choices, value="highest scoring")
|
|
|
with gr.Blocks() as demo:
    # Page header.
    gr.Markdown("# Explainable AI with timm models. NOTE: This is a WIP but some models are functioning.")
    gr.Markdown("Upload an image, select a model, CAM method, and optionally a specific feature module and target class to visualize the explanation.")

    with gr.Row():
        # Left column: all user inputs.
        with gr.Column():
            model_dropdown = gr.Dropdown(label="Select Model", choices=MODELS)
            image_input = gr.Image(label="Upload Image", type="filepath")
            cam_method_dropdown = gr.Dropdown(label="Select CAM Method", choices=list(CAM_METHODS))
            feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")
            class_dropdown = gr.Dropdown(label="Select Target Class (optional)")
            explain_button = gr.Button("Explain Image")

        # Right column: explanation outputs.
        with gr.Column():
            output_image = gr.Image(label="Explained Image", type="pil")
            prediction_text = gr.Textbox(label="Top 5 Predictions")

    # Selecting a model repopulates both optional dropdowns for that model.
    model_dropdown.change(update_feature_modules, inputs=[model_dropdown], outputs=[feature_module_dropdown])
    model_dropdown.change(update_class_dropdown, inputs=[model_dropdown], outputs=[class_dropdown])

    explain_button.click(
        explain_image,
        inputs=[model_dropdown, image_input, cam_method_dropdown, feature_module_dropdown, class_dropdown],
        outputs=[output_image, prediction_text],
    )

demo.launch()