Spaces:
Sleeping
Sleeping
add class selection
Browse files
app.py
CHANGED
@@ -14,11 +14,16 @@ num_nn = 20
|
|
14 |
search_domain = 'all'
|
15 |
num_results_per_domain = 5
|
16 |
src_data_dict = {}
|
|
|
17 |
if search_domain == 'all':
|
18 |
for d in domains:
|
19 |
with open(os.path.join(feat_dir, f'dst_{d}_{shots}.pkl'), 'rb') as fp:
|
20 |
src_data = pickle.load(fp)
|
21 |
-
|
|
|
|
|
|
|
|
|
22 |
src_nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1).fit(src_data[1])
|
23 |
src_data_dict[d] = (src_data,src_nn_fit)
|
24 |
else:
|
@@ -32,12 +37,17 @@ dst_data_dict = {}
|
|
32 |
min_len = 1e10
|
33 |
for d in domains:
|
34 |
with open(os.path.join(feat_dir, f'src_{d}_{shots}.pkl'), 'rb') as fp:
|
35 |
-
|
36 |
-
|
|
|
|
|
|
|
|
|
|
|
37 |
|
38 |
-
def query(query_index, query_domain):
|
39 |
dst_data = dst_data_dict[query_domain]
|
40 |
-
dst_img_path = os.path.join(data_root, dst_data[
|
41 |
img_paths = [dst_img_path]
|
42 |
q_cl = dst_img_path.split('/')[-2]
|
43 |
captions = [f'Query: {q_cl}'.title()]
|
@@ -61,17 +71,21 @@ with demo:
|
|
61 |
gr.Markdown('## Instructions:')
|
62 |
gr.Markdown('Select a query domain from the dropdown menu and the select any random image from the domain using the slider below. The retrieved results from each of the four domains, along with the class label will be presented.')
|
63 |
gr.Markdown('## Select Query Domain: ')
|
64 |
-
|
|
|
|
|
65 |
# domain_select_button = gr.Button("Select Domain")
|
66 |
-
slider = gr.Slider(0, min_len)
|
67 |
# slider = gr.Slider(0, 10000)
|
68 |
image_button = gr.Button("Run")
|
69 |
-
|
70 |
with gr.Row():
|
71 |
-
gr.
|
72 |
-
|
73 |
-
|
74 |
-
|
|
|
|
|
|
|
75 |
with gr.Column():
|
76 |
src_cap = gr.Label()
|
77 |
src_img = gr.Image()
|
@@ -87,6 +101,6 @@ with demo:
|
|
87 |
out_captions.append(gr.Label())
|
88 |
out_images.append(gr.Image())
|
89 |
|
90 |
-
image_button.click(query, inputs=[slider, domain_drop], outputs=[src_img]+out_images +[src_cap]+ out_captions)
|
91 |
|
92 |
demo.launch(share=True)
|
|
|
14 |
search_domain = 'all'
|
15 |
num_results_per_domain = 5
|
16 |
src_data_dict = {}
|
17 |
+
class_list = []
|
18 |
if search_domain == 'all':
|
19 |
for d in domains:
|
20 |
with open(os.path.join(feat_dir, f'dst_{d}_{shots}.pkl'), 'rb') as fp:
|
21 |
src_data = pickle.load(fp)
|
22 |
+
if class_list == []:
|
23 |
+
for p in src_data[0]:
|
24 |
+
cl = p.split('/')[-2]
|
25 |
+
if cl not in class_list
|
26 |
+
class_list.append(cl)
|
27 |
src_nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1).fit(src_data[1])
|
28 |
src_data_dict[d] = (src_data,src_nn_fit)
|
29 |
else:
|
|
|
37 |
min_len = 1e10
|
38 |
for d in domains:
|
39 |
with open(os.path.join(feat_dir, f'src_{d}_{shots}.pkl'), 'rb') as fp:
|
40 |
+
dest_data = pickle.load(fp)
|
41 |
+
dst_data_dict[d] = {cl: [] for cl in class_list}
|
42 |
+
for p in dst_data[0]:
|
43 |
+
cl = p.split('/')[-2]
|
44 |
+
dst_data_dict[d][cl].append(p)
|
45 |
+
for cl in class_list:
|
46 |
+
min_len = min(min_len, len(dst_data_dict[d][cl]))
|
47 |
|
48 |
+
def query(query_index, query_domain, cl):
|
49 |
dst_data = dst_data_dict[query_domain]
|
50 |
+
dst_img_path = os.path.join(data_root, dst_data[cl][query_index])
|
51 |
img_paths = [dst_img_path]
|
52 |
q_cl = dst_img_path.split('/')[-2]
|
53 |
captions = [f'Query: {q_cl}'.title()]
|
|
|
71 |
gr.Markdown('## Instructions:')
|
72 |
gr.Markdown('Select a query domain from the dropdown menu and the select any random image from the domain using the slider below. The retrieved results from each of the four domains, along with the class label will be presented.')
|
73 |
gr.Markdown('## Select Query Domain: ')
|
74 |
+
gr.Markdown('# Query Image: \t\t\t\t')
|
75 |
+
# domain_drop = gr.Dropdown(domains)
|
76 |
+
# cl_drop = gr.Dropdown(class_list)
|
77 |
# domain_select_button = gr.Button("Select Domain")
|
78 |
+
# slider = gr.Slider(0, min_len)
|
79 |
# slider = gr.Slider(0, 10000)
|
80 |
image_button = gr.Button("Run")
|
|
|
81 |
with gr.Row():
|
82 |
+
with gr.Column():
|
83 |
+
domain_drop = gr.Dropdown(domains)
|
84 |
+
cl_drop = gr.Dropdown(class_list)
|
85 |
+
slider = gr.Slider(0, min_len)
|
86 |
+
# gr.Markdown('\t')
|
87 |
+
# gr.Markdown('\t')
|
88 |
+
# gr.Markdown('\t')
|
89 |
with gr.Column():
|
90 |
src_cap = gr.Label()
|
91 |
src_img = gr.Image()
|
|
|
101 |
out_captions.append(gr.Label())
|
102 |
out_images.append(gr.Image())
|
103 |
|
104 |
+
image_button.click(query, inputs=[slider, domain_drop, cl_drop], outputs=[src_img]+out_images +[src_cap]+ out_captions)
|
105 |
|
106 |
demo.launch(share=True)
|