mrfakename committed
Commit 0ac3155
Parent: 5d2c622

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions to the GitHub repo.

Files changed (2):
  1. README_REPO.md +3 -0
  2. inference-cli.py +51 -15
README_REPO.md CHANGED
@@ -86,6 +86,9 @@ Currently support 30s for a single generation, which is the **TOTAL** length of
 
 Either you can specify everything in `inference-cli.toml` or override it with flags. Leaving `--ref_text ""` will have the ASR model transcribe the reference audio automatically (this uses extra GPU memory). If you encounter a network error, consider using a local checkpoint: just set `ckpt_path` in `inference-cli.py`.
 
+To change the model, use `--ckpt_file` to specify the checkpoint you want to load.
+To change the vocabulary, use `--vocab_file` to provide your own `vocab.txt` file.
+
 ```bash
 python inference-cli.py \
 --model "F5-TTS" \
inference-cli.py CHANGED
@@ -36,6 +36,16 @@ parser.add_argument(
     "--model",
     help="F5-TTS | E2-TTS",
 )
+parser.add_argument(
+    "-p",
+    "--ckpt_file",
+    help="Path to the model checkpoint .pt file",
+)
+parser.add_argument(
+    "-v",
+    "--vocab_file",
+    help="Path to the custom vocab .txt file",
+)
 parser.add_argument(
     "-r",
     "--ref_audio",
@@ -88,6 +98,8 @@ if gen_file:
     gen_text = codecs.open(gen_file, "r", "utf-8").read()
 output_dir = args.output_dir if args.output_dir else config["output_dir"]
 model = args.model if args.model else config["model"]
+ckpt_file = args.ckpt_file if args.ckpt_file else ""
+vocab_file = args.vocab_file if args.vocab_file else ""
 remove_silence = args.remove_silence if args.remove_silence else config["remove_silence"]
 wave_path = Path(output_dir)/"out.wav"
 spectrogram_path = Path(output_dir)/"out.png"
@@ -125,11 +137,19 @@ speed = 1.0
 # fix_duration = 27 # None or float (duration in seconds)
 fix_duration = None
 
-def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
-    ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
-    if not Path(ckpt_path).exists():
-        ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-    vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
+def load_model(model_cls, model_cfg, ckpt_path, file_vocab):
+    # fall back to the default Emilia vocab and pinyin tokenizer when no custom vocab is given
+    if file_vocab == "":
+        file_vocab = "Emilia_ZH_EN"
+        tokenizer = "pinyin"
+    else:
+        tokenizer = "custom"
+
+    print("\nvocab : ", file_vocab)
+    print("tokenizer : ", tokenizer)
+    print("model : ", ckpt_path, "\n")
+
+    vocab_char_map, vocab_size = get_tokenizer(file_vocab, tokenizer)
     model = CFM(
         transformer=model_cls(
             **model_cfg, text_num_embeds=vocab_size, mel_dim=n_mel_channels
@@ -149,14 +169,12 @@ def load_model(repo_name, exp_name, model_cls, model_cfg, ckpt_step):
 
     return model
 
-
 # load models
 F5TTS_model_cfg = dict(
     dim=1024, depth=22, heads=16, ff_mult=2, text_dim=512, conv_layers=4
 )
 E2TTS_model_cfg = dict(dim=1024, depth=24, heads=16, ff_mult=4)
 
-
 def chunk_text(text, max_chars=135):
     """
     Splits the input text into chunks, each with a maximum number of characters.
@@ -184,12 +202,29 @@ def chunk_text(text, max_chars=135):
 
     return chunks
 
+# ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt" # .pt | .safetensors
+# if not Path(ckpt_path).exists():
+#     ckpt_path = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
 
-def infer_batch(ref_audio, ref_text, gen_text_batches, model, remove_silence, cross_fade_duration=0.15):
+def infer_batch(ref_audio, ref_text, gen_text_batches, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration=0.15):
     if model == "F5-TTS":
-        ema_model = load_model(model, "F5TTS_Base", DiT, F5TTS_model_cfg, 1200000)
+        # download the default F5-TTS checkpoint from the Hugging Face Hub if none was given
+        if ckpt_file == "":
+            repo_name = "F5-TTS"
+            exp_name = "F5TTS_Base"
+            ckpt_step = 1200000
+            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+        ema_model = load_model(DiT, F5TTS_model_cfg, ckpt_file, file_vocab)
+
     elif model == "E2-TTS":
-        ema_model = load_model(model, "E2TTS_Base", UNetT, E2TTS_model_cfg, 1200000)
+        if ckpt_file == "":
+            repo_name = "E2-TTS"
+            exp_name = "E2TTS_Base"
+            ckpt_step = 1200000
+            ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
+
+        ema_model = load_model(UNetT, E2TTS_model_cfg, ckpt_file, file_vocab)
 
     audio, sr = ref_audio
     if audio.shape[0] > 1:
@@ -325,7 +360,7 @@ def process_voice(ref_audio_orig, ref_text):
         print("Using custom reference text...")
     return ref_audio, ref_text
 
-def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_duration=0.15):
+def infer(ref_audio, ref_text, gen_text, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration=0.15):
     print(gen_text)
     # Add the functionality to ensure it ends with ". "
     if not ref_text.endswith(". ") and not ref_text.endswith("。"):
@@ -343,10 +378,10 @@ def infer(ref_audio, ref_text, gen_text, model, remove_silence, cross_fade_durat
         print(f'gen_text {i}', gen_text)
 
     print(f"Generating audio using {model} in {len(gen_text_batches)} batches, loading models...")
-    return infer_batch((audio, sr), ref_text, gen_text_batches, model, remove_silence, cross_fade_duration)
+    return infer_batch((audio, sr), ref_text, gen_text_batches, model, ckpt_file, file_vocab, remove_silence, cross_fade_duration)
 
 
-def process(ref_audio, ref_text, text_gen, model, remove_silence):
+def process(ref_audio, ref_text, text_gen, model, ckpt_file, file_vocab, remove_silence):
     main_voice = {"ref_audio":ref_audio, "ref_text":ref_text}
     if "voices" not in config:
         voices = {"main": main_voice}
@@ -371,7 +406,7 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
         ref_audio = voices[voice]['ref_audio']
         ref_text = voices[voice]['ref_text']
         print(f"Voice: {voice}")
-        audio, spectragram = infer(ref_audio, ref_text, gen_text, model, remove_silence)
+        audio, spectrogram = infer(ref_audio, ref_text, gen_text, model, ckpt_file, file_vocab, remove_silence)
         generated_audio_segments.append(audio)
 
     if generated_audio_segments:
@@ -389,4 +424,5 @@ def process(ref_audio, ref_text, text_gen, model, remove_silence):
         aseg.export(f.name, format="wav")
     print(f.name)
 
-process(ref_audio, ref_text, gen_text, model, remove_silence)
+
+process(ref_audio, ref_text, gen_text, model, ckpt_file, vocab_file, remove_silence)
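
Both new flags stay optional: with an empty `ckpt_file`, `infer_batch` falls back to downloading the default 1200000-step `.safetensors` checkpoint from `hf://SWivid/F5-TTS` or `hf://SWivid/E2-TTS` via `cached_path`, and an empty `vocab_file` keeps the `Emilia_ZH_EN` vocab with the pinyin tokenizer. So the previous default invocation still works unchanged; a sketch (reference audio and text values are placeholders):

```bash
# No --ckpt_file or --vocab_file: the default E2-TTS checkpoint is fetched
# from the Hugging Face Hub and the stock Emilia_ZH_EN pinyin vocab is used.
python inference-cli.py \
--model "E2-TTS" \
--ref_audio "ref.wav" \
--ref_text "" \
--gen_text "Some text you want the model to speak."
```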