# timmCAM / app.py
import gradio as gr
import timm
import torch
from PIL import Image
import requests
from io import BytesIO
import numpy as np
from pytorch_grad_cam import GradCAM, HiResCAM, ScoreCAM, GradCAMPlusPlus, AblationCAM, XGradCAM, EigenCAM, FullGrad
from pytorch_grad_cam.utils.model_targets import ClassifierOutputTarget
from pytorch_grad_cam.utils.image import show_cam_on_image
from timm.data import ImageNetInfo, create_transform, infer_imagenet_subset
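# Note: pytorch_grad_cam is installed from PyPI as the 'grad-cam' package.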
# List of available timm models
MODELS = timm.list_pretrained()
# List of available GradCAM methods
CAM_METHODS = {
    "GradCAM": GradCAM,
    "HiResCAM": HiResCAM,
    "ScoreCAM": ScoreCAM,
    "GradCAM++": GradCAMPlusPlus,
    "AblationCAM": AblationCAM,
    "XGradCAM": XGradCAM,
    "EigenCAM": EigenCAM,
    "FullGrad": FullGrad,
}
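# All of the above share pytorch_grad_cam's interface: construct with
# (model, target_layers), call with (input_tensor, targets), and receive a
# batch of grayscale heatmaps scaled to the input's spatial size.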
class CustomDatasetInfo:
    def __init__(self, label_names, label_descriptions=None):
        self.label_names = label_names
        self.label_descriptions = label_descriptions or label_names

    def index_to_description(self, index, detailed=False):
        if detailed and self.label_descriptions:
            return self.label_descriptions[index]
        return self.label_names[index]
def load_model(model_name):
    model = timm.create_model(model_name, pretrained=True)
    model.eval()
    return model
def process_image(image_path, model):
    if image_path.startswith('http'):
        response = requests.get(image_path)
        image = Image.open(BytesIO(response.content))
    else:
        image = Image.open(image_path)
    # Ensure 3 channels; uploads may be RGBA or grayscale.
    image = image.convert('RGB')
    # Build the model's eval-time preprocessing from its pretrained config.
    config = model.pretrained_cfg
    transform = create_transform(
        input_size=config['input_size'],
        crop_pct=config['crop_pct'],
        mean=config['mean'],
        std=config['std'],
        interpolation=config['interpolation'],
        is_training=False,
    )
    tensor = transform(image).unsqueeze(0)  # add batch dimension
    return tensor
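# Example (model tag and file name are illustrative):
#   model = load_model("resnet50.a1_in1k")
#   tensor = process_image("cat.jpg", model)  # -> shape (1, 3, H, W)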
def get_cam_image(model, image, target_layer, cam_method, target_class):
    # target_class is an integer index or None (meaning "highest scoring").
    if target_class is not None:
        targets = [ClassifierOutputTarget(target_class)]
    else:
        targets = None
    cam = CAM_METHODS[cam_method](model=model, target_layers=[target_layer])
    grayscale_cam = cam(input_tensor=image, targets=targets)
    # Un-normalize the input back to an RGB image in [0, 1], since
    # show_cam_on_image expects an HxWx3 float image in that range.
    config = model.pretrained_cfg
    mean = torch.tensor(config['mean']).view(3, 1, 1)
    std = torch.tensor(config['std']).view(3, 1, 1)
    rgb_img = (image.squeeze(0) * std + mean).permute(1, 2, 0).cpu().numpy()
    rgb_img = np.clip(rgb_img, 0, 1)
    cam_image = show_cam_on_image(rgb_img, grayscale_cam[0, :], use_rgb=True)
    return Image.fromarray(cam_image)
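# Note: sampling-based methods (ScoreCAM, AblationCAM) run many forward
# passes and can be noticeably slower than gradient-based ones.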
def get_feature_info(model):
    if hasattr(model, 'feature_info'):
        return [f['module'] for f in model.feature_info]
    else:
        return []
def get_target_layer(model, target_layer_name):
    if target_layer_name is None:
        return None
    try:
        return model.get_submodule(target_layer_name)
    except AttributeError:
        print(f"WARNING: Layer '{target_layer_name}' not found in the model.")
        return None
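# get_submodule resolves dotted module paths as reported by
# model.named_modules(), e.g. "layer4" for a ResNet-style stage.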
def get_class_names(model):
    dataset_info = None
    label_names = model.pretrained_cfg.get("label_names", None)
    label_descriptions = model.pretrained_cfg.get("label_descriptions", None)
    if label_names is None:
        imagenet_subset = infer_imagenet_subset(model)
        if imagenet_subset:
            dataset_info = ImageNetInfo(imagenet_subset)
        else:
            label_names = [f"LABEL_{i}" for i in range(model.num_classes)]
    if dataset_info is None:
        dataset_info = CustomDatasetInfo(
            label_names=label_names,
            label_descriptions=label_descriptions,
        )
    return dataset_info
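# Label resolution order: explicit label_names from the pretrained config,
# then an inferred ImageNet subset via timm's ImageNetInfo, then generic
# "LABEL_i" placeholders.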
def explain_image(model_name, image_path, cam_method, feature_module, target_class):
    model = load_model(model_name)
    image = process_image(image_path, model)
    target_layer = get_target_layer(model, feature_module)
    if target_layer is None:
        feature_info = get_feature_info(model)
        if feature_info:
            target_layer = get_target_layer(model, feature_info[-1])
            print(f"Using last feature module: {feature_info[-1]}")
        else:
            # Fall back to the deepest Conv2d layer in the network.
            for name, module in reversed(list(model.named_modules())):
                if isinstance(module, torch.nn.Conv2d):
                    target_layer = module
                    print(f"Fallback: Using last convolutional layer: {name}")
                    break
    if target_layer is None:
        raise ValueError("Could not find a suitable target layer.")
    # The class dropdown holds "index: description" strings; it may also be
    # unset (None) or "highest scoring", both meaning no explicit target.
    if target_class in (None, "highest scoring"):
        target_class_index = None
    else:
        target_class_index = int(target_class.split(':')[0])
    cam_image = get_cam_image(model, image, target_layer, cam_method, target_class_index)
    with torch.no_grad():
        out = model(image)
        probabilities = out.squeeze(0).softmax(dim=0)
    values, indices = torch.topk(probabilities, 5)  # top-5 predictions
    dataset_info = get_class_names(model)
    labels = [
        f"{i.item()}: {dataset_info.index_to_description(i.item(), detailed=True)} ({v.item():.2%})"
        for i, v in zip(indices, values)
    ]
    return cam_image, "\n".join(labels)
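# Sketch of programmatic use outside the UI (argument values are illustrative):
#   cam_img, top5 = explain_image(
#       "resnet50.a1_in1k", "cat.jpg", "GradCAM", None, "highest scoring")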
def update_feature_modules(model_name):
    model = load_model(model_name)
    feature_modules = get_feature_info(model)
    return gr.Dropdown(choices=feature_modules, value=feature_modules[-1] if feature_modules else None)
def update_class_dropdown(model_name):
    model = load_model(model_name)
    dataset_info = get_class_names(model)
    class_names = ["highest scoring"] + [
        f"{i}: {dataset_info.index_to_description(i, detailed=True)}" for i in range(model.num_classes)
    ]
    return gr.Dropdown(choices=class_names, value="highest scoring")
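# Note: both dropdown callbacks rebuild the full pretrained model just to read
# metadata; Hub caching keeps repeat loads fast, though creating the model
# with pretrained=False would likely suffice here.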
with gr.Blocks() as demo:
    gr.Markdown("# Explainable AI with timm models\nNOTE: this is a work in progress, but some models are functioning.")
    gr.Markdown("Upload an image, select a model and CAM method, and optionally choose a specific feature module and target class to visualize the explanation.")
    with gr.Row():
        with gr.Column():
            model_dropdown = gr.Dropdown(choices=MODELS, label="Select Model")
            image_input = gr.Image(type="filepath", label="Upload Image")
            cam_method_dropdown = gr.Dropdown(choices=list(CAM_METHODS.keys()), label="Select CAM Method")
            feature_module_dropdown = gr.Dropdown(label="Select Feature Module (optional)")
            class_dropdown = gr.Dropdown(label="Select Target Class (optional)")
            explain_button = gr.Button("Explain Image")
        with gr.Column():
            output_image = gr.Image(type="pil", label="Explained Image")
            prediction_text = gr.Textbox(label="Top 5 Predictions")

    model_dropdown.change(fn=update_feature_modules, inputs=[model_dropdown], outputs=[feature_module_dropdown])
    model_dropdown.change(fn=update_class_dropdown, inputs=[model_dropdown], outputs=[class_dropdown])
    explain_button.click(
        fn=explain_image,
        inputs=[model_dropdown, image_input, cam_method_dropdown, feature_module_dropdown, class_dropdown],
        outputs=[output_image, prediction_text],
    )
demo.launch()