WorkTimer committed
Commit d4bcb75
Parent: f3caa5b

Update frame interval handling and frame_per slider to reflect exported image interval in video processing

Files changed (1)
  1. app.py +33 -6
app.py CHANGED
@@ -38,6 +38,23 @@ def clean(Seg_Tracker):
     torch.cuda.empty_cache()
     return None, ({}, {}), None, None, 0, None, None, None, 0
 
+def change_video(input_video):
+    if input_video is None:
+        return 0, 0
+    cap = cv2.VideoCapture(input_video)
+    fps = cap.get(cv2.CAP_PROP_FPS)
+    total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
+    cap.release()
+    scale_slider = gr.Slider.update(minimum=1.0,
+                                    maximum=fps,
+                                    step=1.0,
+                                    value=fps,)
+    frame_per = gr.Slider.update(minimum= 0.0,
+                                 maximum= total_frames / fps,
+                                 step=1.0/fps,
+                                 value=0.0,)
+    return scale_slider, frame_per
+
 def get_meta_from_video(Seg_Tracker, input_video, scale_slider, checkpoint):
 
     output_dir = '/tmp/output_frames'
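Review note: with this handler in place, loading a clip reconfigures both sliders from the video's own metadata. A minimal sketch of the resulting ranges, assuming a hypothetical 30 fps, 300-frame clip (values are illustrative, not taken from the commit):

    # What change_video would hand back for an assumed 30 fps, 300-frame clip.
    fps, total_frames = 30.0, 300
    scale_range = (1.0, fps)                 # export-rate slider: 1..30 fps, default 30 (keep every frame)
    seek_range = (0.0, total_frames / fps)   # time slider: 0..10.0 seconds
    seek_step = 1.0 / fps                    # one source frame = 1/30 s per step
    print(scale_range, seek_range, seek_step)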
@@ -49,10 +66,10 @@ def get_meta_from_video(Seg_Tracker, input_video, scale_slider, checkpoint):
     if input_video is None:
         return None, ({}, {}), None, None, 0, None, None, None, 0
     cap = cv2.VideoCapture(input_video)
+    fps = cap.get(cv2.CAP_PROP_FPS)
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
     cap.release()
-    output_frames = int(total_frames * scale_slider)
-    frame_interval = max(1, total_frames // output_frames)
+    frame_interval = max(1, int(fps // scale_slider))
     print(f"frame_interval: {frame_interval}")
     try:
         ffmpeg.input(input_video, hwaccel='cuda').output(
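Review note: the interval is now derived from the requested export rate rather than from a fraction of the total frame count, so ffmpeg keeps one frame out of every int(fps // scale_slider) source frames. A quick sanity check, assuming a 30 fps source (illustrative, not from the commit):

    # Effective export rate for an assumed 30 fps source at different slider settings.
    fps = 30.0
    for scale_slider in (30.0, 15.0, 10.0, 7.0):
        frame_interval = max(1, int(fps // scale_slider))
        print(scale_slider, frame_interval, fps / frame_interval)
    # -> intervals 1, 2, 3, 4; actual export rates 30, 15, 10, 7.5 fps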
@@ -99,7 +116,11 @@ def get_meta_from_video(Seg_Tracker, input_video, scale_slider, checkpoint):
     image_predictor = SAM2ImagePredictor(sam2_model)
     inference_state = predictor.init_state(video_path=output_dir)
     predictor.reset_state(inference_state)
-    return (predictor, inference_state, image_predictor), ({}, {}), first_frame_rgb, first_frame_rgb, 0, None, None, None, 0
+    frame_per = gr.Slider.update(minimum= 0.0,
+                                 maximum= total_frames / fps,
+                                 step=frame_interval / fps,
+                                 value=0.0,)
+    return (predictor, inference_state, image_predictor), ({}, {}), first_frame_rgb, first_frame_rgb, frame_per, None, None, None, 0
 
 def mask2bbox(mask):
     if len(np.where(mask > 0)[0]) == 0:
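Review note: returning a slider update here appears intended to make one step of frame_per correspond to one exported frame. A short check under the same assumed 30 fps, 300-frame clip with the export rate set to 10 fps (illustrative values):

    # Slider step vs. exported frames for an assumed 30 fps, 300-frame clip at scale_slider = 10.
    fps, total_frames = 30.0, 300
    frame_interval = max(1, int(fps // 10.0))   # = 3
    step_seconds = frame_interval / fps         # 0.1 s per exported frame
    exported = total_frames // frame_interval   # 100 exported frames
    duration = total_frames / fps               # 10.0 s slider range
    print(step_seconds, exported, duration)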
@@ -142,7 +163,7 @@ def draw_rect(image, bbox, obj_id):
     rgb_color = tuple(map(int, (color[:3] * 255).astype(np.uint8)))
     inv_color = tuple(map(int, (255 - color[:3] * 255).astype(np.uint8)))
     x0, y0, x1, y1 = bbox
-    image_with_rect = cv2.rectangle(image.copy(), (x0, y0), (x1, y1), inv_color, thickness=2)
+    image_with_rect = cv2.rectangle(image.copy(), (x0, y0), (x1, y1), rgb_color, thickness=2)
     return image_with_rect
 
 def sam_click(Seg_Tracker, frame_num, point_mode, click_stack, ann_obj_id, evt: gr.SelectData):
@@ -432,7 +453,7 @@ def seg_track_app():
             with gr.Row():
                 checkpoint = gr.Dropdown(label="Model Size", choices=["tiny", "small", "base-plus", "large"], value="tiny")
                 scale_slider = gr.Slider(
-                    label="Downsampe Frame Rate",
+                    label="Downsampe Frame Rate (fps)",
                     minimum=0.0,
                     maximum=1.0,
                     step=0.25,
@@ -464,7 +485,7 @@ def seg_track_app():
             with gr.Row():
                 with gr.Column():
                     frame_per = gr.Slider(
-                        label = "Percentage of Frames Viewed",
+                        label = "Time (seconds)",
                         minimum= 0.0,
                         maximum= 100.0,
                         step=0.01,
@@ -611,6 +632,12 @@ def seg_track_app():
             Seg_Tracker, input_first_frame, drawing_board, last_draw
         ]
     )
+
+    input_video.change(
+        fn=change_video,
+        inputs=[input_video],
+        outputs=[scale_slider, frame_per]
+    )
 
     app.queue(concurrency_count=1)
     app.launch(debug=True, enable_queue=True, share=False)
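Review note: the new binding follows the usual Gradio pattern of one update object per output component. A self-contained sketch of the same wiring, written against the Gradio 3.x API the diff uses (gr.Slider.update); component names and ranges here are placeholders, not the app's:

    import gradio as gr

    def on_video_change(video_path):
        # The real change_video reads fps / frame count with cv2; fixed values stand in here.
        return (gr.Slider.update(minimum=1.0, maximum=30.0, step=1.0, value=30.0),
                gr.Slider.update(minimum=0.0, maximum=10.0, step=1.0 / 30.0, value=0.0))

    with gr.Blocks() as demo:
        video = gr.Video()
        rate = gr.Slider(0.0, 1.0, label="Downsample Frame Rate (fps)")
        time_s = gr.Slider(0.0, 100.0, label="Time (seconds)")
        video.change(fn=on_video_change, inputs=[video], outputs=[rate, time_s])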
 