lukemelas commited on
Commit
e11cde7
1 Parent(s): cfe4337

Update app

Browse files
app.py CHANGED
@@ -1,7 +1,6 @@
1
- import os
2
  import os.path
3
  import sys
4
- from os.path import splitext
5
 
6
  import gradio as gr
7
  import matplotlib.pyplot as plt
@@ -19,6 +18,7 @@ from skimage.color import label2rgb
19
  from torch.utils.hooks import RemovableHandle
20
  from torchvision import transforms
21
  from torchvision.utils import make_grid
 
22
 
23
 
24
  def get_model(name: str):
@@ -67,6 +67,8 @@ def get_diagonal(W: scipy.sparse.csr_matrix, threshold: float = 1e-12):
67
  model_name = 'dino_vitb16' # TODOL Figure out how to make this user-editable
68
  K = 5
69
 
 
 
70
 
71
  # Load model
72
  model, val_transform, patch_size, num_heads = get_model(model_name)
@@ -122,7 +124,7 @@ def segment(inp: Image):
122
 
123
  # Remove hook from the model
124
  handle.remove()
125
-
126
  # Normalize features
127
  normalize = True
128
  if normalize:
@@ -160,27 +162,36 @@ def segment(inp: Image):
160
  eigenvectors[k] = 0 - eigenvectors[k]
161
 
162
  # Arrange eigenvectors into grid
163
- output_image_grid = []
164
- for i in range(1, K):
165
- eigenvector = eigenvectors[i].reshape(1, 1, H_pad, W_pad)
166
- eigenvector = F.interpolate(eigenvector, size=(H, W), mode='nearest') # slightly off, but for visualizations this is okay
167
- # plt.imsave('./tmp.png', eigenvector.squeeze().numpy()) # save to a temporary location
168
- # eigenvector = Image.open('./tmp.png').convert('RGB') # load back from our temporary location
169
- output_image_grid.append(eigenvector)
170
- img_tensor_grid = make_grid(output_image_grid, nrow=8, pad_value=1)
171
-
172
- # Postprocess for Gradio
173
- img_tensor_grid.numpy().squeeze()
174
-
175
- return img_tensor_grid
 
 
 
 
 
 
176
 
177
  # Placeholders
178
- input_placeholders = GradioInputImage(shape=(256, 256), source="upload", tool="editor", type="pil")
179
- output_placeholders = GradioOutputImage(type="numpy", label=f"Eigenvectors")
180
- # alternatively: [GradioOutputImage(type="numpy", label=f"Eigenvector {i}") for i in range(K)]
181
 
182
  # Metadata
183
- examples = [["images/img1.jpg"], ["images/img2.jpg"]]
 
 
 
184
  title = "Deep Spectral Segmentation"
185
  description = "Deep spectral segmentation..."
186
  thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"
 
1
+ import io
2
  import os.path
3
  import sys
 
4
 
5
  import gradio as gr
6
  import matplotlib.pyplot as plt
 
18
  from torch.utils.hooks import RemovableHandle
19
  from torchvision import transforms
20
  from torchvision.utils import make_grid
21
+ from matplotlib.pyplot import get_cmap
22
 
23
 
24
  def get_model(name: str):
 
67
  model_name = 'dino_vitb16' # TODOL Figure out how to make this user-editable
68
  K = 5
69
 
70
+ # Fixed parameters
71
+ MAX_SIZE = 384
72
 
73
  # Load model
74
  model, val_transform, patch_size, num_heads = get_model(model_name)
 
124
 
125
  # Remove hook from the model
126
  handle.remove()
127
+
128
  # Normalize features
129
  normalize = True
130
  if normalize:
 
162
  eigenvectors[k] = 0 - eigenvectors[k]
163
 
164
  # Arrange eigenvectors into grid
165
+ cmap = get_cmap('viridis')
166
+ output_images = []
167
+ for i in range(1, K + 1):
168
+ eigenvector = eigenvectors[i].reshape(1, 1, H_patch, W_patch) # .reshape(1, 1, H_pad, W_pad)
169
+ eigenvector: torch.Tensor = F.interpolate(eigenvector, size=(H_pad, W_pad), mode='bilinear', align_corners=False) # slightly off, but for visualizations this is okay
170
+ buffer = io.BytesIO()
171
+ plt.imsave(buffer, eigenvector.squeeze().numpy(), format='png') # save to a temporary location
172
+ buffer.seek(0)
173
+ eigenvector_vis = Image.open(buffer).convert('RGB')
174
+ # eigenvector_vis = TF.to_tensor(eigenvector_vis).unsqueeze(0)
175
+ eigenvector_vis = np.array(eigenvector_vis)
176
+ output_images.append(eigenvector_vis)
177
+ # output_images = torch.cat(output_images, dim=0)
178
+ # output_images = make_grid(output_images, nrow=8, pad_value=1)
179
+
180
+ # # Postprocess for Gradio
181
+ # output_images = np.array(TF.to_pil_image(output_images))
182
+ print(f'{len(output_images)=}')
183
+ return output_images
184
 
185
  # Placeholders
186
+ input_placeholders = GradioInputImage(source="upload", tool="editor", type="pil")
187
+ # output_placeholders = GradioOutputImage(type="numpy", label=f"Eigenvectors")
188
+ output_placeholders = [GradioOutputImage(type="numpy", label=f"Eigenvector {i}") for i in range(K)]
189
 
190
  # Metadata
191
+ examples = [f"examples/{stem}.jpg" for stem in [
192
+ '2008_000099', '2008_000499', '2007_009446', '2007_001586', '2010_001256', '2008_000764', '2008_000705', # '2007_000039'
193
+ ]]
194
+
195
  title = "Deep Spectral Segmentation"
196
  description = "Deep spectral segmentation..."
197
  thumbnail = "https://raw.githubusercontent.com/gradio-app/hub-echonet/master/thumbnail.png"
examples/2007_000039.jpg ADDED
examples/2007_001586.jpg ADDED
examples/2007_009446.jpg ADDED
examples/2008_000099.jpg ADDED
examples/2008_000499.jpg ADDED
examples/2008_000705.jpg ADDED
examples/2008_000764.jpg ADDED
examples/2010_001256.jpg ADDED