Spaces:
Sleeping
Sleeping
File size: 3,553 Bytes
24cb51e d898082 8309552 24cb51e d898082 24cb51e d898082 24cb51e d898082 24cb51e d898082 24cb51e 03be69d 24cb51e dc031a1 24cb51e 5666b8d d898082 24cb51e d898082 24cb51e 5666b8d d898082 24cb51e 80aa682 dc031a1 24cb51e d898082 24cb51e fb1c84a 24cb51e 5666b8d 24cb51e d898082 24cb51e d898082 24cb51e d898082 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 |
import pickle
import os
from sklearn.neighbors import NearestNeighbors
import numpy as np
import gradio as gr
from PIL import Image
# ---------------------------------------------------------------------------
# Demo configuration
# ---------------------------------------------------------------------------
# Base URL onto which the relative image paths stored in the feature pickles
# are joined to form displayable image URLs.
data_root = 'https://ai-vision-public-datasets.s3.eu.cloud-object-storage.appdomain.cloud/DomainNet'
# Local directory that holds the pre-computed feature pickles.
feat_dir = 'brad_feats'
# Suffix used in the pickle file names: '{src,dst}_<domain>_<shots>.pkl'.
shots = '-1'
# Domains offered by the demo; iterated in this order when loading features.
domains = ['sketch', 'painting', 'clipart', 'real']

# Retrieval settings.
search_domain = 'all'  # either 'all' or a single entry of `domains`
num_results_per_domain = 5  # neighbours retrieved (and displayed) per domain
num_nn = 20  # NOTE(review): not referenced anywhere in this file
def _load_search_index(domain):
    """Load the gallery pickle for *domain* and fit a k-NN model on it.

    Reads ``<feat_dir>/dst_<domain>_<shots>.pkl`` and fits a
    ``NearestNeighbors`` model (``num_results_per_domain`` neighbours) on
    element ``[1]`` of the unpickled object; element ``[0]`` is used by
    ``query`` as the sequence of relative image paths.

    Returns:
        ``(data, nn_fit)`` -- the unpickled object and the fitted model.
    """
    path = os.path.join(feat_dir, f'dst_{domain}_{shots}.pkl')
    # NOTE: pickle.load can execute arbitrary code; only load trusted,
    # locally produced feature files.
    with open(path, 'rb') as fp:
        data = pickle.load(fp)
    nn_fit = NearestNeighbors(
        n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1
    ).fit(data[1])
    return data, nn_fit


# Searchable galleries keyed by domain: every domain when search_domain is
# 'all', otherwise only the configured one. Insertion order follows
# `domains`, which is the order in which `query` emits its results.
_index_domains = domains if search_domain == 'all' else [search_domain]
src_data_dict = {d: _load_search_index(d) for d in _index_domains}
def _read_pickle(path):
    """Return the object unpickled from the (trusted, local) file *path*."""
    with open(path, 'rb') as handle:
        return pickle.load(handle)


# Query-side data keyed by domain, read from the 'src_*' pickles. `query`
# uses element [0] as the relative image paths and passes rows of element
# [1] to the k-NN models.
dst_data_dict = {
    domain: _read_pickle(os.path.join(feat_dir, f'src_{domain}_{shots}.pkl'))
    for domain in domains
}
# Size of the smallest query set across domains; ``min_len - 1`` is the
# largest index valid in every domain. The 1e10 floor keeps the value
# defined even if `domains` were empty.
min_len = min([1e10] + [len(entry[0]) for entry in dst_data_dict.values()])
def query(query_index, query_domain):
    """Retrieve the nearest gallery images for one query image.

    Args:
        query_index: Position of the query image within the query set of
            `query_domain`. It may arrive as a float from the UI slider, so it
            is cast to ``int`` and clamped into ``[0, len - 1]``.
        query_domain: Key of `dst_data_dict` (one of `domains`).

    Returns:
        A flat tuple: first all image URLs (the query image, then the
        neighbours for each searched domain in `src_data_dict` order), then
        the matching captions in the same order. This layout must match the
        ``outputs`` list wired to the Run button.

    Raises:
        ValueError: if `query_domain` is not a loaded domain (for example,
            when no domain has been selected).
    """
    if query_domain not in dst_data_dict:
        raise ValueError(f'Unknown query domain: {query_domain!r}')
    dst_data = dst_data_dict[query_domain]

    # Clamp so an out-of-range or fractional slider value cannot break indexing.
    query_index = min(max(int(query_index), 0), len(dst_data[0]) - 1)

    dst_img_path = os.path.join(data_root, dst_data[0][query_index])
    img_paths = [dst_img_path]
    # The parent folder of an image path is used as its class label.
    q_cl = dst_img_path.split('/')[-2]
    captions = [f'Query: {q_cl}'.title()]

    # Slice (not index) so kneighbors receives a 2-D, single-row input.
    query_feat = dst_data[1][query_index:query_index + 1]
    for gallery_data, nn_fit in src_data_dict.values():
        _, top_n_matches_ids = nn_fit.kneighbors(query_feat)
        for ix in top_n_matches_ids[0]:
            src_img_pth = os.path.join(data_root, gallery_data[0][ix])
            img_paths.append(src_img_pth)
            captions.append(src_img_pth.split('/')[-2].title())

    return tuple(img_paths) + tuple(captions)
# ---------------------------------------------------------------------------
# Gradio UI
# ---------------------------------------------------------------------------
demo = gr.Blocks()
with demo:
    gr.Markdown('# Unsupervised Domain Generalization by Learning a Bridge Across Domains')
    gr.Markdown('This demo showcases the cross-domain retrieval capabilities of our self-supervised cross domain training as presented @CVPR 2022. For details please refer to [the paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Harary_Unsupervised_Domain_Generalization_by_Learning_a_Bridge_Across_Domains_CVPR_2022_paper.pdf)')
    gr.Markdown('## Select Query Domain: ')
    # Pre-select a domain so pressing Run first does not send None to `query`.
    domain_drop = gr.Dropdown(domains, value=domains[0])
    # Query index: bounded by the smallest query set (valid in every domain)
    # and stepped by whole numbers because it is used as a list index.
    slider = gr.Slider(0, min_len - 1, step=1)
    image_button = gr.Button("Run")
    with gr.Row():
        gr.Markdown('# Query Image: \t\t\t\t ')
        gr.Markdown('\t')
        gr.Markdown('\t')
        gr.Markdown('\t')
        with gr.Column():
            src_cap = gr.Label()
            src_img = gr.Image()
    out_images = []
    out_captions = []
    # One row per searched domain, iterated exactly as `query` iterates
    # src_data_dict, so the number and order of outputs match its return value
    # even when search_domain is a single domain.
    for d in src_data_dict:
        gr.Markdown(f'# {d.title()} Domain Images')
        with gr.Row():
            for _ in range(num_results_per_domain):
                with gr.Column():
                    out_captions.append(gr.Label())
                    out_images.append(gr.Image())
    image_button.click(
        query,
        inputs=[slider, domain_drop],
        outputs=[src_img] + out_images + [src_cap] + out_captions,
    )
demo.launch(share=True)
|