File size: 4,912 Bytes
24cb51e
 
 
 
 
 
 
d898082
 
8309552
24cb51e
d898082
 
24cb51e
 
 
77733ea
24cb51e
 
 
 
77733ea
 
 
74cdc0c
77733ea
d898082
24cb51e
 
 
d898082
24cb51e
d898082
24cb51e
 
 
03be69d
24cb51e
 
77733ea
fb018ef
 
77733ea
fb018ef
 
 
77733ea
fb018ef
24cb51e
77733ea
24cb51e
fb018ef
 
24cb51e
 
5666b8d
d898082
 
24cb51e
d898082
24cb51e
 
 
 
 
5666b8d
fb018ef
d898082
 
24cb51e
 
80aa682
dc031a1
8e2a7a5
099eca2
eff3219
24cb51e
77733ea
 
 
24cb51e
77733ea
099eca2
a045da3
5666b8d
77733ea
fb018ef
 
8e2a7a5
6633a91
77733ea
 
 
5666b8d
 
 
79a12df
24cb51e
d898082
24cb51e
 
 
099eca2
24cb51e
 
 
 
 
 
77733ea
24cb51e
d898082
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
import pickle
import os
from sklearn.neighbors import NearestNeighbors
import numpy as np
import gradio as gr
from PIL import Image

# --- Retrieval database setup --------------------------------------------
# Root URL of the DomainNet images (public object-storage bucket) and the
# local directory holding the pre-extracted feature pickles.
data_root = 'https://ai-vision-public-datasets.s3.eu.cloud-object-storage.appdomain.cloud/DomainNet'
feat_dir = 'brad_feats'
domains = ['sketch', 'painting', 'clipart', 'real']
shots = '-1'
num_nn = 20

search_domain = 'all'          # 'all' = search every domain; else a single domain name
num_results_per_domain = 5     # k for the nearest-neighbour search

# domain -> (raw pickle data, fitted NearestNeighbors index).
# Each pickle appears to hold (image_paths, feature_matrix, ...) — the code
# below only relies on data[0] (paths) and data[1] (features).
src_data_dict = {}
class_list = []


def _load_search_index(domain):
    """Load the 'dst' feature pickle for *domain* and fit a k-NN index on it."""
    with open(os.path.join(feat_dir, f'dst_{domain}_{shots}.pkl'), 'rb') as fp:
        data = pickle.load(fp)
    nn_index = NearestNeighbors(n_neighbors=num_results_per_domain,
                                algorithm='auto', n_jobs=-1).fit(data[1])
    return data, nn_index


if search_domain == 'all':
    for d in domains:
        src_data, src_nn_fit = _load_search_index(d)
        if not class_list:
            # Class names are encoded as the parent directory of each image path.
            for p in src_data[0]:
                cl = p.split('/')[-2]
                if cl not in class_list:
                    class_list.append(cl)
        src_data_dict[d] = (src_data, src_nn_fit)
else:
    # NOTE(review): in this branch class_list stays empty, which leaves the
    # per-class grouping below empty as well — confirm single-domain mode
    # is actually supported before relying on it.
    src_data_dict[search_domain] = _load_search_index(search_domain)

# domain -> ({class -> (image paths, row indices into feature matrix)}, features)
dst_data_dict = {}
min_len = 1e10  # smallest per-class image count across all query domains
for d in domains:
    with open(os.path.join(feat_dir, f'src_{d}_{shots}.pkl'), 'rb') as fp:
        dest_data = pickle.load(fp)
    by_class = {cl: ([], []) for cl in class_list}
    for row, p in enumerate(dest_data[0]):
        cl = p.split('/')[-2]
        by_class[cl][0].append(p)
        by_class[cl][1].append(row)
    dst_data_dict[d] = (by_class, dest_data[1])

    for cl in class_list:
        # BUG FIX: the original measured len() of the (paths, indices) tuple,
        # which is always 2; the intended value is the number of images in
        # the class, i.e. the length of the paths list.
        min_len = min(min_len, len(by_class[cl][0]))

