# StableTTS1.1 / api.py
# Uploaded by KdaiP ("Upload 80 files"), commit 3dd84f8 (verified), 4.13 kB.
import torch
import torch.nn as nn
from dataclasses import asdict
from utils.audio import LogMelSpectrogram
from config import ModelConfig, MelConfig
from models.model import StableTTS
from text import symbols
from text import cleaned_text_to_sequence
from text.mandarin import chinese_to_cnm3
from text.english import english_to_ipa2
from text.japanese import japanese_to_ipa2
from datas.dataset import intersperse
from utils.audio import load_and_resample_audio
def get_vocoder(model_path, model_name='ffgan') -> nn.Module:
    """Build the mel-spectrogram-to-waveform vocoder selected by ``model_name``.

    Supported names:
      * ``'vocos'`` -- a ``Vocos`` network built from default ``VocosConfig()`` /
        ``MelConfig()``. Its state_dict is loaded from ``model_path`` onto the
        CPU with ``weights_only=True``, and the network is switched to eval mode.
      * ``'ffgan'`` -- a ``FireflyGANBaseWrapper`` constructed directly from
        ``model_path``. Training or changing the ffgan config is not supported in
        this repo; you can train your own model at
        https://github.com/fishaudio/vocoder.

    Raises:
        NotImplementedError: if ``model_name`` is neither ``'vocos'`` nor ``'ffgan'``.
    """
    if model_name == 'vocos':
        # Imports are local to each branch, so only the selected backend's
        # modules are ever imported.
        from vocoders.vocos.models.model import Vocos
        from config import VocosConfig, MelConfig
        net = Vocos(VocosConfig(), MelConfig())
        state_dict = torch.load(model_path, weights_only=True, map_location='cpu')
        net.load_state_dict(state_dict)
        net.eval()
        return net

    if model_name == 'ffgan':
        from vocoders.ffgan.model import FireflyGANBaseWrapper
        return FireflyGANBaseWrapper(model_path)

    raise NotImplementedError(f"Unsupported model: {model_name}")
class StableTTSAPI(nn.Module):
    """End-to-end wrapper: text -> StableTTS mel spectrogram -> vocoder waveform.

    Both sub-models are loaded from checkpoints onto the CPU and put into eval
    mode. Move the whole wrapper with ``.to(device)`` before calling
    :meth:`inference`.
    """

    def __init__(self, tts_model_path, vocoder_model_path, vocoder_name='ffgan'):
        """Load the acoustic model, the vocoder and the per-language phonemizers.

        Args:
            tts_model_path: path to a StableTTS state_dict checkpoint (loaded on
                the CPU with ``weights_only=True``).
            vocoder_model_path: checkpoint path forwarded to ``get_vocoder``.
            vocoder_name: vocoder backend name forwarded to ``get_vocoder``
                (``'ffgan'`` or ``'vocos'``).
        """
        super().__init__()

        self.mel_config = MelConfig()
        self.tts_model_config = ModelConfig()

        # Reference-audio front end, built from self.mel_config (whose n_mels
        # also sizes the TTS model below).
        self.mel_extractor = LogMelSpectrogram(**asdict(self.mel_config))

        # text to mel spectrogram
        self.tts_model = StableTTS(len(symbols), self.mel_config.n_mels, **asdict(self.tts_model_config))
        self.tts_model.load_state_dict(torch.load(tts_model_path, map_location='cpu', weights_only=True))
        self.tts_model.eval()

        # mel spectrogram to waveform
        self.vocoder_model = get_vocoder(vocoder_model_path, vocoder_name)
        self.vocoder_model.eval()

        # language name -> grapheme-to-phoneme function
        self.g2p_mapping = {
            'chinese': chinese_to_cnm3,
            'japanese': japanese_to_ipa2,
            'english': english_to_ipa2,
        }
        self.supported_languages = self.g2p_mapping.keys()

    @torch.inference_mode()
    def inference(self, text, ref_audio, language, step, temperature=1.0, length_scale=1.0, solver=None, cfg=3.0):
        """Synthesise ``text`` in the voice of ``ref_audio``.

        Args:
            text: raw input text written in ``language``.
            ref_audio: reference audio handed to ``load_and_resample_audio``
                (a file path in this module's ``__main__`` example).
            language: a key of ``self.g2p_mapping``: ``'chinese'``,
                ``'japanese'`` or ``'english'``.
            step: forwarded to ``StableTTS.synthesise``; presumably the number
                of sampling/ODE steps -- TODO confirm.
            temperature, length_scale, solver, cfg: forwarded unchanged to
                ``StableTTS.synthesise`` (``cfg`` is presumably a
                classifier-free-guidance scale -- TODO confirm).

        Returns:
            ``(audio_output, mel_output)``: the vocoder output and the decoder
            mel spectrogram, both moved to the CPU.

        Raises:
            ValueError: if ``language`` has no registered phonemizer.
        """
        # Validate first. Previously an unknown language fell through as None
        # and crashed with an opaque "'NoneType' object is not callable".
        phonemizer = self.g2p_mapping.get(language)
        if phonemizer is None:
            raise ValueError(
                f"Unsupported language: {language!r}; "
                f"expected one of {sorted(self.g2p_mapping)}"
            )

        # The wrapper has no explicit device attribute; follow its parameters.
        device = next(self.parameters()).device

        phonemes = phonemizer(text)
        # intersperse(..., item=0) presumably inserts token 0 between symbols -- TODO confirm.
        token_ids = intersperse(cleaned_text_to_sequence(phonemes), item=0)
        text = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)  # add batch dim
        text_length = torch.tensor([text.size(-1)], dtype=torch.long, device=device)

        ref_audio = load_and_resample_audio(ref_audio, self.mel_config.sample_rate).to(device)
        ref_audio = self.mel_extractor(ref_audio)

        mel_output = self.tts_model.synthesise(text, text_length, step, temperature, ref_audio, length_scale, solver, cfg)['decoder_outputs']
        audio_output = self.vocoder_model(mel_output)
        return audio_output.cpu(), mel_output.cpu()

    def get_params(self):
        """Return parameter counts in millions as ``(tts_model, vocoder_model)``."""
        tts_param = sum(p.numel() for p in self.tts_model.parameters()) / 1e6
        vocoder_param = sum(p.numel() for p in self.vocoder_model.parameters()) / 1e6
        return tts_param, vocoder_param
if __name__ == '__main__':
    # Demo: synthesise a Chinese sentence in the voice of ./audio_1.wav with the
    # Vocos vocoder and write the result to output.wav.

    # Imported up front so a missing torchaudio fails before the slow model
    # loading and synthesis, not after it.
    import torchaudio

    device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tts_model_path = './checkpoints/checkpoint_0.pt'
    vocoder_model_path = './vocoders/pretrained/vocos.pt'

    model = StableTTSAPI(tts_model_path, vocoder_model_path, 'vocos')
    model.to(device)

    text = '樱落满殇祈念集……殇歌花落集思祈……樱花满地集于我心……揲舞纷飞祈愿相随……'
    audio = './audio_1.wav'

    audio_output, mel_output = model.inference(text, audio, 'chinese', 10, solver='dopri5', cfg=3)
    print(audio_output.shape)
    print(mel_output.shape)

    # Save at the sample rate of the mel config the model was built with,
    # rather than constructing a fresh MelConfig().
    torchaudio.save('output.wav', audio_output, model.mel_config.sample_rate)