Spaces:
Runtime error
Runtime error
SerdarHelli
commited on
Commit
•
d5c0caa
1
Parent(s):
dcd6671
Update app.py
Browse files
app.py
CHANGED
@@ -15,6 +15,7 @@ from models.vae_flow import *
|
|
15 |
airplane=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt",revision="main")
|
16 |
chair=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_chair.pt",revision="main")
|
17 |
|
|
|
18 |
|
19 |
ckpt_airplane = torch.load(airplane)
|
20 |
ckpt_chair = torch.load(chair)
|
@@ -42,15 +43,15 @@ def predict(Seed,ckpt):
|
|
42 |
seed_all(Seed)
|
43 |
|
44 |
if ckpt['args'].model == 'gaussian':
|
45 |
-
model = GaussianVAE(ckpt['args']).to(
|
46 |
elif ckpt['args'].model == 'flow':
|
47 |
-
model = FlowVAE(ckpt['args']).to(
|
48 |
|
49 |
model.load_state_dict(ckpt['state_dict'])
|
50 |
# Generate Point Clouds
|
51 |
gen_pcs = []
|
52 |
with torch.no_grad():
|
53 |
-
z = torch.randn([1, ckpt['args'].latent_dim]).to(
|
54 |
x = model.sample(z, 2048, flexibility=ckpt['args'].flexibility)
|
55 |
gen_pcs.append(x.detach().cpu())
|
56 |
gen_pcs = torch.cat(gen_pcs, dim=0)[:1]
|
|
|
15 |
airplane=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt",revision="main")
|
16 |
chair=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_chair.pt",revision="main")
|
17 |
|
18 |
+
device='cuda' if torch.cuda.is_available() else 'cpu'
|
19 |
|
20 |
ckpt_airplane = torch.load(airplane)
|
21 |
ckpt_chair = torch.load(chair)
|
|
|
43 |
seed_all(Seed)
|
44 |
|
45 |
if ckpt['args'].model == 'gaussian':
|
46 |
+
model = GaussianVAE(ckpt['args']).to(device)
|
47 |
elif ckpt['args'].model == 'flow':
|
48 |
+
model = FlowVAE(ckpt['args']).to(device)
|
49 |
|
50 |
model.load_state_dict(ckpt['state_dict'])
|
51 |
# Generate Point Clouds
|
52 |
gen_pcs = []
|
53 |
with torch.no_grad():
|
54 |
+
z = torch.randn([1, ckpt['args'].latent_dim]).to(device)
|
55 |
x = model.sample(z, 2048, flexibility=ckpt['args'].flexibility)
|
56 |
gen_pcs.append(x.detach().cpu())
|
57 |
gen_pcs = torch.cat(gen_pcs, dim=0)[:1]
|