mrfakename committed
Commit
118c154
1 Parent(s): fe296ca

Sync from GitHub repo


This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there.

Files changed (3)
  1. api.py +2 -2
  2. inference-cli.py +2 -2
  3. model/utils_infer.py +22 -10
api.py CHANGED

@@ -33,10 +33,10 @@ class F5TTS:
         )
 
         # Load models
-        self.load_vecoder_model(local_path)
+        self.load_vocoder_model(local_path)
         self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)
 
-    def load_vecoder_model(self, local_path):
+    def load_vocoder_model(self, local_path):
         self.vocos = load_vocoder(local_path is not None, local_path, self.device)
 
     def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
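The api.py hunk is a pure rename: the misspelled `load_vecoder_model` becomes `load_vocoder_model` at both its definition and its single call site in `__init__`, so behavior is unchanged. If downstream code already calls the old name, a class-attribute alias would keep it working across the rename; this is a hypothetical sketch, not part of the commit:

```python
class F5TTS:
    # ... excerpt; the rest of the class is as in api.py ...

    def load_vocoder_model(self, local_path):
        self.vocos = load_vocoder(local_path is not None, local_path, self.device)

    # Hypothetical backwards-compatibility alias for callers still using the
    # misspelled name; NOT part of this commit, shown only as an option.
    load_vecoder_model = load_vocoder_model
```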
inference-cli.py CHANGED

@@ -104,7 +104,7 @@ if model == "F5-TTS":
     exp_name = "F5TTS_Base"
     ckpt_step = 1200000
     ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-    # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
+    # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
 
 elif model == "E2-TTS":
     model_cls = UNetT
@@ -114,7 +114,7 @@ elif model == "E2-TTS":
     exp_name = "E2TTS_Base"
     ckpt_step = 1200000
     ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
-    # ckpt_path = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
+    # ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
 
 print(f"Using {model}...")
 ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)
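Both inference-cli.py hunks fix the same copy-paste slip in a commented-out override: the variable the script actually reads is `ckpt_file`, so uncommenting the old `ckpt_path` line would have had no effect. With the fix, switching to a local checkpoint is a one-line change, sketched here under the `ckpts/` layout the comment assumes:

```python
# Sketch: point the CLI at a local checkpoint instead of the HF download.
# The ckpts/ directory layout is taken from the comment in the diff; adjust
# it to wherever the checkpoint actually lives.
exp_name = "F5TTS_Base"
ckpt_step = 1200000
# ckpt_file = str(cached_path(f"hf://SWivid/{repo_name}/{exp_name}/model_{ckpt_step}.safetensors"))
ckpt_file = f"ckpts/{exp_name}/model_{ckpt_step}.pt"  # .pt | .safetensors; local path
ema_model = load_model(model_cls, model_cfg, ckpt_file, vocab_file)  # names from the script
```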
model/utils_infer.py CHANGED

@@ -22,13 +22,6 @@ from model.utils import (
 device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
 print(f"Using {device} device")
 
-asr_pipe = pipeline(
-    "automatic-speech-recognition",
-    model="openai/whisper-large-v3-turbo",
-    torch_dtype=torch.float16,
-    device=device,
-)
-
 vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
 
@@ -82,8 +75,6 @@ def chunk_text(text, max_chars=135):
 
 
 # load vocoder
-
-
 def load_vocoder(is_local=False, local_path="", device=device):
     if is_local:
         print(f"Load vocos from local path {local_path}")
@@ -97,6 +88,22 @@ def load_vocoder(is_local=False, local_path="", device=device):
     return vocos
 
 
+# load asr pipeline
+
+asr_pipe = None
+
+
+def initialize_asr_pipeline(device=device):
+    global asr_pipe
+
+    asr_pipe = pipeline(
+        "automatic-speech-recognition",
+        model="openai/whisper-large",
+        torch_dtype=torch.float16,
+        device=device,
+    )
+
+
 # load model for inference
 
 
@@ -133,7 +140,7 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method="euler
 # preprocess reference audio and text
 
 
-def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
+def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
     show_info("Converting audio...")
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
@@ -152,6 +159,9 @@ def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print):
     ref_audio = f.name
 
     if not ref_text.strip():
+        global asr_pipe
+        if asr_pipe is None:
+            initialize_asr_pipeline(device=device)
         show_info("No reference text provided, transcribing reference audio...")
         ref_text = asr_pipe(
             ref_audio,
@@ -329,6 +339,8 @@ def infer_batch_process(
 
 
 # remove silence from generated wav
+
+
 def remove_silence_for_generated_wav(filename):
     aseg = AudioSegment.from_file(filename)
     non_silent_segs = silence.split_on_silence(aseg, min_silence_len=1000, silence_thresh=-50, keep_silence=500)
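The model/utils_infer.py change is the substantive one. Previously the Whisper ASR pipeline was built unconditionally at import time; it now lives behind `initialize_asr_pipeline()`, which `preprocess_ref_audio_text` calls lazily, and only when the caller supplies no reference text. Imports get cheaper, and the ASR model never loads if every caller passes `ref_text`. (Note the checkpoint also changes from `openai/whisper-large-v3-turbo` to `openai/whisper-large`.) The pattern in isolation, as a minimal self-contained sketch with a stand-in for the expensive `pipeline(...)` call:

```python
# Minimal sketch of the lazy-initialization pattern this commit adopts.
# In utils_infer.py the global is `asr_pipe` and the initializer is
# `initialize_asr_pipeline(device=...)`; `expensive_load` here is a stand-in.
_resource = None


def expensive_load():
    print("loading heavyweight model...")  # runs at most once
    return object()


def get_resource():
    global _resource
    if _resource is None:  # first call pays the cost; later calls reuse it
        _resource = expensive_load()
    return _resource


get_resource()  # triggers the load
get_resource()  # reuses the cached instance, no reload
```

One side effect worth noting: because `asr_pipe` is module-global, the `device` passed on the first transcription wins; later calls with a different `device` reuse the already-built pipeline.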