aarbelle commited on
Commit
d898082
1 Parent(s): 24cb51e

update app.py

Browse files
Files changed (1) hide show
  1. app.py +18 -25
app.py CHANGED
@@ -2,14 +2,15 @@ import pickle
2
  import os
3
  from sklearn.neighbors import NearestNeighbors
4
  import numpy as np
5
- num_nn = 20
6
  import gradio as gr
7
  from PIL import Image
8
 
9
- data_root = '/dccstor/elishc1/datasets/DomainNet'
10
- feat_dir = 'brad_feats'
11
  domains = ['real', 'painting', 'clipart', 'sketch']
12
  shots = '-1'
 
 
13
  search_domain = 'all'
14
  num_results_per_domain = 5
15
  src_data_dict = {}
@@ -18,16 +19,13 @@ if search_domain == 'all':
18
  with open(os.path.join(feat_dir, f'dst_{d}_{shots}.pkl'), 'rb') as fp:
19
  src_data = pickle.load(fp)
20
 
21
- src_nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain,
22
- algorithm='auto', n_jobs=-1).fit(src_data[1])
23
  src_data_dict[d] = (src_data,src_nn_fit)
24
  else:
25
 
26
- with open(os.path.join(feat_dir, f'dst_{search_domain}_{shots}.pkl'), 'rb') as
27
- fp:
28
  src_data = pickle.load(fp)
29
- src_nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain,
30
- algorithm='auto', n_jobs=-1).fit(src_data[1])
31
  src_data_dict[search_domain] = (src_data,src_nn_fit)
32
 
33
  dst_data_dict = {}
@@ -41,27 +39,23 @@ def query(query_index, query_domain):
41
  img_paths = [dst_img_path]
42
  q_cl = dst_img_path.split('/')[-2]
43
  captions = [f'Query: {q_cl}']
44
- for s_domain, s_data in src_data_dict.items():
45
- _, top_n_matches_ids =
46
- s_data[1].kneighbors(dst_data[1][query_index:query_index+1])
47
  top_n_labels = s_data[0][2][top_n_matches_ids][0]
48
- src_img_pths = [os.path.join(data_root, s_data[0][0][ix]) for ix in
49
- top_n_matches_ids[0]]
50
  img_paths += src_img_pths
51
 
52
  for p in src_img_pths:
53
  src_cl = p.split('/')[-2]
54
  src_file = p.split('/')[-1]
55
- captions.append(src_cl)
56
- return tuple([Image.open(p) for p in img_paths])+ tuple(captions)
57
- try:
58
- demo.close()
59
- except:
60
- pass
61
  demo = gr.Blocks()
62
  with demo:
63
  gr.Markdown('## Select Query Domain: ')
64
- domain_drop = gr.Dropdown(domains)
65
  # domain_select_button = gr.Button("Select Domain")
66
  slider = gr.Slider(0, 1000)
67
  image_button = gr.Button("Run")
@@ -70,7 +64,7 @@ with demo:
70
  src_cap = gr.Label()
71
  src_img = gr.Image()
72
 
73
-
74
  out_images = []
75
  out_captions = []
76
  for d in domains:
@@ -81,7 +75,6 @@ with demo:
81
  out_captions.append(gr.Label())
82
  out_images.append(gr.Image())
83
 
84
- image_button.click(query, inputs=[slider, domain_drop],
85
- outputs=[src_img]+out_images +[src_cap]+ out_captions)
86
- demo.launch(share=True)
87
 
 
 
2
  import os
3
  from sklearn.neighbors import NearestNeighbors
4
  import numpy as np
 
5
  import gradio as gr
6
  from PIL import Image
7
 
8
+ data_root = 'https://ai-vision-public-datasets.s3.eu.cloud-object-storage.appdomain.cloud/DomainNet'
9
+ feat_dir = 'brad_feats'
10
  domains = ['real', 'painting', 'clipart', 'sketch']
11
  shots = '-1'
12
+ num_nn = 20
13
+
14
  search_domain = 'all'
15
  num_results_per_domain = 5
16
  src_data_dict = {}
 
19
  with open(os.path.join(feat_dir, f'dst_{d}_{shots}.pkl'), 'rb') as fp:
20
  src_data = pickle.load(fp)
21
 
22
+ src_nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1).fit(src_data[1])
 
23
  src_data_dict[d] = (src_data,src_nn_fit)
24
  else:
25
 
26
+ with open(os.path.join(feat_dir, f'dst_{search_domain}_{shots}.pkl'), 'rb') as fp:
 
27
  src_data = pickle.load(fp)
28
+ src_nn_fit = NearestNeighbors(n_neighbors=num_results_per_domain, algorithm='auto', n_jobs=-1).fit(src_data[1])
 
29
  src_data_dict[search_domain] = (src_data,src_nn_fit)
30
 
31
  dst_data_dict = {}
 
39
  img_paths = [dst_img_path]
40
  q_cl = dst_img_path.split('/')[-2]
41
  captions = [f'Query: {q_cl}']
42
+ for s_domain, s_data in src_data_dict.items():
43
+ _, top_n_matches_ids = s_data[1].kneighbors(dst_data[1][query_index:query_index+1])
 
44
  top_n_labels = s_data[0][2][top_n_matches_ids][0]
45
+ src_img_pths = [os.path.join(data_root, s_data[0][0][ix]) for ix in top_n_matches_ids[0]]
 
46
  img_paths += src_img_pths
47
 
48
  for p in src_img_pths:
49
  src_cl = p.split('/')[-2]
50
  src_file = p.split('/')[-1]
51
+ captions.append(src_cl)
52
+ print(img_paths)
53
+ return tuple([p for p in img_paths])+ tuple(captions)
54
+
 
 
55
  demo = gr.Blocks()
56
  with demo:
57
  gr.Markdown('## Select Query Domain: ')
58
+ domain_drop = gr.Dropdown(domains)
59
  # domain_select_button = gr.Button("Select Domain")
60
  slider = gr.Slider(0, 1000)
61
  image_button = gr.Button("Run")
 
64
  src_cap = gr.Label()
65
  src_img = gr.Image()
66
 
67
+
68
  out_images = []
69
  out_captions = []
70
  for d in domains:
 
75
  out_captions.append(gr.Label())
76
  out_images.append(gr.Image())
77
 
78
+ image_button.click(query, inputs=[slider, domain_drop], outputs=[src_img]+out_images +[src_cap]+ out_captions)
 
 
79
 
80
+ demo.launch(share=True)