"""A2S model definition. | |
Copyright PolyAI Limited. | |
""" | |
from typing import Union | |
import pytorch_lightning as pl | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import torch.optim as optim | |
from einops import rearrange | |
import constants as c | |
from modules import masking_logic | |
from modules.conformer import Conformer | |
from modules.masking_logic import (State, mask_by_random_topk, | |
sample_from_logits, state_init) | |
from utils import load_checkpoint | |


class Pheme(pl.LightningModule):
    def __init__(self, hp):
        super().__init__()
        self.hp = hp
        self.model = TTSConformer(hp)

        self.cross_entropy = nn.CrossEntropyLoss(
            label_smoothing=self.hp.label_smoothing,
            ignore_index=self.hp.n_codes
        )

        if self.hp.pretrained_path:
            self.load()
        else:
            self.apply(self.init_weights)

        if self.hp.only_inference:
            self.model.eval()

        self.save_hyperparameters()

    def load(self):
        state_dict = load_checkpoint(self.hp.pretrained_path)
        self.load_state_dict(state_dict, strict=True)
        print(f"Parameters loaded from {self.hp.pretrained_path}")

    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            module._fill_padding_idx_with_zero()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()

    def configure_optimizers(self):
        optimizer_adam = optim.AdamW(
            self.parameters(), lr=self.hp.lr,
            betas=(self.hp.adam_beta1, self.hp.adam_beta2))

        # Learning rate scheduler
        num_training_steps = self.hp.training_step
        num_warmup_steps = self.hp.warmup_step
        num_flat_steps = int(self.hp.optim_flat_percent * num_training_steps)

        def lambda_lr(current_step: int):
            if current_step < num_warmup_steps:
                return float(current_step) / float(max(1, num_warmup_steps))
            elif current_step < (num_warmup_steps + num_flat_steps):
                return 1.0
            return max(
                0.0,
                float(num_training_steps - current_step)
                / float(
                    max(1, num_training_steps - (num_warmup_steps + num_flat_steps))  # noqa
                ),
            )

        scheduler_adam = {
            "scheduler": optim.lr_scheduler.LambdaLR(
                optimizer_adam, lambda_lr),
            "interval": "step",
        }
        return [optimizer_adam], [scheduler_adam]
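
    # Shape of the multiplier produced by lambda_lr above, with illustrative
    # (not shipped) values warmup_step=1_000, training_step=100_000 and
    # optim_flat_percent=0.1 (so num_flat_steps=10_000):
    #   step    500 -> 0.5   (linear warmup)
    #   step  5_000 -> 1.0   (flat phase until step 11_000)
    #   step 55_500 -> 0.5   (linear decay from step 11_000 down to 0)
    # The multiplier scales hp.lr, so the peak learning rate is hp.lr itself.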

    def top_k_accuracy(self, y_true, y_pred_probabilities, k):
        _, sorted_indices = torch.sort(y_pred_probabilities, descending=True)
        # Get the top-k predictions
        top_k_indices = sorted_indices[:, :k]
        expanded_y_true = y_true.unsqueeze(1).expand_as(top_k_indices)
        # Check if true labels exist in top-k predictions
        hits = torch.sum(torch.eq(top_k_indices, expanded_y_true))
        accuracy = hits.item() / (len(y_true) + 1e-7)
        return accuracy
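
    # Worked example for top_k_accuracy above (hypothetical values):
    #   y_true = torch.tensor([2, 0])
    #   probs  = torch.tensor([[0.1, 0.3, 0.6], [0.2, 0.5, 0.3]])
    # The top-2 indices are [[2, 1], [1, 2]]; only the first row contains its
    # label, so top_k_accuracy(y_true, probs, 2) is roughly 0.5.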

    def training_step(self, batch, batch_idx):
        # Sample the RVQ level to train at this step
        rvq_level = torch.randint(
            0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups), (1,)
        ).item()

        target, chosen_tokens, _, _ = self.model(
            batch["tts_quantize_input"], rvq_level, batch["semantic_tokens"],
            batch["quantization_lengths"],
            speaker_emb=batch["speaker"],
            min_seq_length=batch["quantization_lengths"].min().item())

        # Keep only the masked positions for the loss
        mask = chosen_tokens
        target = target[mask]
        labels = batch["tts_quantize_input"][:, :, rvq_level]
        labels = labels[mask]

        loss = self.cross_entropy(target, labels)
        acc = (target.argmax(-1) == labels).float().mean()

        self.log("train/loss", loss, on_step=True, prog_bar=True)
        self.log("train/acc", acc, on_step=True, prog_bar=True)
        self.log(
            f"train/acc_lvl_{rvq_level}", acc, on_step=True, prog_bar=False)
        return loss
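
    # Training objective in brief: one RVQ level is drawn per batch, a random
    # subset of positions at that level is masked inside the model (see
    # TTSConformer.create_mask), and the cross-entropy above is computed only
    # on those masked positions -- a MaskGIT-style masked-token loss.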

    def validation_step(self, batch, batch_idx, dataloader_idx=0):
        speaker_emb = batch["speaker"]
        acoustic_tokens = batch["tts_quantize_input"]
        semantic_tokens = batch["semantic_tokens"]

        if self.hp.only_inference:
            self.model.inference(
                acoustic_tokens, semantic_tokens,
                torch.tensor([acoustic_tokens.shape[1]]).to(self.device),
                rvq_levels=self.hp.first_n_lvls,
                speaker_emb=speaker_emb)
        else:
            rvq_level = torch.randint(
                0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups), (1,)
            ).item()

            # FIXME: edge case
            if len(semantic_tokens.shape) == 3:
                semantic_tokens = rearrange(semantic_tokens, "B 1 T -> B T")

            target, chosen_tokens, _, _ = self.model(
                acoustic_tokens, rvq_level, semantic_tokens,
                torch.tensor([acoustic_tokens.shape[1]]).to(self.device),
                speaker_emb=speaker_emb,
                min_seq_length=acoustic_tokens.shape[1]
            )
            target = target[chosen_tokens]
            labels = acoustic_tokens[:, :, rvq_level][chosen_tokens]

            loss = self.cross_entropy(target, labels)
            acc = (target.argmax(-1) == labels).float().mean()
            acc_5 = self.top_k_accuracy(labels, target, 5)

            self.log(
                f"val/dataset_{dataloader_idx}/loss",
                loss,
                on_epoch=True,
                logger=True,
                add_dataloader_idx=False,
            )
            self.log(
                f"val/dataset_{dataloader_idx}/acc_lvl",
                acc,
                on_epoch=True,
                logger=True,
                add_dataloader_idx=False,
            )
            self.log(
                f"val/dataset_{dataloader_idx}/acc_lvl_{rvq_level}",
                acc,
                on_epoch=True,
                logger=True,
                add_dataloader_idx=False,
            )
            self.log(
                f"val/dataset_{dataloader_idx}/acc_top_5",
                acc_5,
                on_epoch=True,
                logger=True,
                add_dataloader_idx=False,
            )
            self.log(
                f"val/dataset_{dataloader_idx}/acc_top_5_lvl_{rvq_level}",
                acc_5,
                on_epoch=True,
                logger=True,
                add_dataloader_idx=False,
            )

    def compute_stats(self, logits, labels, mask_ratio=0, rvq_level=0):
        acc = (logits.argmax(-1) == labels).float().mean()
        acc_5 = self.top_k_accuracy(labels, logits, 5)
        acc_10 = self.top_k_accuracy(labels, logits, 10)

        # Quasi-random baseline: the same logits, shuffled across positions
        idx = torch.randperm(logits.shape[0])
        logits_shuffled = logits[idx]
        random = self.top_k_accuracy(labels, logits_shuffled, 10)
        print(f"Mask ratio: {mask_ratio}, Level {rvq_level}: acc {acc}, "
              f"acc 5 {acc_5}, acc 10 {acc_10}, quasi random {random}")
class TTSConformer(pl.LightningModule):
    def __init__(self, hp):
        super().__init__()
        self.hp = hp
        self.padding_id = self.hp.n_codes
        additional_codes = [c.PAD, c.SPKR_1, c.SPKR_2]

        # One embedding table per RVQ level
        self.embedding = nn.ModuleList(
            [
                nn.Embedding(
                    self.hp.n_codes + len(additional_codes),
                    self.hp.hidden_size,
                    padding_idx=self.padding_id)
                for _ in range(self.hp.n_cluster_groups)
            ]
        )

        # Additional modules
        self.semantic_embedding = nn.Embedding(
            self.hp.n_semantic_codes + len(additional_codes),
            self.hp.hidden_size,
            padding_idx=self.padding_id)

        if self.hp.use_spkr_emb:
            self.spkr_linear = nn.Linear(c.SPKR_EMB_SIZE, self.hp.hidden_size)

        self.conformer = Conformer(
            dim=self.hp.hidden_size,
            num_layers=self.hp.enc_nlayers,
            heads=self.hp.nheads,
            dim_head=64,
            ff_mult=4,  # 512 * 4 = 2048
            conv_expansion_factor=2,
            conv_kernel_size=self.hp.depthwise_conv_kernel_size,
            attn_dropout=self.hp.dropout,
            ff_dropout=self.hp.dropout,
            conv_dropout=self.hp.dropout,
            attn_flash=True,
            t5_rel_pos_bias=False
        )

        # One classification head per RVQ level
        self.heads = nn.ModuleList(
            [
                nn.Linear(
                    self.hp.hidden_size,
                    self.hp.n_codes + len(additional_codes)
                )
                for _ in range(self.hp.n_cluster_groups)
            ]
        )

    def build_mask_from_lengths(self, length, max_len=None):
        max_len = max_len or length.max().item()
        mask = torch.arange(
            max_len, device=length.device)[None, :] >= length[:, None]
        return mask.bool()
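
    # Example for build_mask_from_lengths above:
    #   build_mask_from_lengths(torch.tensor([3, 5]), max_len=5) returns
    #     [[False, False, False,  True,  True],
    #      [False, False, False, False, False]]
    # i.e. True marks padding positions past each sequence's length.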

    def create_mask(
        self, B, T, lengths, mask_ratio=None, start_t=None,
        min_seq_length=None
    ):
        # 1. Define the random length of the conditioning (prompt) tokens,
        # given the shortest audio in the batch
        if start_t is None:
            start_t = torch.randint(1, min_seq_length - 1, (1,)).item()

        # 2. Mask the remaining tokens: sample a masking ratio for this batch
        if mask_ratio is None:
            ratio = torch.rand(1).item()
            mask_ratio = masking_logic.schedule(ratio)

        # Create a random tensor with values between 0 and 1
        random_tensor = torch.rand(
            (B, T - start_t), dtype=torch.float).to(self.device)
        # Create a mask where values less than mask_ratio are set to True
        initial_mask = random_tensor < mask_ratio

        # We can't pick tokens past each sequence's length
        length_mask = self.build_mask_from_lengths(
            lengths - start_t, T - start_t)
        initial_mask = torch.logical_and(initial_mask, ~length_mask)

        # Constrain the ratio to always include some positions.
        # If all entries are False, pick at least one per row:
        if torch.sum(initial_mask) == 0:
            choose_steps = torch.randint(low=0, high=(T - start_t), size=(B,))
            initial_mask[torch.arange(B), choose_steps] = torch.tensor(
                True, device=self.device)

        # 3. Prepend the (unmasked) conditioning tokens
        acoustic_token_mask = torch.cat(
            (torch.full((B, start_t), False, device=self.device), initial_mask),  # noqa
            1
        )
        return acoustic_token_mask, start_t, mask_ratio
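
    # Layout of the acoustic_token_mask returned above, per row:
    #   [0, start_t)      -> False: prompt / conditioning tokens, never masked
    #   [start_t, length) -> True with probability mask_ratio (at least one)
    #   [length, T)       -> False: padding positions, never selected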

    def process_input(
        self, data, lengths, rvq_level, min_seq_length=None,
        mask_ratio=None, start_t=None, acoustic_token_mask=None
    ):
        """
        data: (B, T, code_level, D)
        rvq_level: int
        """
        B = data.size(0)
        T = data.size(1)
        level_data = data[:, :, rvq_level, :]  # [B, T, C, D] -> [B, T, D]

        # Choose acoustic tokens to mask
        if acoustic_token_mask is None:
            acoustic_token_mask, start_t, mask_ratio = self.create_mask(
                B, T, lengths, mask_ratio=mask_ratio, start_t=start_t,
                min_seq_length=min_seq_length)

        # Remove code information from the chosen (masked) tokens
        level_data[acoustic_token_mask, :] = 0

        # Embed only the levels below rvq_level
        lower_code_data = data[:, :, :rvq_level, :].sum(dim=2)

        # Combine with the (partially masked) tokens at rvq_level.
        # Note: all tokens at levels rvq_level+1 and above are discarded.
        summed_data = torch.add(lower_code_data, level_data)
        return summed_data, acoustic_token_mask, mask_ratio, start_t
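
    # Sketch of the embedding sum built in process_input above, for
    # rvq_level = 2 at a single position t:
    #   input[t] = emb_0(code_0[t]) + emb_1(code_1[t])
    #              + (0 if t is masked else emb_2(code_2[t]))
    # Levels above rvq_level are ignored, so the model conditions only on
    # coarser codes (plus the unmasked tokens of the current level) when
    # predicting the current level.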

    def forward(
        self, x, code_level, semantic_tokens, lengths,
        speaker_emb=None, min_seq_length=10, mask_ratio=None, start_t=None,
        acoustic_token_mask=None
    ):
        # FIXME: parallelize this
        # Embed each RVQ level up to and including code_level
        batch = []
        for lvl, embed in enumerate(self.embedding[:(code_level + 1)]):
            batch.append(embed(x[:, :, lvl]))  # [B T D]
        x = torch.stack(batch, dim=2)  # [B T C D]

        x, acoustic_token_mask, mask_ratio, start_t = self.process_input(
            x, lengths, code_level, min_seq_length=min_seq_length,
            mask_ratio=mask_ratio, start_t=start_t,
            acoustic_token_mask=acoustic_token_mask
        )

        # TODO: add phoneme embeddings? cross attention for all tokens?
        # Add semantic tokens
        semantic_emb = self.semantic_embedding(semantic_tokens)
        x = torch.add(x, semantic_emb)

        # FIXME pfb30
        # Merge the speaker modality
        if self.hp.use_spkr_emb:
            spkr_emb = F.normalize(speaker_emb, dim=-1)
            spkr_emb = self.spkr_linear(
                F.dropout(spkr_emb, self.hp.speaker_embed_dropout)
            )
            x = torch.add(x, spkr_emb)

        output_frames = self.conformer(x, None)
        x = self.heads[code_level](output_frames)
        return x, acoustic_token_mask, mask_ratio, start_t
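
    # Typical call to forward above (shapes only; names are illustrative):
    #   logits, mask, ratio, start_t = model(
    #       codes,            # (B, T, n_cluster_groups) LongTensor
    #       code_level,       # int, RVQ level to predict
    #       semantic_tokens,  # (B, T) LongTensor
    #       lengths,          # (B,) LongTensor of valid lengths
    #       speaker_emb=spk,  # only if hp.use_spkr_emb; must broadcast to
    #   )                     #   (B, T, hidden) after the linear projection
    #   # logits: (B, T, n_codes + 3); mask: which positions were masked.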

    def inference(
        self, codes, semantic_tokens,
        length: torch.LongTensor, rvq_levels=7,
        mask_ratio=0.99, maskgit_inference=True,
        start_t: Union[torch.LongTensor, None] = None,
        speaker_emb=None, steps=16
    ):
        # Use half of the recording for the conditioning
        if start_t is None:
            start_t = torch.tensor(int((codes.shape[1]) / 2)).long()
        start_t = start_t.item()

        for rvq_level in range(rvq_levels):
            original_codes = torch.clone(codes)
            if rvq_level == 0 and maskgit_inference:
                codes = self.multi_step_inference(
                    original_codes, semantic_tokens, length,
                    start_t=start_t, vamp_filtering=False,
                    speaker_emb=speaker_emb, steps=steps
                )
            else:
                codes = self.one_step_inference(
                    original_codes, semantic_tokens, length,
                    code_level=rvq_level,
                    mask_ratio=mask_ratio, start_t=start_t,
                    speaker_emb=speaker_emb
                )
            codes = rearrange(codes, 'T C -> 1 T C')

        # Remove any padding left
        codes = rearrange(codes, '1 T C -> 1 C T')
        codes = torch.where(codes >= self.hp.n_codes, 0, codes)
        acoustic_tokens = codes

        semantic_tokens = rearrange(semantic_tokens, 'b c -> b 1 c')
        semantic_tokens = torch.where(
            semantic_tokens >= self.hp.n_codes, 0, semantic_tokens)

        codes = torch.cat([semantic_tokens, acoustic_tokens], dim=1)
        return codes

    def one_step_inference(
        self, original_codes, semantic_tokens, lengths, code_level=0,
        mask_ratio=0.99, start_t=0, inference_setup="argmax", speaker_emb=None
    ):
        codes = torch.clone(original_codes)
        logits, _, _, _ = self.forward(
            codes, code_level, semantic_tokens, lengths,
            mask_ratio=mask_ratio, start_t=start_t,
            speaker_emb=speaker_emb, acoustic_token_mask=False)

        if inference_setup == "argmax":
            probs = torch.nn.functional.softmax(logits, dim=-1)
            top_indices = torch.argmax(probs, dim=-1)
        elif inference_setup == "sampling":
            top_indices = torch.distributions.Categorical(
                logits=logits).sample()

        codes = rearrange(codes, '1 T C -> T C')
        codes[start_t:, code_level] = top_indices[0, start_t:]
        return codes
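
    # one_step_inference above fills a whole RVQ level in a single forward
    # pass, either greedily ("argmax") or by categorical sampling
    # ("sampling"); the first start_t frames act as the acoustic prompt and
    # are never overwritten. multi_step_inference below instead refines level
    # 0 iteratively, MaskGIT-style.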

    def multi_step_inference(
        self, original_codes, semantic_tokens, lengths,
        start_t: torch.LongTensor = None,
        choice_temperature=1.0, start_iter=0,
        steps=16, vamp_filtering=False, speaker_emb=None
    ):
        codes = torch.clone(original_codes)
        code_level = 0
        _, seq_len, _ = original_codes.shape
        mask_token_id = self.padding_id

        # Keep the true codes for the prompt
        prompt_codes = codes[:, :start_t, code_level]
        # Fill up the rest with mask tokens
        mask = torch.full(
            (1, seq_len - start_t), mask_token_id, device=self.device)
        inputs = torch.cat((prompt_codes, mask), 1)
        num_mask_tokens_at_start = torch.sum(inputs == mask_token_id, dim=-1)

        # Initialize the decoding state
        state = state_init(inputs, steps, start_iter=start_iter)

        def loop_cond_fn(state):
            """Iterative decoding loop termination condition."""
            not_at_end = (state.cur_index < steps)
            return not_at_end

        # MaskGIT-style iterative decoding loop
        while loop_cond_fn(state):
            step = state.cur_index
            # Current input ids: [batch_size, seq_length].
            cur_ids = state.cur_seqs

            # Call the model on the current seqs to get next-iteration seqs.
            with torch.no_grad():
                logits, _, _, _ = self.forward(
                    rearrange(cur_ids, 'B T -> B T 1'),
                    code_level,
                    semantic_tokens, lengths,
                    acoustic_token_mask=False,
                    speaker_emb=speaker_emb)

            # Sample the ids from the logits:
            if vamp_filtering:
                typical_mass = 0.2
                typical_min_tokens = 1
                top_p = None
                sample_cutoff = 0.5
                typical_filtering = False
                sampled_ids, selected_probs = sample_from_logits(
                    logits, sample=((step / steps) <= sample_cutoff),
                    temperature=choice_temperature,
                    typical_filtering=typical_filtering,
                    typical_mass=typical_mass,
                    typical_min_tokens=typical_min_tokens,
                    top_k=None, top_p=top_p, return_probs=True,
                )
            else:
                sampled_ids = torch.distributions.Categorical(
                    logits=logits).sample()

            # Only update the masked tokens.
            unknown_map = (cur_ids == mask_token_id)
            sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)

            # Define the mask ratio for the next round. The number to mask
            # out is determined by
            # mask_ratio * unknown_number_in_the_beginning.
            ratio = 1. * (step + 1) / steps
            mask_ratio = masking_logic.schedule(ratio)

            # Update final seqs with the current sampled_ids.
            final_seqs = torch.clone(state.final_seqs)
            final_seqs[:, step, :] = sampled_ids

            # Compute the probabilities of each selected token.
            probs = torch.nn.functional.softmax(logits, dim=-1)
            # Extract the probabilities of the sampled ids
            selected_probs = torch.squeeze(
                torch.take_along_dim(
                    probs, torch.unsqueeze(sampled_ids, -1), -1),
                -1
            )
            # Ignore the tokens given in the input
            # by overwriting their confidence.
            selected_probs = torch.where(
                unknown_map, selected_probs, torch.inf)

            # Get the number of tokens to mask for each sample in the
            # batch according to the mask ratio.
            num_to_mask = torch.unsqueeze(
                torch.floor(num_mask_tokens_at_start * mask_ratio), 1)
            # Keep at least one prediction in this round and mask out at
            # least one token for the next iteration.
            num_to_mask = torch.maximum(
                torch.tensor(1),
                torch.minimum(
                    torch.sum(unknown_map, dim=-1, keepdim=True) - 1,
                    num_to_mask)
            )

            # Add noise for randomness
            masking = mask_by_random_topk(
                num_to_mask, selected_probs, choice_temperature * (1. - ratio))
            # Mask tokens with lower confidence.
            sampled_ids = torch.where(masking, mask_token_id, sampled_ids)

            state = State(
                cur_index=state.cur_index + 1,
                cur_seqs=sampled_ids,
                final_seqs=final_seqs
            )

        codes = torch.clone(original_codes)
        codes = rearrange(codes, '1 T C -> T C')
        # Keep the sequence produced at the final decoding step for level 0
        codes[:, 0] = state.final_seqs[0][-1]
        return codes
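

# A minimal smoke-test sketch of TTSConformer.forward. The hyperparameter
# values below are illustrative assumptions, not the repo's shipped config;
# run from the repository root so `constants` and `modules` resolve.
if __name__ == "__main__":
    from types import SimpleNamespace

    hp_demo = SimpleNamespace(
        n_codes=1024, n_semantic_codes=1024, n_cluster_groups=7,
        hidden_size=512, enc_nlayers=4, nheads=8,
        depthwise_conv_kernel_size=31, dropout=0.1, use_spkr_emb=False,
    )
    demo_model = TTSConformer(hp_demo)

    B, T = 2, 50
    demo_codes = torch.randint(
        0, hp_demo.n_codes, (B, T, hp_demo.n_cluster_groups))
    demo_semantic = torch.randint(0, hp_demo.n_semantic_codes, (B, T))
    demo_lengths = torch.tensor([T, T - 10])

    logits, token_mask, ratio, start_t = demo_model(
        demo_codes, 2, demo_semantic, demo_lengths,
        min_seq_length=demo_lengths.min().item())
    # With the assumed sizes: logits is (B, T, n_codes + 3) = (2, 50, 1027)
    # and token_mask is the (B, T) boolean mask of positions hidden from the
    # model at the sampled level.
    print(logits.shape, token_mask.shape)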