Spaces:
Running
on
A10G
Running
on
A10G
# Copyright (c) 2023 Amphion. | |
# | |
# This source code is licensed under the MIT license found in the | |
# LICENSE file in the root directory of this source tree. | |
import torch | |
from torch import nn | |
from torch.nn import functional as F | |
class LayerNorm(nn.Module):
    """Layer normalization applied over dim 1 (the channel axis) of the input.

    ``F.layer_norm`` only normalizes trailing dimensions. So the channel axis
    is swapped to the last position, normalized with a learnable per-channel
    scale and shift, and then swapped back. The output has the input's shape.

    Args:
        channels: size of dim 1 of the input. It is also the length of the
            learnable ``gamma`` (scale) and ``beta`` (shift) vectors.
        eps: value added to the variance for numerical stability.
    """

    def __init__(self, channels, eps=1e-5):
        super().__init__()
        self.channels = channels
        self.eps = eps
        # Start as an identity affine transform: scale of 1, shift of 0.
        self.gamma = nn.Parameter(torch.ones(channels))
        self.beta = nn.Parameter(torch.zeros(channels))

    def forward(self, x):
        # Move the channel axis (dim 1) last so F.layer_norm normalizes it.
        channels_last = x.transpose(1, -1)
        normed = F.layer_norm(
            channels_last,
            normalized_shape=(self.channels,),
            weight=self.gamma,
            bias=self.beta,
            eps=self.eps,
        )
        # Swapping the same pair of axes again restores the original layout.
        return normed.transpose(-1, 1)
class ConvReluNorm(nn.Module):
    """Masked stack of Conv1d -> LayerNorm -> ReLU -> Dropout with a residual output.

    The first convolution maps ``in_channels`` to ``hidden_channels``. Each of
    the remaining ``n_layers - 1`` convolutions maps ``hidden_channels`` to
    ``hidden_channels``. A 1x1 convolution (``proj``) maps the result to
    ``out_channels``, and that is added to the unmodified input.

    ``proj`` is zero-initialized. So at construction time the forward pass
    returns ``x * x_mask``: the block starts out as a masked identity.

    Args:
        in_channels: channels of the input ``x``.
        hidden_channels: channels of every intermediate convolution.
        out_channels: channels produced by ``proj``. Forward adds the ``proj``
            output to the raw input. The two must therefore be
            shape-compatible, which presumably means
            ``out_channels == in_channels``. NOTE(review): confirm with callers.
        kernel_size: kernel size of each stacked convolution. Padding is
            ``kernel_size // 2``, so the time length is preserved only for odd
            kernel sizes.
        n_layers: number of Conv1d/LayerNorm layers; must be at least 1.
        p_dropout: dropout probability applied after each ReLU.
    """

    def __init__(
        self,
        in_channels,
        hidden_channels,
        out_channels,
        kernel_size,
        n_layers,
        p_dropout,
    ):
        super().__init__()
        self.in_channels = in_channels
        self.hidden_channels = hidden_channels
        self.out_channels = out_channels
        self.kernel_size = kernel_size
        self.n_layers = n_layers
        self.p_dropout = p_dropout
        # A single layer is valid: the first conv is always built, and the
        # loop below then adds nothing. The old `> 1` check wrongly rejected
        # n_layers == 1 and contradicted its own message.
        assert n_layers > 0, "Number of layers should be larger than 0."

        self.conv_layers = nn.ModuleList()
        self.norm_layers = nn.ModuleList()
        self.conv_layers.append(
            nn.Conv1d(
                in_channels, hidden_channels, kernel_size, padding=kernel_size // 2
            )
        )
        self.norm_layers.append(LayerNorm(hidden_channels))
        self.relu_drop = nn.Sequential(nn.ReLU(), nn.Dropout(p_dropout))
        for _ in range(n_layers - 1):
            self.conv_layers.append(
                nn.Conv1d(
                    hidden_channels,
                    hidden_channels,
                    kernel_size,
                    padding=kernel_size // 2,
                )
            )
            self.norm_layers.append(LayerNorm(hidden_channels))

        self.proj = nn.Conv1d(hidden_channels, out_channels, 1)
        # Zero init means the residual branch contributes nothing at the start.
        nn.init.zeros_(self.proj.weight)
        nn.init.zeros_(self.proj.bias)

    def forward(self, x, x_mask):
        """Run the masked conv stack and add the projected result to ``x``.

        Args:
            x: input tensor with ``in_channels`` on dim 1. Presumably shaped
                (batch, in_channels, time); TODO confirm.
            x_mask: tensor multiplied elementwise with the activations. It is
                presumably (batch, 1, time), with 1 for valid frames and 0 for
                padding; TODO confirm.

        Returns:
            ``(x + proj(h)) * x_mask``, where ``h`` is the output of the stack.
        """
        residual = x
        for conv, norm in zip(self.conv_layers, self.norm_layers):
            # Re-mask before every conv so masked-out positions cannot leak
            # into neighbouring frames through the kernel.
            x = conv(x * x_mask)
            x = norm(x)
            x = self.relu_drop(x)
        x = residual + self.proj(x)
        return x * x_mask