update examples to use cache
app.py CHANGED
@@ -20,10 +20,6 @@ from extern.ZoeDepth.zoedepth.utils.misc import colorize
 
 from gradio_model3dgscamera import Model3DGSCamera
 
-IMAGE_SIZE = 512
-NEAR, FAR = 0.01, 100
-FOVY = np.deg2rad(55)
-
 def download_models():
     models = [
         {
@@ -59,6 +55,57 @@ def download_models():
             token=model['token']
         )
 
+# Setup.
+download_models()
+
+mde = torch.hub.load(
+    './extern/ZoeDepth',
+    'ZoeD_N',
+    source='local',
+    pretrained=True,
+    trust_repo=True
+)
+
+import spaces
+
+check_call([
+    sys.executable, '-m', 'pip', 'install',
+    'extern/splatting-0.0.1-py3-none-any.whl'
+])
+
+from genwarp import GenWarp
+from genwarp.ops import (
+    camera_lookat, get_projection_matrix, get_viewport_matrix
+)
+
+# GenWarp
+genwarp_cfg = dict(
+    pretrained_model_path='checkpoints',
+    checkpoint_name='multi1',
+    half_precision_weights=True
+)
+genwarp_nvs = GenWarp(cfg=genwarp_cfg, device='cpu')
+
+# Fixed parameters.
+IMAGE_SIZE = 512
+NEAR, FAR = 0.01, 100
+FOVY = np.deg2rad(55)
+PROJ_MTX = get_projection_matrix(
+    fovy=torch.ones(1) * FOVY,
+    aspect_wh=1.,
+    near=NEAR,
+    far=FAR
+)
+VIEW_MTX = camera_lookat(
+    torch.tensor([[0., 0., 0.]]),
+    torch.tensor([[0., 0., 1.]]),
+    torch.tensor([[0., -1., 0.]])
+)
+VIEWPORT_MTX = get_viewport_matrix(
+    IMAGE_SIZE, IMAGE_SIZE,
+    batch_size=1
+)
+
 # Crop the image to the shorter side.
 def crop(img: Image) -> Image:
     W, H = img.size
@@ -68,68 +115,9 @@ def crop(img: Image) -> Image:
     else:
         left, right = np.ceil((W - H) / 2.), np.floor((W - H) / 2.) + H
         top, bottom = 0, H
-    return img.crop((left, top, right, bottom))
-
-def unproject(depth):
-    fovy_deg = 55
-    H, W = depth.shape[2:4]
-
-    mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
-
-    viewport_mtx = get_viewport_matrix(
-        IMAGE_SIZE, IMAGE_SIZE,
-        batch_size=1
-    ).to(depth)
-
-    # Projection matrix.
-    fovy = torch.ones(1) * FOVY
-    proj_mtx = get_projection_matrix(
-        fovy=fovy,
-        aspect_wh=1.,
-        near=NEAR,
-        far=FAR
-    ).to(depth)
-
-    view_mtx = camera_lookat(
-        torch.tensor([[0., 0., 0.]]),
-        torch.tensor([[0., 0., 1.]]),
-        torch.tensor([[0., -1., 0.]])
-    ).to(depth)
-
-    scr_mtx = (viewport_mtx @ proj_mtx).to(depth)
-
-    grid = torch.stack(torch.meshgrid(
-        torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
-    ).to(depth)[None]  # BHW2
-
-    screen = F.pad(grid, (0, 1), 'constant', 0)
-    screen = F.pad(screen, (0, 1), 'constant', 1)
-    screen_flat = rearrange(screen, 'b h w c -> b (h w) c')
-
-    eye = screen_flat @ torch.linalg.inv_ex(
-        scr_mtx.float()
-    )[0].mT.to(depth)
-    eye = eye * rearrange(depth, 'b c h w -> b (h w) c')
-    eye[..., 3] = 1
-
-    points = eye @ torch.linalg.inv_ex(view_mtx.float())[0].mT.to(depth)
-    points = points[0, :, :3]
-
-    # Translate to the origin.
-    points[..., 2] -= mean_depth
-    camera_pos = (0, 0, -mean_depth)
-    view_mtx = camera_lookat(
-        torch.tensor([[0., 0., -mean_depth]]),
-        torch.tensor([[0., 0., 0.]]),
-        torch.tensor([[0., -1., 0.]])
-    ).to(depth)
-
-    return points, camera_pos, view_mtx, proj_mtx
-
-def calc_dist2(points: np.ndarray):
-    dists, _ = KDTree(points).query(points, k=4)
-    mean_dists = (dists[:, 1:] ** 2).mean(1)
-    return mean_dists
+    img = img.crop((left, top, right, bottom))
+    img = img.resize((IMAGE_SIZE, IMAGE_SIZE))
+    return img
 
 def save_as_splat(
     filepath: str,
@@ -171,6 +159,44 @@ def save_as_splat(
     with open(filepath, "wb") as f:
         f.write(buffer.getvalue())
 
+def calc_dist2(points: np.ndarray):
+    dists, _ = KDTree(points).query(points, k=4)
+    mean_dists = (dists[:, 1:] ** 2).mean(1)
+    return mean_dists
+
+def unproject(depth):
+    H, W = depth.shape[2:4]
+    mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
+
+    # Matrices.
+    viewport_mtx = VIEWPORT_MTX.to(depth)
+    proj_mtx = PROJ_MTX.to(depth)
+    view_mtx = VIEW_MTX.to(depth)
+    scr_mtx = (viewport_mtx @ proj_mtx).to(depth)
+
+    grid = torch.stack(torch.meshgrid(
+        torch.arange(W), torch.arange(H), indexing='xy'), dim=-1
+    ).to(depth)[None]  # BHW2
+
+    screen = F.pad(grid, (0, 1), 'constant', 0)
+    screen = F.pad(screen, (0, 1), 'constant', 1)
+    screen_flat = rearrange(screen, 'b h w c -> b (h w) c')
+
+    eye = screen_flat @ torch.linalg.inv_ex(
+        scr_mtx.float()
+    )[0].mT.to(depth)
+    eye = eye * rearrange(depth, 'b c h w -> b (h w) c')
+    eye[..., 3] = 1
+
+    points = eye @ torch.linalg.inv_ex(view_mtx.float())[0].mT.to(depth)
+    points = points[0, :, :3]
+
+    # Translate to the origin.
+    points[..., 2] -= mean_depth
+    camera_pos = (0, 0, -mean_depth)
+
+    return points, camera_pos
+
 def view_from_rt(position, rotation):
     t = np.array(position)
     euler = np.array(rotation)
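The refactored `unproject` above lifts each pixel back to 3D by inverting the rasterization chain: screen coordinates are mapped to eye space through the inverse of `viewport @ projection`, scaled by the metric depth, then mapped to world space through the inverse view matrix. A self-contained sketch of that math, assuming a `(1, 1, H, W)` depth tensor and `(1, 4, 4)` matrices as defined earlier (`depth2points` is a hypothetical name, not part of app.py):

import torch

def depth2points(depth, scr_mtx, view_mtx):
    # depth: (1, 1, H, W); scr_mtx = viewport @ projection, view_mtx: (1, 4, 4).
    H, W = depth.shape[2:4]
    # Pixel grid as homogeneous screen coordinates (x, y, 0, 1).
    xs, ys = torch.meshgrid(torch.arange(W), torch.arange(H), indexing='xy')
    screen = torch.stack(
        [xs, ys, torch.zeros_like(xs), torch.ones_like(xs)], dim=-1
    ).reshape(1, -1, 4).float()
    eye = screen @ torch.linalg.inv(scr_mtx).mT  # undo viewport + projection
    eye = eye * depth.reshape(1, -1, 1)          # scale each ray by its depth
    eye[..., 3] = 1
    world = eye @ torch.linalg.inv(view_mtx).mT  # undo the camera pose
    return world[0, :, :3]                       # (H*W, 3) point cloud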
@@ -213,48 +239,87 @@ def view_from_rt(position, rotation):
     return B @ view_mtx
 
 
-# Setup.
-download_models()
-
-mde = torch.hub.load(
-    './extern/ZoeDepth',
-    'ZoeD_N',
-    source='local',
-    pretrained=True,
-    trust_repo=True
-)
-
-import spaces
-
-check_call([
-    sys.executable, '-m', 'pip', 'install',
-    'extern/splatting-0.0.1-py3-none-any.whl'
-])
-
-from genwarp import GenWarp
-from genwarp.ops import (
-    camera_lookat, get_projection_matrix, get_viewport_matrix
-)
-
-# GenWarp
-genwarp_cfg = dict(
-    pretrained_model_path='checkpoints',
-    checkpoint_name='multi1',
-    half_precision_weights=True
-)
-genwarp_nvs = GenWarp(cfg=genwarp_cfg, device='cpu')
-
-
 with tempfile.TemporaryDirectory() as tmpdir:
     with gr.Blocks(
         title='GenWarp Demo',
         css='img {display: inline;}'
     ) as demo:
         # Internal states.
-        image = gr.State()
-        depth = gr.State()
-        src_view_mtx = gr.State()
-        proj_mtx = gr.State()
+        image = gr.State()
+        depth = gr.State()
+
+        # Callbacks
+        @spaces.GPU()
+        def cb_mde(image_file: str):
+            # Load an image.
+            image_pil = crop(Image.open(image_file).convert('RGB'))
+            image = to_tensor(image_pil)[None].detach()
+            # Get depth.
+            depth = mde.cuda().infer(image.cuda()).cpu().detach()
+            depth_pil = to_pil_image(colorize(depth[0]))
+            return image_pil, depth_pil, image, depth
+
+        @spaces.GPU()
+        def cb_3d(image_file, image, depth):
+            # Unproject.
+            xyz, camera_pos = unproject(depth.cuda())
+            xyz = xyz.cpu().detach().numpy()
+            # Save as a splat.
+            ## Output filename.
+            splat_file = join(
+                tmpdir, f'./{splitext(basename(image_file))[0]}.splat')
+            rgb = rearrange(image, 'b c h w -> b (h w) c')[0].numpy()
+            save_as_splat(splat_file, xyz, rgb)
+            return splat_file, camera_pos, (0, 0, 0)
+
+        @spaces.GPU()
+        def cb_generate(viewer, image, depth):
+            if depth is None:
+                gr.Error('Image and Depth are not set. Try again.')
+                return None, None
+
+            mean_depth = depth.mean(dim=(2, 3)).squeeze().item()
+            src_view_mtx = camera_lookat(
+                torch.tensor([[0., 0., -mean_depth]]),
+                torch.tensor([[0., 0., 0.]]),
+                torch.tensor([[0., -1., 0.]])
+            ).to(depth)
+            tar_camera_pos, tar_camera_rot = viewer[1:3]
+            tar_view_mtx = torch.from_numpy(view_from_rt(
+                tar_camera_pos, tar_camera_rot
+            ))
+            rel_view_mtx = (
+                tar_view_mtx @ torch.linalg.inv(src_view_mtx.double())
+            ).half().cuda()
+            proj_mtx = PROJ_MTX.half().cuda()
+
+            # GenWarp.
+            renders = genwarp_nvs.to('cuda')(
+                src_image=image.half().cuda(),
+                src_depth=depth.half().cuda(),
+                rel_view_mtx=rel_view_mtx,
+                src_proj_mtx=proj_mtx,
+                tar_proj_mtx=proj_mtx
+            )
+            warped_pil = to_pil_image(renders['warped'].cpu()[0])
+            synthesized_pil = to_pil_image(renders['synthesized'].cpu()[0])
+
+            return warped_pil, synthesized_pil
+
+        def process_example(image_file):
+            gr.Error('')
+            image_pil, depth_pil, image, depth = cb_mde(image_file)
+            viewer = cb_3d(image_file, image, depth)
+            # Fixed angle for examples.
+            viewer = (viewer[0], (-2.020, -0.727, -5.236), (-0.132, 0.378, 0.0))
+            warped_pil, synthesized_pil = cb_generate(
+                viewer, image, depth
+            )
+            return (
+                image_pil, depth_pil, viewer,
+                warped_pil, synthesized_pil,
+                None, None  # Clear internal states.
+            )
 
         # Blocks.
         gr.Markdown(
@@ -270,23 +335,23 @@ with tempfile.TemporaryDirectory() as tmpdir:
     This is an official demo for the paper "[GenWarp: Single Image to Novel Views with Semantic-Preserving Generative Warping](https://genwarp-nvs.github.io/)". GenWarp can generate novel-view images from a single input image, conditioned on camera poses. In this demo, we offer basic inference with the model. For detailed information, please refer to the [paper](https://arxiv.org/abs/2405.17251).
 
     ## How to Use
+
+    ### Try examples
+    - Examples are at the bottom of the page
+
+    ### Upload your own images
     1. Upload a reference image to "Reference Input"
-        - You can also select an image from "Examples"
     2. Move the camera to your desired view in the "Unprojected 3DGS" 3D viewer
     3. Hit the "Generate a novel view" button and check the result
 
     ## Tips
+    - This model is mainly trained for indoor/outdoor scenery. It might not work well for object-centric inputs. For details on how the model was trained, please check our [paper](https://arxiv.org/abs/2405.17251).
     - Extremely large camera movement from the input view may produce poor results, since such views deviate from the training distribution and are outside the scope of this model. Instead, you can feed the generated result back in repeatedly, moving progressively towards the desired view with small camera movements.
     - The 3D viewer might take some time to update, especially when trying different images back to back. Wait until it has fully updated to the new image.
 
     """
     )
         file = gr.File(label='Reference Input', file_types=['image'])
-        examples = gr.Examples(
-            examples=['./assets/pexels-heyho-5998120_19mm.jpg',
-                      './assets/pexels-itsterrymag-12639296_24mm.jpg'],
-            inputs=file
-        )
         with gr.Row():
             image_widget = gr.Image(
                 label='Reference View', type='filepath',
@@ -312,68 +377,39 @@ with tempfile.TemporaryDirectory() as tmpdir:
             gen_widget = gr.Image(
                 label='Generated View', type='pil', interactive=False
             )
-
-        # Callbacks
-        @spaces.GPU
-        def cb_mde(image_file: str):
-            # Load an image.
-            image_pil = crop(Image.open(image_file).convert('RGB'))
-            image = to_tensor(image_pil)[None].detach()
-            # Get depth.
-            depth = mde.cuda().infer(image.cuda()).cpu().detach()
-            depth_pil = to_pil_image(colorize(depth[0]))
-            return image_pil, depth_pil, image, depth
-
-        @spaces.GPU
-        def cb_3d(image, depth, image_file):
-            xyz, camera_pos, view_mtx, proj_mtx = unproject(depth.cuda())
-            rgb = rearrange(image, 'b c h w -> b (h w) c')[0]
-            splat_file = join(tmpdir, f'./{splitext(basename(image_file))[0]}.splat')
-            save_as_splat(splat_file, xyz.cpu().detach().numpy(), rgb.cpu().detach().numpy())
-            return (splat_file, camera_pos, None), view_mtx.cpu().detach(), proj_mtx.cpu().detach()
-
-        @spaces.GPU
-        def cb_generate(viewer, image, depth, src_view_mtx, proj_mtx):
-            image = image.cuda()
-            depth = depth.cuda()
-            src_view_mtx = src_view_mtx.cuda()
-            proj_mtx = proj_mtx.cuda()
-            src_camera_pos = viewer[1]
-            src_camera_rot = viewer[2]
-            tar_view_mtx = view_from_rt(src_camera_pos, src_camera_rot)
-            tar_view_mtx = torch.from_numpy(tar_view_mtx).to(image)
-            rel_view_mtx = (
-                tar_view_mtx @ torch.linalg.inv(src_view_mtx.to(image))
-            ).to(image)
-
-            # GenWarp.
-            renders = genwarp_nvs.to('cuda')(
-                src_image=image.half(),
-                src_depth=depth.half(),
-                rel_view_mtx=rel_view_mtx.half(),
-                src_proj_mtx=proj_mtx.half(),
-                tar_proj_mtx=proj_mtx.half()
-            )
-
-            warped = renders['warped']
-            synthesized = renders['synthesized']
-            warped_pil = to_pil_image(warped[0])
-            synthesized_pil = to_pil_image(synthesized[0])
-
-            return warped_pil, synthesized_pil
+        examples = gr.Examples(
+            examples=[
+                './assets/pexels-heyho-5998120_19mm.jpg',
+                './assets/pexels-itsterrymag-12639296_24mm.jpg'
+            ],
+            fn=process_example,
+            inputs=file,
+            outputs=[image_widget, depth_widget, viewer,
+                     warped_widget, gen_widget,
+                     image, depth]
+        )
 
         # Events
-        file.upload(
+        file.upload(
             fn=cb_mde,
             inputs=file,
-            outputs=[image_widget, depth_widget, image, depth]
-        ).success(
+            outputs=[image_widget, depth_widget, image, depth]
+        ).success(
             fn=cb_3d,
-            inputs=[image, depth, file],
-            outputs=[viewer, src_view_mtx, proj_mtx])
+            inputs=[image_widget, image, depth],
+            outputs=viewer
+        )
         button.click(
             fn=cb_generate,
-            inputs=[viewer, image, depth, src_view_mtx, proj_mtx],
-            outputs=[warped_widget, gen_widget])
+            inputs=[viewer, image, depth],
+            outputs=[warped_widget, gen_widget]
+        )
+        # Recalculate the uncached depth for examples in the background.
+        examples.load_input_event.success(
+            fn=lambda x: cb_mde(x)[2:4],
+            inputs=file,
+            outputs=[image, depth]
+        )
 
 if __name__ == '__main__':
     demo.launch()