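"""Gradio demo for CHM-Corr on CUB: retrieves nearest neighbors for a query
image via kNN over precomputed ImageNet-pretrained ResNet-50 features, reranks
them with CHM correspondences, and shows the visualization and prediction."""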
import io
import csv
import sys
import pickle
from collections import Counter

import numpy as np
import gradio as gr
import gdown
import torchvision
from torchvision.datasets import ImageFolder
from PIL import Image

from SimSearch import FaissCosineNeighbors, SearchableTrainingSet
from ExtractEmbedding import QueryToEmbedding
from CHMCorr import chm_classify_and_visualize
from visualization import plot_from_reranker_corrmap

# Allow very large fields when reading CSV metadata files.
csv.field_size_limit(sys.maxsize)


# Concatenate a list of arrays along the first axis.
def concat(x):
    return np.concatenate(x, axis=0)

# Precomputed embeddings for the CUB training set.
gdown.cached_download(
    url="https://huggingface.co/datasets/XAI/CHM-Corr-Data/resolve/main/embeddings.pickle",
    path="./embeddings.pickle",
    quiet=False,
    md5="002b2a7f5c80d910b9cc740c2265f058",
)

# Training-set labels (expected to be saved as labels.pickle, loaded below).
gdown.download(id="1SDtq6ap7LPPpYfLbAxaMGGmj0EAV_m_e")

# CUB training images.
gdown.cached_download(
    url="https://huggingface.co/datasets/XAI/CHM-Corr-Data/resolve/main/CUB_train.zip",
    path="./CUB_train.zip",
    quiet=False,
    md5="1bd99e73b2fea8e4c2ebcb0e7722f1b1",
)

# Unpack the training images into ./data/ (expected to contain a train/ folder).
torchvision.datasets.utils.extract_archive(
    from_path="CUB_train.zip",
    to_path="data/",
    remove_finished=False,
)

# Pretrained CHM model weights.
gdown.cached_download(
    url="https://drive.google.com/u/0/uc?id=1zsJRlAsoOn5F0GTCprSFYwDDfV85xDy6&export=download",
    path="pas_psi.pt",
    quiet=False,
    md5="6b7b4d7bad7f89600fac340d6aa7708b",
)

# Load the precomputed embeddings and labels, then build the kNN search index.
with open("./embeddings.pickle", "rb") as f:
    Xtrain = pickle.load(f)

with open("./labels.pickle", "rb") as f:
    ytrain = pickle.load(f)

searcher = SearchableTrainingSet(Xtrain, ytrain)
searcher.build_index()

# Map ImageFolder class indices to human-readable bird names.
training_folder = ImageFolder(root="./data/train/")
id_to_bird_name = {
    x[1]: x[0].split("/")[-2].replace(".", " ") for x in training_folder.imgs
}
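# For example, assuming the standard CUB folder naming ("001.Black_footed_Albatross"),
# the class index maps to "001 Black_footed_Albatross" (only the "." is replaced
# here; underscores are kept).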


def search(query_image, searcher=searcher):
    """Run kNN retrieval for the query image, then rerank with CHM-Corr."""
    query_embedding = QueryToEmbedding(query_image)
    scores, indices, labels = searcher.search(query_embedding, k=50)

    # Majority vote over the 20 nearest neighbors (N=50 retrieved, k=20 voted).
    result_ctr = Counter(labels[0][:20]).most_common(5)
    top1_label = result_ctr[0][0]

    # Keep the neighbors that agree with the top-1 kNN label.
    top_indices = []
    for a, b in zip(labels[0][:20], indices[0][:20]):
        if a == top1_label:
            top_indices.append(b)

    gallery_images = [training_folder.imgs[int(X)][0] for X in top_indices[:5]]
    # kNN-only scores, kept for reference; the prediction returned below comes
    # from the CHM reranker instead.
    predicted_labels = {id_to_bird_name[X[0]]: X[1] / 20.0 for X in result_ctr}

    # Rerank the retrieved neighbors with CHM and build the visualization.
    kNN_results = (top1_label, result_ctr[0][1], gallery_images)
    support_files = [training_folder.imgs[int(X)][0] for X in indices[0]]
    support_labels = [training_folder.imgs[int(X)][1] for X in indices[0]]
    support = [support_files, support_labels]

    chm_output = chm_classify_and_visualize(
        query_image, kNN_results, support, training_folder
    )
    fig, chm_output_label = plot_from_reranker_corrmap(chm_output)

    # Render the matplotlib figure to a PIL image and crop away its margins.
    img_buf = io.BytesIO()
    fig.savefig(img_buf, format="jpg")
    image = Image.open(img_buf)
    width, height = image.size
    viz_image = image.crop((310, 60, width - 248, height - 80))

    # CHM-Corr prediction: vote over the top-20 reranked neighbors.
    chm_output_labels = Counter(
        x.split("/")[-2].replace(".", " ").replace("_", " ")
        for x in chm_output["chm-nearest-neighbors-all"][:20]
    )

    return viz_image, {l: s / 20.0 for l, s in chm_output_labels.items()}
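

# Minimal usage sketch (assumes the downloads above have completed and the
# bundled example image is present):
#
#   viz, preds = search("./examples/bird.jpg")
#   viz.save("chm_corr_visualization.jpg")  # PIL image of the CHM-Corr figure
#   print(preds)                            # {bird name: score in [0, 1], ...}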

blocks = gr.Blocks()

tldr = """
We propose two architectures of interpretable image classifiers
that first explain, and then predict, by harnessing
the visual correspondences between a query image and exemplars.
Our models improve on several out-of-distribution (OOD) ImageNet
datasets while achieving performance on ImageNet that is
competitive with black-box baselines (e.g. an ImageNet-pretrained ResNet-50).
In a large-scale human study (∼60 users per method per dataset)
on ImageNet and CUB, our correspondence-based explanations led
to human-alone image classification accuracy and human-AI team
accuracy that are consistently better than those of kNN.
We show that it is possible to achieve complementary human-AI
team accuracy (i.e., higher than either AI-alone or
human-alone accuracy) on ImageNet and CUB.

<div align="center">
<a href="https://github.com/anguyen8/visual-correspondence-XAI">Github Page</a>
</div>
"""

with blocks:
    gr.Markdown(""" # CHM-Corr DEMO""")
    gr.Markdown(f""" ## Description: \n {tldr}""")

    with gr.Row():
        input_image = gr.Image(type="filepath")

        with gr.Column():
            gr.Markdown("### Parameters:")
            gr.Markdown(
                "`N=50`\n `k=20` \nUsing `ImageNet Pretrained ResNet50` features"
            )

    run_btn = gr.Button("Classify")
    gr.Markdown(""" ### CHM-Corr Output Visualization """)
    viz_plot = gr.Image(type="pil", label="Visualization")

    with gr.Row():
        with gr.Column():
            gr.Markdown(""" ### CHM-Corr Prediction """)
            labels = gr.Label(label="Prediction")
        with gr.Column():
            gr.Markdown(""" ### Examples """)
            examples = gr.Examples(
                examples=[
                    ["./examples/bird.jpg"],
                    ["./examples/Red_Winged_Blackbird_0012_6015.jpg"],
                    ["./examples/Red_Winged_Blackbird_0025_5342.jpg"],
                    ["./examples/sample1.jpeg"],
                    ["./examples/sample2.jpeg"],
                    ["./examples/Yellow_Headed_Blackbird_0020_8549.jpg"],
                    ["./examples/Yellow_Headed_Blackbird_0026_8545.jpg"],
                ],
                inputs=[input_image],
                outputs=[viz_plot, labels],
                fn=search,
                cache_examples=False,
            )

    run_btn.click(
        search,
        inputs=[input_image],
        outputs=[viz_plot, labels],
    )


if __name__ == "__main__":
    blocks.launch()