# Page-scrape residue (not part of the module): "Spaces: Runtime error"
import torch | |
import torch.nn as nn | |
from dataclasses import asdict | |
from utils.audio import LogMelSpectrogram | |
from config import ModelConfig, MelConfig | |
from models.model import StableTTS | |
from text import symbols | |
from text import cleaned_text_to_sequence | |
from text.mandarin import chinese_to_cnm3 | |
from text.english import english_to_ipa2 | |
from text.japanese import japanese_to_ipa2 | |
from datas.dataset import intersperse | |
from utils.audio import load_and_resample_audio | |
def get_vocoder(model_path, model_name='ffgan') -> nn.Module:
    """Build the vocoder selected by ``model_name`` from ``model_path``.

    Args:
        model_path: Checkpoint location. For ``'ffgan'`` it is handed to
            ``FireflyGANBaseWrapper``; for ``'vocos'`` it is read with
            ``torch.load`` and applied as a state dict.
        model_name: Backend identifier, either ``'ffgan'`` or ``'vocos'``.

    Returns:
        The constructed vocoder module. The ``'vocos'`` backend is switched
        to eval mode here.

    Raises:
        NotImplementedError: If ``model_name`` is not a known backend.
    """
    # Backend modules are imported lazily so only the selected one is loaded.
    if model_name == 'ffgan':
        # training or changing ffgan config is not supported in this repo
        # you can train your own model at https://github.com/fishaudio/vocoder
        from vocoders.ffgan.model import FireflyGANBaseWrapper

        return FireflyGANBaseWrapper(model_path)

    if model_name == 'vocos':
        from vocoders.vocos.models.model import Vocos
        from config import VocosConfig, MelConfig

        vocos = Vocos(VocosConfig(), MelConfig())
        weights = torch.load(model_path, weights_only=True, map_location='cpu')
        vocos.load_state_dict(weights)
        vocos.eval()
        return vocos

    raise NotImplementedError(f"Unsupported model: {model_name}")
class StableTTSAPI(nn.Module):
    """Inference wrapper pairing a StableTTS text-to-mel model with a vocoder.

    The constructor loads both checkpoints on CPU and puts both models in
    eval mode; move the whole module with ``.to(device)`` before calling
    :meth:`inference`.
    """

    def __init__(self, tts_model_path, vocoder_model_path, vocoder_name='ffgan'):
        """Load the TTS model, the vocoder and the per-language phonemizers.

        Args:
            tts_model_path: Checkpoint holding the ``StableTTS`` state dict.
            vocoder_model_path: Checkpoint passed to ``get_vocoder``.
            vocoder_name: Vocoder backend name forwarded to ``get_vocoder``.
        """
        super().__init__()

        self.mel_config = MelConfig()
        self.tts_model_config = ModelConfig()

        self.mel_extractor = LogMelSpectrogram(**asdict(self.mel_config))

        # text to mel spectrogram
        self.tts_model = StableTTS(len(symbols), self.mel_config.n_mels, **asdict(self.tts_model_config))
        self.tts_model.load_state_dict(torch.load(tts_model_path, map_location='cpu', weights_only=True))
        self.tts_model.eval()

        # mel spectrogram to waveform
        self.vocoder_model = get_vocoder(vocoder_model_path, vocoder_name)
        self.vocoder_model.eval()

        # language name -> grapheme-to-phoneme function
        self.g2p_mapping = {
            'chinese': chinese_to_cnm3,
            'japanese': japanese_to_ipa2,
            'english': english_to_ipa2,
        }
        self.supported_languages = self.g2p_mapping.keys()

    # eval() does not turn autograd off; without no_grad every call would
    # record a graph and hand back tensors that still require grad.
    @torch.no_grad()
    def inference(self, text, ref_audio, language, step, temperature=1.0, length_scale=1.0, solver=None, cfg=3.0):
        """Synthesise speech for ``text`` using ``ref_audio`` as the reference.

        Args:
            text: Raw input text in ``language``.
            ref_audio: Reference audio handed to ``load_and_resample_audio``
                (the demo script in this file passes a wav file path).
            language: One of ``self.supported_languages``.
            step, temperature, length_scale, solver, cfg: Forwarded unchanged
                to ``self.tts_model.synthesise``.

        Returns:
            Tuple ``(audio_output, mel_output)``, both moved to CPU.

        Raises:
            ValueError: If ``language`` has no registered phonemizer.
        """
        # Validate first: a missing entry would otherwise surface as an
        # opaque "'NoneType' object is not callable" error.
        phonemizer = self.g2p_mapping.get(language)
        if phonemizer is None:
            supported = ', '.join(self.g2p_mapping)
            raise ValueError(f"Unsupported language: {language!r}. Supported languages: {supported}")

        device = next(self.parameters()).device

        phonemes = phonemizer(text)
        # item=0 is interleaved between the symbol ids before batching
        token_ids = intersperse(cleaned_text_to_sequence(phonemes), item=0)
        text = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)
        text_length = torch.tensor([text.size(-1)], dtype=torch.long, device=device)

        ref_audio = load_and_resample_audio(ref_audio, self.mel_config.sample_rate).to(device)
        ref_audio = self.mel_extractor(ref_audio)

        mel_output = self.tts_model.synthesise(text, text_length, step, temperature, ref_audio, length_scale, solver, cfg)['decoder_outputs']
        audio_output = self.vocoder_model(mel_output)
        return audio_output.cpu(), mel_output.cpu()

    def get_params(self):
        """Return ``(tts_param, vocoder_param)`` parameter counts in millions."""
        tts_param = sum(p.numel() for p in self.tts_model.parameters()) / 1e6
        vocoder_param = sum(p.numel() for p in self.vocoder_model.parameters()) / 1e6
        return tts_param, vocoder_param
if __name__ == '__main__':
    # Demo: synthesise one Chinese sentence with the vocos vocoder and
    # write the waveform to output.wav.
    target_device = 'cuda' if torch.cuda.is_available() else 'cpu'

    tts_checkpoint = './checkpoints/checkpoint_0.pt'
    vocoder_checkpoint = './vocoders/pretrained/vocos.pt'

    # nn.Module.to returns the module itself, so construction and the
    # device move can be chained.
    model = StableTTSAPI(tts_checkpoint, vocoder_checkpoint, 'vocos').to(target_device)

    sentence = '樱落满殇祈念集……殇歌花落集思祈……樱花满地集于我心……揲舞纷飞祈愿相随……'
    reference_wav = './audio_1.wav'

    audio_output, mel_output = model.inference(
        sentence,
        reference_wav,
        'chinese',
        10,
        solver='dopri5',
        cfg=3,
    )
    for tensor in (audio_output, mel_output):
        print(tensor.shape)

    # Imported late, as in the original, so it is only required once
    # synthesis has succeeded.
    import torchaudio
    torchaudio.save('output.wav', audio_output, MelConfig().sample_rate)