"""Gradio demo that explains timm image-classifier predictions with CAM heatmaps.

The user picks a pretrained timm model, an image, a CAM method from
pytorch-grad-cam and (optionally) a feature module and a target class; the app
shows the heatmap overlaid on the image together with the top predictions.
"""

from collections.abc import Mapping
from io import BytesIO

import gradio as gr
import numpy as np
import requests
import timm
import torch
from PIL import Image
from pytorch_grad_cam import (
    AblationCAM,
    EigenCAM,
    FullGrad,
    GradCAM,
    GradCAMPlusPlus,
    HiResCAM,
    ScoreCAM,
    XGradCAM,
)
from pytorch_grad_cam.utils.image import show_cam_on_image
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from timm.data import ImageNetInfo, create_transform, infer_imagenet_subset

# List of available timm models
MODELS = timm.list_pretrained()

# List of available GradCAM methods
CAM_METHODS = {
    "GradCAM": GradCAM,
    "HiResCAM": HiResCAM,
    "ScoreCAM": ScoreCAM,
    "GradCAM++": GradCAMPlusPlus,
    "AblationCAM": AblationCAM,
    "XGradCAM": XGradCAM,
    "EigenCAM": EigenCAM,
    "FullGrad": FullGrad,
}

# Sentinel dropdown entry meaning "explain whichever class the model ranks first".
HIGHEST_SCORING = "highest scoring"

# Upper bound on how long a remote image download may block a request.
REQUEST_TIMEOUT_SECONDS = 30

# Maximum number of predictions listed next to the heatmap.
TOP_K = 5


class CustomDatasetInfo:
    """Minimal label lookup for models whose config carries its own label names."""

    def __init__(self, label_names, label_descriptions=None):
        """Store the labels; descriptions fall back to the names when absent.

        Args:
            label_names: Sequence mapping class index to label name.
            label_descriptions: Optional descriptions, either a sequence indexed
                by class index or a mapping keyed by label name.
        """
        self.label_names = label_names
        self.label_descriptions = label_descriptions or label_names

    def index_to_description(self, index, detailed=False):
        """Return the label for ``index``, or its description when ``detailed``.

        Args:
            index: Integer class index.
            detailed: When true, prefer the description over the bare name.

        Returns:
            The description if requested and available, otherwise the name.
        """
        name = self.label_names[index]
        if not detailed or not self.label_descriptions:
            return name
        if isinstance(self.label_descriptions, Mapping):
            # A mapping is keyed by label name, not by index (timm's own
            # CustomDatasetInfo uses that layout - TODO confirm for the
            # configs served here); indexing it with an int would KeyError.
            return self.label_descriptions.get(name, name)
        return self.label_descriptions[index]


def load_model(model_name):
    """Create the pretrained timm model ``model_name`` and put it in eval mode."""
    model = timm.create_model(model_name, pretrained=True)
    model.eval()
    return model


def process_image(image_path, model):
    """Load an image from a path or URL and preprocess it for ``model``.

    Args:
        image_path: Local file path, or an ``http://`` / ``https://`` URL.
        model: timm model whose ``pretrained_cfg`` supplies the input size,
            crop percentage, normalisation statistics and interpolation.

    Returns:
        The transformed image as a tensor with a leading batch dimension of 1.

    Raises:
        requests.RequestException: If a remote image cannot be downloaded.
    """
    if image_path.startswith(("http://", "https://")):
        # Bound the download and surface HTTP errors instead of handing an
        # error page to PIL.
        response = requests.get(image_path, timeout=REQUEST_TIMEOUT_SECONDS)
        response.raise_for_status()
        source = BytesIO(response.content)
    else:
        source = image_path

    with Image.open(source) as opened:
        # RGBA / palette / greyscale inputs would not match the 3-channel
        # normalisation statistics, so force RGB before transforming.
        image = opened.convert("RGB")

    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False,
    )
    return transform(image).unsqueeze(0)


