|
|
|
import os |
|
import time |
|
from collections import OrderedDict |
|
from typing import Optional, List |
|
import argparse |
|
from functools import partial |
|
|
|
from einops import repeat, rearrange |
|
import numpy as np |
|
from PIL import Image |
|
import trimesh |
|
import cv2 |
|
|
|
import torch |
|
import pytorch_lightning as pl |
|
|
|
from michelangelo.models.tsal.tsal_base import Latent2MeshOutput |
|
from michelangelo.models.tsal.inference_utils import extract_geometry |
|
from michelangelo.utils.misc import get_config_from_file, instantiate_from_config |
|
from michelangelo.utils.visualizers.pythreejs_viewer import PyThreeJSViewer |
|
from michelangelo.utils.visualizers import html_util |
|
|
|
def load_model(args):
    """Instantiate the model from ``args.config_path`` and load its weights.

    The checkpoint at ``args.ckpt_path`` is restored via
    ``instantiate_from_config``; the model is then moved to the GPU and
    put into eval mode before being returned.
    """
    config = get_config_from_file(args.config_path)
    # Some configs nest the model spec under a top-level "model" key.
    if hasattr(config, "model"):
        config = config.model

    return instantiate_from_config(config, ckpt_path=args.ckpt_path).cuda().eval()
|
|
|
def load_surface(fp, device="cuda"):
    """Load a point cloud with normals from an ``.npz`` file and subsample it.

    Args:
        fp: Path to an ``.npz`` archive with ``points`` (N, 3) and
            ``normals`` (N, 3) arrays; N must be at least 4096.
        device: Torch device for the returned tensor. Defaults to "cuda"
            to match the CUDA-resident model (backward compatible).

    Returns:
        Float tensor of shape (1, 4096, 6): xyz coordinates concatenated
        with normals, with a leading batch dimension.
    """
    # Bug fix: the original ignored ``fp`` and read the module-level
    # ``args.pointcloud_path`` instead; use the argument.
    with np.load(fp) as input_pc:
        surface = input_pc['points']
        normal = input_pc['normals']

    # Subsample a fixed-size point set without replacement.
    # NOTE(review): this fresh Generator is not seeded, so the subsample
    # does not honor pl.seed_everything — confirm whether reproducible
    # sampling is desired here.
    rng = np.random.default_rng()
    ind = rng.choice(surface.shape[0], 4096, replace=False)
    surface = torch.FloatTensor(surface[ind])
    normal = torch.FloatTensor(normal[ind])

    # (4096, 6) -> (1, 4096, 6) on the requested device.
    surface = torch.cat([surface, normal], dim=-1).unsqueeze(0).to(device)

    return surface
|
|
|
def prepare_image(args, number_samples=2):
    """Read ``args.image_path`` and return it as a normalized image batch.

    Returns a float tensor of shape (number_samples, C, H, W) with pixel
    values rescaled from [0, 255] to [-1, 1]; the single image is repeated
    along the batch dimension.
    """
    bgr = cv2.imread(f"{args.image_path}")
    # OpenCV loads BGR; convert to the RGB order the model expects.
    rgb = cv2.cvtColor(bgr, cv2.COLOR_BGR2RGB)

    # uint8 [0, 255] -> float [-1, 1], channels-first layout.
    tensor = torch.tensor(rgb).float()
    tensor = tensor / 255 * 2 - 1
    tensor = rearrange(tensor, "h w c -> c h w")

    return repeat(tensor, "c h w -> b c h w", b=number_samples)
|
|
|
def save_output(args, mesh_outputs):
    """Export each generated mesh as an OBJ file under ``args.output_dir``.

    Each output is written as ``<i>_out_mesh.obj``. Face vertex order is
    reversed before export (flips the face winding). Returns 0.
    """
    os.makedirs(args.output_dir, exist_ok=True)

    for idx, item in enumerate(mesh_outputs):
        # Reverse per-face vertex order before building the trimesh.
        item.mesh_f = item.mesh_f[:, ::-1]
        tri = trimesh.Trimesh(item.mesh_v, item.mesh_f)

        out_name = f"{idx}_out_mesh.obj"
        tri.export(os.path.join(args.output_dir, out_name), include_normals=True)

    print('-----------------------------------------------------------------------------')
    print(f'>>> Finished and mesh saved in {args.output_dir}')
    print('-----------------------------------------------------------------------------')

    return 0
