multimodalart committed
Commit b7be496
Parent: d691db5

Update app.py

Files changed (1):
  1. app.py +25 -23
app.py CHANGED
@@ -55,36 +55,39 @@ Other times the user will not want modifications , but instead want a new image
 Video descriptions must have the same num of words as examples below. Extra words will be ignored.
 """
 
+def resize_if_unfit(input_video, progress=gr.Progress(track_tqdm=True)):
+    width, height = get_video_dimensions(input_video)
+
+    if width == 720 and height == 480:
+        processed_video = input_video
+    else:
+        processed_video = center_crop_resize(input_video)
+    return processed_video
+
 def get_video_dimensions(input_video_path):
     reader = imageio_ffmpeg.read_frames(input_video_path)
     metadata = next(reader)
     return metadata['size']
 
 def center_crop_resize(input_video_path, target_width=720, target_height=480):
-    # Open the video file
     cap = cv2.VideoCapture(input_video_path)
 
-    # Get original video properties
     orig_width = int(cap.get(cv2.CAP_PROP_FRAME_WIDTH))
     orig_height = int(cap.get(cv2.CAP_PROP_FRAME_HEIGHT))
     orig_fps = cap.get(cv2.CAP_PROP_FPS)
     total_frames = int(cap.get(cv2.CAP_PROP_FRAME_COUNT))
 
-    # Calculate resize factor
     width_factor = target_width / orig_width
     height_factor = target_height / orig_height
     resize_factor = max(width_factor, height_factor)
 
-    # Calculate intermediate size
     inter_width = int(orig_width * resize_factor)
     inter_height = int(orig_height * resize_factor)
 
-    # Calculate frame skip
     target_fps = 8
     ideal_skip = max(0, math.ceil(orig_fps / target_fps) - 1)
     skip = min(5, ideal_skip)  # Cap at 5
 
-    # Adjust skip if not enough frames
     while (total_frames / (skip + 1)) < 49 and skip > 0:
         skip -= 1
 
@@ -98,10 +101,8 @@ def center_crop_resize(input_video_path, target_width=720, target_height=480):
             break
 
         if total_read % (skip + 1) == 0:
-            # Resize frame
             resized = cv2.resize(frame, (inter_width, inter_height), interpolation=cv2.INTER_AREA)
 
-            # Center crop
             start_x = (inter_width - target_width) // 2
             start_y = (inter_height - target_height) // 2
             cropped = resized[start_y:start_y+target_height, start_x:start_x+target_width]
@@ -113,7 +114,6 @@ def center_crop_resize(input_video_path, target_width=720, target_height=480):
 
     cap.release()
 
-    # Save the processed video to a temporary file
     with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
         temp_video_path = temp_file.name
         fourcc = cv2.VideoWriter_fourcc(*'mp4v')
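The delete=False in this hunk is deliberate: the with block only reserves a path, and the file must outlive it so cv2.VideoWriter can write to that path afterwards. A standalone sketch of the same pattern (frame count, size, and fps are placeholder values):

import tempfile

import cv2
import numpy as np

with tempfile.NamedTemporaryFile(suffix='.mp4', delete=False) as temp_file:
    temp_video_path = temp_file.name  # path survives after the handle closes

fourcc = cv2.VideoWriter_fourcc(*'mp4v')
writer = cv2.VideoWriter(temp_video_path, fourcc, 8, (720, 480))  # fps, (width, height)
for _ in range(49):
    writer.write(np.zeros((480, 720, 3), dtype=np.uint8))  # frames are (height, width, 3) BGR
writer.release()
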
@@ -188,13 +188,12 @@ def infer(
     seed = random.randint(0, 2 ** 8 - 1)
     if(video_input):
         video = load_video(video_input)[:49]  # Limit to 49 frames
-        video_pt = pipe(
+        video_pt = pipe_video(
             video=video,
             prompt=prompt,
             num_inference_steps=num_inference_steps,
             num_videos_per_prompt=1,
             strength=video_strenght,
-            num_frames=49,
             use_dynamic_cfg=True,
             output_type="pt",
             guidance_scale=guidance_scale,
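One note on the renamed pipe_video call: in diffusers-style video-to-video pipelines, strength typically determines what fraction of the denoising schedule is actually run on the encoded input, while the frame count follows the clip itself, which is consistent with dropping the explicit num_frames=49 here. A rough sketch of the usual relationship (an assumption about generic diffusers behavior, not code from this repo):

def effective_steps(num_inference_steps: int, strength: float) -> int:
    # img2img/vid2vid pipelines usually skip the early timesteps and only
    # denoise the final `strength` fraction of the schedule
    return min(int(num_inference_steps * strength), num_inference_steps)

print(effective_steps(50, 0.8))  # 40: with strength=0.8, 40 of 50 steps touch the input video
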
@@ -241,7 +240,7 @@ def delete_old_files():
 
 
 threading.Thread(target=delete_old_files, daemon=True).start()
-examples = [["horse.mp4", "Pixel art of a horse running"]]
+examples = [["horse.mp4"], ["kitten.mp4"], ["train_running.mp4"]]
 
 with gr.Blocks() as demo:
     gr.Markdown("""
@@ -265,12 +264,11 @@ with gr.Blocks() as demo:
 
     """)
     with gr.Row():
-        with gr.Accordion("Video-to-video", open=False):
-            video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
-            strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
-            examples_component = gr.Examples(examples, fn=process_video, inputs=[input_video, prompt], outputs=output_video, cache_examples="lazy")
-            examples_component.dataset._components = [input_video]
         with gr.Column():
+            with gr.Accordion("Video-to-video", open=False):
+                video_input = gr.Video(label="Input Video (will be cropped to 49 frames, 6 seconds at 8fps)")
+                strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
+                examples_component = gr.Examples(examples, inputs=[video_input], cache_examples=False)
             prompt = gr.Textbox(label="Prompt (Less than 200 Words)", placeholder="Enter your prompt here", lines=5)
 
             with gr.Row():
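Moving the accordion inside the column matters for more than layout: video_input and strength are defined here and wired up further down, while the old gr.Examples call referenced input_video, process_video, and output_video, names that are not defined in the surrounding code shown here. A minimal self-contained sketch of the new nesting (assuming the example .mp4 files exist next to the script):

import gradio as gr

examples = [["horse.mp4"], ["kitten.mp4"], ["train_running.mp4"]]

with gr.Blocks() as demo:
    with gr.Row():
        with gr.Column():
            with gr.Accordion("Video-to-video", open=False):
                video_input = gr.Video(label="Input Video")
                strength = gr.Slider(0.1, 1.0, value=0.8, step=0.01, label="Strength")
                gr.Examples(examples, inputs=[video_input], cache_examples=False)
            prompt = gr.Textbox(label="Prompt", lines=5)

demo.launch()
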
@@ -366,14 +364,18 @@ with gr.Blocks() as demo:
 
 
     def generate(prompt,
+                 video_input,
+                 video_strenght,
                  seed_value,
                  scale_status,
                  rife_status,
-                 progress=gr.Progress(track_tqdm=True)
+                 #progress=gr.Progress(track_tqdm=True)
                  ):
 
         latents, seed = infer(
             prompt,
+            video_input,
+            video_strenght,
             num_inference_steps=50,  # NOT Changed
             guidance_scale=7.0,  # NOT Changed
             seed=seed_value,
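Commenting out the progress default disables the tqdm-tracked progress bar rather than removing the feature outright. For reference, when such a parameter is present, Gradio injects the argument itself and it is never listed in click(inputs=...); a minimal sketch, independent of this app:

import gradio as gr

def generate(prompt, progress=gr.Progress(track_tqdm=True)):
    progress(0.5, desc="halfway")  # injected by Gradio, not passed by the caller
    return prompt.upper()

with gr.Blocks() as demo:
    box = gr.Textbox()
    btn = gr.Button("Generate")
    btn.click(generate, inputs=[box], outputs=[box])  # progress not included

demo.launch()
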
@@ -409,17 +411,17 @@ with gr.Blocks() as demo:
 
     generate_button.click(
         generate,
-        inputs=[prompt, seed_param, enable_scale, enable_rife],
+        inputs=[prompt, video_input, strength, seed_param, enable_scale, enable_rife],
         outputs=[video_output, download_video_button, download_gif_button, seed_text],
     )
 
     enhance_button.click(enhance_prompt_func, inputs=[prompt], outputs=[prompt])
 
-    input_video.upload(
+    video_input.upload(
         resize_if_unfit,
-        inputs=[input_video],
-        outputs=[input_video]
+        inputs=[video_input],
+        outputs=[video_input]
     )
 if __name__ == "__main__":
     demo.queue(max_size=15)
-    demo.launch()
+    demo.launch()
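End to end, the upload hook now normalizes any clip before generation: video_input.upload feeds the file through resize_if_unfit and writes the result back into the same component, so generate always sees a 720x480 clip trimmed to 49 frames at roughly 8 fps. A condensed sketch of that round trip (assuming the two helper functions from the first hunk are in scope):

import gradio as gr

def resize_if_unfit(input_video):
    width, height = get_video_dimensions(input_video)  # helper from the first hunk
    if (width, height) == (720, 480):
        return input_video
    return center_crop_resize(input_video)  # helper from the first hunk

with gr.Blocks() as demo:
    video_input = gr.Video(label="Input Video")
    # output is the same component, so the uploaded file is replaced in place
    video_input.upload(resize_if_unfit, inputs=[video_input], outputs=[video_input])

demo.launch()
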
 