File size: 2,362 Bytes
2631d60 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 |
# Copyright (c) Meta Platforms, Inc. and affiliates.
# All rights reserved.
#
# This source code is licensed under the license found in the
# LICENSE file in the root directory of this source tree.
import random
import torch
from audiocraft.adversarial.discriminators import (
MultiPeriodDiscriminator,
MultiScaleDiscriminator,
MultiScaleSTFTDiscriminator
)
class TestMultiPeriodDiscriminator:
    """Output-structure checks for ``MultiPeriodDiscriminator``."""

    def test_mpd_discriminator(self):
        """Feed a random batch through the model and check what comes back.

        One logits tensor and one feature-map list are expected per
        configured period; every returned tensor must keep the batch
        dimension, and each logits tensor must be 4-dimensional.
        """
        batch_size, channels = 2, 2
        # Unseeded draw from [1, 100_000): the length differs on every run.
        num_samples = random.randrange(1, 100_000)
        audio = torch.randn(batch_size, channels, num_samples)

        periods = [1, 2, 3]
        discriminator = MultiPeriodDiscriminator(periods=periods, in_channels=channels)
        logits, fmaps = discriminator(audio)

        expected_outputs = len(periods)
        assert len(logits) == expected_outputs
        assert len(fmaps) == expected_outputs
        assert all(
            logit.shape[0] == batch_size and len(logit.shape) == 4
            for logit in logits
        )
        assert all(
            feature.shape[0] == batch_size
            for fmap in fmaps
            for feature in fmap
        )
class TestMultiScaleDiscriminator:
    """Output-structure checks for ``MultiScaleDiscriminator``."""

    def test_msd_discriminator(self):
        """Feed a random batch through the model and check what comes back.

        One logits tensor and one feature-map list are expected per entry
        in ``scale_norms``; every returned tensor must keep the batch
        dimension, and each logits tensor must be 3-dimensional.
        """
        batch_size, channels = 2, 2
        # Unseeded draw from [1, 100_000): the length differs on every run.
        num_samples = random.randrange(1, 100_000)
        audio = torch.randn(batch_size, channels, num_samples)

        scale_norms = ['weight_norm', 'weight_norm']
        discriminator = MultiScaleDiscriminator(scale_norms=scale_norms, in_channels=channels)
        logits, fmaps = discriminator(audio)

        expected_outputs = len(scale_norms)
        assert len(logits) == expected_outputs
        assert len(fmaps) == expected_outputs
        assert all(
            logit.shape[0] == batch_size and len(logit.shape) == 3
            for logit in logits
        )
        assert all(
            feature.shape[0] == batch_size
            for fmap in fmaps
            for feature in fmap
        )
class TestMultiScaleStftDiscriminator:
    """Output-structure checks for ``MultiScaleSTFTDiscriminator``."""

    def test_msstftd_discriminator(self):
        """Feed a random batch through the model and check what comes back.

        One logits tensor and one feature-map list are expected per STFT
        scale (one scale per entry in ``n_ffts``); every returned tensor
        must keep the batch dimension, and each logits tensor must be
        4-dimensional.
        """
        n_filters = 4
        n_ffts = [128, 256, 64]
        hop_lengths = [32, 64, 16]
        win_lengths = [128, 256, 64]

        N, C = 2, 2
        # The length used to be drawn from [1, 100_000). A signal shorter than
        # the largest FFT window cannot fill even one un-padded STFT frame, so
        # such draws could fail the test intermittently for a reason unrelated
        # to the discriminator. Starting at max(n_ffts) keeps every scale valid.
        # NOTE(review): assumes the model does not zero-pad short inputs
        # itself - confirm against the discriminator implementation.
        T = random.randrange(max(n_ffts), 100_000)
        t0 = torch.randn(N, C, T)

        msstftd = MultiScaleSTFTDiscriminator(
            filters=n_filters,
            n_ffts=n_ffts,
            hop_lengths=hop_lengths,
            win_lengths=win_lengths,
            in_channels=C,
        )
        logits, fmaps = msstftd(t0)

        assert len(logits) == len(n_ffts)
        assert len(fmaps) == len(n_ffts)
        # Generator expressions avoid building throwaway lists for all().
        assert all(logit.shape[0] == N and len(logit.shape) == 4 for logit in logits)
        assert all(feature.shape[0] == N for fmap in fmaps for feature in fmap)
|