# Spaces: Sleeping
import os
import pickle
import posixpath

import gradio as gr
import numpy as np
from PIL import Image
from sklearn.neighbors import NearestNeighbors
# --- Demo configuration ------------------------------------------------------

# Domains that can be queried; also used to locate the per-domain feature files.
domains = ['sketch', 'painting', 'clipart', 'real']
# Domain(s) to retrieve from: 'all' means every entry of `domains`,
# any other value names a single domain.
search_domain = 'all'
# Number of neighbours fetched (and image slots shown) per searched domain.
num_results_per_domain = 5
# NOTE(review): not referenced anywhere in the visible part of this file.
num_nn = 20

# Directory holding the pickled feature files.
feat_dir = 'brad_feats'
# Tag interpolated into the feature file names (`..._{shots}.pkl`).
shots = '-1'
# Base URL that the relative image paths from the feature files are joined onto.
data_root = 'https://ai-vision-public-datasets.s3.eu.cloud-object-storage.appdomain.cloud/DomainNet'
# Retrieval side: for every searched domain keep the unpickled feature file
# together with a nearest-neighbour index fitted on its element [1].
# NOTE(review): the dict is named `src_*` but reads the `dst_*.pkl` files;
# presumably intentional -- confirm against the feature-extraction script.
src_data_dict = {}

# 'all' searches every known domain; any other value names exactly one.
domains_to_search = domains if search_domain == 'all' else [search_domain]

for s_dom in domains_to_search:
    feat_file = os.path.join(feat_dir, f'dst_{s_dom}_{shots}.pkl')
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(feat_file, 'rb') as fp:
        src_data = pickle.load(fp)
    knn = NearestNeighbors(
        n_neighbors=num_results_per_domain,
        algorithm='auto',
        n_jobs=-1,
    )
    src_nn_fit = knn.fit(src_data[1])
    src_data_dict[s_dom] = (src_data, src_nn_fit)
# Query side: the unpickled feature file of every domain, keyed by domain name.
# NOTE(review): the dict is named `dst_*` but reads the `src_*.pkl` files;
# presumably intentional -- confirm against the feature-extraction script.
dst_data_dict = {}
for q_dom in domains:
    feat_file = os.path.join(feat_dir, f'src_{q_dom}_{shots}.pkl')
    # NOTE: unpickling can execute arbitrary code; only load trusted files.
    with open(feat_file, 'rb') as fp:
        dst_data_dict[q_dom] = pickle.load(fp)

# Size of the smallest domain (length of element [0] of its feature file);
# 1e10 is the starting sentinel and survives only if no domain was loaded.
min_len = min([1e10] + [len(entry[0]) for entry in dst_data_dict.values()])
def query(query_index, query_domain):
    """Retrieve the nearest neighbours of one query image in every searched domain.

    Args:
        query_index: Position of the query image inside the chosen domain's
            feature file. It is coerced to ``int`` and clamped to the valid
            range, because the UI slider may deliver a float or a value past
            the end of the selected domain.
        query_domain: Key into ``dst_data_dict`` selecting the query domain.

    Returns:
        A flat tuple: the query image URL, then the retrieved image URLs
        (grouped per searched domain, in ``src_data_dict`` order), then the
        query caption, then one caption per retrieved image in the same order.

    Raises:
        KeyError: If ``query_domain`` is not a key of ``dst_data_dict``.
    """
    dst_data = dst_data_dict[query_domain]

    # Clamp into [0, len - 1]: an out-of-range index would otherwise raise, or
    # hand an empty slice to `kneighbors` below.
    last_index = len(dst_data[0]) - 1
    query_index = min(max(int(query_index), 0), last_index)

    # URLs always use '/', so join with posixpath rather than os.path (which
    # would insert a backslash on Windows and break the split('/') below).
    dst_img_path = posixpath.join(data_root, dst_data[0][query_index])
    img_paths = [dst_img_path]
    # The caption is the parent directory name of the image path.
    q_cl = dst_img_path.split('/')[-2]
    captions = [f'Query: {q_cl}'.title()]

    # The one-row slice keeps the query 2-D, as `kneighbors` expects.
    query_feat = dst_data[1][query_index:query_index + 1]
    for s_data, s_nn_fit in src_data_dict.values():
        _, top_n_matches_ids = s_nn_fit.kneighbors(query_feat)
        src_img_pths = [posixpath.join(data_root, s_data[0][ix]) for ix in top_n_matches_ids[0]]
        img_paths += src_img_pths
        captions += [p.split('/')[-2].title() for p in src_img_pths]

    # Kept from the original: logs the resolved URLs to stdout.
    print(img_paths)
    return tuple(img_paths) + tuple(captions)
# Gradio UI: pick a query domain and image index, press "Run", and show the
# query image followed by the retrieved images of every searched domain.
# NOTE(review): the source reached review without indentation; the Row/Column
# nesting below is reconstructed -- confirm against the deployed layout.
demo = gr.Blocks()
with demo:
    gr.Markdown('# Unsupervised Domain Generalization by Learning a Bridge Across Domains')
    gr.Markdown('This demo showcases the cross-domain retrieval capabilities of our self-supervised cross domain training as presented @CVPR 2022. For details please refer to [the paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Harary_Unsupervised_Domain_Generalization_by_Learning_a_Bridge_Across_Domains_CVPR_2022_paper.pdf)')
    gr.Markdown('## Instructions:')
    # NOTE(review): "and the select" below looks like a typo for "and then
    # select"; left untouched here because it is user-facing text.
    gr.Markdown('Select a query domain from the dropdown menu and the select any random image from the domain using the slider below. The retrieved results from each of the four domains, along with the class label will be presented.')
    gr.Markdown('## Select Query Domain: ')
    # Preselect the first domain so pressing "Run" without touching the
    # dropdown does not pass None to `query`.
    domain_drop = gr.Dropdown(domains, value=domains[0] if domains else None)
    # Valid indices are 0 .. min_len - 1 (min_len is a length, not an index);
    # step=1 keeps the slider on whole numbers.
    slider = gr.Slider(0, max(min_len - 1, 0), step=1)
    image_button = gr.Button("Run")

    # Query row: heading, three spacer cells, then the query caption and image.
    with gr.Row():
        gr.Markdown('# Query Image: \t\t\t\t ')
        gr.Markdown('\t')
        gr.Markdown('\t')
        gr.Markdown('\t')
        with gr.Column():
            src_cap = gr.Label()
            src_img = gr.Image()

    # One row of result slots per searched domain. Iterating `src_data_dict`
    # (instead of `domains`) keeps the slot count and order identical to what
    # `query` returns, also when only a single domain is searched.
    out_images = []
    out_captions = []
    for d in src_data_dict:
        gr.Markdown(f'# Retrieved Images from {d.title()} Domain:')
        with gr.Row():
            for _ in range(num_results_per_domain):
                with gr.Column():
                    out_captions.append(gr.Label())
                    out_images.append(gr.Image())

    # Output order mirrors the tuple returned by `query`:
    # all images first (query, then results), then all captions.
    image_button.click(
        query,
        inputs=[slider, domain_drop],
        outputs=[src_img] + out_images + [src_cap] + out_captions,
    )

demo.launch(share=True)