File size: 4,125 Bytes
3dd84f8
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
import torch
import torch.nn as nn

from dataclasses import asdict

from utils.audio import LogMelSpectrogram
from config import ModelConfig, MelConfig
from models.model import StableTTS

from text import symbols
from text import cleaned_text_to_sequence
from text.mandarin import chinese_to_cnm3
from text.english import english_to_ipa2
from text.japanese import japanese_to_ipa2


from datas.dataset import intersperse
from utils.audio import load_and_resample_audio

def get_vocoder(model_path, model_name='ffgan') -> nn.Module:
    """Build the vocoder that turns mel spectrograms into waveforms.

    Args:
        model_path: checkpoint path. For ``'ffgan'`` it is handed straight to
            ``FireflyGANBaseWrapper``; for ``'vocos'`` it is read with
            ``torch.load`` (weights only, mapped to CPU) and loaded as a
            state dict.
        model_name: vocoder backend, ``'ffgan'`` (default) or ``'vocos'``.

    Returns:
        The vocoder module. The Vocos model is switched to eval mode here;
        the FireFly-GAN wrapper is returned exactly as constructed.

    Raises:
        NotImplementedError: if ``model_name`` is not a known backend.
    """
    if model_name not in ('ffgan', 'vocos'):
        raise NotImplementedError(f"Unsupported model: {model_name}")

    if model_name == 'ffgan':
        # Training or reconfiguring FireFly-GAN is not supported in this repo;
        # you can train your own model at https://github.com/fishaudio/vocoder
        from vocoders.ffgan.model import FireflyGANBaseWrapper
        return FireflyGANBaseWrapper(model_path)

    # Remaining case: Vocos. Imports are local to each path, so only the
    # selected backend's modules get imported.
    from vocoders.vocos.models.model import Vocos
    from config import VocosConfig, MelConfig
    vocos = Vocos(VocosConfig(), MelConfig())
    state_dict = torch.load(model_path, weights_only=True, map_location='cpu')
    vocos.load_state_dict(state_dict)
    vocos.eval()
    return vocos

class StableTTSAPI(nn.Module):
    """End-to-end TTS wrapper: text -> mel spectrogram (StableTTS) -> waveform (vocoder).

    Both sub-models are loaded from checkpoints onto the CPU and put in eval
    mode; move the whole wrapper with ``.to(device)`` before calling
    :meth:`inference`.

    Args:
        tts_model_path: path to a StableTTS state dict (loaded with
            ``weights_only=True``).
        vocoder_model_path: checkpoint path forwarded to ``get_vocoder``.
        vocoder_name: vocoder backend forwarded to ``get_vocoder``
            (``'ffgan'`` by default).
    """

    def __init__(self, tts_model_path, vocoder_model_path, vocoder_name='ffgan'):
        super().__init__()

        self.mel_config = MelConfig()
        self.tts_model_config = ModelConfig()

        # Converts the reference audio into the mel spectrogram fed to the TTS model.
        self.mel_extractor = LogMelSpectrogram(**asdict(self.mel_config))

        # text to mel spectrogram
        self.tts_model = StableTTS(len(symbols), self.mel_config.n_mels, **asdict(self.tts_model_config))
        self.tts_model.load_state_dict(torch.load(tts_model_path, map_location='cpu', weights_only=True))
        self.tts_model.eval()

        # mel spectrogram to waveform
        self.vocoder_model = get_vocoder(vocoder_model_path, vocoder_name)
        # get_vocoder does not call eval() for every backend, so do it here.
        self.vocoder_model.eval()

        # language name -> grapheme-to-phoneme function
        self.g2p_mapping = {
            'chinese': chinese_to_cnm3,
            'japanese': japanese_to_ipa2,
            'english': english_to_ipa2,
        }
        # A live keys view of g2p_mapping.
        self.supported_languages = self.g2p_mapping.keys()

    @torch.inference_mode()
    def inference(self, text, ref_audio, language, step, temperature=1.0, length_scale=1.0, solver=None, cfg=3.0):
        """Synthesise speech for ``text`` conditioned on ``ref_audio``.

        Args:
            text: input text written in ``language``.
            ref_audio: reference audio passed to ``load_and_resample_audio``
                (presumably a file path — confirm against that helper).
            language: key of ``g2p_mapping`` (``'chinese'``, ``'japanese'`` or
                ``'english'``) selecting the phonemizer.
            step, temperature, length_scale, solver, cfg: forwarded unchanged
                to ``StableTTS.synthesise``. Presumably ``step`` is the number of
                sampling steps and ``cfg`` the guidance scale — TODO confirm.

        Returns:
            ``(audio_output, mel_output)``: the vocoder output and the decoder
            mel spectrogram, both moved to CPU.

        Raises:
            ValueError: if ``language`` has no registered phonemizer.
        """
        # Validate first so an unknown language fails with a clear message
        # instead of "'NoneType' object is not callable".
        phonemizer = self.g2p_mapping.get(language)
        if phonemizer is None:
            raise ValueError(
                f"Unsupported language: {language!r}; "
                f"expected one of {sorted(self.g2p_mapping)}"
            )

        # Inputs are created on whatever device the model parameters live on.
        device = next(self.parameters()).device

        # text -> phonemes -> symbol ids (run through intersperse with item=0),
        # batched as shape (1, L); text_length holds that L.
        phonemes = phonemizer(text)
        token_ids = intersperse(cleaned_text_to_sequence(phonemes), item=0)
        text = torch.tensor(token_ids, dtype=torch.long, device=device).unsqueeze(0)
        text_length = torch.tensor([text.size(-1)], dtype=torch.long, device=device)

        # Reference audio, resampled to the configured rate, -> mel spectrogram.
        ref_audio = load_and_resample_audio(ref_audio, self.mel_config.sample_rate).to(device)
        ref_audio = self.mel_extractor(ref_audio)

        mel_output = self.tts_model.synthesise(text, text_length, step, temperature, ref_audio, length_scale, solver, cfg)['decoder_outputs']
        audio_output = self.vocoder_model(mel_output)
        return audio_output.cpu(), mel_output.cpu()

    def get_params(self):
        """Return parameter counts in millions as ``(tts_params, vocoder_params)``."""
        tts_param = sum(p.numel() for p in self.tts_model.parameters()) / 1e6
        vocoder_param = sum(p.numel() for p in self.vocoder_model.parameters()) / 1e6
        return tts_param, vocoder_param
    
if __name__ == '__main__':
    # Imported up front so a missing dependency fails before the slow model
    # loading and synthesis below rather than after them.
    import torchaudio

    device = 'cuda' if torch.cuda.is_available() else 'cpu'
    tts_model_path = './checkpoints/checkpoint_0.pt'
    vocoder_model_path = './vocoders/pretrained/vocos.pt'

    model = StableTTSAPI(tts_model_path, vocoder_model_path, 'vocos')
    model.to(device)

    text = '樱落满殇祈念集……殇歌花落集思祈……樱花满地集于我心……揲舞纷飞祈愿相随……'
    audio = './audio_1.wav'

    audio_output, mel_output = model.inference(text, audio, 'chinese', 10, solver='dopri5', cfg=3)
    print(audio_output.shape)
    print(mel_output.shape)

    # Save at the rate from the API's own mel config (the rate inference
    # resampled the reference audio to) instead of building a new MelConfig.
    # NOTE(review): torchaudio.save expects a 2-D (channels, time) tensor;
    # confirm the vocoder output has that shape.
    torchaudio.save('output.wav', audio_output, model.mel_config.sample_rate)