Update app.py
app.py
CHANGED
@@ -256,13 +256,15 @@ def sam_process(input_first_frame_image, checkpoint, tracking_points, trackings_
 
     return "output_first_frame.jpg", frame_names, inference_state
 
-def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, video_frames_dir):
+def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type):
     #### PROPAGATION ####
     sam2_checkpoint, model_cfg = load_model(checkpoint)
     predictor = build_sam2_video_predictor(model_cfg, sam2_checkpoint)
+
     inference_state = stored_inference_state
     frame_names = stored_frame_names
     video_dir = video_frames_dir
+
     # Define a directory to save the JPEG images
     frames_output_dir = "frames_output_images"
     os.makedirs(frames_output_dir, exist_ok=True)
@@ -279,7 +281,10 @@ def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, vid
     }
 
     # render the segmentation results every few frames
-
+    if vis_frame_type == "check":
+        vis_frame_stride = 15
+    elif vis_frame_type == "render":
+        vis_frame_stride = 1
     plt.close("all")
     for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
         plt.figure(figsize=(6, 4))
@@ -298,7 +303,11 @@ def propagate_to_all(checkpoint, stored_inference_state, stored_frame_names, vid
         # Close the plot
         plt.close()
 
-
+    if vis_frame_type == "check":
+        return gr.update(value=jpeg_images, visible=True), gr.update(visible=False, value=None)
+    elif vis_frame_type == "render":
+        return gr.update(visible=False, value=None), gr.update(value=final_vid, visible=True)
+
 
 with gr.Blocks() as demo:
     first_frame_path = gr.State()
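The hunks above make propagate_to_all branch on the new vis_frame_type argument: "check" renders every 15th frame as a quick preview, "render" walks every frame, and the function finishes by returning a pair of gr.update() values so that only one of the two output components stays visible. A minimal standalone sketch of that handler shape, with placeholder frame paths in place of the Space's SAM2 and matplotlib work:

import gradio as gr

def propagate_sketch(vis_frame_type):
    # Placeholder frame list; in the Space these come from the extracted video frames.
    frame_names = [f"{i:05d}.jpg" for i in range(60)]

    # Stride selection as in the hunk above: sparse preview vs. full render.
    if vis_frame_type == "check":
        vis_frame_stride = 15
    elif vis_frame_type == "render":
        vis_frame_stride = 1

    jpeg_images = []
    for out_frame_idx in range(0, len(frame_names), vis_frame_stride):
        # The Space renders a matplotlib figure per frame here; this sketch only collects names.
        jpeg_images.append(frame_names[out_frame_idx])

    if vis_frame_type == "check":
        # First update targets the Gallery, second the Video: show sampled frames, hide the video.
        return gr.update(value=jpeg_images, visible=True), gr.update(visible=False, value=None)
    else:
        # Full render: hide the Gallery and surface the assembled video file (placeholder path).
        final_vid = "output_video.mp4"
        return gr.update(visible=False, value=None), gr.update(value=final_vid, visible=True)

With a stride of 15, a 60-frame clip yields only 4 preview frames, which is what makes "check" mode cheap compared to a full per-frame render.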
@@ -323,7 +332,7 @@ with gr.Blocks() as demo:
             points_map = gr.Image(
                 label="points map",
                 type="filepath",
-                interactive=
+                interactive=False
             )
             video_in = gr.Video(label="Video IN")
             with gr.Row():
@@ -333,8 +342,11 @@ with gr.Blocks() as demo:
                 submit_btn = gr.Button("Submit")
         with gr.Column():
             output_result = gr.Image()
-
-
+            with gr.Row():
+                vis_frame_type = gr.Radio(choices=["check", "render"], value="render", scale=2)
+                propagate_btn = gr.Button("Propagate", scale=1)
+            output_propagated = gr.Gallery(visible=False)
+            output_video = gr.Video(visible=False)
             # output_result_mask = gr.Image()
 
     clear_points_btn.click(
@@ -366,8 +378,8 @@ with gr.Blocks() as demo:
 
     propagate_btn.click(
         fn = propagate_to_all,
-        inputs = [checkpoint, stored_inference_state, stored_frame_names, video_frames_dir],
-        outputs = [output_propagated]
+        inputs = [checkpoint, stored_inference_state, stored_frame_names, video_frames_dir, vis_frame_type],
+        outputs = [output_propagated, output_video]
     )
 
 demo.launch(show_api=False, show_error=True)
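On the UI side, the diff adds a mode selector and two mutually exclusive outputs, then rewires propagate_btn.click to pass the mode in and to target both outputs. A hedged sketch of just that wiring; the handler below is a stub, and the Space's other inputs (checkpoint, stored inference state, frame directory) are omitted:

import gradio as gr

def propagate_stub(vis_frame_type):
    # Stand-in for propagate_to_all; returns updates for (Gallery, Video) in that order.
    if vis_frame_type == "check":
        return gr.update(value=[], visible=True), gr.update(visible=False, value=None)
    return gr.update(visible=False, value=None), gr.update(visible=True)

with gr.Blocks() as demo:
    with gr.Column():
        output_result = gr.Image()
        with gr.Row():
            # The Radio both selects the mode and is passed to the handler as an input.
            vis_frame_type = gr.Radio(choices=["check", "render"], value="render", scale=2)
            propagate_btn = gr.Button("Propagate", scale=1)
        # Both result components start hidden; the handler's gr.update() pair
        # decides which one becomes visible after propagation.
        output_propagated = gr.Gallery(visible=False)
        output_video = gr.Video(visible=False)

    propagate_btn.click(
        fn=propagate_stub,
        inputs=[vis_frame_type],
        outputs=[output_propagated, output_video],
    )

if __name__ == "__main__":
    demo.launch()

Because outputs is [output_propagated, output_video], the order of the two returned gr.update() values determines which component each update applies to.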