abreza committed
Commit
0121705
1 Parent(s): f1e3ccc

refactor the app file

Files changed (2)
  1. app.py +56 -209
  2. examples/get_examples.py +66 -0
app.py CHANGED
@@ -1,13 +1,14 @@
  import os
  import platform
- import torch
- import gradio as gr
- from huggingface_hub import snapshot_download
  import uuid
  import shutil
  from pydub import AudioSegment
  import spaces
+ import torch
+ import gradio as gr
+ from huggingface_hub import snapshot_download

+ from examples.get_examples import get_examples
  from src.facerender.pirender_animate import AnimateFromCoeff_PIRender
  from src.utils.preprocess import CropAndExtract
  from src.test_audio2coeff import Audio2Coeff
@@ -16,179 +17,113 @@ from src.generate_batch import get_data
  from src.generate_facerender_batch import get_facerender_data
  from src.utils.init_path import init_path

-
- def get_source_image(image):
- return image
-
-
- def toggle_audio_file(choice):
- if choice == False:
- return gr.update(visible=True), gr.update(visible=False)
- else:
- return gr.update(visible=False), gr.update(visible=True)
-
-
- def ref_video_fn(path_of_ref_video):
- if path_of_ref_video is not None:
- return gr.update(value=True)
- else:
- return gr.update(value=False)
-
-
- if torch.cuda.is_available():
- device = "cuda"
- elif platform.system() == 'Darwin': # macos
- device = "mps"
- else:
- device = "cpu"
-
- os.environ['TORCH_HOME'] = 'checkpoints'
-
  checkpoint_path = 'checkpoints'
  config_path = 'src/config'
+ device = "cuda" if torch.cuda.is_available(
+ ) else "mps" if platform.system() == 'Darwin' else "cpu"

+ os.environ['TORCH_HOME'] = checkpoint_path
  snapshot_download(repo_id='vinthony/SadTalker-V002rc',
- local_dir='./checkpoints', local_dir_use_symlinks=True)
+ local_dir=checkpoint_path, local_dir_use_symlinks=True)


  def mp3_to_wav(mp3_filename, wav_filename, frame_rate):
- mp3_file = AudioSegment.from_file(file=mp3_filename)
- mp3_file.set_frame_rate(frame_rate).export(wav_filename, format="wav")
+ AudioSegment.from_file(file=mp3_filename).set_frame_rate(
+ frame_rate).export(wav_filename, format="wav")


  @spaces.GPU()
- def test(source_image, driven_audio, preprocess='crop',
- still_mode=False, use_enhancer=False, batch_size=1, size=256,
- pose_style=0,
- facerender='facevid2vid',
- exp_scale=1.0,
- use_ref_video=False,
- ref_video=None,
- ref_info=None,
- use_idle_mode=False,
- length_of_audio=0, use_blink=True,
- result_dir='./results/'):
-
+ def generate_video(source_image, driven_audio, preprocess='crop', still_mode=False, use_enhancer=False,
+ batch_size=1, size=256, pose_style=0, facerender='facevid2vid', exp_scale=1.0,
+ use_ref_video=False, ref_video=None, ref_info=None, use_idle_mode=False,
+ length_of_audio=0, use_blink=True, result_dir='./results/'):
+ # Initialize models and paths
  sadtalker_paths = init_path(
  checkpoint_path, config_path, size, False, preprocess)
-
  audio_to_coeff = Audio2Coeff(sadtalker_paths, device)
  preprocess_model = CropAndExtract(sadtalker_paths, device)
+ animate_from_coeff = AnimateFromCoeff(sadtalker_paths, device) if facerender == 'facevid2vid' and device != 'mps' \
+ else AnimateFromCoeff_PIRender(sadtalker_paths, device)

- if facerender == 'facevid2vid' and device != 'mps':
- animate_from_coeff = AnimateFromCoeff(
- sadtalker_paths, device)
- elif facerender == 'pirender' or device == 'mps':
- animate_from_coeff = AnimateFromCoeff_PIRender(
- sadtalker_paths, device)
- facerender = 'pirender'
- else:
- raise (RuntimeError('Unknown model: {}'.format(facerender)))
-
+ # Create directories for saving results
  time_tag = str(uuid.uuid4())
  save_dir = os.path.join(result_dir, time_tag)
  os.makedirs(save_dir, exist_ok=True)
-
  input_dir = os.path.join(save_dir, 'input')
  os.makedirs(input_dir, exist_ok=True)

- print(source_image)
+ # Process source image
  pic_path = os.path.join(input_dir, os.path.basename(source_image))
  shutil.move(source_image, input_dir)

- if driven_audio is not None and os.path.isfile(driven_audio):
+ # Process driven audio
+ if driven_audio and os.path.isfile(driven_audio):
  audio_path = os.path.join(input_dir, os.path.basename(driven_audio))
-
- # mp3 to wav
  if '.mp3' in audio_path:
  mp3_to_wav(driven_audio, audio_path.replace('.mp3', '.wav'), 16000)
  audio_path = audio_path.replace('.mp3', '.wav')
  else:
  shutil.move(driven_audio, input_dir)
-
  elif use_idle_mode:
- # generate audio from this new audio_path
  audio_path = os.path.join(
  input_dir, 'idlemode_'+str(length_of_audio)+'.wav')
- from pydub import AudioSegment
- one_sec_segment = AudioSegment.silent(
- duration=1000*length_of_audio) # duration in milliseconds
- one_sec_segment.export(audio_path, format="wav")
+ AudioSegment.silent(
+ duration=1000*length_of_audio).export(audio_path, format="wav")
  else:
- print(use_ref_video, ref_info)
- assert use_ref_video == True and ref_info == 'all'
+ assert use_ref_video and ref_info == 'all'

- if use_ref_video and ref_info == 'all': # full ref mode
- ref_video_videoname = os.path.basename(ref_video)
+ # Process reference video
+ if use_ref_video and ref_info == 'all':
+ ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
  audio_path = os.path.join(save_dir, ref_video_videoname+'.wav')
- print('new audiopath:', audio_path)
- # if ref_video contains audio, set the audio from ref_video.
- cmd = r"ffmpeg -y -hide_banner -loglevel error -i %s %s" % (
- ref_video, audio_path)
- os.system(cmd)
-
- os.makedirs(save_dir, exist_ok=True)
+ os.system(
+ f"ffmpeg -y -hide_banner -loglevel error -i {ref_video} {audio_path}")
+ ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
+ os.makedirs(ref_video_frame_dir, exist_ok=True)
+ ref_video_coeff_path, _, _ = preprocess_model.generate(
+ ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)
+ else:
+ ref_video_coeff_path = None

- # crop image and extract 3dmm from image
+ # Preprocess source image
  first_frame_dir = os.path.join(save_dir, 'first_frame_dir')
  os.makedirs(first_frame_dir, exist_ok=True)
  first_coeff_path, crop_pic_path, crop_info = preprocess_model.generate(
  pic_path, first_frame_dir, preprocess, True, size)
-
  if first_coeff_path is None:
  raise AttributeError("No face is detected")

+ # Determine reference coefficients
  if use_ref_video:
- print('using ref video for genreation')
- ref_video_videoname = os.path.splitext(os.path.split(ref_video)[-1])[0]
- ref_video_frame_dir = os.path.join(save_dir, ref_video_videoname)
- os.makedirs(ref_video_frame_dir, exist_ok=True)
- print('3DMM Extraction for the reference video providing pose')
- ref_video_coeff_path, _, _ = preprocess_model.generate(
- ref_video, ref_video_frame_dir, preprocess, source_image_flag=False)
- else:
- ref_video_coeff_path = None
-
- if use_ref_video:
+ ref_pose_coeff_path, ref_eyeblink_coeff_path = None, None
  if ref_info == 'pose':
  ref_pose_coeff_path = ref_video_coeff_path
- ref_eyeblink_coeff_path = None
  elif ref_info == 'blink':
- ref_pose_coeff_path = None
  ref_eyeblink_coeff_path = ref_video_coeff_path
  elif ref_info == 'pose+blink':
- ref_pose_coeff_path = ref_video_coeff_path
- ref_eyeblink_coeff_path = ref_video_coeff_path
- elif ref_info == 'all':
- ref_pose_coeff_path = None
- ref_eyeblink_coeff_path = None
- else:
- raise ('error in refinfo')
- else:
- ref_pose_coeff_path = None
- ref_eyeblink_coeff_path = None
+ ref_pose_coeff_path = ref_eyeblink_coeff_path = ref_video_coeff_path

- # audio2ceoff
+ # Generate coefficients from audio or reference video
  if use_ref_video and ref_info == 'all':
- # audio_to_coeff.generate(batch, save_dir, pose_style, ref_pose_coeff_path)
  coeff_path = ref_video_coeff_path
  else:
- batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path, still=still_mode,
- idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink) # longer audio?
+ batch = get_data(first_coeff_path, audio_path, device, ref_eyeblink_coeff_path=ref_eyeblink_coeff_path,
+ still=still_mode, idlemode=use_idle_mode, length_of_audio=length_of_audio, use_blink=use_blink)
  coeff_path = audio_to_coeff.generate(
  batch, save_dir, pose_style, ref_pose_coeff_path)

- # coeff2video
+ # Generate video from coefficients
  data = get_facerender_data(coeff_path, crop_pic_path, first_coeff_path, audio_path, batch_size, still_mode=still_mode,
  preprocess=preprocess, size=size, expression_scale=exp_scale, facemodel=facerender)
- return_path = animate_from_coeff.generate(
- data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None, preprocess=preprocess, img_size=size)
+ return_path = animate_from_coeff.generate(data, save_dir, pic_path, crop_info, enhancer='gfpgan' if use_enhancer else None,
+ preprocess=preprocess, img_size=size)
  video_name = data['video_name']
  print(f'The generated video is named {video_name} in {save_dir}')

  return return_path


+ # Gradio UI
  with gr.Blocks(analytics_enabled=False) as demo:
  with gr.Row():
  with gr.Column(variant='panel'):
@@ -214,8 +149,8 @@ with gr.Blocks(analytics_enabled=False) as demo:
  label="Use Idle Animation")
  length_of_audio = gr.Number(
  value=5, label="The length(seconds) of the generated video.")
- use_idle_mode.change(toggle_audio_file, inputs=use_idle_mode, outputs=[
- driven_audio, driven_audio_no]) # todo
+ use_idle_mode.change(lambda choice: (gr.update(visible=not choice), gr.update(visible=choice)),
+ inputs=use_idle_mode, outputs=[driven_audio, driven_audio_no])

  with gr.Row():
  ref_video = gr.Video(
@@ -227,15 +162,13 @@ with gr.Blocks(analytics_enabled=False) as demo:
  ref_info = gr.Radio(['pose', 'blink', 'pose+blink', 'all'], value='pose', label='Reference Video',
  info="How to borrow from reference Video?((fully transfer, aka, video driving mode))")

- ref_video.change(ref_video_fn, inputs=ref_video, outputs=[
- use_ref_video]) # todo
+ ref_video.change(lambda path: gr.update(
+ value=path is not None), inputs=ref_video, outputs=use_ref_video)

  with gr.Column(variant='panel'):
  with gr.Tabs(elem_id="sadtalker_checkbox"):
  with gr.TabItem('Settings'):
  with gr.Column(variant='panel'):
- # width = gr.Slider(minimum=64, elem_id="img2img_width", maximum=2048, step=8, label="Manually Crop Width", value=512) # img2img_width
- # height = gr.Slider(minimum=64, elem_id="img2img_height", maximum=2048, step=8, label="Manually Crop Height", value=512) # img2img_width
  with gr.Row():
  pose_style = gr.Slider(
  minimum=0, maximum=45, step=1, label="Pose style", value=0)
@@ -265,104 +198,18 @@ with gr.Blocks(analytics_enabled=False) as demo:
  submit = gr.Button(
  'Generate', elem_id="sadtalker_generate", variant='primary')

- with gr.Tabs(elem_id="sadtalker_genearted"):
+ with gr.Tabs(elem_id="sadtalker_generated"):
  gen_video = gr.Video(label="Generated video")

  submit.click(
- fn=test,
- inputs=[source_image,
- driven_audio,
- preprocess_type,
- is_still_mode,
- enhancer,
- batch_size,
- size_of_image,
- pose_style,
- facerender,
- exp_weight,
- use_ref_video,
- ref_video,
- ref_info,
- use_idle_mode,
- length_of_audio,
- blink_every
- ],
+ fn=generate_video,
+ inputs=[source_image, driven_audio, preprocess_type, is_still_mode, enhancer, batch_size, size_of_image,
+ pose_style, facerender, exp_weight, use_ref_video, ref_video, ref_info, use_idle_mode, length_of_audio, blink_every],
  outputs=[gen_video],
  )

  with gr.Row():
- gr.Examples(examples=[
- [
- 'examples/source_image/full_body_1.png',
- 'examples/driven_audio/bus_chinese.wav',
- 'crop',
- True,
- False
- ],
- [
- 'examples/source_image/full_body_2.png',
- 'examples/driven_audio/japanese.wav',
- 'crop',
- False,
- False
- ],
- [
- 'examples/source_image/full3.png',
- 'examples/driven_audio/deyu.wav',
- 'crop',
- False,
- True
- ],
- [
- 'examples/source_image/full4.jpeg',
- 'examples/driven_audio/eluosi.wav',
- 'full',
- False,
- True
- ],
- [
- 'examples/source_image/full4.jpeg',
- 'examples/driven_audio/imagine.wav',
- 'full',
- True,
- True
- ],
- [
- 'examples/source_image/full_body_1.png',
- 'examples/driven_audio/bus_chinese.wav',
- 'full',
- True,
- False
- ],
- [
- 'examples/source_image/art_13.png',
- 'examples/driven_audio/fayu.wav',
- 'resize',
- True,
- False
- ],
- [
- 'examples/source_image/art_5.png',
- 'examples/driven_audio/chinese_news.wav',
- 'resize',
- False,
- False
- ],
- [
- 'examples/source_image/art_5.png',
- 'examples/driven_audio/RD_Radio31_000.wav',
- 'resize',
- True,
- True
- ],
- ],
- inputs=[
- source_image,
- driven_audio,
- preprocess_type,
- is_still_mode,
- enhancer],
- outputs=[gen_video],
- fn=test)
+ gr.Examples(examples=get_examples(), inputs=[source_image, driven_audio, preprocess_type, is_still_mode, enhancer],
+ outputs=[gen_video], fn=generate_video)

  demo.launch(debug=True)

examples/get_examples.py ADDED
@@ -0,0 +1,66 @@
+ def get_examples():
+ return [
+ [
+ 'examples/source_image/full_body_1.png',
+ 'examples/driven_audio/bus_chinese.wav',
+ 'crop',
+ True,
+ False
+ ],
+ [
+ 'examples/source_image/full_body_2.png',
+ 'examples/driven_audio/japanese.wav',
+ 'crop',
+ False,
+ False
+ ],
+ [
+ 'examples/source_image/full3.png',
+ 'examples/driven_audio/deyu.wav',
+ 'crop',
+ False,
+ True
+ ],
+ [
+ 'examples/source_image/full4.jpeg',
+ 'examples/driven_audio/eluosi.wav',
+ 'full',
+ False,
+ True
+ ],
+ [
+ 'examples/source_image/full4.jpeg',
+ 'examples/driven_audio/imagine.wav',
+ 'full',
+ True,
+ True
+ ],
+ [
+ 'examples/source_image/full_body_1.png',
+ 'examples/driven_audio/bus_chinese.wav',
+ 'full',
+ True,
+ False
+ ],
+ [
+ 'examples/source_image/art_13.png',
+ 'examples/driven_audio/fayu.wav',
+ 'resize',
+ True,
+ False
+ ],
+ [
+ 'examples/source_image/art_5.png',
+ 'examples/driven_audio/chinese_news.wav',
+ 'resize',
+ False,
+ False
+ ],
+ [
+ 'examples/source_image/art_5.png',
+ 'examples/driven_audio/RD_Radio31_000.wav',
+ 'resize',
+ True,
+ True
+ ],
+ ]
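For reference, a minimal usage sketch (not part of this commit; the loop variable names are illustrative) showing that each row returned by get_examples() lines up with the gr.Examples inputs wired in app.py, i.e. [source_image, driven_audio, preprocess_type, is_still_mode, enhancer]:

# Minimal sketch (not part of the commit): iterate the example rows.
from examples.get_examples import get_examples

for image_path, audio_path, preprocess_type, still_mode, use_enhancer in get_examples():
    print(f"{image_path} + {audio_path}: preprocess={preprocess_type}, "
          f"still={still_mode}, enhancer={use_enhancer}")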