import typing as tp

from einops import rearrange
from librosa import filters
import torch
from torch import nn
import torch.nn.functional as F
import torchaudio


class ChromaExtractor(nn.Module):
    """Chroma extraction and quantization.

    Args:
        sample_rate (int): Sample rate for the chroma extraction.
        n_chroma (int): Number of chroma bins for the chroma extraction.
        radix2_exp (int): Exponent of the STFT window size, as a power of 2 (e.g. 12 -> 2^12 samples).
        nfft (int, optional): FFT size. Defaults to the window length.
        winlen (int, optional): Window length. Defaults to 2^radix2_exp.
        winhop (int, optional): Window hop size. Defaults to winlen // 4.
        argmax (bool, optional): Whether to quantize each frame to a one-hot chroma vector
            via argmax. Defaults to True.
        norm (float, optional): Norm for chroma normalization. Defaults to inf (max norm).
    """
    
    def __init__(self,
                 sample_rate: int,
                 n_chroma: int = 12,
                 radix2_exp: int = 12,
                 nfft: tp.Optional[int] = None,
                 winlen: tp.Optional[int] = None,
                 winhop: tp.Optional[int] = None,
                 argmax: bool = True,
                 norm: float = torch.inf):
        super().__init__()
        self.winlen = winlen or 2 ** radix2_exp
        self.nfft = nfft or self.winlen
        self.winhop = winhop or (self.winlen // 4)
        self.sample_rate = sample_rate
        self.n_chroma = n_chroma
        self.norm = norm
        self.argmax = argmax
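        # librosa chroma filterbank, shape (n_chroma, 1 + nfft // 2):
        # maps linear frequency bins to pitch classes (tuning fixed at 0)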
        self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0,
                                                                       n_chroma=self.n_chroma)), persistent=False)
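        # power spectrogram (power=2) with no implicit centering; forward() pads manually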
        self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen,
                                                      hop_length=self.winhop, power=2, center=False,
                                                      pad=0, normalized=True)

    def forward(self, wav: torch.Tensor) -> torch.Tensor:
        T = wav.shape[-1]
        # in case we get a wav that was dropped out (nullified) by the
        # conditioner, make sure the wav length is no less than nfft
        if T < self.nfft:
            pad = self.nfft - T
            r = 0 if pad % 2 == 0 else 1
            wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0)
            assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}"

        # reflect-pad both sides so frame centers line up with the signal
        # (the Spectrogram above was built with center=False)
        wav = F.pad(wav, (int(self.nfft // 2 - self.winhop // 2),
                          int(self.nfft // 2 - self.winhop // 2)), mode="reflect")

        spec = self.spec(wav).squeeze(1)
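        # project the power spectrogram onto the chroma filterbank:
        # (n_chroma, n_freqs) x (..., n_freqs, n_frames) -> (..., n_chroma, n_frames)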
        raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec)
        norm_chroma = F.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6)
        norm_chroma = rearrange(norm_chroma, 'b d t -> b t d')

        if self.argmax:
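            # quantize: keep only the strongest chroma bin per frame (one-hot)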
            idx = norm_chroma.argmax(-1, keepdim=True)
            norm_chroma[:] = 0
            norm_chroma.scatter_(dim=-1, index=idx, value=1)

        return norm_chroma


if __name__ == "__main__":
    # smoke test: 1s of random audio at 16kHz, with a 1s window and 0.25s hop
    chroma = ChromaExtractor(sample_rate=16000,
                             n_chroma=4,
                             winlen=16000,
                             nfft=16000,
                             winhop=4000)
    audio = torch.rand(1, 16000)
    c = chroma(audio)
    print(c.shape)  # torch.Size([1, 4, 4]) -> (batch, frames, chroma bins)
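
    # the argmax path yields a one-hot vector per frame: exactly one active bin
    assert torch.all(c.sum(-1) == 1)

    # a minimal sketch of the continuous (non-quantized) variant with the same
    # toy settings; norm=inf scales each frame by its maximum chroma energy
    chroma_soft = ChromaExtractor(sample_rate=16000,
                                  n_chroma=4,
                                  winlen=16000,
                                  nfft=16000,
                                  winhop=4000,
                                  argmax=False)
    c_soft = chroma_soft(audio)
    print(c_soft.shape)  # same shape, but values are continuous in [0, 1]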