# File-viewer page header captured together with the source (not program code):
# kritsg's picture | modified desc | 7570470 | raw | history blame | 3.89 kB
from cgitb import enable
from pyexpat import model
from statistics import mode
import numpy as np
import gradio as gr
import argparse
import os
from os.path import exists, dirname
import sys
import json
import flask
from PIL import Image
parent_dir = dirname(os.path.abspath(os.getcwd()))
sys.path.append(parent_dir)
from bayes.explanations import BayesLocalExplanations, explain_many
from bayes.data_routines import get_dataset_by_name
from bayes.models import *
from image_posterior import create_gif
def get_image_data(inp_image):
    """Load the image for ``inp_image`` and build its model/data bundle.

    Args:
        inp_image: Forwarded unchanged to ``get_dataset_by_name`` (called
            with ``get_label=False``).

    Returns:
        A 2-tuple ``(image, model_and_data)``: the value returned by
        ``get_dataset_by_name`` and the result of passing that value to
        ``process_imagenet_get_model``.
    """
    loaded_image = get_dataset_by_name(inp_image, get_label=False)
    return loaded_image, process_imagenet_get_model(loaded_image)
def segmentation_generation(input_image, c_width, n_top, n_gif_imgs):
    """Explain the model's prediction for an image and render the result.

    Args:
        input_image: Image passed to ``get_image_data`` and later to
            ``create_gif``.
        c_width: Forwarded to the explainer as ``cred_width``.
        n_top: Forwarded to ``create_gif`` as its last argument.
        n_gif_imgs: Forwarded to ``create_gif`` as its second-to-last argument.

    Returns:
        Whatever ``create_gif`` returns for the computed explanation.
    """
    print("Inputs Received:", input_image, c_width, n_top, n_gif_imgs)

    # The loaded image itself is not needed here; only the bundle is used.
    _, bundle = get_image_data(input_image)

    # Pull the entries out of the bundle in a fixed order. "ytest" and "label"
    # are still looked up but are not used below.
    # NOTE: a label sanity check for specific example images used to live
    # here and is currently disabled.
    xtest, _ytest, segs, build_model, _label = (
        bundle[key] for key in ("xtest", "ytest", "xtest_segs", "model", "label")
    )

    # Only the first instance and its segmentation are explained.
    instance, segments = xtest[0], segs[0]

    # Model wrapped for this particular instance/segmentation pair.
    wrapped_model = build_model(instance, segments)

    # Background data for the explainer.
    background = get_xtrain(segments)

    # Predicted class index for the first background row.
    predicted = np.argmax(wrapped_model(background[:1]), axis=1)

    explainer = BayesLocalExplanations(
        training_data=background,
        data="image",
        kernel="lime",
        categorical_features=np.arange(background.shape[1]),
        verbose=True,
    )
    explanation = explainer.explain(
        classifier_f=wrapped_model,
        # presumably an all-ones vector stands for "every segment present" — confirm
        data=np.ones_like(background[0]),
        label=int(predicted[0]),
        cred_width=c_width,
        focus_sample=False,
        l2=False,
    )

    # Turn the explanation into the GIF output.
    return create_gif(
        explanation["blr"],
        input_image,
        segments,
        instance,
        predicted[0],
        n_gif_imgs,
        n_top,
    )
if __name__ == "__main__":
    # Script entry point: assemble the Gradio interface around
    # ``segmentation_generation`` and start it.

    # Image input and the two outputs (HTML + text).
    image_input = gr.inputs.Image(
        label="Input Image (Or select an example)",
        type="pil",
    )
    result_outputs = [
        gr.outputs.HTML(label="Output GIF"),
        gr.outputs.Textbox(label="Prediction"),
    ]

    # One slider per numeric argument of ``segmentation_generation``,
    # in the same order as its parameters after the image.
    cred_width_slider = gr.inputs.Slider(
        minimum=0.01, maximum=0.8, step=0.01, default=0.01,
        label="cred_width", optional=False,
    )
    top_segments_slider = gr.inputs.Slider(
        minimum=1, maximum=10, step=1, default=5,
        label="n_top_segs", optional=False,
    )
    gif_frames_slider = gr.inputs.Slider(
        minimum=10, maximum=100, step=1, default=30,
        label="n_gif_images", optional=False,
    )

    # Example rows: [image path, cred_width, n_top_segs, n_gif_images].
    example_rows = [
        ["./data/diego.png", 0.01, 7, 50],
        ["./data/french_bulldog.jpg", 0.01, 5, 50],
        ["./data/pepper.jpeg", 0.01, 5, 50],
        ["./data/bird.jpg", 0.01, 5, 50],
        ["./data/hockey.jpg", 0.01, 5, 50],
    ]

    demo = gr.Interface(
        segmentation_generation,
        [image_input, cred_width_slider, top_segments_slider, gif_frames_slider],
        outputs=result_outputs,
        examples=example_rows,
        title="Reliable Post Hoc Explanations: Modeling Uncertainty in Explainability",
        description="Dylan Slack, Sophie Hilgard, Sameer Singh, and Hima Lakkaraju. NeurIPS 2021.",
        article="Research paper and Github can be found [here](https://dylanslacks.website/reliable/index.html)",
    )
    demo.launch(enable_queue=True)