def query(query_index, query_domain, cl):
    """Retrieve cross-domain nearest-neighbour images for one query image.

    Parameters
    ----------
    query_index : int
        Position of the query image within its class (from the UI slider).
    query_domain : str
        Domain the query image is drawn from (key into ``dst_data_dict``).
    cl : str
        Class name of the query image.

    Returns
    -------
    tuple
        All image URLs first (query image, then the retrieved images for
        each search domain), followed by all caption strings in the same
        order — matching the Gradio outputs wiring.
    """
    by_class, feats = dst_data_dict[query_domain]
    paths, rows = by_class[cl]
    dst_img_path = os.path.join(data_root, paths[query_index])
    feat_row = rows[query_index]  # row of the query image in the feature matrix

    img_paths = [dst_img_path]
    q_cl = dst_img_path.split('/')[-2]
    captions = [f'Query: {q_cl}'.title()]

    for s_data, s_nn in src_data_dict.values():
        # k nearest neighbours of the query feature within this search domain.
        _, top_n_matches_ids = s_nn.kneighbors(feats[feat_row:feat_row + 1])
        for ix in top_n_matches_ids[0]:
            p = os.path.join(data_root, s_data[0][ix])
            img_paths.append(p)
            # Caption is the class (parent directory) of the retrieved image,
            # shown so the user can judge retrieval success/failure.
            captions.append(p.split('/')[-2].title())

    return tuple(img_paths) + tuple(captions)

# --- Gradio UI -----------------------------------------------------------
# Layout: header/instructions, a query-selection row (domain, class, image
# index) next to the query image preview, then one row of retrieved images
# per domain.
demo = gr.Blocks()
with demo:
    gr.Markdown('# Unsupervised Domain Generalization by Learning a Bridge Across Domains')
    gr.Markdown('This demo showcases the cross-domain retrieval capabilities of our self-supervised cross domain training as presented @CVPR 2022. For details please refer to [the paper](https://openaccess.thecvf.com/content/CVPR2022/papers/Harary_Unsupervised_Domain_Generalization_by_Learning_a_Bridge_Across_Domains_CVPR_2022_paper.pdf)')
    gr.Markdown('The model is trained in an unsupervised manner on all domains without class labels. The labels are displayed to indicate retrieval success/failure.')
    gr.Markdown('## Instructions:')
    gr.Markdown('Select a query domain and a class from the drop-down menus and select any random image index from the domain using the slider below, then press the "Run" button. The query image and the retrieved results from each of the four domains, along with the class label will be presented.')
    gr.Markdown('## Select Query Domain: ')
    gr.Markdown('# Query Image: \t\t\t\t')

    with gr.Row():
        with gr.Column():
            domain_drop = gr.Dropdown(domains, label='Domain')
            cl_drop = gr.Dropdown(class_list, label='Query Class')
            # NOTE(review): maximum is hard-coded to 100; the computed
            # min_len (smallest per-class image count) was presumably
            # intended here — confirm before switching.
            slider = gr.Slider(0, 100, label='Query image selector slider')

        with gr.Column():
            src_cap = gr.Label()
            src_img = gr.Image()
    image_button = gr.Button("Run")

    # One result row per domain, num_results_per_domain image+caption pairs
    # each, collected in order so they line up with query()'s return tuple.
    out_images = []
    out_captions = []
    for d in domains:
        gr.Markdown(f'# Retrieved Images from {d.title()} Domain:')
        with gr.Row():
            for _ in range(num_results_per_domain):
                with gr.Column():
                    out_captions.append(gr.Label())
                    out_images.append(gr.Image())

    # Output order must match query(): all images first, then all captions.
    image_button.click(query, inputs=[slider, domain_drop, cl_drop],
                       outputs=[src_img] + out_images + [src_cap] + out_captions)

demo.launch(share=True)