# Captured from a Hugging Face Spaces page; the Space's status banner read "Runtime error".
import os
import random
import subprocess
import sys

import gradio as gr
import numpy as np
import plotly.graph_objects as go
import torch
from huggingface_hub import hf_hub_download

# The model classes live in the paper authors' repository, which is not an
# installable package: clone it next to this script (only once, so restarts
# don't fail on an existing directory) and put it on the import path.
REPO_URL = "https://github.com/luost26/diffusion-point-cloud"
REPO_DIR = "diffusion-point-cloud"
if not os.path.isdir(REPO_DIR):
    # List-form argv (no shell); check=True surfaces a failed clone here
    # instead of as a confusing ImportError on the lines below.
    subprocess.run(["git", "clone", REPO_URL, REPO_DIR], check=True)
sys.path.append(REPO_DIR)

# These imports must follow the sys.path change above.
from models.vae_gaussian import GaussianVAE  # noqa: E402
from models.vae_flow import FlowVAE  # noqa: E402

# Pretrained generator checkpoints hosted on the Hugging Face Hub.
HF_REPO_ID = "SerdarHelli/diffusion-point-cloud"
airplane = hf_hub_download(HF_REPO_ID, filename="GEN_airplane.pt", revision="main")
chair = hf_hub_download(HF_REPO_ID, filename="GEN_chair.pt", revision="main")

device = 'cuda' if torch.cuda.is_available() else 'cpu'


def _load_checkpoint(path):
    """Load a checkpoint file onto ``device`` and return the unpickled object.

    The rest of the app reads ``ckpt['args']`` (via attribute access) and
    ``ckpt['state_dict']`` from the result.
    """
    # 'args' is an arbitrary pickled Python object (presumably an
    # argparse.Namespace), which torch>=2.6's weights_only=True default refuses
    # to unpickle. Full unpickling can execute code: only acceptable because the
    # files come from the fixed repo above -- never point this at untrusted data.
    try:
        return torch.load(path, map_location=torch.device(device), weights_only=False)
    except TypeError:
        # torch<1.13 has no weights_only parameter (and always fully unpickles).
        return torch.load(path, map_location=torch.device(device))


ckpt_airplane = _load_checkpoint(airplane)
ckpt_chair = _load_checkpoint(chair)
def seed_all(seed):
    """Seed torch, NumPy and the stdlib ``random`` module with the same value.

    Args:
        seed: value handed unchanged to each library's seeding function.
    """
    # Same order as before: torch first, then NumPy, then ``random``.
    for seed_fn in (torch.manual_seed, np.random.seed, random.seed):
        seed_fn(seed)
def normalize_point_clouds(pcs, mode):
    """Centre and rescale every point cloud in a batch, in place.

    Args:
        pcs: float tensor where ``pcs[i]`` is one cloud whose rows are points
            with 3 coordinates (the ``reshape(1, 3)`` below requires 3 columns),
            i.e. shape ``(batch, num_points, 3)``. It is modified in place.
        mode: ``None`` to leave ``pcs`` untouched; ``'shape_unit'`` to subtract
            the per-cloud mean point and divide by the standard deviation of
            all its coordinates; ``'shape_bbox'`` to centre the axis-aligned
            bounding box at the origin and scale so its longest side spans
            ``[-1, 1]``.

    Returns:
        The same ``pcs`` tensor object, now normalized.

    Raises:
        ValueError: if ``mode`` is not ``None``, ``'shape_unit'`` or
            ``'shape_bbox'`` (previously this crashed with UnboundLocalError).
    """
    if mode is None:
        return pcs
    if mode not in ('shape_unit', 'shape_bbox'):
        raise ValueError(f"unknown normalization mode: {mode!r}")
    for i, pc in enumerate(pcs):
        if mode == 'shape_unit':
            shift = pc.mean(dim=0).reshape(1, 3)
            scale = pc.flatten().std().reshape(1, 1)
        else:  # 'shape_bbox'
            pc_max, _ = pc.max(dim=0, keepdim=True)  # (1, 3)
            pc_min, _ = pc.min(dim=0, keepdim=True)  # (1, 3)
            shift = ((pc_min + pc_max) / 2).view(1, 3)
            # Half of the longest bbox side, so that side maps onto [-1, 1].
            scale = (pc_max - pc_min).max().reshape(1, 1) / 2
        # The right-hand side is a fresh tensor, so writing it back into the
        # slot that `pc` views is safe.
        pcs[i] = (pc - shift) / scale
    return pcs
