kaz-sony committed
Commit a0bafb1
Parent: 06bf94f

update examples to use cache
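
The gist of the change, as a minimal sketch (illustrative names, not this repo's code): gr.Examples only pre-computes and stores example outputs when it is given fn and outputs, with caching enabled either through the cache_examples argument or the GRADIO_CACHE_EXAMPLES environment variable on Spaces. Clicking a cached example then restores the stored outputs instead of re-running the pipeline.

import gradio as gr

def process_example(path: str):
    # Stand-in for the app's heavy pipeline (depth estimation, splatting,
    # generation). With caching enabled this runs once per example at
    # startup; later clicks only restore the stored result.
    return f'processed {path}'

with gr.Blocks() as demo:
    inp = gr.Textbox(label='Input')
    out = gr.Textbox(label='Output')
    gr.Examples(
        examples=['example A', 'example B'],
        fn=process_example,
        inputs=inp,
        outputs=out,
        cache_examples=True  # assumption: the Space opts in to caching
    )

if __name__ == '__main__':
    demo.launch()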

Files changed (1)
  1. app.py +199 -163
app.py CHANGED
@@ -20,10 +20,6 @@ from extern.ZoeDepth.zoedepth.utils.misc import colorize
 
 from gradio_model3dgscamera import Model3DGSCamera
 
-IMAGE_SIZE = 512
-NEAR, FAR = 0.01, 100
-FOVY = np.deg2rad(55)
-
 def download_models():
     models = [
         {
@@ -59,6 +55,57 @@ def download_models():
             token=model['token']
         )
 
+# Setup.
+download_models()
+
+mde = torch.hub.load(
+    './extern/ZoeDepth',
+    'ZoeD_N',
+    source='local',
+    pretrained=True,
+    trust_repo=True
+)
+
+import spaces
+
+check_call([
+    sys.executable, '-m', 'pip', 'install',
+    'extern/splatting-0.0.1-py3-none-any.whl'
+])
+
+from genwarp import GenWarp
+from genwarp.ops import (
+    camera_lookat, get_projection_matrix, get_viewport_matrix
+)
+
+# GenWarp
+genwarp_cfg = dict(
+    pretrained_model_path='checkpoints',
+    checkpoint_name='multi1',
+    half_precision_weights=True
+)
+genwarp_nvs = GenWarp(cfg=genwarp_cfg, device='cpu')
+
+# Fixed parameters.
+IMAGE_SIZE = 512
+NEAR, FAR = 0.01, 100
+FOVY = np.deg2rad(55)
+PROJ_MTX = get_projection_matrix(
+    fovy=torch.ones(1) * FOVY,
+    aspect_wh=1.,
+    near=NEAR,
+    far=FAR
+)
+VIEW_MTX = camera_lookat(
+    torch.tensor([[0., 0., 0.]]),
+    torch.tensor([[0., 0., 1.]]),
+    torch.tensor([[0., -1., 0.]])
+)
+VIEWPORT_MTX = get_viewport_matrix(
+    IMAGE_SIZE, IMAGE_SIZE,
+    batch_size=1
+)
+
 # Crop the image to the shorter side.
 def crop(img: Image) -> Image:
     W, H = img.size
@@ -68,68 +115,9 @@ def crop(img: Image) -> Image:
     else:
         left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
         top, bottom = 0, H
-    return img.crop((left, top, right, bottom))
-
-def unproject(depth):
-    fovy_deg = 55
-    H, W = depth.shape[2:4]
-
-    mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
-
-    viewport_mtx = get_viewport_matrix(
-        IMAGE_SIZE, IMAGE_SIZE,
-        batch_size=1
-    ).to(depth)
-
-    # Projection matrix.
-    fovy = torch.ones(1) * FOVY
-    proj_mtx = get_projection_matrix(
-        fovy=fovy,
-        aspect_wh=1.,
-        near=NEAR,
-        far=FAR
-    ).to(depth)
-
-    view_mtx = camera_lookat(
-        torch.tensor([[0., 0., 0.]]),
-        torch.tensor([[0., 0., 1.]]),
-        torch.tensor([[0., -1., 0.]])
-    ).to(depth)
-
-    scr_mtx = (viewport_mtx @ proj_mtx).to(depth)
-
-    grid = torch.stack(torch.meshgrid(
-        torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
-    ).to(depth)[None]  # BHW2
-
-    screen = F.pad(grid, (0, 1), 'constant', 0)
-    screen = F.pad(screen, (0, 1), 'constant', 1)
-    screen_flat = rearrange(screen, 'b h w c -> b (h w) c')
-
-    eye = screen_flat @ torch.linalg.inv_ex(
-        scr_mtx.float()
-    )[0].mT.to(depth)
-    eye = eye * rearrange(depth, 'b c h w -> b (h w) c')
-    eye[..., 3] = 1
-
-    points = eye @ torch.linalg.inv_ex(view_mtx.float())[0].mT.to(depth)
-    points = points[0, :, :3]
-
-    # Translate to the origin.
-    points[..., 2] -= mean_depth
-    camera_pos = (0, 0, -mean_depth)
-    view_mtx = camera_lookat(
-        torch.tensor([[0., 0., -mean_depth]]),
-        torch.tensor([[0., 0., 0.]]),
-        torch.tensor([[0., -1., 0.]])
-    ).to(depth)
-
-    return points, camera_pos, view_mtx, proj_mtx
-
-def calc_dist2(points: np.ndarray):
-    dists, _ = KDTree(points).query(points, k=4)
-    mean_dists = (dists[:, 1:] ** 2).mean(1)
-    return mean_dists
+    img = img.crop((left, top, right, bottom))
+    img = img.resize((IMAGE_SIZE, IMAGE_SIZE))
+    return img
 
 def save_as_splat(
     filepath: str,
@@ -171,6 +159,44 @@
     with open(filepath, "wb") as f:
         f.write(buffer.getvalue())
 
+def calc_dist2(points: np.ndarray):
+    dists, _ = KDTree(points).query(points, k=4)
+    mean_dists = (dists[:, 1:] ** 2).mean(1)
+    return mean_dists
+
+def unproject(depth):
+    H, W = depth.shape[2:4]
+    mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
+
+    # Matrices.
+    viewport_mtx = VIEWPORT_MTX.to(depth)
+    proj_mtx = PROJ_MTX.to(depth)
+    view_mtx = VIEW_MTX.to(depth)
+    scr_mtx = (viewport_mtx @ proj_mtx).to(depth)
+
+    grid = torch.stack(torch.meshgrid(
+        torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
+    ).to(depth)[None]  # BHW2
+
+    screen = F.pad(grid, (0, 1), 'constant', 0)
+    screen = F.pad(screen, (0, 1), 'constant', 1)
+    screen_flat = rearrange(screen, 'b h w c -> b (h w) c')
+
+    eye = screen_flat @ torch.linalg.inv_ex(
+        scr_mtx.float()
+    )[0].mT.to(depth)
+    eye = eye * rearrange(depth, 'b c h w -> b (h w) c')
+    eye[..., 3] = 1
+
+    points = eye @ torch.linalg.inv_ex(view_mtx.float())[0].mT.to(depth)
+    points = points[0, :, :3]
+
+    # Translate to the origin.
+    points[..., 2] -= mean_depth
+    camera_pos = (0, 0, -mean_depth)
+
+    return points, camera_pos
+
 def view_from_rt(position, rotation):
     t = np.array(position)
     euler = np.array(rotation)
@@ -213,48 +239,87 @@ def view_from_rt(position, rotation):
     return B @ view_mtx
 
 
-# Setup.
-download_models()
-
-mde = torch.hub.load(
-    './extern/ZoeDepth',
-    'ZoeD_N',
-    source='local',
-    pretrained=True,
-    trust_repo=True
-)
-
-import spaces
-
-check_call([
-    sys.executable, '-m', 'pip', 'install',
-    'extern/splatting-0.0.1-py3-none-any.whl'
-])
-
-from genwarp import GenWarp
-from genwarp.ops import (
-    camera_lookat, get_projection_matrix, get_viewport_matrix
-)
-
-# GenWarp
-genwarp_cfg = dict(
-    pretrained_model_path='checkpoints',
-    checkpoint_name='multi1',
-    half_precision_weights=True
-)
-genwarp_nvs = GenWarp(cfg=genwarp_cfg, device='cpu')
-
-
 with tempfile.TemporaryDirectory() as tmpdir:
     with gr.Blocks(
        title='GenWarp Demo',
        css='img {display: inline;}'
    ) as demo:
        # Internal states.
-        src_image = gr.State()
-        src_depth = gr.State()
-        proj_mtx = gr.State()
-        src_view_mtx = gr.State()
+        image = gr.State()
+        depth = gr.State()
+
+        # Callbacks
+        @spaces.GPU()
+        def cb_mde(image_file: str):
+            # Load an image.
+            image_pil = crop(Image.open(image_file).convert('RGB'))
+            image = to_tensor(image_pil)[None].detach()
+            # Get depth.
+            depth = mde.cuda().infer(image.cuda()).cpu().detach()
+            depth_pil = to_pil_image(colorize(depth[0]))
+            return image_pil, depth_pil, image, depth
+
+        @spaces.GPU()
+        def cb_3d(image_file, image, depth):
+            # Unproject.
+            xyz, camera_pos = unproject(depth.cuda())
+            xyz = xyz.cpu().detach().numpy()
+            # Save as a splat.
+            ## Output filename.
+            splat_file = join(
+                tmpdir, f'./{splitext(basename(image_file))[0]}.splat')
+            rgb = rearrange(image, 'b c h w -> b (h w) c')[0].numpy()
+            save_as_splat(splat_file, xyz, rgb)
+            return splat_file, camera_pos, (0, 0, 0)
+
+        @spaces.GPU()
+        def cb_generate(viewer, image, depth):
+            if depth is None:
+                gr.Error('Image and Depth are not set. Try again.')
+                return None, None
+
+            mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
+            src_view_mtx = camera_lookat(
+                torch.tensor([[0., 0., -mean_depth]]),
+                torch.tensor([[0., 0., 0.]]),
+                torch.tensor([[0., -1., 0.]])
+            ).to(depth)
+            tar_camera_pos, tar_camera_rot = viewer[1:3]
+            tar_view_mtx = torch.from_numpy(view_from_rt(
+                tar_camera_pos, tar_camera_rot
+            ))
+            rel_view_mtx = (
+                tar_view_mtx @ torch.linalg.inv(src_view_mtx.double())
+            ).half().cuda()
+            proj_mtx = PROJ_MTX.half().cuda()
+
+            # GenWarp.
+            renders = genwarp_nvs.to('cuda')(
+                src_image=image.half().cuda(),
+                src_depth=depth.half().cuda(),
+                rel_view_mtx=rel_view_mtx,
+                src_proj_mtx=proj_mtx,
+                tar_proj_mtx=proj_mtx
+            )
+            warped_pil = to_pil_image(renders['warped'].cpu()[0])
+            synthesized_pil = to_pil_image(renders['synthesized'].cpu()[0])
+
+            return warped_pil, synthesized_pil
+
+        def process_example(image_file):
+            gr.Error('')
+            image_pil, depth_pil, image, depth = cb_mde(image_file)
+            viewer = cb_3d(image_file, image, depth)
+            # Fixed angle for examples.
+            viewer = (viewer[0], (-2.020, -0.727, -5.236), (-0.132, 0.378, 0.0))
+            warped_pil, synthesized_pil = cb_generate(
+                viewer, image, depth
+            )
+            return (
+                image_pil, depth_pil, viewer,
+                warped_pil, synthesized_pil,
+                None, None  # Clear internal states.
+            )
 
        # Blocks.
        gr.Markdown(
@@ -270,23 +335,23 @@ with tempfile.TemporaryDirectory() as tmpdir:
            This is an official demo for the paper "[GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping](https://genwarp-nvs.github.io/)". GenWarp can generate novel-view images from a single input, conditioned on camera poses. This demo offers basic inference with the model. For detailed information, please refer to the [paper](https://arxiv.org/abs/2405.17251).
 
            ## How to Use
+
+            ### Try examples
+            - Examples are at the bottom of the page
+
+            ### Upload your own images
            1. Upload a reference image to "Reference Input"
-                - You can also select a image from "Examples"
            2. Move the camera to your desired view in the "Unprojected 3DGS" 3D viewer
            3. Hit the "Generate a novel view" button and check the result
 
            ## Tips
+            - This model is mainly trained for indoor/outdoor scenery. It might not work well for object-centric inputs. For details on training the model, please check our [paper](https://arxiv.org/abs/2405.17251).
            - Extremely large camera movement from the input view may produce low-quality results, since such views deviate from the training distribution and are out of scope for this model. Instead, feed the result of a small camera movement back in repeatedly to move progressively towards the desired view.
            - The 3D viewer might take some time to update, especially when trying different images back to back. Wait until it has fully updated to the new image.
 
            """
        )
        file = gr.File(label='Reference Input', file_types=['image'])
-        examples = gr.Examples(
-            examples=['./assets/pexels-heyho-5998120_19mm.jpg',
-                      './assets/pexels-itsterrymag-12639296_24mm.jpg'],
-            inputs=file
-        )
        with gr.Row():
            image_widget = gr.Image(
                label='Reference View', type='filepath',
@@ -312,68 +377,39 @@ with tempfile.TemporaryDirectory() as tmpdir:
            gen_widget = gr.Image(
                label='Generated View', type='pil', interactive=False
            )
-
-        # Callbacks
-        @spaces.GPU
-        def cb_mde(image_file: str):
-            image = to_tensor(crop(Image.open(
-                image_file
-            ).convert('RGB')).resize((IMAGE_SIZE, IMAGE_SIZE)))[None].cuda()
-            depth = mde.cuda().infer(image)
-            depth_image = to_pil_image(colorize(depth[0]))
-            return to_pil_image(image[0]), depth_image, image.cpu().detach(), depth.cpu().detach()
-
-        @spaces.GPU
-        def cb_3d(image, depth, image_file):
-            xyz, camera_pos, view_mtx, proj_mtx = unproject(depth.cuda())
-            rgb = rearrange(image, 'b c h w -> b (h w) c')[0]
-            splat_file = join(tmpdir, f'./{splitext(basename(image_file))[0]}.splat')
-            save_as_splat(splat_file, xyz.cpu().detach().numpy(), rgb.cpu().detach().numpy())
-            return (splat_file, camera_pos, None), view_mtx.cpu().detach(), proj_mtx.cpu().detach()
-
-        @spaces.GPU
-        def cb_generate(viewer, image, depth, src_view_mtx, proj_mtx):
-            image = image.cuda()
-            depth = depth.cuda()
-            src_view_mtx = src_view_mtx.cuda()
-            proj_mtx = proj_mtx.cuda()
-            src_camera_pos = viewer[1]
-            src_camera_rot = viewer[2]
-            tar_view_mtx = view_from_rt(src_camera_pos, src_camera_rot)
-            tar_view_mtx = torch.from_numpy(tar_view_mtx).to(image)
-            rel_view_mtx = (
-                tar_view_mtx @ torch.linalg.inv(src_view_mtx.to(image))
-            ).to(image)
-
-            # GenWarp.
-            renders = genwarp_nvs.to('cuda')(
-                src_image=image.half(),
-                src_depth=depth.half(),
-                rel_view_mtx=rel_view_mtx.half(),
-                src_proj_mtx=proj_mtx.half(),
-                tar_proj_mtx=proj_mtx.half()
-            )
-
-            warped = renders['warped']
-            synthesized = renders['synthesized']
-            warped_pil = to_pil_image(warped[0])
-            synthesized_pil = to_pil_image(synthesized[0])
-
-            return warped_pil, synthesized_pil
+        examples = gr.Examples(
+            examples=[
+                './assets/pexels-heyho-5998120_19mm.jpg',
+                './assets/pexels-itsterrymag-12639296_24mm.jpg'
+            ],
+            fn=process_example,
+            inputs=file,
+            outputs=[image_widget, depth_widget, viewer,
+                     warped_widget, gen_widget,
+                     image, depth]
+        )
 
        # Events
-        file.change(
+        file.upload(
            fn=cb_mde,
            inputs=file,
-            outputs=[image_widget, depth_widget, src_image, src_depth]
-        ).then(
+            outputs=[image_widget, depth_widget, image, depth]
+        ).success(
            fn=cb_3d,
-            inputs=[src_image, src_depth, image_widget],
-            outputs=[viewer, src_view_mtx, proj_mtx])
+            inputs=[image_widget, image, depth],
+            outputs=viewer
+        )
        button.click(
            fn=cb_generate,
-            inputs=[viewer, src_image, src_depth, src_view_mtx, proj_mtx],
-            outputs=[warped_widget, gen_widget])
+            inputs=[viewer, image, depth],
+            outputs=[warped_widget, gen_widget]
+        )
+        # Recompute the uncached depth for examples in the background.
+        examples.load_input_event.success(
+            fn=lambda x: cb_mde(x)[2:4],
+            inputs=file,
+            outputs=[image, depth]
+        )
 
 if __name__ == '__main__':
    demo.launch()
 
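
Note on the caching pattern in this commit (a sketch under stated assumptions, not a verified recipe): cached example outputs restore the visible widgets instantly, but process_example deliberately returns None for the image/depth states, and examples.load_input_event.success(...) re-runs cb_mde when an example is selected, so the heavy tensors are rebuilt in the background rather than served from the cache. A minimal reproduction of that wiring with illustrative names, assuming load_input_event fires on example selection as it is used above:

import gradio as gr

def cached_fn(text: str):
    # Runs once per example at cache time: fill the visible widget,
    # clear the non-cacheable state.
    return text.upper(), None

def recompute_state(text: str):
    # Runs on example click: rebuild the state the other callbacks need.
    return {'payload': text}

with gr.Blocks() as demo:
    inp = gr.Textbox(label='Input')
    out = gr.Textbox(label='Output')
    state = gr.State()
    examples = gr.Examples(
        examples=['hello'],
        fn=cached_fn,
        inputs=inp,
        outputs=[out, state],
        cache_examples=True
    )
    # Assumption: load_input_event is exposed by gr.Examples in the Gradio
    # version this app targets; chaining .success() refills the state.
    examples.load_input_event.success(
        fn=recompute_state, inputs=inp, outputs=state
    )

if __name__ == '__main__':
    demo.launch()

The trade-off: example clicks feel instant, while the state needed by the "Generate a novel view" button arrives a moment later in the background.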