BrAD / app.py
aarbelle's picture
Fix title
80aa682
raw
history blame
3.55 kB
import pickle
import os
from sklearn.neighbors import NearestNeighbors
import numpy as np
import gradio as gr
from PIL import Image
# ---------------------------------------------------------------------------
# Configuration
# ---------------------------------------------------------------------------
# Base URL that the relative image paths stored in the feature files are
# joined onto.
data_root = 'https://ai-vision-public-datasets.s3.eu.cloud-object-storage.appdomain.cloud/DomainNet'
# Local directory holding the pickled feature files.
feat_dir = 'brad_feats'
domains = ['sketch', 'painting', 'clipart', 'real']
# Suffix embedded in every feature file name.
shots = '-1'
# Not referenced anywhere in the visible code.
num_nn = 20
# 'all' indexes every entry of `domains`; any other value indexes only that
# single domain.
search_domain = 'all'
# Number of neighbours retrieved (and displayed) per searched domain.
num_results_per_domain = 5


def _read_pickle(file_name):
    """Return the object unpickled from `file_name` inside `feat_dir`."""
    with open(os.path.join(feat_dir, file_name), 'rb') as handle:
        return pickle.load(handle)


# ---------------------------------------------------------------------------
# Search side: one fitted nearest-neighbour index per searched domain.
# Note that `src_data_dict` is populated from the `dst_*` pickle files.
# Each value is a `(loaded_data, fitted_index)` pair; element [1] of the
# loaded data is what the index is fitted on (presumably the feature
# matrix - confirm against the code that wrote the pickles).
# ---------------------------------------------------------------------------
searched_domains = domains if search_domain == 'all' else [search_domain]
src_data_dict = {}
for domain_name in searched_domains:
    loaded = _read_pickle(f'dst_{domain_name}_{shots}.pkl')
    neighbours = NearestNeighbors(
        n_neighbors=num_results_per_domain,
        algorithm='auto',
        n_jobs=-1,
    ).fit(loaded[1])
    src_data_dict[domain_name] = (loaded, neighbours)

# ---------------------------------------------------------------------------
# Query side: raw data for every domain, populated from the `src_*` files.
# `min_len` ends up as the smallest `len(data[0])` over all domains.
# ---------------------------------------------------------------------------
dst_data_dict = {}
min_len = 1e10
for domain_name in domains:
    dst_data_dict[domain_name] = _read_pickle(f'src_{domain_name}_{shots}.pkl')
    min_len = min(min_len, len(dst_data_dict[domain_name][0]))
def query(query_index, query_domain):
    """Retrieve the nearest neighbours of one query image in every searched domain.

    Args:
        query_index: Position of the query image inside the selected domain's
            entry of ``dst_data_dict``. Numeric values (e.g. from the UI
            slider) are truncated to ``int`` and clamped to the valid range
            ``[0, len - 1]`` of that domain.
        query_domain: Key of ``dst_data_dict`` selecting the query domain.

    Returns:
        A flat tuple: the query image URL, then the URLs retrieved from each
        entry of ``src_data_dict`` (in dict order), followed by one caption
        per image in the same order. Captions are the title-cased parent
        directory of each image path (presumably the class name).

    Raises:
        KeyError: If ``query_domain`` is not a key of ``dst_data_dict``, for
            example when no domain has been selected in the dropdown.
    """
    # Local import: URLs must always be joined with '/', whereas os.path.join
    # uses the host OS separator (a backslash on Windows).
    import posixpath

    if query_domain not in dst_data_dict:
        raise KeyError(
            f'Unknown query domain {query_domain!r}; '
            f'expected one of {list(dst_data_dict)}'
        )
    dst_data = dst_data_dict[query_domain]

    # The slider range is not tied to the data length and may deliver floats,
    # so coerce to int and clamp instead of failing with an IndexError (or an
    # empty feature slice for negative values).
    query_index = int(query_index)
    query_index = max(0, min(query_index, len(dst_data[0]) - 1))

    dst_img_path = posixpath.join(data_root, dst_data[0][query_index])
    img_paths = [dst_img_path]
    captions = [f"Query: {dst_img_path.split('/')[-2]}".title()]

    # Slice (rather than index) so the query keeps a leading sample axis,
    # which kneighbors is called with here.
    query_feat = dst_data[1][query_index:query_index + 1]

    for s_data, s_nn_fit in src_data_dict.values():
        _, top_n_matches_ids = s_nn_fit.kneighbors(query_feat)
        src_img_pths = [
            posixpath.join(data_root, s_data[0][ix])
            for ix in top_n_matches_ids[0]
        ]
        img_paths.extend(src_img_pths)
        captions.extend(p.split('/')[-2].title() for p in src_img_pths)

    print(img_paths)
    return tuple(img_paths) + tuple(captions)
# Build the Gradio Blocks UI: a domain dropdown, an index slider and a "Run"
# button, one slot for the query image, and one row of result slots per domain.
# NOTE(review): indentation was lost in the copy this was reviewed from; the
# nesting below is reconstructed. In particular the query Column is assumed to
# sit inside the preceding Row (4 Markdown cells + 1 Column matches the 5
# columns of each result row) - confirm against the original file.
demo = gr.Blocks()
with demo:
    gr.Markdown('# Unsupervised Domain Generalization by Learning a Bridge Across Domains')
    gr.Markdown('This demo showcases the cross-domain retrieval capabilities of our self-supervised cross domain training as presented @CVPR 2022. For details please refer to [the paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Harary_Unsupervised_Domain_Generalization_by_Learning_a_Bridge_Across_Domains_CVPR_2022_paper.pdf)')
    gr.Markdown('## Select Query Domain: ')
    domain_drop = gr.Dropdown(domains)
    # domain_select_button = gr.Button("Select Domain")
    # slider = gr.Slider(0, min_len)
    # The slider range is hard-coded and not tied to `min_len`.
    slider = gr.Slider(0, 10000)
    image_button = gr.Button("Run")
    with gr.Row():
        gr.Markdown('# Query Image: \t\t\t\t ')
        # Empty Markdown cells; presumably spacers for the row layout.
        gr.Markdown('\t')
        gr.Markdown('\t')
        gr.Markdown('\t')
        with gr.Column():
            # Caption and image slot for the query itself.
            src_cap = gr.Label()
            src_img = gr.Image()
    # Result slots, appended domain by domain so their order follows `domains`.
    out_images = []
    out_captions = []
    for d in domains:
        gr.Markdown(f'# {d.title()} Domain Images')
        with gr.Row():
            for _ in range(num_results_per_domain):
                with gr.Column():
                    out_captions.append(gr.Label())
                    out_images.append(gr.Image())
    # `query` returns all image paths first and all captions second, which is
    # the order of the outputs list here. The number of outputs assumes that
    # `query` yields `num_results_per_domain` results for every entry of
    # `domains`, i.e. that every domain is being searched.
    image_button.click(query, inputs=[slider, domain_drop], outputs=[src_img]+out_images +[src_cap]+ out_captions)
# Start the app; share=True asks Gradio for a publicly shareable link.
demo.launch(share=True)