def get_cam_image(model, image, target_layer, cam_method, target_class):
    """Render the CAM heatmap for ``image`` blended over the de-normalised input.

    Args:
        model: The classifier being explained.
        image: Preprocessed input tensor with a batch dimension of 1.
        target_layer: Module whose activations the CAM method inspects.
        cam_method: Key into ``CAM_METHODS``.
        target_class: Integer class index, or ``None`` / ``"highest scoring"``
            to let the CAM method use the top-scoring class.

    Returns:
        A PIL image of the heatmap overlay.
    """
    if target_class is None or target_class == HIGHEST_SCORING:
        targets = None
    else:
        targets = [ClassifierOutputTarget(target_class)]

    # The context manager releases the hooks the CAM object registers on the
    # model once the heatmap has been computed.
    with CAM_METHODS[cam_method](model=model, target_layers=[target_layer]) as cam:
        grayscale_cam = cam(input_tensor=image, targets=targets)

    # Undo the normalisation so the overlay is drawn on a viewable image.
    config = model.pretrained_cfg
    mean = torch.tensor(config['mean']).view(-1, 1, 1)
    std = torch.tensor(config['std']).view(-1, 1, 1)
    pixels = image.detach().squeeze(0).cpu()
    rgb_img = (pixels * std + mean).permute(1, 2, 0).numpy()
    rgb_img = np.clip(rgb_img, 0, 1)

    cam_image = show_cam_on_image(rgb_img, grayscale_cam[0, :], use_rgb=True)
    return Image.fromarray(cam_image)


def get_feature_info(model):
    """Return the module names listed in ``model.feature_info`` (empty if absent)."""
    if not hasattr(model, 'feature_info'):
        return []
    return [entry['module'] for entry in model.feature_info]


def get_target_layer(model, target_layer_name):
    """Resolve a dotted submodule name to the module, or ``None`` if unavailable.

    Args:
        model: Model to search.
        target_layer_name: Dotted submodule path; ``None`` or ``""`` means
            "not specified".

    Returns:
        The submodule, or ``None`` when no name was given or it does not exist.
    """
    # An empty name would make get_submodule return the whole model, which is
    # not a meaningful CAM target, so treat it like "not specified".
    if not target_layer_name:
        return None
    try:
        return model.get_submodule(target_layer_name)
    except AttributeError:
        print(f"WARNING: Layer '{target_layer_name}' not found in the model.")
        return None


def get_class_names(model):
    """Build a label lookup object for ``model``.

    Labels from ``model.pretrained_cfg`` are used when present; otherwise an
    ImageNet subset is inferred, and as a last resort generic ``LABEL_<i>``
    names are generated.

    Returns:
        An object exposing ``index_to_description(index, detailed=...)``.
    """
    label_names = model.pretrained_cfg.get("label_names", None)
    label_descriptions = model.pretrained_cfg.get("label_descriptions", None)

    if label_names is None:
        imagenet_subset = infer_imagenet_subset(model)
        if imagenet_subset:
            return ImageNetInfo(imagenet_subset)
        label_names = [f"LABEL_{i}" for i in range(model.num_classes)]

    return CustomDatasetInfo(
        label_names=label_names,
        label_descriptions=label_descriptions,
    )


def _resolve_target_layer(model, feature_module):
    """Pick the CAM target layer: requested module, last feature, or last conv."""
    target_layer = get_target_layer(model, feature_module)
    if target_layer is not None:
        return target_layer

    feature_info = get_feature_info(model)
    if feature_info:
        target_layer = get_target_layer(model, feature_info[-1])
        print(f"Using last feature module: {feature_info[-1]}")
    else:
        for name, module in reversed(list(model.named_modules())):
            if isinstance(module, torch.nn.Conv2d):
                target_layer = module
                print(f"Fallback: Using last convolutional layer: {name}")
                break

    if target_layer is None:
        raise ValueError("Could not find a suitable target layer.")
    return target_layer


def _parse_target_class(target_class):
    """Turn a ``"<index>: <label>"`` dropdown value into an int (or ``None``)."""
    if target_class is None or target_class == HIGHEST_SCORING:
        return None
    return int(str(target_class).partition(':')[0])


def _format_top_predictions(model, image):
    """Run ``model`` on ``image`` and format its most probable classes."""
    with torch.no_grad():
        out = model(image)
    probabilities = out.squeeze(0).softmax(dim=0)
    # Models with fewer than TOP_K classes would make topk raise.
    k = min(TOP_K, probabilities.numel())
    values, indices = torch.topk(probabilities, k)

    dataset_info = get_class_names(model)
    return [
        f"{i}: {dataset_info.index_to_description(i.item(), detailed=True)} ({v.item():.2%})"
        for i, v in zip(indices, values)
    ]


