Spaces:
Sleeping
Sleeping
File size: 4,888 Bytes
24cb51e d898082 8309552 24cb51e d898082 24cb51e 77733ea 24cb51e 77733ea 74cdc0c 77733ea d898082 24cb51e d898082 24cb51e d898082 24cb51e 03be69d 24cb51e 77733ea fb018ef 77733ea fb018ef 77733ea fb018ef 24cb51e 77733ea 24cb51e fb018ef 24cb51e 5666b8d d898082 24cb51e d898082 24cb51e 5666b8d fb018ef d898082 24cb51e 80aa682 dc031a1 8e2a7a5 099eca2 a393337 24cb51e 77733ea 24cb51e 77733ea 099eca2 a045da3 5666b8d 77733ea fb018ef 8e2a7a5 6633a91 77733ea 5666b8d 79a12df 24cb51e d898082 24cb51e 099eca2 24cb51e 77733ea 24cb51e d898082 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 |
import pickle
import os
from sklearn.neighbors import NearestNeighbors
import numpy as np
import gradio as gr
from PIL import Image
# ---------------------------------------------------------------------------
# Demo configuration
# ---------------------------------------------------------------------------
# Root URL onto which the relative image paths from the feature pickles are joined.
data_root = ('https://ai-vision-public-datasets.s3.eu.cloud-object-storage.appdomain.cloud'
             '/DomainNet')
# Local directory that holds the pre-extracted feature pickles.
feat_dir = 'brad_feats'
# Domain names; also interpolated into the feature-file names.
domains = 'sketch painting clipart real'.split()
# Suffix interpolated into the feature-file names (kept as a string).
shots = '-1'
# NOTE(review): num_nn appears unused in this file; kept so nothing that imports it breaks.
num_nn = 20
# 'all' builds a k-NN index for every domain; otherwise the single domain to index.
search_domain = 'all'
# How many nearest neighbours to retrieve from each indexed domain.
num_results_per_domain = 5
# ---------------------------------------------------------------------------
# Retrieval gallery: one fitted k-NN index per searched domain.
# src_data_dict[domain] = (data, nn_fit), where data[0] holds image paths and
# data[1] the feature matrix the index is fitted on.
# NOTE(review): the 'dst_*' pickles feed src_data_dict (and 'src_*' feed the
# query pool); the prefixes look swapped but are presumably intentional — confirm.
# ---------------------------------------------------------------------------


def _load_src_domain(domain):
    """Load ``domain``'s gallery pickle and fit a k-NN index on its features.

    Returns:
        ``(data, nn_fit)``: the unpickled object and a ``NearestNeighbors``
        model fitted on ``data[1]`` that returns ``num_results_per_domain``
        neighbours per query.
    """
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted feature files.
    with open(os.path.join(feat_dir, f'dst_{domain}_{shots}.pkl'), 'rb') as fp:
        data = pickle.load(fp)
    nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1).fit(data[1])
    return data, nn_fit


src_data_dict = {}
class_list = []
for d in (domains if search_domain == 'all' else [search_domain]):
    src_data, src_nn_fit = _load_src_domain(d)
    if not class_list:
        # Class names are the parent directory of each image path, in first-seen
        # order. This must be built whichever domain(s) are searched: the per-class
        # query pool and the class drop-down are both keyed by it.
        class_list = list(dict.fromkeys(p.split('/')[-2] for p in src_data[0]))
    src_data_dict[d] = (src_data, src_nn_fit)
