|
|
|
|
|
|
|
|
|
|
|
|
|
from itertools import product
|
|
|
|
import pytest
|
|
import torch
|
|
|
|
from audiocraft.modules.seanet import SEANetEncoder, SEANetDecoder, SEANetResnetBlock
|
|
from audiocraft.modules import StreamableConv1d, StreamableConvTranspose1d
|
|
|
|
|
|
class TestSEANetModel:
    """Tests for SEANetEncoder / SEANetDecoder shapes and per-block normalization settings."""

    def _check_roundtrip(self, encoder: SEANetEncoder, decoder: SEANetDecoder):
        """Encode a (1, 1, 24000) noise signal, decode it, and check both shapes.

        The expected latent shape [1, 128, 75] assumes the encoder was built with its default
        output dimension and an overall stride of 24000 / 75 = 320; the decoder must restore
        the exact input shape.
        """
        x = torch.randn(1, 1, 24000)
        z = encoder(x)
        assert list(z.shape) == [1, 128, 75], z.shape
        y = decoder(z)
        assert y.shape == x.shape, (x.shape, y.shape)

    def test_base(self):
        self._check_roundtrip(SEANetEncoder(), SEANetDecoder())

    def test_causal(self):
        self._check_roundtrip(SEANetEncoder(causal=True), SEANetDecoder(causal=True))

    def test_conv_skip_connection(self):
        self._check_roundtrip(SEANetEncoder(true_skip=False), SEANetDecoder(true_skip=False))

    def test_seanet_encoder_decoder_final_act(self):
        self._check_roundtrip(SEANetEncoder(true_skip=False),
                              SEANetDecoder(true_skip=False, final_activation='Tanh'))

    def _check_encoder_blocks_norm(self, encoder: SEANetEncoder, n_disable_blocks: int, norm: str):
        """Check that the first ``n_disable_blocks`` encoder blocks have no norm and the rest use ``norm``.

        Blocks are counted from the input side: every ``StreamableConv1d`` found directly in
        ``encoder.model`` increments the block counter, and a ``SEANetResnetBlock`` is attributed
        to the block that the next such conv will close (hence ``n_blocks + 1``).
        """
        n_blocks = 0
        for layer in encoder.model:
            if isinstance(layer, StreamableConv1d):
                n_blocks += 1
                # Build the expectation separately: ``assert a == b if c else d`` parses as
                # ``assert (a == b) if c else d`` and would silently pass on a truthy ``d``.
                expected = 'none' if n_blocks <= n_disable_blocks else norm
                assert layer.conv.norm_type == expected, (n_blocks, layer.conv.norm_type, expected)
            elif isinstance(layer, SEANetResnetBlock):
                expected = 'none' if (n_blocks + 1) <= n_disable_blocks else norm
                for resnet_layer in layer.block:
                    if isinstance(resnet_layer, StreamableConv1d):
                        assert resnet_layer.conv.norm_type == expected, \
                            (n_blocks + 1, resnet_layer.conv.norm_type, expected)

    def test_encoder_disable_norm(self):
        n_residuals = [0, 1, 3]
        disable_blocks = [0, 1, 2, 3, 4, 5, 6]
        norms = ['weight_norm', 'none']
        # Use a distinct loop name so the ``disable_blocks`` list is not shadowed.
        for n_res, n_disable, norm in product(n_residuals, disable_blocks, norms):
            encoder = SEANetEncoder(n_residual_layers=n_res, norm=norm,
                                    disable_norm_outer_blocks=n_disable)
            self._check_encoder_blocks_norm(encoder, n_disable, norm)

    def _check_decoder_blocks_norm(self, decoder: SEANetDecoder, n_disable_blocks: int, norm: str):
        """Check that the last ``n_disable_blocks`` decoder blocks have no norm and the rest use ``norm``.

        Blocks are counted from the latent side: every ``StreamableConv1d`` or
        ``StreamableConvTranspose1d`` found directly in ``decoder.model`` starts a new block, and a
        ``SEANetResnetBlock`` belongs to the block most recently started. A block is expected to be
        un-normalized when ``decoder.n_blocks - n_blocks < n_disable_blocks``.
        """
        n_blocks = 0
        for layer in decoder.model:
            if isinstance(layer, (StreamableConv1d, StreamableConvTranspose1d)):
                n_blocks += 1
            # Parenthesized expectation (see the encoder check) so inner blocks are really checked.
            expected = 'none' if (decoder.n_blocks - n_blocks) < n_disable_blocks else norm
            if isinstance(layer, StreamableConv1d):
                assert layer.conv.norm_type == expected, (n_blocks, layer.conv.norm_type, expected)
            elif isinstance(layer, StreamableConvTranspose1d):
                assert layer.convtr.norm_type == expected, (n_blocks, layer.convtr.norm_type, expected)
            elif isinstance(layer, SEANetResnetBlock):
                for resnet_layer in layer.block:
                    if isinstance(resnet_layer, StreamableConv1d):
                        assert resnet_layer.conv.norm_type == expected, \
                            (n_blocks, resnet_layer.conv.norm_type, expected)

    def test_decoder_disable_norm(self):
        n_residuals = [0, 1, 3]
        disable_blocks = [0, 1, 2, 3, 4, 5, 6]
        norms = ['weight_norm', 'none']
        # Use a distinct loop name so the ``disable_blocks`` list is not shadowed.
        for n_res, n_disable, norm in product(n_residuals, disable_blocks, norms):
            decoder = SEANetDecoder(n_residual_layers=n_res, norm=norm,
                                    disable_norm_outer_blocks=n_disable)
            self._check_decoder_blocks_norm(decoder, n_disable, norm)

    def test_disable_norm_raises_exception(self):
        # A negative count must be rejected.
        with pytest.raises(AssertionError):
            SEANetEncoder(disable_norm_outer_blocks=-1)

        # 7 presumably exceeds the number of blocks of a 4-ratio network (the norm tests above
        # go up to 6) — TODO confirm against SEANetEncoder's block count.
        with pytest.raises(AssertionError):
            SEANetEncoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)

        with pytest.raises(AssertionError):
            SEANetDecoder(disable_norm_outer_blocks=-1)

        with pytest.raises(AssertionError):
            SEANetDecoder(ratios=[1, 1, 2, 2], disable_norm_outer_blocks=7)
|
|
|