"""A2S model definition.
Copyright PolyAI Limited.
"""
from typing import Union

import pytorch_lightning as pl
import torch
import torch.nn as nn
import torch.nn.functional as F
import torch.optim as optim
from einops import rearrange

import constants as c
from modules import masking_logic
from modules.conformer import Conformer
from modules.masking_logic import (State, mask_by_random_topk,
                                   sample_from_logits, state_init)
from utils import load_checkpoint

class Pheme(pl.LightningModule):
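    """Lightning wrapper around the S2A (semantic-to-acoustic) conformer.

    Trains ``TTSConformer`` with a masked-token cross-entropy objective
    (the padding code ``hp.n_codes`` is ignored) and optionally restores
    a pretrained checkpoint.

    Minimal usage sketch (assuming ``hp`` carries the hyperparameters
    referenced throughout this file, e.g. ``pretrained_path`` and
    ``first_n_lvls``)::

        model = Pheme(hp)
        codes = model.model.inference(
            prompt_codes, semantic_tokens, lengths,
            rvq_levels=hp.first_n_lvls, speaker_emb=speaker_emb)
    """
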
def __init__(self, hp):
super().__init__()
self.hp = hp
self.model = TTSConformer(hp)
self.cross_entropy = nn.CrossEntropyLoss(
label_smoothing=self.hp.label_smoothing,
ignore_index=self.hp.n_codes
)
if self.hp.pretrained_path:
self.load()
else:
self.apply(self.init_weights)
if self.hp.only_inference:
self.model.eval()
self.save_hyperparameters()
    def load(self):
        state_dict = load_checkpoint(self.hp.pretrained_path)
        self.load_state_dict(state_dict, strict=True)
        print(f"Parameters loaded from {self.hp.pretrained_path}")
    def init_weights(self, module):
        if isinstance(module, nn.Linear):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
        elif isinstance(module, nn.Embedding):
            module.weight.data.normal_(mean=0.0, std=0.02)
            module._fill_padding_idx_with_zero()
        elif isinstance(module, (nn.LayerNorm, nn.GroupNorm)):
            module.bias.data.zero_()
            module.weight.data.fill_(1.0)
        elif isinstance(module, nn.Conv1d):
            module.weight.data.normal_(mean=0.0, std=0.02)
            if module.bias is not None:
                module.bias.data.zero_()
def configure_optimizers(self):
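        """AdamW with a warmup, flat, then linear-decay LR schedule."""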
optimizer_adam = optim.AdamW(
self.parameters(), lr=self.hp.lr,
betas=(self.hp.adam_beta1, self.hp.adam_beta2))
# Learning rate scheduler
num_training_steps = self.hp.training_step
num_warmup_steps = self.hp.warmup_step
num_flat_steps = int(self.hp.optim_flat_percent * num_training_steps)
def lambda_lr(current_step: int):
if current_step < num_warmup_steps:
return float(current_step) / float(max(1, num_warmup_steps))
elif current_step < (num_warmup_steps + num_flat_steps):
return 1.0
return max(
0.0,
float(num_training_steps - current_step)
/ float(
max(1, num_training_steps - (num_warmup_steps + num_flat_steps)) # noqa
),
)
scheduler_adam = {
"scheduler": optim.lr_scheduler.LambdaLR(
optimizer_adam, lambda_lr),
"interval": "step",
}
return [optimizer_adam], [scheduler_adam]
def top_k_accuracy(self, y_true, y_pred_probabilities, k):
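        """Fraction of samples whose true label is within the top-k logits."""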
_, sorted_indices = torch.sort(y_pred_probabilities, descending=True)
# Get the top-k predictions
top_k_indices = sorted_indices[:, :k]
expanded_y_true = y_true.unsqueeze(1).expand_as(top_k_indices)
# Check if true labels exist in top-k predictions
hits = torch.sum(torch.eq(top_k_indices, expanded_y_true))
accuracy = hits.item() / (len(y_true) + 1e-7)
return accuracy
def training_step(self, batch, batch_idx):
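        """Mask acoustic tokens at a randomly chosen RVQ level and predict them."""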
# Sample training level
        rvq_level = torch.randint(
            0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups), (1,)
        ).item()
target, chosen_tokens, _, _ = self.model(
batch["tts_quantize_input"], rvq_level, batch["semantic_tokens"],
batch["quantization_lengths"],
speaker_emb=batch["speaker"],
min_seq_length=batch["quantization_lengths"].min().item())
# Mask targets and labels
mask = chosen_tokens
target = target[mask]
labels = batch["tts_quantize_input"][:, :, rvq_level]
labels = labels[mask]
loss = self.cross_entropy(target, labels)
acc = (target.argmax(-1) == labels).float().mean()
self.log("train/loss", loss, on_step=True, prog_bar=True)
self.log("train/acc", acc, on_step=True, prog_bar=True)
self.log(
f"train/acc_lvl_{rvq_level}", acc, on_step=True, prog_bar=False)
return loss
def validation_step(self, batch, batch_idx, dataloader_idx=0):
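        """Run full inference, or evaluate one masked-prediction step."""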
speaker_emb = batch["speaker"]
acoustic_tokens = batch["tts_quantize_input"]
semantic_tokens = batch["semantic_tokens"]
if self.hp.only_inference:
self.inference(
acoustic_tokens, semantic_tokens, self.hp.first_n_lvls)
else:
            rvq_level = torch.randint(
                0, min(self.hp.first_n_lvls, self.hp.n_cluster_groups), (1,)
            ).item()
# FIXME: edge case
if len(semantic_tokens.shape) == 3:
semantic_tokens = rearrange(semantic_tokens, "B 1 T -> B T")
target, chosen_tokens, _, _ = self.model(
acoustic_tokens, rvq_level, semantic_tokens,
torch.tensor([acoustic_tokens.shape[1]]).to(self.device),
speaker_emb=speaker_emb,
min_seq_length=acoustic_tokens.shape[1]
)
target = target[chosen_tokens]
labels = acoustic_tokens[:, :, rvq_level][chosen_tokens]
loss = self.cross_entropy(target, labels)
acc = (target.argmax(-1) == labels).float().mean()
acc_5 = self.top_k_accuracy(labels, target, 5)
self.log(
f"val/dataset_{dataloader_idx}/loss",
loss,
on_epoch=True,
logger=True,
add_dataloader_idx=False,
)
self.log(
f"val/dataset_{dataloader_idx}/acc_lvl",
acc,
on_epoch=True,
logger=True,
add_dataloader_idx=False,
)
self.log(
f"val/dataset_{dataloader_idx}/acc_lvl_{rvq_level}",
acc,
on_epoch=True,
logger=True,
add_dataloader_idx=False,
)
self.log(
f"val/dataset_{dataloader_idx}/acc_top_5",
acc_5,
on_epoch=True,
logger=True,
add_dataloader_idx=False,
)
self.log(
f"val/dataset_{dataloader_idx}/acc_top_5_lvl_{rvq_level}",
acc_5,
on_epoch=True,
logger=True,
add_dataloader_idx=False,
)
def compute_stats(self, logits, labels, mask_ratio=0, rvq_level=0):
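        """Print top-1/5/10 accuracy and a shuffled-logits (chance) baseline."""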
acc = (logits.argmax(-1) == labels).float().mean()
acc_5 = self.top_k_accuracy(labels, logits, 5)
acc_10 = self.top_k_accuracy(labels, logits, 10)
idx = torch.randperm(logits.shape[0])
logits_shuffled = logits[idx]
random = self.top_k_accuracy(labels, logits_shuffled, 10)
print(f"Mask ratio: {mask_ratio}, Level {rvq_level}: acc {acc},"
f"acc 5 {acc_5}, acc 10 {acc_10}, quasi random {random}")
class TTSConformer(pl.LightningModule):
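    """Conformer that predicts acoustic codes one RVQ level at a time.

    The input for a given level is the sum of the embedded codes from all
    lower levels, the (partially masked) embeddings at that level, the
    semantic-token embeddings and, optionally, a projected speaker
    embedding; a per-level linear head produces the output logits.
    """
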
def __init__(self, hp):
super().__init__()
self.hp = hp
self.padding_id = self.hp.n_codes
additional_codes = [c.PAD, c.SPKR_1, c.SPKR_2]
self.embedding = nn.ModuleList(
[
nn.Embedding(
self.hp.n_codes + len(additional_codes),
self.hp.hidden_size,
padding_idx=self.padding_id)
for _ in range(self.hp.n_cluster_groups)
]
)
# Additional modules
self.semantic_embedding = nn.Embedding(
self.hp.n_semantic_codes + len(additional_codes),
self.hp.hidden_size,
padding_idx=self.padding_id)
if self.hp.use_spkr_emb:
self.spkr_linear = nn.Linear(c.SPKR_EMB_SIZE, self.hp.hidden_size)
self.conformer = Conformer(
dim=self.hp.hidden_size,
num_layers=self.hp.enc_nlayers,
heads=self.hp.nheads,
dim_head=64,
ff_mult=4, # 512*4=2048
conv_expansion_factor=2,
conv_kernel_size=self.hp.depthwise_conv_kernel_size,
attn_dropout=self.hp.dropout,
ff_dropout=self.hp.dropout,
conv_dropout=self.hp.dropout,
attn_flash=True,
t5_rel_pos_bias=False
)
self.heads = nn.ModuleList(
[
nn.Linear(
self.hp.hidden_size,
self.hp.n_codes + len(additional_codes)
)
for _ in range(self.hp.n_cluster_groups)
]
)
def build_mask_from_lengths(self, length, max_len=None):
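        """Boolean padding mask: True at positions at or past each length."""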
max_len = max_len or length.max().item()
mask = torch.arange(
max_len, device=length.device)[None, :] >= length[:, None]
return mask.bool()
@torch.no_grad()
def create_mask(
self, B, T, lengths, mask_ratio=None, start_t=None,
min_seq_length=None
):
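        """Sample a random mask over the acoustic tokens.

        The first ``start_t`` positions form the unmasked prompt; later
        positions are masked i.i.d. with probability ``mask_ratio`` (drawn
        from ``masking_logic.schedule`` when not given), never beyond the
        per-sample length, and with at least one position always masked.
        """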
# 1. Define the random length of condition tokens given the shortest
# audio in the batch
if start_t is None:
start_t = torch.randint(1, min_seq_length - 1, (1,)).item()
        # 2. Mask other tokens - sample a different masking level per batch
if mask_ratio is None:
ratio = torch.rand(1).item()
mask_ratio = masking_logic.schedule(ratio)
# Create a random tensor with values between 0 and 1
random_tensor = torch.rand(
(B, T - start_t), dtype=torch.float).to(self.device)
# Create a mask where values less than p are set to True
initial_mask = random_tensor < mask_ratio
length_mask = self.build_mask_from_lengths(
lengths - start_t, T - start_t)
# we can't pick up tokens past token lengths
initial_mask = torch.logical_and(initial_mask, ~length_mask)
# Constrain ratio to always include some samples
# If all are False let's pick up at least one:
if torch.sum(initial_mask) == 0:
            choose_steps = torch.randint(low=0, high=(T - start_t), size=(B,))
            initial_mask[torch.arange(B), choose_steps] = True
# 3. Add condition tokens containing information
acoustic_token_mask = torch.cat(
(torch.full((B, start_t), False, device=self.device), initial_mask), # noqa
1
)
return acoustic_token_mask, start_t, mask_ratio
def process_input(
self, data, lengths, rvq_level, min_seq_length=None,
mask_ratio=None, start_t=None, acoustic_token_mask=None
):
"""
data: (B, T, code_level, D)
rvq_level: int
"""
B = data.size(0)
T = data.size(1)
level_data = data[:, :, rvq_level, :] # [B, T, C, D] -> [B, T, D]
# Choose acoustic tokens to mask
if acoustic_token_mask is None:
acoustic_token_mask, start_t, mask_ratio = self.create_mask(
B, T, lengths, mask_ratio=mask_ratio, start_t=start_t,
min_seq_length=min_seq_length)
# Remove code information from chosen tokens
level_data[acoustic_token_mask, :] = 0
# Embed only lower rvq_level
lower_code_data = data[:, :, :rvq_level, :].sum(dim=2)
        # Combine with the (partially masked) tokens at rvq_level.
        # Note: all tokens at levels above rvq_level are discarded.
summed_data = torch.add(lower_code_data, level_data)
return summed_data, acoustic_token_mask, mask_ratio, start_t
def forward(
self, x, code_level, semantic_tokens, lengths,
speaker_emb=None, min_seq_length=10, mask_ratio=None, start_t=None,
acoustic_token_mask=None
):
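        """Embed codes up to ``code_level``, apply masking and return logits."""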
# FIXME: parallelize this
batch = []
for lvl, embed in enumerate(self.embedding[:(code_level + 1)]):
batch.append(embed(x[:, :, lvl])) # [B T D]
x = torch.stack(batch, dim=2) # [B T C D]
x, acoustic_token_mask, mask_ratio, start_t = self.process_input(
x, lengths, code_level, min_seq_length=min_seq_length,
mask_ratio=mask_ratio, start_t=start_t,
acoustic_token_mask=acoustic_token_mask
)
        # Add semantic token embeddings
semantic_emb = self.semantic_embedding(semantic_tokens)
x = torch.add(x, semantic_emb)
        # Merge different modalities
if self.hp.use_spkr_emb:
spkr_emb = F.normalize(speaker_emb, dim=-1)
            # Respect eval mode; functional dropout is otherwise always active
            spkr_emb = self.spkr_linear(
                F.dropout(spkr_emb, self.hp.speaker_embed_dropout,
                          training=self.training)
            )
x = torch.add(x, spkr_emb)
output_frames = self.conformer(x, None)
x = self.heads[code_level](output_frames)
return x, acoustic_token_mask, mask_ratio, start_t
@torch.no_grad()
def inference(
self, codes, semantic_tokens,
length: torch.LongTensor, rvq_levels=7,
mask_ratio=0.99, maskgit_inference=True,
start_t: Union[torch.LongTensor, None] = None,
speaker_emb=None, steps=16
):
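        """Generate all RVQ levels conditioned on semantic tokens and a prompt.

        Level 0 is decoded iteratively (MaskGIT-style) via
        ``multi_step_inference``; higher levels are filled in a single
        pass. Returns semantic and acoustic codes stacked as (1, C, T).
        """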
# Use half of the recording for the conditioning
if start_t is None:
start_t = torch.tensor(int((codes.shape[1]) / 2)).long()
start_t = start_t.item()
for rvq_level in range(rvq_levels):
original_codes = torch.clone(codes)
if rvq_level == 0 and maskgit_inference:
codes = self.multi_step_inference(
original_codes, semantic_tokens, length,
start_t=start_t, vamp_filtering=False,
                    speaker_emb=speaker_emb, steps=steps
)
else:
codes = self.one_step_inference(
original_codes, semantic_tokens, length,
code_level=rvq_level,
mask_ratio=mask_ratio, start_t=start_t,
speaker_emb=speaker_emb
)
codes = rearrange(codes, 'T C -> 1 T C')
# Remove any padding left
codes = rearrange(codes, '1 T C -> 1 C T')
codes = torch.where(codes >= self.hp.n_codes, 0, codes)
acoustic_tokens = codes
semantic_tokens = rearrange(semantic_tokens, 'b c -> b 1 c')
semantic_tokens = torch.where(
semantic_tokens >= self.hp.n_codes, 0, semantic_tokens)
codes = torch.cat([semantic_tokens, acoustic_tokens], dim=1)
return codes
@torch.no_grad()
def one_step_inference(
self, original_codes, semantic_tokens, lengths, code_level=0,
mask_ratio=0.99, start_t=0, inference_setup="argmax", speaker_emb=None
):
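        """Fill one RVQ level in a single pass (argmax or categorical sampling)."""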
codes = torch.clone(original_codes)
logits, _, _, _ = self.forward(
codes, code_level, semantic_tokens, lengths,
mask_ratio=mask_ratio, start_t=start_t,
speaker_emb=speaker_emb, acoustic_token_mask=False)
        if inference_setup == "argmax":
            probs = torch.nn.functional.softmax(logits, dim=-1)
            top_indices = torch.argmax(probs, dim=-1)
        elif inference_setup == "sampling":
            top_indices = torch.distributions.Categorical(
                logits=logits).sample()
        codes = rearrange(codes, '1 T C -> T C')
        codes[start_t:, code_level] = top_indices[0, start_t:]
return codes
@torch.no_grad()
def multi_step_inference(
self, original_codes, semantic_tokens, lengths,
        start_t: Union[torch.LongTensor, None] = None,
choice_temperature=1.0, start_iter=0,
steps=16, vamp_filtering=False, speaker_emb=None
):
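        """MaskGIT-style iterative decoding of RVQ level 0.

        Starting from a fully masked continuation of the prompt, each of
        the ``steps`` rounds samples every unknown token, keeps the most
        confident predictions and re-masks the rest on a decaying schedule.
        """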
codes = torch.clone(original_codes)
code_level = 0
_, seq_len, _ = original_codes.shape
mask_token_id = self.padding_id
        # Keep the true codes of the prompt
        prompt_tokens = codes[:, :start_t, code_level]
        # Fill the remainder with mask tokens
        mask = torch.full(
            (1, seq_len - start_t), mask_token_id, device=self.device)
        inputs = torch.cat((prompt_tokens, mask), 1)
        num_mask_tokens_at_start = torch.sum(inputs == mask_token_id, dim=-1)
# Initializes state
state = state_init(inputs, steps, start_iter=start_iter)
        def loop_cond_fn(state):
            """Decoding loop termination condition."""
            return state.cur_index < steps
while loop_cond_fn(state):
"""Beam search loop state update function."""
step = state.cur_index
# Current input ids: [batch_size, seq_length].
cur_ids = state.cur_seqs
# Calls model on current seqs to get next-iteration seqs.
with torch.no_grad():
                logits, _, _, _ = self.forward(
                    rearrange(cur_ids, 'B T -> B T 1'),
code_level,
semantic_tokens, lengths,
acoustic_token_mask=False,
speaker_emb=speaker_emb)
# Samples the ids using categorical sampling:
if vamp_filtering:
typical_mass = 0.2
typical_min_tokens = 1
top_p = None
sample_cutoff = 0.5
typical_filtering = False
sampled_ids, selected_probs = sample_from_logits(
logits, sample=((step / steps) <= sample_cutoff),
temperature=choice_temperature,
typical_filtering=typical_filtering,
typical_mass=typical_mass,
typical_min_tokens=typical_min_tokens,
top_k=None, top_p=top_p, return_probs=True,
)
else:
sampled_ids = torch.distributions.Categorical(
logits=logits).sample()
# Just updates the masked tokens.
unknown_map = (cur_ids == mask_token_id)
sampled_ids = torch.where(unknown_map, sampled_ids, cur_ids)
# Defines the mask ratio for the next round. The number to mask out
# is determined by mask_ratio * unknown_number_in_the_beginning.
ratio = 1. * (step + 1) / steps
mask_ratio = masking_logic.schedule(ratio)
# Updates final seqs with the current sampled_ids.
final_seqs = torch.clone(state.final_seqs)
final_seqs[:, step, :] = sampled_ids
# Computes the probabilities of each selected tokens.
probs = torch.nn.functional.softmax(logits, dim=-1)
# Extract the probabilities of sampled ids
            selected_probs = torch.squeeze(
                torch.take_along_dim(
                    probs, torch.unsqueeze(sampled_ids, -1), -1),
                -1
            )
# Ignores the tokens given in the input
# by overwriting their confidence.
selected_probs = torch.where(
unknown_map, selected_probs, torch.inf)
# Gets mask lens for each sample in the
# batch according to the mask ratio.
num_to_mask = torch.unsqueeze(
torch.floor(num_mask_tokens_at_start * mask_ratio), 1)
            # Keep at least one prediction this round and leave at least
            # one token masked for the next iteration.
num_to_mask = torch.maximum(
torch.tensor(1),
torch.minimum(
torch.sum(unknown_map, dim=-1, keepdim=True) - 1,
num_to_mask)
)
# Adds noise for randomness
masking = mask_by_random_topk(
num_to_mask, selected_probs, choice_temperature * (1. - ratio))
# Masks tokens with lower confidence.
sampled_ids = torch.where(masking, mask_token_id, sampled_ids)
state = State(
cur_index=state.cur_index + 1,
cur_seqs=sampled_ids,
final_seqs=final_seqs
)
codes = torch.clone(original_codes)
codes = rearrange(codes, '1 T C -> T C')
codes[:, 0] = state.final_seqs[0][-1]
return codes