SerdarHelli commited on
Commit
562e1b7
1 Parent(s): e651224

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +7 -7
app.py CHANGED
@@ -48,7 +48,7 @@ def normalize_point_clouds(pcs,mode):
48
 
49
 
50
 
51
- def predict(Seed,ckpt,truncate_std):
52
  if Seed==None:
53
  Seed=777
54
  seed_all(Seed)
@@ -63,14 +63,14 @@ def predict(Seed,ckpt,truncate_std):
63
  gen_pcs = []
64
  with torch.no_grad():
65
  z = torch.randn([1, ckpt['args'].latent_dim]).to(device)
66
- x = model.sample(z, 2048, flexibility=ckpt['args'].flexibility,truncate_std=truncate_std)
67
  gen_pcs.append(x.detach().cpu())
68
  gen_pcs = torch.cat(gen_pcs, dim=0)[:1]
69
  gen_pcs = normalize_point_clouds(gen_pcs, mode="shape_bbox")
70
 
71
  return gen_pcs[0]
72
 
73
- def generate(seed,value,truncate_std):
74
  if value=="Airplane":
75
  ckpt=ckpt_airplane
76
  elif value=="Chair":
@@ -79,7 +79,7 @@ def generate(seed,value,truncate_std):
79
  ckpt=ckpt_airplane
80
 
81
  colors=(238, 75, 43)
82
- points=predict(seed,ckpt,truncate_std)
83
  num_points=points.shape[0]
84
 
85
 
@@ -126,11 +126,11 @@ with gr.Blocks() as demo:
126
  with gr.Row():
127
  seed = gr.Slider( minimum=0, maximum=2**16,label='Seed')
128
  value=gr.Dropdown(choices=["Airplane","Chair"],label="Choose Model Type")
129
- truncate_std = gr.Slider( minimum=1, maximum=2,label='Truncate Std')
130
 
131
  btn = gr.Button(value="Generate")
132
  point_cloud = gr.Plot()
133
- demo.load(generate, [seed,value,truncate_std], point_cloud)
134
- btn.click(generate, [seed,value,truncate_std], point_cloud)
135
 
136
  demo.launch()
 
48
 
49
 
50
 
51
+ def predict(Seed,ckpt):
52
  if Seed==None:
53
  Seed=777
54
  seed_all(Seed)
 
63
  gen_pcs = []
64
  with torch.no_grad():
65
  z = torch.randn([1, ckpt['args'].latent_dim]).to(device)
66
+ x = model.sample(z, 2048, flexibility=ckpt['args'].flexibility)
67
  gen_pcs.append(x.detach().cpu())
68
  gen_pcs = torch.cat(gen_pcs, dim=0)[:1]
69
  gen_pcs = normalize_point_clouds(gen_pcs, mode="shape_bbox")
70
 
71
  return gen_pcs[0]
72
 
73
+ def generate(seed,value):
74
  if value=="Airplane":
75
  ckpt=ckpt_airplane
76
  elif value=="Chair":
 
79
  ckpt=ckpt_airplane
80
 
81
  colors=(238, 75, 43)
82
+ points=predict(seed,ckpt)
83
  num_points=points.shape[0]
84
 
85
 
 
126
  with gr.Row():
127
  seed = gr.Slider( minimum=0, maximum=2**16,label='Seed')
128
  value=gr.Dropdown(choices=["Airplane","Chair"],label="Choose Model Type")
129
+ #truncate_std = gr.Slider( minimum=1, maximum=2,label='Truncate Std')
130
 
131
  btn = gr.Button(value="Generate")
132
  point_cloud = gr.Plot()
133
+ demo.load(generate, [seed,value], point_cloud)
134
+ btn.click(generate, [seed,value], point_cloud)
135
 
136
  demo.launch()