"""Convenience wrapper exposing F5-TTS / E2-TTS inference as a single class."""

import random
import sys
from importlib.resources import files

import soundfile as sf
import torch
import tqdm
from cached_path import cached_path

from f5_tts.infer.utils_infer import (
    infer_process,
    load_model,
    load_vocoder,
    remove_silence_for_generated_wav,
    save_spectrogram,
)
from f5_tts.model import DiT, UNetT
from f5_tts.model.utils import seed_everything


class F5TTS:
    """Bundle a vocoder and a TTS model behind one ``infer`` call.

    Construction picks a device, loads the vocoder and then loads the
    requested model checkpoint. After each ``infer`` call the seed that was
    actually used is available as ``self.seed``.
    """

    def __init__(
        self,
        model_type="F5-TTS",
        ckpt_file="",
        vocab_file="",
        ode_method="euler",
        use_ema=True,
        local_path=None,
        device=None,
    ):
        """Select a device and load the vocoder and the model.

        Args:
            model_type: ``"F5-TTS"`` or ``"E2-TTS"``; anything else makes
                ``load_ema_model`` raise ``ValueError``.
            ckpt_file: Checkpoint path. A falsy value selects the default
                checkpoint for ``model_type``.
            vocab_file: Forwarded unchanged to ``load_model``.
            ode_method: Forwarded unchanged to ``load_model``.
            use_ema: Forwarded unchanged to ``load_model``.
            local_path: Forwarded to ``load_vocoder``; when it is not
                ``None`` the vocoder loader is told to use a local copy.
            device: Device to use. A falsy value auto-selects ``"cuda"``,
                then ``"mps"``, then ``"cpu"``.
        """
        # Fixed audio / bookkeeping parameters.
        self.final_wave = None
        self.target_sample_rate = 24000
        self.n_mel_channels = 100
        self.hop_length = 256
        self.target_rms = 0.1
        self.seed = -1  # overwritten by infer() with the seed it used

        # Device: honour an explicit choice, otherwise probe in priority order.
        if not device:
            if torch.cuda.is_available():
                device = "cuda"
            elif torch.backends.mps.is_available():
                device = "mps"
            else:
                device = "cpu"
        self.device = device

        # Vocoder first, then the TTS model itself.
        self.load_vocoder_model(local_path)
        self.load_ema_model(model_type, ckpt_file, vocab_file, ode_method, use_ema)

    def load_vocoder_model(self, local_path):
        """Load the vocoder onto ``self.device`` and store it in ``self.vocos``."""
        use_local = local_path is not None
        self.vocos = load_vocoder(use_local, local_path, self.device)

    def load_ema_model(self, model_type, ckpt_file, vocab_file, ode_method, use_ema):
        """Build the model for ``model_type`` and store it in ``self.ema_model``.

        When ``ckpt_file`` is falsy, the default checkpoint URL for the chosen
        model type is resolved through ``cached_path``.

        Raises:
            ValueError: If ``model_type`` is neither ``"F5-TTS"`` nor
                ``"E2-TTS"``.
        """
        if model_type == "F5-TTS":
            model_cls = DiT
            model_cfg = {
                "dim": 1024,
                "depth": 22,
                "heads": 16,
                "ff_mult": 2,
                "text_dim": 512,
                "conv_layers": 4,
            }
            default_ckpt = "hf://SWivid/F5-TTS/F5TTS_Base/model_1200000.safetensors"
        elif model_type == "E2-TTS":
            model_cls = UNetT
            model_cfg = {"dim": 1024, "depth": 24, "heads": 16, "ff_mult": 4}
            default_ckpt = "hf://SWivid/E2-TTS/E2TTS_Base/model_1200000.safetensors"
        else:
            raise ValueError(f"Unknown model type: {model_type}")

        # Only resolve the default checkpoint when the caller gave none.
        if not ckpt_file:
            ckpt_file = str(cached_path(default_ckpt))

        self.ema_model = load_model(
            model_cls,
            model_cfg,
            ckpt_file,
            vocab_file,
            ode_method,
            use_ema,
            self.device,
        )

    def export_wav(self, wav, file_wave, remove_silence=False):
        """Write ``wav`` to ``file_wave`` at ``self.target_sample_rate``.

        When ``remove_silence`` is truthy, the written file is additionally
        passed to ``remove_silence_for_generated_wav``.
        """
        sf.write(file_wave, wav, self.target_sample_rate)
        if not remove_silence:
            return
        remove_silence_for_generated_wav(file_wave)

    def export_spectrogram(self, spect, file_spect):
        """Save ``spect`` to ``file_spect`` via ``save_spectrogram``."""
        save_spectrogram(spect, file_spect)

    def infer(
        self,
        ref_file,
        ref_text,
        gen_text,
        show_info=print,
        progress=tqdm,
        target_rms=0.1,
        cross_fade_duration=0.15,
        sway_sampling_coef=-1,
        cfg_strength=2,
        nfe_step=32,
        speed=1.0,
        fix_duration=None,
        remove_silence=False,
        file_wave=None,
        file_spect=None,
        seed=-1,
    ):
        """Run ``infer_process`` with the loaded model and optionally save results.

        Args:
            ref_file, ref_text, gen_text: Forwarded positionally to
                ``infer_process``.
            show_info, progress, target_rms, cross_fade_duration,
            sway_sampling_coef, cfg_strength, nfe_step, speed, fix_duration:
                Forwarded unchanged to ``infer_process`` as keyword arguments.
            remove_silence: Passed to ``export_wav`` when ``file_wave`` is set.
            file_wave: If not ``None``, the waveform is written there.
            file_spect: If not ``None``, the spectrogram is saved there.
            seed: Seed handed to ``seed_everything``. ``-1`` draws a random
                seed from ``[0, sys.maxsize]``.

        Returns:
            The ``(wav, sr, spect)`` triple returned by ``infer_process``.
        """
        # -1 is the "pick one for me" sentinel.
        seed = random.randint(0, sys.maxsize) if seed == -1 else seed
        seed_everything(seed)
        self.seed = seed

        audio, sample_rate, spectrogram = infer_process(
            ref_file,
            ref_text,
            gen_text,
            self.ema_model,
            show_info=show_info,
            progress=progress,
            target_rms=target_rms,
            cross_fade_duration=cross_fade_duration,
            nfe_step=nfe_step,
            cfg_strength=cfg_strength,
            sway_sampling_coef=sway_sampling_coef,
            speed=speed,
            fix_duration=fix_duration,
            device=self.device,
        )

        # Optional on-disk outputs.
        if file_wave is not None:
            self.export_wav(audio, file_wave, remove_silence)
        if file_spect is not None:
            self.export_spectrogram(spectrogram, file_spect)

        return audio, sample_rate, spectrogram


if __name__ == "__main__":
    # Demo: synthesize a sample with the bundled English reference clip.
    package_root = files("f5_tts")
    f5tts = F5TTS()

    wav, sr, spect = f5tts.infer(
        ref_file=str(package_root.joinpath("infer/examples/basic/basic_ref_en.wav")),
        ref_text="some call me nature, others call me mother nature.",
        gen_text="""I don't really care what you call me. I've been a silent spectator, watching species evolve, empires rise and fall. But always remember, I am mighty and enduring. Respect me and I'll nurture you; ignore me and you shall face the consequences.""",
        file_wave=str(package_root.joinpath("../../tests/api_out.wav")),
        file_spect=str(package_root.joinpath("../../tests/api_out.png")),
        seed=-1,  # random seed = -1
    )

    print("seed :", f5tts.seed)