# StableTTS1.1 / models/model.py
import math
import torch
import torch.nn as nn
import monotonic_align
from models.text_encoder import TextEncoder
from models.flow_matching import CFMDecoder
from models.reference_encoder import MelStyleEncoder
from models.duration_predictor import DurationPredictor, duration_loss
from utils.mask import sequence_mask
def convert_pad_shape(pad_shape):
    # Flatten a per-dimension pad spec into the reversed flat list that
    # torch.nn.functional.pad expects (last dimension first).
    inverted_shape = pad_shape[::-1]
    pad_shape = [item for sublist in inverted_shape for item in sublist]
    return pad_shape
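
# Illustrative example (not in the original file):
#   convert_pad_shape([[0, 0], [1, 0], [0, 0]]) == [0, 0, 1, 0, 0, 0]
# which torch.nn.functional.pad reads as: pad the last dim by (0, 0), dim 1 by (1, 0), dim 0 by (0, 0).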
def generate_path(duration, mask):
    # Expand per-token durations into a hard monotonic alignment map.
    # duration: (b, t_x), mask: (b, t_x, t_y) -> path: (b, t_x, t_y),
    # where path[b, i, j] == 1 iff mel frame j is assigned to token i.
    b, t_x, t_y = mask.shape
    cum_duration = torch.cumsum(duration, 1)
    cum_duration_flat = cum_duration.view(b * t_x)
    path = sequence_mask(cum_duration_flat, t_y).to(mask.dtype)
    path = path.view(b, t_x, t_y)
    # Subtracting the row-shifted cumulative mask keeps only each token's own span.
    path = path - torch.nn.functional.pad(path, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = path * mask
    return path
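
# Illustrative example (not in the original file): for duration [[2, 1, 3]] and an
# all-ones (1, 3, 6) mask, generate_path returns
#   [[[1, 1, 0, 0, 0, 0],
#     [0, 0, 1, 0, 0, 0],
#     [0, 0, 0, 1, 1, 1]]]
# i.e. token 0 covers frames 0-1, token 1 covers frame 2, token 2 covers frames 3-5.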
# modified from https://github.com/shivammehta25/Matcha-TTS/blob/main/matcha/models/matcha_tts.py
class StableTTS(nn.Module):
def __init__(self, n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, n_dec_layers, kernel_size, p_dropout, gin_channels):
super().__init__()
self.n_vocab = n_vocab
self.mel_channels = mel_channels
self.encoder = TextEncoder(n_vocab, mel_channels, hidden_channels, filter_channels, n_heads, n_enc_layers, kernel_size, p_dropout, gin_channels)
self.ref_encoder = MelStyleEncoder(mel_channels, style_vector_dim=gin_channels, style_kernel_size=5, dropout=0.25)
self.dp = DurationPredictor(hidden_channels, filter_channels, kernel_size, 0.5, gin_channels)
self.decoder = CFMDecoder(mel_channels, mel_channels, hidden_channels, mel_channels, filter_channels, n_heads, n_dec_layers, kernel_size, p_dropout, gin_channels)
        # unconditional inputs for classifier-free guidance (CFG)
self.fake_speaker = nn.Parameter(torch.zeros(1, gin_channels))
self.fake_content = nn.Parameter(torch.zeros(1, mel_channels, 1))
self.cfg_dropout = 0.2
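
        # During training, the speaker/content conditions are randomly replaced by the
        # fake_speaker / fake_content embeddings with probability cfg_dropout (see forward);
        # at inference, cfg > 1.0 blends conditional and unconditional predictions.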
@torch.inference_mode()
def synthesise(self, x, x_lengths, n_timesteps, temperature=1.0, y=None, length_scale=1.0, solver=None, cfg=1.0):
"""
Generates mel-spectrogram from text. Returns:
1. encoder outputs
2. decoder outputs
3. generated alignment
Args:
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
shape: (batch_size, max_text_length)
x_lengths (torch.Tensor): lengths of texts in batch.
shape: (batch_size,)
n_timesteps (int): number of steps to use for reverse diffusion in decoder.
temperature (float, optional): controls variance of terminal distribution.
y (torch.Tensor): mel spectrogram of reference audio
shape: (batch_size, mel_channels, time)
length_scale (float, optional): controls speech pace.
Increase value to slow down generated speech and vice versa.
Returns:
dict: {
"encoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
# Average mel spectrogram generated by the encoder
"decoder_outputs": torch.Tensor, shape: (batch_size, n_feats, max_mel_length),
# Refined mel spectrogram improved by the CFM
"attn": torch.Tensor, shape: (batch_size, max_text_length, max_mel_length),
# Alignment map between text and mel spectrogram
"""
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
c = self.ref_encoder(y, None)
x, mu_x, x_mask = self.encoder(x, c, x_lengths)
logw = self.dp(x, x_mask, c)
w = torch.exp(logw) * x_mask
        w_ceil = torch.ceil(w * length_scale)  # apply length_scale before rounding so durations stay integer-valued
y_lengths = torch.clamp_min(torch.sum(w_ceil, [1, 2]), 1).long()
y_max_length = y_lengths.max()
# Using obtained durations `w` construct alignment map `attn`
y_mask = sequence_mask(y_lengths, y_max_length).unsqueeze(1).to(x_mask.dtype)
attn_mask = x_mask.unsqueeze(-1) * y_mask.unsqueeze(2)
attn = generate_path(w_ceil.squeeze(1), attn_mask.squeeze(1)).unsqueeze(1)
# Align encoded text and get mu_y
mu_y = torch.matmul(attn.squeeze(1).transpose(1, 2), mu_x.transpose(1, 2))
mu_y = mu_y.transpose(1, 2)
encoder_outputs = mu_y[:, :, :y_max_length]
# Generate sample tracing the probability flow
if cfg == 1.0:
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c, solver)
else:
cfg_kwargs = {'fake_speaker': self.fake_speaker, 'fake_content': self.fake_content, 'cfg_strength': cfg}
decoder_outputs = self.decoder(mu_y, y_mask, n_timesteps, temperature, c, solver, cfg_kwargs)
decoder_outputs = decoder_outputs[:, :, :y_max_length]
return {
"encoder_outputs": encoder_outputs,
"decoder_outputs": decoder_outputs,
"attn": attn[:, :, :y_max_length],
}
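
    # A minimal usage sketch (illustrative, not part of the original file); shapes
    # follow the docstring above, and the concrete values are assumptions:
    #
    #   model = StableTTS(...)                       # construct with your config values
    #   x = torch.LongTensor([[5, 3, 8, 1]])         # phoneme ids, (1, text_len)
    #   x_lengths = torch.LongTensor([4])
    #   ref_mel = torch.randn(1, mel_channels, 120)  # reference mel for speaker style
    #   out = model.synthesise(x, x_lengths, n_timesteps=10, y=ref_mel, cfg=3.0)
    #   mel = out['decoder_outputs']                 # (1, mel_channels, mel_len)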
def forward(self, x, x_lengths, y, y_lengths, z, z_lengths):
"""
Computes 3 losses:
            1. duration loss: loss between predicted token durations and those extracted by Monotonic Alignment Search (MAS).
2. prior loss: loss between mel-spectrogram and encoder outputs.
3. flow matching loss: loss between mel-spectrogram and decoder outputs.
Args:
x (torch.Tensor): batch of texts, converted to a tensor with phoneme embedding ids.
shape: (batch_size, max_text_length)
x_lengths (torch.Tensor): lengths of texts in batch.
shape: (batch_size,)
y (torch.Tensor): batch of corresponding mel-spectrograms.
shape: (batch_size, n_feats, max_mel_length)
y_lengths (torch.Tensor): lengths of mel-spectrograms in batch.
shape: (batch_size,)
            z (torch.Tensor): batch of sliced mel-spectrograms (reference segments for the speaker style).
shape: (batch_size, n_feats, max_mel_length)
z_lengths (torch.Tensor): lengths of sliced mel-spectrograms in batch.
shape: (batch_size,)
"""
# Get encoder_outputs `mu_x` and log-scaled token durations `logw`
y_mask = sequence_mask(y_lengths, y.size(2)).unsqueeze(1).to(y.dtype)
z_mask = sequence_mask(z_lengths, z.size(2)).unsqueeze(1).to(z.dtype)
        # randomly drop the speaker condition (CFG dropout): where cfg_mask is False,
        # substitute the learned unconditional speaker embedding
        cfg_mask = torch.rand(y.size(0), 1, device=y.device) > self.cfg_dropout
        # compute global speaker embedding
        c = self.ref_encoder(z, z_mask) * cfg_mask + ~cfg_mask * self.fake_speaker.repeat(z.size(0), 1)
x, mu_x, x_mask = self.encoder(x, c, x_lengths)
logw = self.dp(x, x_mask, c)
        # Use MAS to find most likely alignment `attn` between text and mel-spectrogram
        with torch.no_grad():
            # Negative cross-entropy of y under a unit-variance Gaussian N(mu_x, I), expanded as
            #   -0.5 * d * log(2*pi) - 0.5 * ||y||^2 + y . mu_x - 0.5 * ||mu_x||^2
            # (neg_cent1..neg_cent4 below); s_p_sq_r is the prior precision, all ones here,
            # and the zeros_like term is the (zero) log-variance.
            s_p_sq_r = torch.ones_like(mu_x)  # [b, d, t]
            neg_cent1 = torch.sum(-0.5 * math.log(2 * math.pi) - torch.zeros_like(mu_x), [1], keepdim=True)
            neg_cent2 = torch.einsum("bdt, bds -> bts", -0.5 * (y**2), s_p_sq_r)
            neg_cent3 = torch.einsum("bdt, bds -> bts", y, (mu_x * s_p_sq_r))
            neg_cent4 = torch.sum(-0.5 * (mu_x**2) * s_p_sq_r, [1], keepdim=True)
            neg_cent = neg_cent1 + neg_cent2 + neg_cent3 + neg_cent4
            attn_mask = torch.unsqueeze(x_mask, 2) * torch.unsqueeze(y_mask, -1)
            attn = monotonic_align.maximum_path(neg_cent, attn_mask.squeeze(1)).unsqueeze(1).detach()
        # Compute loss between predicted log-scaled durations and those obtained from MAS
        # (the duration loss)
logw_ = torch.log(1e-8 + attn.sum(2)) * x_mask
dur_loss = duration_loss(logw, logw_, x_lengths)
        # Align encoded text with mel-spectrogram and get mu_y
        attn = attn.squeeze(1).transpose(1, 2)  # [b, t_x, t_y]
        mu_y = torch.matmul(attn.transpose(1, 2), mu_x.transpose(1, 2))
        mu_y = mu_y.transpose(1, 2)
        # Compute the flow-matching loss of the decoder; drop the content condition with the
        # same CFG mask so the decoder also learns an unconditional mode
        cfg_mask = cfg_mask.unsqueeze(-1)
        mu_y_masked = mu_y * cfg_mask + ~cfg_mask * self.fake_content.repeat(mu_y.size(0), 1, mu_y.size(-1))
diff_loss, _ = self.decoder.compute_loss(y, y_mask, mu_y_masked, c)
        # prior loss: negative log-likelihood of y under N(mu_y, I), averaged over valid frames and mel channels
        prior_loss = torch.sum(0.5 * ((y - mu_y) ** 2 + math.log(2 * math.pi)) * y_mask)
        prior_loss = prior_loss / (torch.sum(y_mask) * self.mel_channels)
return dur_loss, diff_loss, prior_loss, attn
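

if __name__ == '__main__':
    # Minimal smoke test (not part of the original file). The hyperparameter values
    # below are illustrative assumptions, not the repo's shipped config.
    model = StableTTS(
        n_vocab=100, mel_channels=80, hidden_channels=192, filter_channels=768,
        n_heads=2, n_enc_layers=4, n_dec_layers=3, kernel_size=3, p_dropout=0.1,
        gin_channels=256,
    )
    x = torch.randint(1, 100, (2, 20))   # fake phoneme ids
    x_lengths = torch.LongTensor([20, 15])
    y = torch.randn(2, 80, 100)          # fake target mel-spectrograms
    y_lengths = torch.LongTensor([100, 80])
    z = torch.randn(2, 80, 50)           # fake sliced reference mels
    z_lengths = torch.LongTensor([50, 40])
    dur_loss, diff_loss, prior_loss, attn = model(x, x_lengths, y, y_lengths, z, z_lengths)
    print(dur_loss.item(), diff_loss.item(), prior_loss.item(), attn.shape)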