|
|
|
def reconstruction(args, model, bounds=(-1.25, -1.25, -1.25, 1.25, 1.25, 1.25), octree_depth=7, num_chunks=10000):
    """Encode a point cloud, decode it back to geometry, and save the mesh.

    Args:
        args: Parsed CLI namespace; uses ``pointcloud_path`` and ``output_dir``.
        model: Loaded shape auto-encoder wrapper (CUDA, eval mode).
        bounds: (xmin, ymin, zmin, xmax, ymax, zmax) extraction volume.
        octree_depth: Resolution of the geometry-extraction octree.
        num_chunks: Query-point batch size passed to ``extract_geometry``.

    Returns:
        0 on success; the mesh is written to ``<output_dir>/reconstruction.obj``.
    """
    surface = load_surface(args.pointcloud_path)

    # Inference only: disable autograd to avoid building a graph through
    # the encode/decode passes (saves memory and time).
    with torch.no_grad():
        # encoding
        shape_embed, shape_latents = model.model.encode_shape_embed(surface, return_latents=True)
        shape_zq, posterior = model.model.shape_model.encode_kl_embed(shape_latents)

        # decoding
        latents = model.model.shape_model.decode(shape_zq)
        geometric_func = partial(model.model.shape_model.query_geometry, latents=latents)

        # reconstruction: march the implicit field into a mesh
        mesh_v_f, has_surface = extract_geometry(
            geometric_func=geometric_func,
            device=surface.device,
            batch_size=surface.shape[0],
            bounds=bounds,
            octree_depth=octree_depth,
            num_chunks=num_chunks,
        )
    recon_mesh = trimesh.Trimesh(mesh_v_f[0][0], mesh_v_f[0][1])

    os.makedirs(args.output_dir, exist_ok=True)
    recon_mesh.export(os.path.join(args.output_dir, 'reconstruction.obj'))

    print(f'-----------------------------------------------------------------------------')
    print(f'>>> Finished and mesh saved in {os.path.join(args.output_dir, "reconstruction.obj")}')
    print(f'-----------------------------------------------------------------------------')

    return 0
|
|
|
def image2mesh(args, model, guidance_scale=7.5, box_v=1.1, octree_depth=7):
    """Sample meshes conditioned on the image at ``args.image_path``.

    Runs one sampling pass of ``model.sample`` with classifier-free
    guidance and saves the resulting meshes via ``save_output``. Returns 0.
    """
    conditions = {
        "image": prepare_image(args)
    }

    extraction_bounds = [-box_v, -box_v, -box_v, box_v, box_v, box_v]
    outputs = model.sample(
        conditions,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=extraction_bounds,
        octree_depth=octree_depth,
    )[0]

    save_output(args, outputs)

    return 0
|
|
|
def text2mesh(args, model, num_samples=2, guidance_scale=7.5, box_v=1.1, octree_depth=7):
    """Sample meshes conditioned on the text prompt in ``args.text``.

    The prompt is duplicated ``num_samples`` times so one call produces a
    batch of candidate meshes, which are saved via ``save_output``.
    Returns 0.
    """
    conditions = {
        "text": [args.text] * num_samples
    }

    extraction_bounds = [-box_v, -box_v, -box_v, box_v, box_v, box_v]
    outputs = model.sample(
        conditions,
        sample_times=1,
        guidance_scale=guidance_scale,
        return_intermediates=False,
        bounds=extraction_bounds,
        octree_depth=octree_depth,
    )[0]

    save_output(args, outputs)

    return 0
|
|
|
# Dispatch table mapping CLI task names to their handler functions.
task_dict = {
    'reconstruction': reconstruction,
    'image2mesh': image2mesh,
    'text2mesh': text2mesh,
}
# Backward-compatible alias for the original (misspelled) name.
task_dick = task_dict

if __name__ == "__main__":
    '''
    1. Reconstruct point cloud
    2. Image-conditioned generation
    3. Text-conditioned generation
    '''
    parser = argparse.ArgumentParser()
    parser.add_argument("--task", type=str, choices=['reconstruction', 'image2mesh', 'text2mesh'], required=True)
    parser.add_argument("--config_path", type=str, required=True)
    parser.add_argument("--ckpt_path", type=str, required=True)
    parser.add_argument("--pointcloud_path", type=str, default='./example_data/surface.npz', help='Path to the input point cloud')
    parser.add_argument("--image_path", type=str, help='Path to the input image')
    parser.add_argument("--text", type=str, help='Input text within a format: A 3D model of motorcar; Porsche 911.')
    parser.add_argument("--output_dir", type=str, default='./output')
    parser.add_argument("-s", "--seed", type=int, default=0)
    args = parser.parse_args()

    # Seed python / numpy / torch RNGs for reproducible sampling.
    pl.seed_everything(args.seed)

    print(f'-----------------------------------------------------------------------------')
    print(f'>>> Running {args.task}')
    # Each task writes into its own subdirectory of --output_dir.
    args.output_dir = os.path.join(args.output_dir, args.task)
    print(f'>>> Output directory: {args.output_dir}')
    print(f'-----------------------------------------------------------------------------')

    task_dict[args.task](args, load_model(args))