Spaces:
Runtime error
Runtime error
File size: 3,870 Bytes
0d0e451 7872317 0d0e451 2f68096 0d0e451 d5c0caa 0d0e451 7872317 0d0e451 d5c0caa 0d0e451 d5c0caa 0d0e451 d5c0caa 0d0e451 7872317 0d0e451 2f68096 0d0e451 7872317 2f68096 0d0e451 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 108 109 110 111 112 113 114 115 116 117 118 119 120 121 122 123 124 125 126 127 128 129 130 131 |
import os
import gradio as gr
import plotly.graph_objects as go
import sys
import torch
from huggingface_hub import hf_hub_download
import numpy as np
import random
# The official implementation provides the `models.vae_gaussian` / `models.vae_flow`
# modules imported below. Clone it only when the checkout is missing (e.g. not on a
# restart), and fail loudly if the clone fails instead of surfacing later as an
# unrelated ModuleNotFoundError.
_REPO_DIR = "diffusion-point-cloud"
if not os.path.isdir(_REPO_DIR):
    import subprocess
    subprocess.run(
        ["git", "clone", "https://github.com/luost26/diffusion-point-cloud", _REPO_DIR],
        check=True,
    )
sys.path.append(_REPO_DIR)
from models.vae_gaussian import *
from models.vae_flow import *
# Airplane checkpoint is downloaded from the Hugging Face Hub; the chair checkpoint
# is read from a path relative to the current working directory.
airplane = hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt", revision="main")
chair = "./GEN_chair.pt"
device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _load_checkpoint(path):
    """Load a training checkpoint from ``path`` with tensors mapped onto ``device``.

    The checkpoints store an ``args`` object next to the weights (read later as
    ``ckpt['args'].model``). Recent PyTorch releases default ``torch.load`` to
    ``weights_only=True``, which refuses to unpickle such objects, so full
    unpickling is requested explicitly.

    SECURITY: full unpickling can execute arbitrary code -- only load
    checkpoints from trusted sources.
    """
    try:
        return torch.load(path, map_location=torch.device(device), weights_only=False)
    except TypeError:
        # PyTorch releases before 1.13 do not accept the `weights_only` keyword.
        return torch.load(path, map_location=torch.device(device))


ckpt_airplane = _load_checkpoint(airplane)
ckpt_chair = _load_checkpoint(chair)
def seed_all(seed):
    """Seed every random generator the demo relies on with the same value.

    Seeds, in order, PyTorch's global generator, NumPy's legacy global
    generator and Python's ``random`` module, so that one seed reproduces the
    same sample.
    """
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
def normalize_point_clouds(pcs, mode):
    """Normalize every point cloud of a batch, overwriting ``pcs`` in place.

    Args:
        pcs: float tensor where ``pcs[i]`` is one cloud whose rows are points
            with 3 coordinates (shifts are reshaped to ``(1, 3)``).
        mode: ``None`` returns ``pcs`` untouched; ``'shape_unit'`` centers each
            cloud on its centroid and divides by the standard deviation of all
            its coordinates; ``'shape_bbox'`` centers the bounding box and
            divides by half its longest side, so the longest axis spans [-1, 1].

    Returns:
        ``pcs`` itself (the same tensor object, modified in place).

    Raises:
        ValueError: if ``mode`` is not ``None``, ``'shape_unit'`` or ``'shape_bbox'``.

    Note: a cloud whose points all coincide has zero scale and yields inf/NaN.
    """
    if mode is None:
        return pcs
    if mode not in ('shape_unit', 'shape_bbox'):
        # Fail fast: otherwise `shift`/`scale` would be unbound in the loop below.
        raise ValueError(
            f"unknown normalization mode {mode!r}; expected None, 'shape_unit' or 'shape_bbox'"
        )
    for i in range(pcs.size(0)):
        pc = pcs[i]
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        else:  # 'shape_bbox'
            pc_max, _ = pc.max(dim=0, keepdim=True)  # (1, 3)
            pc_min, _ = pc.min(dim=0, keepdim=True)  # (1, 3)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        pcs[i] = (pc - shift) / scale
    return pcs
