SerdarHelli commited on
Commit
d5c0caa
1 Parent(s): dcd6671

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +4 -3
app.py CHANGED
@@ -15,6 +15,7 @@ from models.vae_flow import *
15
  airplane=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt",revision="main")
16
  chair=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_chair.pt",revision="main")
17
 
 
18
 
19
  ckpt_airplane = torch.load(airplane)
20
  ckpt_chair = torch.load(chair)
@@ -42,15 +43,15 @@ def predict(Seed,ckpt):
42
  seed_all(Seed)
43
 
44
  if ckpt['args'].model == 'gaussian':
45
- model = GaussianVAE(ckpt['args']).to("cuda")
46
  elif ckpt['args'].model == 'flow':
47
- model = FlowVAE(ckpt['args']).to("cuda")
48
 
49
  model.load_state_dict(ckpt['state_dict'])
50
  # Generate Point Clouds
51
  gen_pcs = []
52
  with torch.no_grad():
53
- z = torch.randn([1, ckpt['args'].latent_dim]).to("cuda")
54
  x = model.sample(z, 2048, flexibility=ckpt['args'].flexibility)
55
  gen_pcs.append(x.detach().cpu())
56
  gen_pcs = torch.cat(gen_pcs, dim=0)[:1]
 
15
  airplane=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_airplane.pt",revision="main")
16
  chair=network_pkl=hf_hub_download("SerdarHelli/diffusion-point-cloud", filename="GEN_chair.pt",revision="main")
17
 
18
+ device='cuda' if torch.cuda.is_available() else 'cpu'
19
 
20
  ckpt_airplane = torch.load(airplane)
21
  ckpt_chair = torch.load(chair)
 
43
  seed_all(Seed)
44
 
45
  if ckpt['args'].model == 'gaussian':
46
+ model = GaussianVAE(ckpt['args']).to(device)
47
  elif ckpt['args'].model == 'flow':
48
+ model = FlowVAE(ckpt['args']).to(device)
49
 
50
  model.load_state_dict(ckpt['state_dict'])
51
  # Generate Point Clouds
52
  gen_pcs = []
53
  with torch.no_grad():
54
+ z = torch.randn([1, ckpt['args'].latent_dim]).to(device)
55
  x = model.sample(z, 2048, flexibility=ckpt['args'].flexibility)
56
  gen_pcs.append(x.detach().cpu())
57
  gen_pcs = torch.cat(gen_pcs, dim=0)[:1]