Pie31415 committed
Commit 4a38f47 • 1 Parent(s): 8160e04

updated app for video inference

Files changed (1)
  1. app.py +39 -31
app.py CHANGED
@@ -138,34 +138,35 @@ def image_inference(
         out['render_masked'].cpu(), out['pred_target_shape_img'][0].cpu()], dim=2))
     return res[..., ::-1]
 
-def extract_frames(driver_vid):
-    image_frames = []
-    vid = cv2.VideoCapture(driver_vid) # path to mp4
-
-    while True:
-        success, img = vid.read()
-
-        if not success: break
-
-        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
-        pil_img = Image.fromarray(img)
-        image_frames.append(pil_img)
-
-    return image_frames
-
-def video_inference(source_img, driver_vid):
+def extract_frames(
+        driver_vid: gr.inputs.Video = None
+):
+    image_frames = []
+    vid = cv2.VideoCapture(driver_vid) # path to mp4
+
+    while True:
+        success, img = vid.read()
+
+        if not success: break
+
+        img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB)
+        pil_img = Image.fromarray(img)
+        image_frames.append(pil_img)
+
+    return image_frames
+
+def video_inference(
+        source_img: gr.inputs.Image = None,
+        driver_vid: gr.inputs.Video = None
+):
     image_frames = extract_frames(driver_vid)
 
     resulted_imgs = defaultdict(list)
 
-    video_folder = 'jenya_driver/'
-    image_frames = sorted(glob(f"{video_folder}/*", recursive=True), key=lambda x: int(x.split('/')[-1][:-4]))
-
     mask_hard_threshold = 0.5
-    N = len(image_frames)//20
-    for i in range(0, N, 4):
-        new_out = infer.evaluate(source_img, Image.open(image_frames[i]),
-                                 source_information_for_reuse=out.get('source_information'))
+    N = len(image_frames)
+    for i in range(0, N, 4): # frame limits
+        new_out = infer.evaluate(source_img, image_frames[i])
 
         mask_pred = (new_out['pred_target_unet_mask'].cpu() > mask_hard_threshold).float()
         mask_pred = mask_errosion(mask_pred[0].float().numpy() * 255)
@@ -192,34 +193,41 @@ def video_inference(source_img, driver_vid):
         im.set_data(video[i,:,:,::-1])
         return im
 
-    anim = animation.FuncAnimation(fig, animate, init_func=init,
-                                   frames=video.shape[0], interval=30)
-
-    return anim
+    anim = animation.FuncAnimation(fig, animate, init_func=init, frames=video.shape[0], interval=30)
+    anim.save("avatar.gif", dpi=300, writer = animation.PillowWriter(fps=24))
+
+    return "avatar.gif"
 
 with gr.Blocks() as demo:
     gr.Markdown("# **<p align='center'>ROME: Realistic one-shot mesh-based head avatars</p>**")
-
+
     gr.Markdown(
         """
+        <img src='https://github.com/SamsungLabs/rome/blob/main/media/tease.gif'>
+
         <p style='text-align: center'>
         Create a personal avatar from just a single image using ROME.
         <br> <a href='https://arxiv.org/abs/2206.08343' target='_blank'>Paper</a> | <a href='https://samsunglabs.github.io/rome' target='_blank'>Project Page</a> | <a href='https://github.com/SamsungLabs/rome' target='_blank'>Github</a>
         </p>
+
+        <blockquote>
+        [The] system creates realistic mesh-based avatars from a single <strong>source</strong>
+        photo. These avatars are rigged, i.e., they can be driven by the animation parameters from a different <strong>driving</strong> frame.
+        </blockquote>
         """
     )
 
     with gr.Tab("Image Inference"):
         with gr.Row():
-            source_img = gr.Image(type="pil", label="source image", show_label=True)
-            driver_img = gr.Image(type="pil", label="driver image", show_label=True)
-        image_output = gr.Image()
+            source_img = gr.Image(type="pil", label="Source image", show_label=True)
+            driver_img = gr.Image(type="pil", label="Driver image", show_label=True)
+        image_output = gr.Image("Rendered avatar")
         image_button = gr.Button("Predict")
     with gr.Tab("Video Inference"):
         with gr.Row():
             source_img2 = gr.Image(type="pil", label="source image", show_label=True)
             driver_vid = gr.Video(label="driver video")
-        video_output = gr.Image()
+        video_output = gr.Image(label="Rendered GIF avatar")
         video_button = gr.Button("Predict")
 
     gr.Examples(
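
A note on the new return path: Gradio cannot display a live Matplotlib `FuncAnimation` object, so the updated `video_inference` renders the animation to `avatar.gif` with `PillowWriter` and returns the file path instead. Below is a self-contained sketch of that save pattern; the random `video` array is a stand-in for the stacked ROME outputs, not the app's real data:

```python
import numpy as np
import matplotlib.pyplot as plt
from matplotlib import animation

# Stand-in for the stacked model outputs: (num_frames, H, W, 3) uint8.
video = (np.random.rand(8, 64, 64, 3) * 255).astype(np.uint8)

fig = plt.figure()
im = plt.imshow(video[0])
plt.axis('off')

def init():
    im.set_data(video[0])
    return im

def animate(i):
    im.set_data(video[i])
    return im

anim = animation.FuncAnimation(fig, animate, init_func=init,
                               frames=video.shape[0], interval=30)
# PillowWriter rasterizes each frame and writes an animated GIF to disk;
# the app then hands the resulting file path to its gr.Image output.
anim.save("avatar.gif", dpi=300, writer=animation.PillowWriter(fps=24))
```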
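The hunk is truncated at the `gr.Examples(` call, so the event wiring does not appear in this diff. For orientation, here is a minimal sketch of how these components would typically be connected inside the `with gr.Blocks() as demo:` block; the `.click` handlers and `demo.launch()` below are assumptions for illustration, not part of this commit:

```python
# Hypothetical wiring (not shown in the diff): each Predict button calls its
# handler; video_inference returns the "avatar.gif" path for video_output.
image_button.click(image_inference,
                   inputs=[source_img, driver_img],
                   outputs=image_output)
video_button.click(video_inference,
                   inputs=[source_img2, driver_vid],
                   outputs=video_output)

demo.launch()  # called after the Blocks context exits
```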