|
import random |
|
import sys |
|
import tqdm |
|
from importlib.resources import files |
|
|
|
import soundfile as sf |
|
import torch |
|
from cached_path import cached_path |
|
|
|
from f5_tts.model import DiT, UNetT |
|
from f5_tts.model.utils import seed_everything |
|
from f5_tts.infer.utils_infer import ( |
|
load_vocoder, |
|
load_model, |
|
infer_process, |
|
remove_silence_for_generated_wav, |
|
save_spectrogram, |
|
) |
|
|
|
|
|
class F5TTS:
    """Thin wrapper that loads a vocoder and a TTS model, then runs inference.

    Construction selects a device, loads the vocoder through ``load_vocoder``
    and the model through ``load_model``. ``infer`` delegates synthesis to
    ``infer_process`` and can optionally write the resulting waveform and
    spectrogram to disk.
    """

    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        local_path=None,
        device=None,
    ):
        """Set default attributes, pick a device, and load vocoder and model.

        Args:
            model_type: ``"F5-TTS"`` or ``"E2-TTS"``; any other value makes
                ``load_ema_model`` raise ``ValueError``.
            ckpt_file: Checkpoint path; when falsy a default checkpoint is
                fetched with ``cached_path``.
            vocab_file: Forwarded unchanged to ``load_model``.
            ode_method: Forwarded unchanged to ``load_model``.
            use_ema: Forwarded unchanged to ``load_model``.
            local_path: Forwarded to ``load_vocoder``; a non-``None`` value
                also sets that call's first (boolean) argument to ``True``.
            device: Device to use; when falsy it is auto-selected in the
                order ``"cuda"``, ``"mps"``, ``"cpu"``.
        """
        # Bookkeeping defaults.
        self.final_wave = None
        self.seed = -1  # overwritten by every call to ``infer``

        # Audio constants; ``target_sample_rate`` is what ``export_wav`` writes with.
        self.target_sample_rate = 24000
        self.n_mel_channels = 100
        self.hop_length = 256
        self.target_rms = 0.1

        # Keep a caller-supplied device; otherwise probe the available backends.
        if not device:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = device

        # The vocoder is loaded before the model.
        self.load_vocoder_model(local_path)
        self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

    def load_vocoder_model(self, local_path):
        """Load the vocoder with ``load_vocoder`` and store it on ``self.vocos``."""
        has_local_copy = local_path is not None
        self.vocos = load_vocoder(has_local_copy, local_path, self.device)

    def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
        """Build the model for ``model_type`` and store it on ``self.ema_model``.

        When ``ckpt_file`` is falsy, the default checkpoint for the selected
        model type is resolved through ``cached_path``.

        Raises:
            ValueError: If ``model_type`` is neither ``"F5-TTS"`` nor ``"E2-TTS"``.
        """
        # Resolve the per-model preset first; unknown types fail before any download.
        if model_type == "F5-TTS":
            architecture = DiT
            hyperparams = {
                "dim": 1024,
                "depth": 22,
                "heads": 16,
                "ff_mult": 2,
                "text_dim": 512,
                "conv_layers": 4,
            }
            default_ckpt_url = "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"
        elif model_type == "E2-TTS":
            architecture = UNetT
            hyperparams = {"dim": 1024, "depth": 24, "heads": 16, "ff_mult": 4}
            default_ckpt_url = "hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        if not ckpt_file:
            ckpt_file = str(cached_path(default_ckpt_url))

        self.ema_model = load_model(
            architecture,
            hyperparams,
            ckpt_file,
            vocab_file,
            ode_method,
            use_ema,
            self.device,
        )

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at ``self.target_sample_rate``.

        When ``remove_silence`` is truthy, the written file is afterwards
        passed to ``remove_silence_for_generated_wav``.
        """
        sf.write(file_wave, wav, self.target_sample_rate)
        if remove_silence:
            remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spect, file_spect):
        """Save ``spect`` to ``file_spect`` via ``save_spectrogram``."""
        save_spectrogram(spect, file_spect)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        seed=-1,
    ):
        """Seed the RNGs, run ``infer_process``, and optionally export the results.

        Args:
            ref_file: Forwarded to ``infer_process`` as its first argument.
            ref_text: Forwarded to ``infer_process`` as its second argument.
            gen_text: Forwarded to ``infer_process`` as its third argument.
            show_info, progress, target_rms, cross_fade_duration,
            sway_sampling_coef, cfg_strength, nfe_step, speed, fix_duration:
                Forwarded unchanged to ``infer_process`` as keyword arguments.
            remove_silence: Passed to ``export_wav``; only has an effect when
                ``file_wave`` is given.
            file_wave: If not ``None``, the waveform is written to this path.
            file_spect: If not ``None``, the spectrogram is saved to this path.
            seed: Seed handed to ``seed_everything``; ``-1`` draws a random
                one from ``[0, sys.maxsize]``.

        Returns:
            The ``(wav, sr, spect)`` triple produced by ``infer_process``.
        """
        # A seed of -1 means "pick one"; the seed actually used is recorded on self.
        if seed == -1:
            seed = random.randint(0, sys.maxsize)
        seed_everything(seed)
        self.seed = seed

        synthesis_options = {
            "show_info": show_info,
            "progress": progress,
            "target_rms": target_rms,
            "cross_fade_duration": cross_fade_duration,
            "nfe_step": nfe_step,
            "cfg_strength": cfg_strength,
            "sway_sampling_coef": sway_sampling_coef,
            "speed": speed,
            "fix_duration": fix_duration,
            "device": self.device,
        }
        waveform, sample_rate, spectrogram = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            **synthesis_options,
        )

        # Exports are opt-in: each happens only when its output path is supplied.
        if file_wave is not None:
            self.export_wav(waveform, file_wave, remove_silence)
        if file_spect is not None:
            self.export_spectrogram(spectrogram, file_spect)

        return waveform, sample_rate, spectrogram
|
|
|
|
|
if __name__ == "__main__":
    # Demo run: synthesize one sample with the default model and report the seed used.
    tts_engine = F5TTS()

    # All example/output paths are resolved relative to the f5_tts package.
    package_root = files("f5_tts")
    reference_audio = str(package_root.joinpath("infer/examples/basic/basic_ref_en.wav"))
    output_wave = str(package_root.joinpath("../../tests/api_out.wav"))
    output_spect = str(package_root.joinpath("../../tests/api_out.png"))

    waveform, sample_rate, spectrogram = tts_engine.infer(
        ref_file=reference_audio,
        ref_text="some call me nature, others call me mother nature.",
        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
        file_wave=output_wave,
        file_spect=output_spect,
        seed=-1,  # -1 lets infer() draw a random seed
    )

    print("seed :", tts_engine.seed)
|
|