import numpy as np
import torch
import torch.nn.functional as F
import torchaudio
from torch import Tensor, nn

SAMPLE_RATE = 16000
N_FFT = 1024
N_MELS = 128
HOP_LENGTH = int(0.01 * SAMPLE_RATE)  # 10 ms hop
DURATION = 10  # seconds
N_SAMPLES = int(DURATION * SAMPLE_RATE)
N_FRAMES = N_SAMPLES // HOP_LENGTH + 1  # +1 for the centered frame at t = 0
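# Worked example with the values above: HOP_LENGTH = 160 samples, N_SAMPLES = 160000,
# so a 10 s clip yields N_FRAMES = 160000 // 160 + 1 = 1001 centered STFT frames.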


def sinusoids(length, channels, max_timescale=10000):
    """Returns sinusoids for positional embedding."""
    assert channels % 2 == 0, "channels must be even to split into sin/cos halves"
    log_timescale_increment = np.log(max_timescale) / (channels // 2 - 1)
    inv_timescales = torch.exp(-log_timescale_increment * torch.arange(channels // 2))
    scaled_time = torch.arange(length)[:, np.newaxis] * inv_timescales[np.newaxis, :]
    return torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], dim=1)
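

# Illustrative sketch (not part of the original module): position 0 embeds as the
# all-zeros sine bank followed by the all-ones cosine bank, matching the
# concatenation order in sinusoids().
def _demo_sinusoids():
    pe = sinusoids(length=4, channels=8)
    assert pe.shape == (4, 8)
    assert torch.allclose(pe[0, :4], torch.zeros(4))
    assert torch.allclose(pe[0, 4:], torch.ones(4))
    return pe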


class MelEncoder(nn.Module):
    """
    Waveform-to-log-mel time-frequency representation.
    """
    def __init__(self,
                 sample_rate=16000,
                 f_min=0,
                 f_max=8000,
                 n_fft=1024,
                 win_length=1024,
                 hop_length=int(0.01 * 16000),
                 n_mels=128,
                 power=None,
                 pad=0,
                 normalized=False,
                 center=True,
                 pad_mode="reflect"):
        super().__init__()
        # power=None keeps the complex STFT; the power spectrum is computed in forward().
        # Spectrogram applies a Hann window by default, so no explicit window is needed.
        self.spec_fn = torchaudio.transforms.Spectrogram(
            n_fft=n_fft,
            win_length=win_length,
            hop_length=hop_length,
            pad=pad,
            power=power,
            normalized=normalized,
            center=center,
            pad_mode=pad_mode,
        )
        self.mel_scale = torchaudio.transforms.MelScale(
            n_mels,
            sample_rate,
            f_min,
            f_max,
            n_fft // 2 + 1,
        )
        self.amplitude_to_db = torchaudio.transforms.AmplitudeToDB()

    def forward(self, wav):
        spec = self.spec_fn(wav)  # complex STFT, shape (..., n_fft // 2 + 1, n_frames)
        power_spec = spec.abs().pow(2)  # squared magnitude, not just the real part
        mel_spec = self.mel_scale(power_spec)
        mel_spec = self.amplitude_to_db(mel_spec)
        return mel_spec
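

# Usage sketch (example values, not from the original file): a batch of two
# 10-second, 16 kHz waveforms becomes a (2, N_MELS, N_FRAMES) log-mel tensor.
def _demo_mel_encoder():
    wav = torch.randn(2, N_SAMPLES)  # random stand-in for real audio
    mel = MelEncoder(n_mels=N_MELS)(wav)
    assert mel.shape == (2, N_MELS, N_FRAMES)
    return mel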


class AudioEncoder(nn.Module):
    def __init__(
        self, n_mels: int, n_ctx: int, audio_dim: int, text_dim: int, num_of_stride_conv: int,
    ):
        super().__init__()
        self.mel_encoder = MelEncoder(n_mels=n_mels)
        self.conv1 = nn.Conv1d(n_mels, audio_dim, kernel_size=3, padding=1)
        # each stride-2 conv roughly halves the number of time frames
        self.conv_stack = nn.ModuleList([])
        for _ in range(num_of_stride_conv):
            self.conv_stack.append(
                nn.Conv1d(audio_dim, audio_dim, kernel_size=3, stride=2, padding=1)
            )
        # fixed (non-learned) sinusoidal positions, shape (n_ctx, text_dim)
        self.register_buffer("positional_embedding", sinusoids(n_ctx, text_dim))

    def forward(self, x: Tensor):
        """
        x : torch.Tensor, shape = (batch_size, n_samples)
            single-channel waveform
        """
        x = self.mel_encoder(x)    # (batch_size, n_mels, n_frames)
        x = F.gelu(self.conv1(x))  # (batch_size, audio_dim, n_frames)
        for conv in self.conv_stack:
            x = F.gelu(conv(x))
        x = x.permute(0, 2, 1)  # (batch_size, n_frames', audio_dim)
        # requires n_frames' == n_ctx and audio_dim == text_dim for the broadcast add
        x = (x + self.positional_embedding).to(x.dtype)
        return x
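

if __name__ == "__main__":
    # Smoke-test sketch with illustrative dimensions (these values are assumptions,
    # not from the original file): three stride-2 convs shrink 1001 frames to
    # 501 -> 251 -> 126, so n_ctx must be 126 and audio_dim must equal text_dim.
    encoder = AudioEncoder(
        n_mels=N_MELS, n_ctx=126, audio_dim=512, text_dim=512, num_of_stride_conv=3
    )
    wav = torch.randn(2, N_SAMPLES)  # stand-in for a batch of two 10 s waveforms
    out = encoder(wav)
    print(out.shape)  # torch.Size([2, 126, 512])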