fffiloni commited on
Commit
28dd534
1 Parent(s): 62402d2

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +20 -8
app.py CHANGED
@@ -256,13 +256,15 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
256
 
257
  return "output_first_frame.jpg", frame_names, inference_state
258
 
259
- def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, video_frames_dir):
260
  #### PROPAGATION ####
261
  sam2_checkpoint, model_cfg = load_model(checkpoint)
262
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
 
263
  inference_state = stored_inference_state
264
  frame_names = stored_frame_names
265
  video_dir = video_frames_dir
 
266
  # Define a directory to save the JPEG images
267
  frames_output_dir = "frames_output_images"
268
  os.makedirs(frames_output_dir, exist_ok=True)
@@ -279,7 +281,10 @@ def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, vid
279
  }
280
 
281
  # render the segmentation results every few frames
282
- vis_frame_stride = 15
 
 
 
283
  plt.close("all")
284
  for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
285
  plt.figure(figsize=(6, 4))
@@ -298,7 +303,11 @@ def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, vid
298
  # Close the plot
299
  plt.close()
300
 
301
- return jpeg_images
 
 
 
 
302
 
303
  with gr.Blocks() as demo:
304
  first_frame_path = gr.State()
@@ -323,7 +332,7 @@ with gr.Blocks() as demo:
323
  points_map = gr.Image(
324
  label="points map",
325
  type="filepath",
326
- interactive=True
327
  )
328
  video_in = gr.Video(label="Video IN")
329
  with gr.Row():
@@ -333,8 +342,11 @@ with gr.Blocks() as demo:
333
  submit_btn = gr.Button("Submit")
334
  with gr.Column():
335
  output_result = gr.Image()
336
- propagate_btn = gr.Button("Propagate")
337
- output_propagated = gr.Gallery()
 
 
 
338
  # output_result_mask = gr.Image()
339
 
340
  clear_points_btn.click(
@@ -366,8 +378,8 @@ with gr.Blocks() as demo:
366
 
367
  propagate_btn.click(
368
  fn = propagate_to_all,
369
- inputs = [checkpoint, stored_inference_state, stored_frame_names, video_frames_dir],
370
- outputs = [output_propagated]
371
  )
372
 
373
  demo.launch(show_api=False, show_error=True)
 
256
 
257
  return "output_first_frame.jpg", frame_names, inference_state
258
 
259
+ def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type):
260
  #### PROPAGATION ####
261
  sam2_checkpoint, model_cfg = load_model(checkpoint)
262
  predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
263
+
264
  inference_state = stored_inference_state
265
  frame_names = stored_frame_names
266
  video_dir = video_frames_dir
267
+
268
  # Define a directory to save the JPEG images
269
  frames_output_dir = "frames_output_images"
270
  os.makedirs(frames_output_dir, exist_ok=True)
 
281
  }
282
 
283
  # render the segmentation results every few frames
284
+ if vis_frame_type == "check":
285
+ vis_frame_stride = 15
286
+ elif vis_frame_type == "render":
287
+ vis_frame_stride = 1
288
  plt.close("all")
289
  for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
290
  plt.figure(figsize=(6, 4))
 
303
  # Close the plot
304
  plt.close()
305
 
306
+ if vis_frame_type == "check":
307
+ return gr.update(value=jpeg_images, visible=True), gr.update(visible=False, value=None)
308
+ elif vis_frame_type == "render":
309
+ return gr.update(visible=False, value=None), gr.update(value=final_vid, visible=True)
310
+
311
 
312
  with gr.Blocks() as demo:
313
  first_frame_path = gr.State()
 
332
  points_map = gr.Image(
333
  label="points map",
334
  type="filepath",
335
+ interactive=False
336
  )
337
  video_in = gr.Video(label="Video IN")
338
  with gr.Row():
 
342
  submit_btn = gr.Button("Submit")
343
  with gr.Column():
344
  output_result = gr.Image()
345
+ with gr.Row():
346
+ vis_frame_type = gr.Radio(choices=["check", "render"], value="render", scale=2)
347
+ propagate_btn = gr.Button("Propagate", scale=1)
348
+ output_propagated = gr.Gallery(visible=False)
349
+ output_video = gr.Video(visible=False)
350
  # output_result_mask = gr.Image()
351
 
352
  clear_points_btn.click(
 
378
 
379
  propagate_btn.click(
380
  fn = propagate_to_all,
381
+ inputs = [checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type],
382
+ outputs = [output_propagated, output_video]
383
  )
384
 
385
  demo.launch(show_api=False, show_error=True)