mrfakename committed
Commit 35005eb
Parent: 2669b3f

Sync from GitHub repo

This Space is synced from the GitHub repo: https://github.com/SWivid/F5-TTS. Please submit contributions there.

Files changed (1)
  1. model/utils_infer.py +25 -7
model/utils_infer.py CHANGED

@@ -19,8 +19,14 @@ from model.utils import (
     convert_char_to_pinyin,
 )
 
-device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
-print(f"Using {device} device")
+# get device
+
+
+def get_device():
+    device = "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"
+    # print(f"Using {device} device")
+    return device
+
 
 vocos = Vocos.from_pretrained("charactr/vocos-mel-24khz")
 
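The practical effect of this hunk is that importing the module no longer picks a device at import time; hardware detection now happens on first use and can be overridden per call. A minimal sketch of the pattern, assuming only the torch APIs shown in the diff (to_best_device is a hypothetical name, not part of F5-TTS):

    import torch

    def get_device():
        # Same resolution order as the diff: CUDA, then Apple MPS, then CPU.
        return "cuda" if torch.cuda.is_available() else "mps" if torch.backends.mps.is_available() else "cpu"

    def to_best_device(x, device=None):
        # Hypothetical helper: callers may pin a device, else detect at call time.
        if device is None:
            device = get_device()
        return x.to(device)

    print(to_best_device(torch.zeros(2)).device)         # auto-detected
    print(to_best_device(torch.zeros(2), "cpu").device)  # pinned explicitly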
@@ -76,7 +82,9 @@ def chunk_text(text, max_chars=135):
 
 
 # load vocoder
-def load_vocoder(is_local=False, local_path="", device=device):
+def load_vocoder(is_local=False, local_path="", device=None):
+    if device is None:
+        device = get_device()
     if is_local:
         print(f"Load vocos from local path {local_path}")
         vocos = Vocos.from_hparams(f"{local_path}/config.yaml")
@@ -94,8 +102,10 @@ def load_vocoder(is_local=False, local_path="", device=device):
 asr_pipe = None
 
 
-def initialize_asr_pipeline(device=device):
+def initialize_asr_pipeline(device=None):
     global asr_pipe
+    if device is None:
+        device = get_device()
 
     asr_pipe = pipeline(
         "automatic-speech-recognition",
@@ -108,7 +118,9 @@ def initialize_asr_pipeline(device=device):
 # load model for inference
 
 
-def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
+def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=None):
+    if device is None:
+        device = get_device()
     if vocab_file == "":
         vocab_file = "Emilia_ZH_EN"
         tokenizer = "pinyin"
@@ -141,7 +153,10 @@ def load_model(model_cls, model_cfg, ckpt_path, vocab_file="", ode_method=ode_method, use_ema=True, device=device):
 # preprocess reference audio and text
 
 
-def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=device):
+def preprocess_ref_audio_text(ref_audio_orig, ref_text, show_info=print, device=None):
+    if device is None:
+        device = get_device()
+
     show_info("Converting audio...")
     with tempfile.NamedTemporaryFile(delete=False, suffix=".wav") as f:
         aseg = AudioSegment.from_file(ref_audio_orig)
@@ -243,7 +257,11 @@ def infer_batch_process(
     sway_sampling_coef=-1,
     speed=1,
     fix_duration=None,
+    device=None,
 ):
+    if device is None:
+        device = get_device()
+
     audio, sr = ref_audio
     if audio.shape[0] > 1:
         audio = torch.mean(audio, dim=0, keepdim=True)
@@ -254,7 +272,7 @@ def infer_batch_process(
     if sr != target_sample_rate:
         resampler = torchaudio.transforms.Resample(sr, target_sample_rate)
         audio = resampler(audio)
-    audio = audio.to(device)
+    audio = audio.to(device)
 
     generated_waves = []
     spectrograms = []
 
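Taken together, every public entry point now accepts device=None and resolves it lazily, so callers can rely on auto-detection or pin a device per call. A hedged usage sketch, assuming the module path model/utils_infer.py maps to model.utils_infer:

    from model.utils_infer import initialize_asr_pipeline, load_vocoder

    vocos = load_vocoder()                  # device resolved lazily via get_device()
    vocos_cpu = load_vocoder(device="cpu")  # explicit override
    initialize_asr_pipeline(device="cuda")  # e.g. pin the ASR pipeline to CUDA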