def _build_model(args):
    """Instantiate the VAE variant named by ``args.model`` and move it to ``device``.

    Raises:
        ValueError: if ``args.model`` is neither ``'gaussian'`` nor ``'flow'``.
    """
    if args.model == 'gaussian':
        return GaussianVAE(args).to(device)
    if args.model == 'flow':
        return FlowVAE(args).to(device)
    raise ValueError(f"Unsupported model type in checkpoint: {args.model!r}")


def predict(Seed, ckpt, num_points=2048):
    """Sample one point cloud from a trained checkpoint and normalize it.

    Args:
        Seed: RNG seed; ``None`` falls back to 777. Converted with ``int()``
            because ``np.random.seed`` rejects float seeds such as ``5.0``.
        ckpt: checkpoint dict providing ``'args'`` (reads ``model``,
            ``latent_dim`` and ``flexibility``) and ``'state_dict'``.
        num_points: number of points requested from ``model.sample``.

    Returns:
        Element ``[0]`` of the sampled batch, on CPU, after
        ``mode="shape_bbox"`` normalization (presumably shape
        ``(num_points, 3)`` -- assumes ``model.sample`` returns
        ``(batch, points, 3)``; TODO confirm).

    Raises:
        ValueError: if the checkpoint names an unsupported model type.
    """
    if Seed is None:
        Seed = 777
    # Seed *before* building the model: constructing it may consume the torch
    # RNG (random weight init), so this order is part of seed reproducibility.
    seed_all(int(Seed))
    args = ckpt['args']
    model = _build_model(args)
    model.load_state_dict(ckpt['state_dict'])
    with torch.no_grad():
        z = torch.randn([1, args.latent_dim]).to(device)
        x = model.sample(z, num_points, flexibility=args.flexibility)
    gen_pcs = x.detach().cpu()[:1]
    gen_pcs = normalize_point_clouds(gen_pcs, mode="shape_bbox")
    return gen_pcs[0]
def generate(seed, value):
    """Sample a point cloud and return it as a 3D Plotly scatter plot.

    Args:
        seed: RNG seed forwarded to ``predict``.
        value: model name from the dropdown. ``"Chair"`` selects the chair
            checkpoint; ``"Airplane"``, ``None`` or anything else selects the
            airplane checkpoint.

    Returns:
        A ``plotly.graph_objects.Figure`` with one marker per point and all
        axes hidden.
    """
    ckpt = ckpt_chair if value == "Chair" else ckpt_airplane
    # Plotly reads a tuple given as marker.color as an array of per-point
    # numbers mapped through a colorscale, not as one RGB triple; a CSS colour
    # string paints every marker the same colour.
    color = 'rgb(238, 75, 43)'
    # predict() returns a CPU tensor; hand Plotly a plain NumPy array.
    points = predict(seed, ckpt).numpy()
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(size=1, color=color),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )
    return fig
# Header text rendered by gr.Markdown. This is an f-string (for {device}), so the
# literal BibTeX braces must be doubled -- single braces would be parsed as a
# replacement field and make the whole module a SyntaxError.
markdown = f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation
[The space demo for the CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".](https://arxiv.org/abs/2103.01458)
[For the official implementation.](https://github.com/luost26/diffusion-point-cloud)
It is running on {device}
### Citation By
```bibtex
@inproceedings{{luo2021diffusion,
  author = {{Luo, Shitong and Hu, Wei}},
  title = {{Diffusion Probabilistic Models for 3D Point Cloud Generation}},
  booktitle = {{Proceedings of the IEEE/CVF Conference on Computer Vision and Pattern Recognition (CVPR)}},
  month = {{June}},
  year = {{2021}}
}}
```
'''
# Page layout: header text, the seed / model controls, a Generate button and the plot.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # step=1 keeps the seed an integer; seed_all passes it to
            # np.random.seed, which rejects float seeds.
            seed = gr.Slider(minimum=0, maximum=2**16, step=1, label='Seed')
            value = gr.Dropdown(choices=["Airplane", "Chair"], value="Airplane", label="Choose Model Type")
        btn = gr.Button(value="Generate")
        point_cloud = gr.Plot()
    # Render a first sample when the page loads, and a fresh one on every click.
    demo.load(generate, [seed, value], point_cloud)
    btn.click(generate, [seed, value], point_cloud)

if __name__ == "__main__":
    demo.launch()