def explain_image(model_name, image_path, cam_method, feature_module, target_class):
    """Produce the CAM overlay and the top predictions for one image.

    Args:
        model_name: Name of the pretrained timm model to load.
        image_path: Local path or URL of the image to explain.
        cam_method: Key into ``CAM_METHODS``.
        feature_module: Optional dotted name of the module to explain; when it
            is missing or invalid the last feature module (or, failing that,
            the last ``Conv2d``) is used.
        target_class: ``"highest scoring"``, ``None``, or a dropdown value of
            the form ``"<index>: <label>"``.

    Returns:
        A ``(PIL image, text)`` pair: the heatmap overlay and the newline-joined
        top predictions.

    Raises:
        gr.Error: If the model, image or CAM method has not been chosen.
        ValueError: If no usable target layer can be found.
    """
    if not model_name:
        raise gr.Error("Please select a model.")
    if not image_path:
        raise gr.Error("Please upload an image.")
    if cam_method not in CAM_METHODS:
        raise gr.Error("Please select a CAM method.")

    model = load_model(model_name)
    image = process_image(image_path, model)
    target_layer = _resolve_target_layer(model, feature_module)
    target_class_index = _parse_target_class(target_class)

    cam_image = get_cam_image(model, image, target_layer, cam_method, target_class_index)
    labels = _format_top_predictions(model, image)
    return cam_image, "\n".join(labels)


def _feature_dropdown(model):
    """Build the feature-module dropdown update for an already-loaded model."""
    feature_modules = get_feature_info(model) if model is not None else []
    return gr.Dropdown(choices=feature_modules, value=feature_modules[-1] if feature_modules else None)


def _class_dropdown(model):
    """Build the target-class dropdown update for an already-loaded model."""
    class_names = [HIGHEST_SCORING]
    if model is not None:
        dataset_info = get_class_names(model)
        class_names += [
            f"{i}: {dataset_info.index_to_description(i, detailed=True)}"
            for i in range(model.num_classes)
        ]
    return gr.Dropdown(choices=class_names, value=HIGHEST_SCORING)


def update_feature_modules(model_name):
    """Return a dropdown update listing the feature modules of ``model_name``.

    An empty selection (model dropdown cleared) yields an empty dropdown.
    """
    return _feature_dropdown(load_model(model_name) if model_name else None)


def update_class_dropdown(model_name):
    """Return a dropdown update listing the classes of ``model_name``.

    An empty selection yields a dropdown with only ``"highest scoring"``.
    """
    return _class_dropdown(load_model(model_name) if model_name else None)


def update_model_dropdowns(model_name):
    """Refresh both dependent dropdowns while loading the model only once."""
    model = load_model(model_name) if model_name else None
    return _feature_dropdown(model), _class_dropdown(model)


with gr.Blocks() as demo:
    gr.Markdown("# Explainable AI with timm models. NOTE: This is a WIP but some models are functioning.")
    gr.Markdown("Upload an image, select a model, CAM method, and optionally a specific feature module and target class to visualize the explanation.")

    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(choices=MODELS, label="Select Model")
            image_input = gr.Image(type="filepath", label="Upload Image")
            cam_method_dropdown = gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method")
            feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")
            class_dropdown = gr.Dropdown(label="Select Target Class (optional)")
            explain_button = gr.Button("Explain Image")

        with gr.Column():
            output_image = gr.Image(type="pil", label="Explained Image")
            prediction_text = gr.Textbox(label="Top 5 Predictions")

    # One handler fills both dropdowns so the pretrained weights are loaded a
    # single time per model change rather than once per dropdown.
    model_dropdown.change(
        fn=update_model_dropdowns,
        inputs=[model_dropdown],
        outputs=[feature_module_dropdown, class_dropdown],
    )

    explain_button.click(
        fn=explain_image,
        inputs=[model_dropdown, image_input, cam_method_dropdown, feature_module_dropdown, class_dropdown],
        outputs=[output_image, prediction_text],
    )

demo.launch()