def _build_model(ckpt):
    """Instantiate the VAE named by ``ckpt['args'].model`` and load its weights.

    Raises:
        ValueError: if the checkpoint names a model type other than
            ``'gaussian'`` or ``'flow'``.
    """
    args = ckpt['args']
    if args.model == 'gaussian':
        model = GaussianVAE(args)
    elif args.model == 'flow':
        model = FlowVAE(args)
    else:
        raise ValueError(f"unsupported model type in checkpoint: {args.model!r}")
    model = model.to(device)
    model.load_state_dict(ckpt['state_dict'])
    # Inference only: disable any train-mode layer behaviour (a no-op if the
    # sampling path has no dropout / batch-norm).
    model.eval()
    return model


def predict(Seed, ckpt, num_points=2048):
    """Sample one point cloud from a checkpointed generative VAE.

    Args:
        Seed: RNG seed; ``None`` falls back to 777. Cast to ``int`` because
            ``np.random.seed`` rejects float values such as ``100.0``.
        ckpt: checkpoint dict with ``'args'`` (read for ``model``,
            ``latent_dim`` and ``flexibility``) and ``'state_dict'``.
        num_points: number of points requested from ``model.sample``.

    Returns:
        The sampled cloud as a CPU tensor, normalized with mode
        ``'shape_bbox'``; presumably shaped ``(num_points, 3)`` -- confirm
        against ``model.sample``.

    Raises:
        ValueError: if the checkpoint's model type is not supported.
    """
    if Seed is None:
        Seed = 777
    seed_all(int(Seed))
    model = _build_model(ckpt)
    with torch.no_grad():
        # A single latent code, i.e. a batch of one shape.
        z = torch.randn([1, ckpt['args'].latent_dim]).to(device)
        x = model.sample(z, num_points, flexibility=ckpt['args'].flexibility)
    gen_pcs = normalize_point_clouds(x.detach().cpu(), mode="shape_bbox")
    return gen_pcs[0]
def generate(seed, value):
    """Gradio callback: sample a shape and return it as a 3-D scatter figure.

    Args:
        seed: seed from the slider, forwarded to ``predict``.
        value: ``"Airplane"`` or ``"Chair"``; anything else (e.g. ``None``
            when the dropdown is unset) falls back to the airplane model.

    Returns:
        A plotly ``Figure`` showing the points with all axes hidden.
    """
    ckpt = {"Airplane": ckpt_airplane, "Chair": ckpt_chair}.get(value, ckpt_airplane)
    # One fixed colour for every point. A bare (r, g, b) tuple would be read by
    # plotly as a 3-element per-point numeric array mapped through a colorscale.
    color = 'rgb(238, 75, 43)'
    points = predict(seed, ckpt).numpy()
    fig = go.Figure(
        data=[
            go.Scatter3d(
                x=points[:, 0], y=points[:, 1], z=points[:, 2],
                mode='markers',
                marker=dict(size=1, color=color),
            )
        ],
        layout=dict(
            scene=dict(
                xaxis=dict(visible=False),
                yaxis=dict(visible=False),
                zaxis=dict(visible=False),
            )
        ),
    )
    return fig
markdown = f'''
# Diffusion Probabilistic Models for 3D Point Cloud Generation
[[The Paper](https://arxiv.org/abs/2103.01458)] [[Original Code](https://github.com/luost26/diffusion-point-cloud)]
The space demo for our CVPR 2021 paper "Diffusion Probabilistic Models for 3D Point Cloud Generation".
It is running on {device}
'''

# Layout: header row, a row of inputs, then the button and the plot. The same
# callback runs once when the page loads and again on every button click.
with gr.Blocks() as demo:
    with gr.Column():
        with gr.Row():
            gr.Markdown(markdown)
        with gr.Row():
            # step=1 makes every integer seed selectable; some Gradio versions
            # otherwise derive a coarse step from the range (apparently 100 here).
            seed = gr.Slider(minimum=0, maximum=2**16, step=1, label='Seed')
            # Display the model that generate() falls back to when unset.
            value = gr.Dropdown(choices=["Airplane", "Chair"], value="Airplane",
                                label="Choose Model Type")
        btn = gr.Button(value="Generate")
        point_cloud = gr.Plot()
    demo.load(generate, [seed, value], point_cloud)
    btn.click(generate, [seed, value], point_cloud)

if __name__ == "__main__":
    demo.launch()