File size: 3,531 Bytes
ec0c8fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728b726
ec0c8fa
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
728b726
ec0c8fa
 
 
728b726
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
import os
import time
from pathlib import Path
import uuid
import tempfile
from typing import Union
import spaces
import atexit
from concurrent.futures import ThreadPoolExecutor

import gradio as gr
import cv2
import torch
import numpy as np

from moge.model import MoGeModel
from moge.utils.vis import colorize_depth
import utils3d

# Load the 'Ruicheng/moge-vitl' checkpoint once at import time, then move it to
# the GPU and put it in eval mode for inference.
model = MoGeModel.from_pretrained('Ruicheng/moge-vitl')
model = model.cuda().eval()

# Single background worker that delete_later() submits deferred file removals to.
thread_pool_executor = ThreadPoolExecutor(max_workers=1)


def delete_later(path: Union[str, os.PathLike], delay: int = 300):
    """Schedule best-effort removal of ``path`` about ``delay`` seconds from now.

    The removal runs on the module-level ``thread_pool_executor``. An ``atexit``
    hook is also registered so the file is removed at interpreter exit if the
    scheduled deletion has not happened yet; that hook is unregistered once the
    scheduled deletion has run.

    Args:
        path: File to delete.
        delay: Seconds to wait before deleting.
    """
    # Fix the deadline now rather than sleeping a full ``delay`` inside the
    # task: the executor may run tasks one at a time, and per-task sleeps
    # would make each queued deletion also wait out all earlier ones.
    deadline = time.monotonic() + delay

    def _delete():
        try:
            os.remove(path)
        except OSError:
            # Already removed (e.g. by the other of the two deletion paths)
            # or not removable: cleanup is deliberately best-effort.
            pass

    def _wait_and_delete():
        remaining = deadline - time.monotonic()
        if remaining > 0:
            time.sleep(remaining)
        _delete()
        # The file is handled; drop the exit hook so hooks do not pile up
        # over the lifetime of a long-running server.
        atexit.unregister(_delete)

    # Register before submitting so a zero-delay task cannot unregister the
    # hook before it exists.
    atexit.register(_delete)
    thread_pool_executor.submit(_wait_and_delete)

@spaces.GPU
def run(image: np.ndarray, remove_edge: bool = True):
    """Reconstruct a colored 3D mesh from a single image with MoGe.

    Args:
        image: Input image, indexed below as (H, W, 3) with values in 0..255
            (assumes RGB order as requested from the Gradio input — TODO confirm).
            It is downscaled so its longer side is at most 1024 px.
        remove_edge: If True, pixels flagged by ``utils3d.numpy.depth_edge``
            (rtol=0.02) are removed from the mask before meshing.

    Returns:
        ``(colorized_depth, glb_path, ply_path)``: the output of
        ``colorize_depth`` on the predicted depth, the ``Path`` of the written
        .glb file, and the POSIX path string of the written .ply file. Both
        files are scheduled for deletion after 300 seconds.
    """
    run_id = str(uuid.uuid4())

    # Cap the longer side at 1024 px.
    larger_size = max(image.shape[:2])
    if larger_size > 1024:
        scale = 1024 / larger_size
        image = cv2.resize(image, (0, 0), fx=scale, fy=scale, interpolation=cv2.INTER_AREA)
    height, width = image.shape[:2]

    # (H, W, C) array -> (C, H, W) float tensor in [0, 1] on the GPU.
    image_tensor = torch.tensor(image, dtype=torch.float32, device=torch.device('cuda')).permute(2, 0, 1) / 255
    output = model.infer(image_tensor, resolution_level=9, apply_mask=True)
    points, depth, mask = output['points'].cpu().numpy(), output['depth'].cpu().numpy(), output['mask'].cpu().numpy()

    if remove_edge:
        mask = mask & ~utils3d.numpy.depth_edge(depth, mask=mask, rtol=0.02)
    mask = mask & (depth > 0)

    _, faces, indices = utils3d.numpy.image_mesh(width=width, height=height, mask=mask)
    faces = utils3d.numpy.triangulate(faces)

    # Per-vertex data shared by both export formats, computed once. The x/y
    # sign flip is kept as-is (presumably a camera-convention change for the
    # viewers — verify before altering).
    vertices = points.reshape(-1, 3)[indices] * [-1, -1, 1]
    vertex_colors = image.reshape(-1, 3)[indices] / 255

    tempdir = Path(tempfile.gettempdir(), 'moge')
    tempdir.mkdir(exist_ok=True)

    # Schedule each file's deletion as soon as it is written so a later
    # failure in this function does not leave it behind.
    output_glb_path = Path(tempdir, f'{run_id}.glb')
    utils3d.io.write_glb(
        output_glb_path,
        vertices=vertices,
        faces=faces,
        vertex_colors=vertex_colors,
    )
    delete_later(output_glb_path, delay=300)

    output_ply_path = Path(tempdir, f'{run_id}.ply')
    utils3d.io.write_ply(
        output_ply_path,
        vertices=vertices,
        faces=faces,
        vertex_colors=vertex_colors,
    )
    delete_later(output_ply_path, delay=300)

    colorized_depth = colorize_depth(depth)

    return colorized_depth, output_glb_path, output_ply_path.as_posix()


# Markdown shown above the Gradio interface.
DESCRIPTION = (
    "\n"
    "Turns 2D images into 3D point maps with MoGe\n"
    "\n"
    "NOTE: \n"
    "* If the image is too large (> 1024px), it will be resized accordingly.\n"
    "* The color in the 3D viewer may look dark due to rendering of 3D viewer. You may download the 3D model as .glb or .ply file to view it in other 3D viewers.\n"
)

if __name__ == '__main__':
    # Inputs map positionally onto run(image, remove_edge).
    input_components = [
        gr.Image(type="numpy", image_mode="RGB"),
        gr.Checkbox(True, label="Remove edges"),
    ]
    # Outputs map positionally onto run()'s (depth, glb, ply) return tuple.
    output_components = [
        gr.Image(type="numpy", label="Depth map (colorized)"),
        gr.Model3D(display_mode="solid", clear_color=[1.0, 1.0, 1.0, 1.0], label="3D Viewer"),
        gr.File(type="filepath", label="Download the model as .ply file"),
    ]
    demo = gr.Interface(
        fn=run,
        inputs=input_components,
        outputs=output_components,
        title=None,
        description=DESCRIPTION,
        clear_btn=None,
        allow_flagging="never",
        theme=gr.themes.Soft(),
    )
    demo.launch()