# ---------------------------------------------------------------------------
# Query pool: each domain's images grouped by class.
# dst_data_dict[domain] = ({class: (image_paths, row_indices)}, features)
# row_indices are positions in the pickle's path list (dest_data[0]); query()
# uses them as row indices into `features` (dest_data[1]).
# ---------------------------------------------------------------------------
dst_data_dict = {}
# Smallest number of query images that any listed class has in any domain.
min_len = 1e10
for d in domains:
    # NOTE(review): pickle.load can execute arbitrary code; only load trusted feature files.
    with open(os.path.join(feat_dir, f'src_{d}_{shots}.pkl'), 'rb') as fp:
        dest_data = pickle.load(fp)
    per_class = {cl: ([], []) for cl in class_list}
    for c, p in enumerate(dest_data[0]):
        # setdefault: an image whose class is not in class_list is still grouped
        # instead of aborting start-up with a KeyError.
        paths, rows = per_class.setdefault(p.split('/')[-2], ([], []))
        paths.append(p)
        rows.append(c)
    dst_data_dict[d] = (per_class, dest_data[1])
    for cl in class_list:
        # Count the class's image paths, not the (paths, rows) pair, whose length is always 2.
        min_len = min(min_len, len(per_class[cl][0]))
def query(query_index, query_domain, cl):
dst_data = dst_data_dict[query_domain]
dst_img_path = os.path.join(data_root, dst_data[0][cl][0][query_index])
query_index = dst_data[0][cl][1][query_index]
img_paths = [dst_img_path]
q_cl = dst_img_path.split('/')[-2]
captions = [f'Query: {q_cl}'.title()]
for s_domain, s_data in src_data_dict.items():
_, top_n_matches_ids = s_data[1].kneighbors(dst_data[1][query_index:query_index+1])
top_n_labels = s_data[0][2][top_n_matches_ids][0]
src_img_pths = [os.path.join(data_root, s_data[0][0][ix]) for ix in top_n_matches_ids[0]]
img_paths += src_img_pths
for p in src_img_pths:
src_cl = p.split('/')[-2]
src_file = p.split('/')[-1]
captions.append(src_cl.title())
# print(img_paths)
return tuple([p for p in img_paths])+ tuple(captions)
# ---------------------------------------------------------------------------
# Gradio UI: query controls, the query image, and one row of results per
# indexed domain, wired to query().
# ---------------------------------------------------------------------------
demo = gr.Blocks()
with demo:
    gr.Markdown('# Unsupervised Domain Generalization by Learning a Bridge Across Domains')
    gr.Markdown('This demo showcases the cross-domain retrieval capabilities of our self-supervised cross domain training as presented @CVPR 2022. For details please refer to [the paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Harary_Unsupervised_Domain_Generalization_by_Learning_a_Bridge_Across_Domains_CVPR_2022_paper.pdf)')
    gr.Markdown('The model is trained in an unsupervised manner on all domains without class labels. The labels are displayed to indicate retrieval success/failure.')
    gr.Markdown('## Instructions:')
    gr.Markdown('Select a query domain and a class from the drop-down menus and then select any random image from the domain using the slider below and press the "Run" button. The retrieved results from each of the four domains, along with the class label will be presented.')
    gr.Markdown('## Select Query Domain: ')
    gr.Markdown('# Query Image: \t\t\t\t')
    with gr.Row():
        with gr.Column():
            # Preselect values so pressing "Run" straight away does not call query() with None.
            domain_drop = gr.Dropdown(domains, value=domains[0], label='Domain')
            cl_drop = gr.Dropdown(class_list, value=class_list[0] if class_list else None, label='Query Class')
            # step=1 so the slider yields whole-number positions.
            slider = gr.Slider(0, 100, step=1, label='Query image selector slider')
        with gr.Column():
            src_cap = gr.Label()
            src_img = gr.Image()
    image_button = gr.Button("Run")
    out_images = []
    out_captions = []
    # One result section per *indexed* domain (src_data_dict order). This keeps the
    # number of output components equal to the number of images query() returns,
    # even when search_domain names a single domain.
    for d in src_data_dict:
        gr.Markdown(f'# Retrieved Images from {d.title()} Domain:')
        with gr.Row():
            for _ in range(num_results_per_domain):
                with gr.Column():
                    out_captions.append(gr.Label())
                    out_images.append(gr.Image())
    # Output order mirrors query()'s return value: all image paths first, then all captions.
    image_button.click(
        query,
        inputs=[slider, domain_drop, cl_drop],
        outputs=[src_img] + out_images + [src_cap] + out_captions,
    )
demo.launch(share=True)
|