diff --git a/audiotools/__init__.py b/audiotools/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..573ffd06100ad72614df9363b12cda6672f1b70e --- /dev/null +++ b/audiotools/__init__.py @@ -0,0 +1,10 @@ +__version__ = "0.7.3" +from .core import AudioSignal +from .core import STFTParams +from .core import Meter +from .core import util +from . import metrics +from . import data +from . import ml +from .data import datasets +from .data import transforms diff --git a/audiotools/core/__init__.py b/audiotools/core/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..8660c4e67f43d0ded584a38939425e2c28d95cd3 --- /dev/null +++ b/audiotools/core/__init__.py @@ -0,0 +1,4 @@ +from . import util +from .audio_signal import AudioSignal +from .audio_signal import STFTParams +from .loudness import Meter diff --git a/audiotools/core/audio_signal.py b/audiotools/core/audio_signal.py new file mode 100644 index 0000000000000000000000000000000000000000..fb6d751cb968a003656e3e7874c487b83d94c82e --- /dev/null +++ b/audiotools/core/audio_signal.py @@ -0,0 +1,1682 @@ +import copy +import functools +import hashlib +import math +import pathlib +import tempfile +import typing +import warnings +from collections import namedtuple +from pathlib import Path + +import julius +import numpy as np +import soundfile +import torch + +from . import util +from .display import DisplayMixin +from .dsp import DSPMixin +from .effects import EffectMixin +from .effects import ImpulseResponseMixin +from .ffmpeg import FFMPEGMixin +from .loudness import LoudnessMixin +from .playback import PlayMixin +from .whisper import WhisperMixin + + +STFTParams = namedtuple( + "STFTParams", + ["window_length", "hop_length", "window_type", "match_stride", "padding_type"], +) +""" +STFTParams object is a container that holds STFT parameters - window_length, +hop_length, and window_type. Not all parameters need to be specified. Ones that +are not specified will be inferred by the AudioSignal parameters. + +Parameters +---------- +window_length : int, optional + Window length of STFT, by default ``0.032 * self.sample_rate``. +hop_length : int, optional + Hop length of STFT, by default ``window_length // 4``. +window_type : str, optional + Type of window to use, by default ``sqrt\_hann``. +match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False +padding_type : str, optional + Type of padding to use, by default 'reflect' +""" +STFTParams.__new__.__defaults__ = (None, None, None, None, None) + + +class AudioSignal( + EffectMixin, + LoudnessMixin, + PlayMixin, + ImpulseResponseMixin, + DSPMixin, + DisplayMixin, + FFMPEGMixin, + WhisperMixin, +): + """This is the core object of this library. Audio is always + loaded into an AudioSignal, which then enables all the features + of this library, including audio augmentations, I/O, playback, + and more. + + The structure of this object is that the base functionality + is defined in ``core/audio_signal.py``, while extensions to + that functionality are defined in the other ``core/*.py`` + files. For example, all the display-based functionality + (e.g. plot spectrograms, waveforms, write to tensorboard) + are in ``core/display.py``. + + Parameters + ---------- + audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray] + Object to create AudioSignal from. Can be a tensor, numpy array, + or a path to a file. The file is always reshaped to + sample_rate : int, optional + Sample rate of the audio. 
If different from underlying file, resampling is + performed. If passing in an array or tensor, this must be defined, + by default None + stft_params : STFTParams, optional + Parameters of STFT to use. , by default None + offset : float, optional + Offset in seconds to read from file, by default 0 + duration : float, optional + Duration in seconds to read from file, by default None + device : str, optional + Device to load audio onto, by default None + + Examples + -------- + Loading an AudioSignal from an array, at a sample rate of + 44100. + + >>> signal = AudioSignal(torch.randn(5*44100), 44100) + + Note, the signal is reshaped to have a batch size, and one + audio channel: + + >>> print(signal.shape) + (1, 1, 44100) + + You can treat AudioSignals like tensors, and many of the same + functions you might use on tensors are defined for AudioSignals + as well: + + >>> signal.to("cuda") + >>> signal.cuda() + >>> signal.clone() + >>> signal.detach() + + Indexing AudioSignals returns an AudioSignal: + + >>> signal[..., 3*44100:4*44100] + + The above signal is 1 second long, and is also an AudioSignal. + """ + + def __init__( + self, + audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray], + sample_rate: int = None, + stft_params: STFTParams = None, + offset: float = 0, + duration: float = None, + device: str = None, + ): + audio_path = None + audio_array = None + + if isinstance(audio_path_or_array, str): + audio_path = audio_path_or_array + elif isinstance(audio_path_or_array, pathlib.Path): + audio_path = audio_path_or_array + elif isinstance(audio_path_or_array, np.ndarray): + audio_array = audio_path_or_array + elif torch.is_tensor(audio_path_or_array): + audio_array = audio_path_or_array + else: + raise ValueError( + "audio_path_or_array must be either a Path, " + "string, numpy array, or torch Tensor!" + ) + + self.path_to_file = None + + self.audio_data = None + self.sources = None # List of AudioSignal objects. + self.stft_data = None + if audio_path is not None: + self.load_from_file( + audio_path, offset=offset, duration=duration, device=device + ) + elif audio_array is not None: + assert sample_rate is not None, "Must set sample rate!" + self.load_from_array(audio_array, sample_rate, device=device) + + self.window = None + self.stft_params = stft_params + + self.metadata = { + "offset": offset, + "duration": duration, + } + + @property + def path_to_input_file( + self, + ): + """ + Path to input file, if it exists. + Alias to ``path_to_file`` for backwards compatibility + """ + return self.path_to_file + + @classmethod + def excerpt( + cls, + audio_path: typing.Union[str, Path], + offset: float = None, + duration: float = None, + state: typing.Union[np.random.RandomState, int] = None, + **kwargs, + ): + """Randomly draw an excerpt of ``duration`` seconds from an + audio file specified at ``audio_path``, between ``offset`` seconds + and end of file. ``state`` can be used to seed the random draw. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to audio file to grab excerpt from. + offset : float, optional + Lower bound for the start time, in seconds drawn from + the file, by default None. + duration : float, optional + Duration of excerpt, in seconds, by default None + state : typing.Union[np.random.RandomState, int], optional + RandomState or seed of random state, by default None + + Returns + ------- + AudioSignal + AudioSignal containing excerpt. 
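# A minimal sketch of drawing a reproducible random excerpt with the classmethod
# above. "speech.wav" is a hypothetical local file assumed to be longer than the
# requested duration; pass `state` to seed the random start offset.
from audiotools import AudioSignal

excerpt = AudioSignal.excerpt("speech.wav", duration=5.0, state=0)
print(excerpt.metadata["offset"], excerpt.metadata["duration"])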
+ + Examples + -------- + >>> signal = AudioSignal.excerpt("path/to/audio", duration=5) + """ + info = util.info(audio_path) + total_duration = info.duration + + state = util.random_state(state) + lower_bound = 0 if offset is None else offset + upper_bound = max(total_duration - duration, 0) + offset = state.uniform(lower_bound, upper_bound) + + signal = cls(audio_path, offset=offset, duration=duration, **kwargs) + signal.metadata["offset"] = offset + signal.metadata["duration"] = duration + + return signal + + @classmethod + def salient_excerpt( + cls, + audio_path: typing.Union[str, Path], + loudness_cutoff: float = None, + num_tries: int = 8, + state: typing.Union[np.random.RandomState, int] = None, + **kwargs, + ): + """Similar to AudioSignal.excerpt, except it extracts excerpts only + if they are above a specified loudness threshold, which is computed via + a fast LUFS routine. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to audio file to grab excerpt from. + loudness_cutoff : float, optional + Loudness threshold in dB. Typical values are ``-40, -60``, + etc, by default None + num_tries : int, optional + Number of tries to grab an excerpt above the threshold + before giving up, by default 8. + state : typing.Union[np.random.RandomState, int], optional + RandomState or seed of random state, by default None + kwargs : dict + Keyword arguments to AudioSignal.excerpt + + Returns + ------- + AudioSignal + AudioSignal containing excerpt. + + + .. warning:: + if ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can + result in an infinite loop if ``audio_path`` does not have + any loud enough excerpts. + + Examples + -------- + >>> signal = AudioSignal.salient_excerpt( + "path/to/audio", + loudness_cutoff=-40, + duration=5 + ) + """ + state = util.random_state(state) + if loudness_cutoff is None: + excerpt = cls.excerpt(audio_path, state=state, **kwargs) + else: + loudness = -np.inf + num_try = 0 + while loudness <= loudness_cutoff: + excerpt = cls.excerpt(audio_path, state=state, **kwargs) + loudness = excerpt.loudness() + num_try += 1 + if num_tries is not None and num_try >= num_tries: + break + return excerpt + + @classmethod + def zeros( + cls, + duration: float, + sample_rate: int, + num_channels: int = 1, + batch_size: int = 1, + **kwargs, + ): + """Helper function create an AudioSignal of all zeros. + + Parameters + ---------- + duration : float + Duration of AudioSignal + sample_rate : int + Sample rate of AudioSignal + num_channels : int, optional + Number of channels, by default 1 + batch_size : int, optional + Batch size, by default 1 + + Returns + ------- + AudioSignal + AudioSignal containing all zeros. + + Examples + -------- + Generate 5 seconds of all zeros at a sample rate of 44100. + + >>> signal = AudioSignal.zeros(5.0, 44100) + """ + n_samples = int(duration * sample_rate) + return cls( + torch.zeros(batch_size, num_channels, n_samples), sample_rate, **kwargs + ) + + @classmethod + def wave( + cls, + frequency: float, + duration: float, + sample_rate: int, + num_channels: int = 1, + shape: str = "sine", + **kwargs, + ): + """ + Generate a waveform of a given frequency and shape. 
+ + Parameters + ---------- + frequency : float + Frequency of the waveform + duration : float + Duration of the waveform + sample_rate : int + Sample rate of the waveform + num_channels : int, optional + Number of channels, by default 1 + shape : str, optional + Shape of the waveform, by default "saw" + One of "sawtooth", "square", "sine", "triangle" + kwargs : dict + Keyword arguments to AudioSignal + """ + n_samples = int(duration * sample_rate) + t = torch.linspace(0, duration, n_samples) + if shape == "sawtooth": + from scipy.signal import sawtooth + + wave_data = sawtooth(2 * np.pi * frequency * t, 0.5) + elif shape == "square": + from scipy.signal import square + + wave_data = square(2 * np.pi * frequency * t) + elif shape == "sine": + wave_data = np.sin(2 * np.pi * frequency * t) + elif shape == "triangle": + from scipy.signal import sawtooth + + # frequency is doubled by the abs call, so omit the 2 in 2pi + wave_data = sawtooth(np.pi * frequency * t, 0.5) + wave_data = -np.abs(wave_data) * 2 + 1 + else: + raise ValueError(f"Invalid shape {shape}") + + wave_data = torch.tensor(wave_data, dtype=torch.float32) + wave_data = wave_data.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1) + return cls(wave_data, sample_rate, **kwargs) + + @classmethod + def batch( + cls, + audio_signals: list, + pad_signals: bool = False, + truncate_signals: bool = False, + resample: bool = False, + dim: int = 0, + ): + """Creates a batched AudioSignal from a list of AudioSignals. + + Parameters + ---------- + audio_signals : list[AudioSignal] + List of AudioSignal objects + pad_signals : bool, optional + Whether to pad signals to length of the maximum length + AudioSignal in the list, by default False + truncate_signals : bool, optional + Whether to truncate signals to length of shortest length + AudioSignal in the list, by default False + resample : bool, optional + Whether to resample AudioSignal to the sample rate of + the first AudioSignal in the list, by default False + dim : int, optional + Dimension along which to batch the signals. + + Returns + ------- + AudioSignal + Batched AudioSignal. + + Raises + ------ + RuntimeError + If not all AudioSignals are the same sample rate, and + ``resample=False``, an error is raised. + RuntimeError + If not all AudioSignals are the same the length, and + both ``pad_signals=False`` and ``truncate_signals=False``, + an error is raised. + + Examples + -------- + Batching a bunch of random signals: + + >>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)] + >>> signal = AudioSignal.batch(signal_list) + >>> print(signal.shape) + (10, 1, 44100) + + """ + signal_lengths = [x.signal_length for x in audio_signals] + sample_rates = [x.sample_rate for x in audio_signals] + + if len(set(sample_rates)) != 1: + if resample: + for x in audio_signals: + x.resample(sample_rates[0]) + else: + raise RuntimeError( + f"Not all signals had the same sample rate! Got {sample_rates}. " + f"All signals must have the same sample rate, or resample must be True. " + ) + + if len(set(signal_lengths)) != 1: + if pad_signals: + max_length = max(signal_lengths) + for x in audio_signals: + pad_len = max_length - x.signal_length + x.zero_pad(0, pad_len) + elif truncate_signals: + min_length = min(signal_lengths) + for x in audio_signals: + x.truncate_samples(min_length) + else: + raise RuntimeError( + f"Not all signals had the same length! Got {signal_lengths}. " + f"All signals must be the same length, or pad_signals/truncate_signals " + f"must be True. 
" + ) + # Concatenate along the specified dimension (default 0) + audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim) + audio_paths = [x.path_to_file for x in audio_signals] + + batched_signal = cls( + audio_data, + sample_rate=audio_signals[0].sample_rate, + ) + batched_signal.path_to_file = audio_paths + return batched_signal + + # I/O + def load_from_file( + self, + audio_path: typing.Union[str, Path], + offset: float, + duration: float, + device: str = "cpu", + ): + """Loads data from file. Used internally when AudioSignal + is instantiated with a path to a file. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to file + offset : float + Offset in seconds + duration : float + Duration in seconds + device : str, optional + Device to put AudioSignal on, by default "cpu" + + Returns + ------- + AudioSignal + AudioSignal loaded from file + """ + import librosa + + data, sample_rate = librosa.load( + audio_path, + offset=offset, + duration=duration, + sr=None, + mono=False, + ) + data = util.ensure_tensor(data) + if data.shape[-1] == 0: + raise RuntimeError( + f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!" + ) + + if data.ndim < 2: + data = data.unsqueeze(0) + if data.ndim < 3: + data = data.unsqueeze(0) + self.audio_data = data + + self.original_signal_length = self.signal_length + + self.sample_rate = sample_rate + self.path_to_file = audio_path + return self.to(device) + + def load_from_array( + self, + audio_array: typing.Union[torch.Tensor, np.ndarray], + sample_rate: int, + device: str = "cpu", + ): + """Loads data from array, reshaping it to be exactly 3 + dimensions. Used internally when AudioSignal is called + with a tensor or an array. + + Parameters + ---------- + audio_array : typing.Union[torch.Tensor, np.ndarray] + Array/tensor of audio of samples. + sample_rate : int + Sample rate of audio + device : str, optional + Device to move audio onto, by default "cpu" + + Returns + ------- + AudioSignal + AudioSignal loaded from array + """ + audio_data = util.ensure_tensor(audio_array) + + if audio_data.dtype == torch.double: + audio_data = audio_data.float() + + if audio_data.ndim < 2: + audio_data = audio_data.unsqueeze(0) + if audio_data.ndim < 3: + audio_data = audio_data.unsqueeze(0) + self.audio_data = audio_data + + self.original_signal_length = self.signal_length + + self.sample_rate = sample_rate + return self.to(device) + + def write(self, audio_path: typing.Union[str, Path]): + """Writes audio to a file. Only writes the audio + that is in the very first item of the batch. To write other items + in the batch, index the signal along the batch dimension + before writing. After writing, the signal's ``path_to_file`` + attribute is updated to the new path. + + Parameters + ---------- + audio_path : typing.Union[str, Path] + Path to write audio to. + + Returns + ------- + AudioSignal + Returns original AudioSignal, so you can use this in a fluent + interface. 
+ + Examples + -------- + Creating and writing a signal to disk: + + >>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100) + >>> signal.write("/tmp/out.wav") + + Writing a different element of the batch: + + >>> signal[5].write("/tmp/out.wav") + + Using this in a fluent interface: + + >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav") + + """ + if self.audio_data[0].abs().max() > 1: + warnings.warn("Audio amplitude > 1 clipped when saving") + soundfile.write(str(audio_path), self.audio_data[0].numpy().T, self.sample_rate) + + self.path_to_file = audio_path + return self + + def deepcopy(self): + """Copies the signal and all of its attributes. + + Returns + ------- + AudioSignal + Deep copy of the audio signal. + """ + return copy.deepcopy(self) + + def copy(self): + """Shallow copy of signal. + + Returns + ------- + AudioSignal + Shallow copy of the audio signal. + """ + return copy.copy(self) + + def clone(self): + """Clones all tensors contained in the AudioSignal, + and returns a copy of the signal with everything + cloned. Useful when using AudioSignal within autograd + computation graphs. + + Relevant attributes are the stft data, the audio data, + and the loudness of the file. + + Returns + ------- + AudioSignal + Clone of AudioSignal. + """ + clone = type(self)( + self.audio_data.clone(), + self.sample_rate, + stft_params=self.stft_params, + ) + if self.stft_data is not None: + clone.stft_data = self.stft_data.clone() + if self._loudness is not None: + clone._loudness = self._loudness.clone() + clone.path_to_file = copy.deepcopy(self.path_to_file) + clone.metadata = copy.deepcopy(self.metadata) + return clone + + def detach(self): + """Detaches tensors contained in AudioSignal. + + Relevant attributes are the stft data, the audio data, + and the loudness of the file. + + Returns + ------- + AudioSignal + Same signal, but with all tensors detached. + """ + if self._loudness is not None: + self._loudness = self._loudness.detach() + if self.stft_data is not None: + self.stft_data = self.stft_data.detach() + + self.audio_data = self.audio_data.detach() + return self + + def hash(self): + """Writes the audio data to a temporary file, and then + hashes it using hashlib. Useful for creating a file + name based on the audio content. + + Returns + ------- + str + Hash of audio data. + + Examples + -------- + Creating a signal, and writing it to a unique file name: + + >>> signal = AudioSignal(torch.randn(44100), 44100) + >>> hash = signal.hash() + >>> signal.write(f"{hash}.wav") + + """ + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + self.write(f.name) + h = hashlib.sha256() + b = bytearray(128 * 1024) + mv = memoryview(b) + with open(f.name, "rb", buffering=0) as f: + for n in iter(lambda: f.readinto(mv), 0): + h.update(mv[:n]) + file_hash = h.hexdigest() + return file_hash + + # Signal operations + def to_mono(self): + """Converts audio data to mono audio, by taking the mean + along the channels dimension. + + Returns + ------- + AudioSignal + AudioSignal with mean of channels. + """ + self.audio_data = self.audio_data.mean(1, keepdim=True) + return self + + def resample(self, sample_rate: int): + """Resamples the audio, using sinc interpolation. This works on both + cpu and gpu, and is much faster on gpu. + + Parameters + ---------- + sample_rate : int + Sample rate to resample to. 
+ + Returns + ------- + AudioSignal + Resampled AudioSignal + """ + if sample_rate == self.sample_rate: + return self + self.audio_data = julius.resample_frac( + self.audio_data, self.sample_rate, sample_rate + ) + self.sample_rate = sample_rate + return self + + # Tensor operations + def to(self, device: str): + """Moves all tensors contained in signal to the specified device. + + Parameters + ---------- + device : str + Device to move AudioSignal onto. Typical values are + "cuda", "cpu", or "cuda:n" to specify the nth gpu. + + Returns + ------- + AudioSignal + AudioSignal with all tensors moved to specified device. + """ + if self._loudness is not None: + self._loudness = self._loudness.to(device) + if self.stft_data is not None: + self.stft_data = self.stft_data.to(device) + if self.audio_data is not None: + self.audio_data = self.audio_data.to(device) + return self + + def float(self): + """Calls ``.float()`` on ``self.audio_data``. + + Returns + ------- + AudioSignal + """ + self.audio_data = self.audio_data.float() + return self + + def cpu(self): + """Moves AudioSignal to cpu. + + Returns + ------- + AudioSignal + """ + return self.to("cpu") + + def cuda(self): # pragma: no cover + """Moves AudioSignal to cuda. + + Returns + ------- + AudioSignal + """ + return self.to("cuda") + + def numpy(self): + """Detaches ``self.audio_data``, moves to cpu, and converts to numpy. + + Returns + ------- + np.ndarray + Audio data as a numpy array. + """ + return self.audio_data.detach().cpu().numpy() + + def zero_pad(self, before: int, after: int): + """Zero pads the audio_data tensor before and after. + + Parameters + ---------- + before : int + How many zeros to prepend to audio. + after : int + How many zeros to append to audio. + + Returns + ------- + AudioSignal + AudioSignal with padding applied. + """ + self.audio_data = torch.nn.functional.pad(self.audio_data, (before, after)) + return self + + def zero_pad_to(self, length: int, mode: str = "after"): + """Pad with zeros to a specified length, either before or after + the audio data. + + Parameters + ---------- + length : int + Length to pad to + mode : str, optional + Whether to prepend or append zeros to signal, by default "after" + + Returns + ------- + AudioSignal + AudioSignal with padding applied. + """ + if mode == "before": + self.zero_pad(max(length - self.signal_length, 0), 0) + elif mode == "after": + self.zero_pad(0, max(length - self.signal_length, 0)) + return self + + def trim(self, before: int, after: int): + """Trims the audio_data tensor before and after. + + Parameters + ---------- + before : int + How many samples to trim from beginning. + after : int + How many samples to trim from end. + + Returns + ------- + AudioSignal + AudioSignal with trimming applied. + """ + if after == 0: + self.audio_data = self.audio_data[..., before:] + else: + self.audio_data = self.audio_data[..., before:-after] + return self + + def truncate_samples(self, length_in_samples: int): + """Truncate signal to specified length. + + Parameters + ---------- + length_in_samples : int + Truncate to this many samples. + + Returns + ------- + AudioSignal + AudioSignal with truncation applied. + """ + self.audio_data = self.audio_data[..., :length_in_samples] + return self + + @property + def device(self): + """Get device that AudioSignal is on. + + Returns + ------- + torch.device + Device that AudioSignal is on. 
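# A minimal sketch of the device handling described above. The CUDA branch is
# guarded because it assumes a GPU is present.
import torch
from audiotools import AudioSignal

signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
if torch.cuda.is_available():
    signal = signal.to("cuda")  # moves audio_data, stft_data and loudness tensors
print(signal.device)            # reflects wherever audio_data currently lives
samples = signal.numpy()        # detach -> cpu -> numpy, shape (1, 1, 44100)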
+ """ + if self.audio_data is not None: + device = self.audio_data.device + elif self.stft_data is not None: + device = self.stft_data.device + return device + + # Properties + @property + def audio_data(self): + """Returns the audio data tensor in the object. + + Audio data is always of the shape + (batch_size, num_channels, num_samples). If value has less + than 3 dims (e.g. is (num_channels, num_samples)), then it will + be reshaped to (1, num_channels, num_samples) - a batch size of 1. + + Parameters + ---------- + data : typing.Union[torch.Tensor, np.ndarray] + Audio data to set. + + Returns + ------- + torch.Tensor + Audio samples. + """ + return self._audio_data + + @audio_data.setter + def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]): + if data is not None: + assert torch.is_tensor(data), "audio_data should be torch.Tensor" + assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)" + self._audio_data = data + # Old loudness value not guaranteed to be right, reset it. + self._loudness = None + return + + # alias for audio_data + samples = audio_data + + @property + def stft_data(self): + """Returns the STFT data inside the signal. Shape is + (batch, channels, frequencies, time). + + Returns + ------- + torch.Tensor + Complex spectrogram data. + """ + return self._stft_data + + @stft_data.setter + def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]): + if data is not None: + assert torch.is_tensor(data) and torch.is_complex(data) + if self.stft_data is not None and self.stft_data.shape != data.shape: + warnings.warn("stft_data changed shape") + self._stft_data = data + return + + @property + def batch_size(self): + """Batch size of audio signal. + + Returns + ------- + int + Batch size of signal. + """ + return self.audio_data.shape[0] + + @property + def signal_length(self): + """Length of audio signal. + + Returns + ------- + int + Length of signal in samples. + """ + return self.audio_data.shape[-1] + + # alias for signal_length + length = signal_length + + @property + def shape(self): + """Shape of audio data. + + Returns + ------- + tuple + Shape of audio data. + """ + return self.audio_data.shape + + @property + def signal_duration(self): + """Length of audio signal in seconds. + + Returns + ------- + float + Length of signal in seconds. + """ + return self.signal_length / self.sample_rate + + # alias for signal_duration + duration = signal_duration + + @property + def num_channels(self): + """Number of audio channels. + + Returns + ------- + int + Number of audio channels. + """ + return self.audio_data.shape[1] + + # STFT + @staticmethod + @functools.lru_cache(None) + def get_window(window_type: str, window_length: int, device: str): + """Wrapper around scipy.signal.get_window so one can also get the + popular sqrt-hann window. This function caches for efficiency + using functools.lru\_cache. + + Parameters + ---------- + window_type : str + Type of window to get + window_length : int + Length of the window + device : str + Device to put window onto. + + Returns + ------- + torch.Tensor + Window returned by scipy.signal.get_window, as a tensor. 
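# A tiny sketch of the cached window helper documented above; "sqrt_hann" is the
# special case added on top of scipy.signal.get_window.
from audiotools import AudioSignal

win = AudioSignal.get_window("sqrt_hann", 2048, "cpu")
print(win.shape, win.dtype)  # torch.Size([2048]) torch.float32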
+ """ + from scipy import signal + + if window_type == "average": + window = np.ones(window_length) / window_length + elif window_type == "sqrt_hann": + window = np.sqrt(signal.get_window("hann", window_length)) + else: + window = signal.get_window(window_type, window_length) + window = torch.from_numpy(window).to(device).float() + return window + + @property + def stft_params(self): + """Returns STFTParams object, which can be re-used to other + AudioSignals. + + This property can be set as well. If values are not defined in STFTParams, + they are inferred automatically from the signal properties. The default is to use + 32ms windows, with 8ms hop length, and the square root of the hann window. + + Returns + ------- + STFTParams + STFT parameters for the AudioSignal. + + Examples + -------- + >>> stft_params = STFTParams(128, 32) + >>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params) + >>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params) + >>> signal1.stft_params = STFTParams() # Defaults + """ + return self._stft_params + + @stft_params.setter + def stft_params(self, value: STFTParams): + default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate)))) + default_hop_len = default_win_len // 4 + default_win_type = "hann" + default_match_stride = False + default_padding_type = "reflect" + + default_stft_params = STFTParams( + window_length=default_win_len, + hop_length=default_hop_len, + window_type=default_win_type, + match_stride=default_match_stride, + padding_type=default_padding_type, + )._asdict() + + value = value._asdict() if value else default_stft_params + + for key in default_stft_params: + if value[key] is None: + value[key] = default_stft_params[key] + + self._stft_params = STFTParams(**value) + self.stft_data = None + + def compute_stft_padding( + self, window_length: int, hop_length: int, match_stride: bool + ): + """Compute how the STFT should be padded, based on match\_stride. + + Parameters + ---------- + window_length : int + Window length of STFT. + hop_length : int + Hop length of STFT. + match_stride : bool + Whether or not to match stride, making the STFT have the same alignment as + convolutional layers. + + Returns + ------- + tuple + Amount to pad on either side of audio. + """ + length = self.signal_length + + if match_stride: + assert ( + hop_length == window_length // 4 + ), "For match_stride, hop must equal n_fft // 4" + right_pad = math.ceil(length / hop_length) * hop_length - length + pad = (window_length - hop_length) // 2 + else: + right_pad = 0 + pad = 0 + + return right_pad, pad + + def stft( + self, + window_length: int = None, + hop_length: int = None, + window_type: str = None, + match_stride: bool = None, + padding_type: str = None, + ): + """Computes the short-time Fourier transform of the audio data, + with specified STFT parameters. + + Parameters + ---------- + window_length : int, optional + Window length of STFT, by default ``0.032 * self.sample_rate``. + hop_length : int, optional + Hop length of STFT, by default ``window_length // 4``. + window_type : str, optional + Type of window to use, by default ``sqrt\_hann``. + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + padding_type : str, optional + Type of padding to use, by default 'reflect' + + Returns + ------- + torch.Tensor + STFT of audio data. 
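# A brief sketch of the match_stride behaviour described above: with hop_length equal
# to window_length // 4 and a signal length divisible by the hop, the frames kept
# after dropping the padding frames satisfy num_frames * hop_length == num_samples.
import torch
from audiotools import AudioSignal, STFTParams

signal = AudioSignal(torch.randn(1, 1, 51200), 44100)
signal.stft_params = STFTParams(window_length=512, hop_length=128, match_stride=True)
spec = signal.stft()
print(spec.shape[-1] * 128 == signal.signal_length)  # True: 400 frames * 128 hop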
+ + Examples + -------- + Compute the STFT of an AudioSignal: + + >>> signal = AudioSignal(torch.randn(44100), 44100) + >>> signal.stft() + + Vary the window and hop length: + + >>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)] + >>> for stft_param in stft_params: + >>> signal.stft_params = stft_params + >>> signal.stft() + + """ + window_length = ( + self.stft_params.window_length + if window_length is None + else int(window_length) + ) + hop_length = ( + self.stft_params.hop_length if hop_length is None else int(hop_length) + ) + window_type = ( + self.stft_params.window_type if window_type is None else window_type + ) + match_stride = ( + self.stft_params.match_stride if match_stride is None else match_stride + ) + padding_type = ( + self.stft_params.padding_type if padding_type is None else padding_type + ) + + window = self.get_window(window_type, window_length, self.audio_data.device) + window = window.to(self.audio_data.device) + + audio_data = self.audio_data + right_pad, pad = self.compute_stft_padding( + window_length, hop_length, match_stride + ) + audio_data = torch.nn.functional.pad( + audio_data, (pad, pad + right_pad), padding_type + ) + stft_data = torch.stft( + audio_data.reshape(-1, audio_data.shape[-1]), + n_fft=window_length, + hop_length=hop_length, + window=window, + return_complex=True, + center=True, + ) + _, nf, nt = stft_data.shape + stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt) + + if match_stride: + # Drop first two and last two frames, which are added + # because of padding. Now num_frames * hop_length = num_samples. + stft_data = stft_data[..., 2:-2] + self.stft_data = stft_data + + return stft_data + + def istft( + self, + window_length: int = None, + hop_length: int = None, + window_type: str = None, + match_stride: bool = None, + length: int = None, + ): + """Computes inverse STFT and sets it to audio\_data. + + Parameters + ---------- + window_length : int, optional + Window length of STFT, by default ``0.032 * self.sample_rate``. + hop_length : int, optional + Hop length of STFT, by default ``window_length // 4``. + window_type : str, optional + Type of window to use, by default ``sqrt\_hann``. + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + length : int, optional + Original length of signal, by default None + + Returns + ------- + AudioSignal + AudioSignal with istft applied. + + Raises + ------ + RuntimeError + Raises an error if stft was not called prior to istft on the signal, + or if stft_data is not set. 
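# A minimal round-trip sketch using the default STFT parameters: compute the STFT,
# gate the magnitude, and invert back into audio_data with istft().
import torch
from audiotools import AudioSignal

signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
signal.stft()                                  # populates signal.stft_data
mag = signal.magnitude
signal.magnitude = mag * (mag > mag.mean())    # crude spectral gate
signal.istft()                                 # writes the result back to audio_data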
+ """ + if self.stft_data is None: + raise RuntimeError("Cannot do inverse STFT without self.stft_data!") + + window_length = ( + self.stft_params.window_length + if window_length is None + else int(window_length) + ) + hop_length = ( + self.stft_params.hop_length if hop_length is None else int(hop_length) + ) + window_type = ( + self.stft_params.window_type if window_type is None else window_type + ) + match_stride = ( + self.stft_params.match_stride if match_stride is None else match_stride + ) + + window = self.get_window(window_type, window_length, self.stft_data.device) + + nb, nch, nf, nt = self.stft_data.shape + stft_data = self.stft_data.reshape(nb * nch, nf, nt) + right_pad, pad = self.compute_stft_padding( + window_length, hop_length, match_stride + ) + + if length is None: + length = self.original_signal_length + length = length + 2 * pad + right_pad + + if match_stride: + # Zero-pad the STFT on either side, putting back the frames that were + # dropped in stft(). + stft_data = torch.nn.functional.pad(stft_data, (2, 2)) + + audio_data = torch.istft( + stft_data, + n_fft=window_length, + hop_length=hop_length, + window=window, + length=length, + center=True, + ) + audio_data = audio_data.reshape(nb, nch, -1) + if match_stride: + audio_data = audio_data[..., pad : -(pad + right_pad)] + self.audio_data = audio_data + + return self + + @staticmethod + @functools.lru_cache(None) + def get_mel_filters( + sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None + ): + """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins. + + Parameters + ---------- + sr : int + Sample rate of audio + n_fft : int + Number of FFT bins + n_mels : int + Number of mels + fmin : float, optional + Lowest frequency, in Hz, by default 0.0 + fmax : float, optional + Highest frequency, by default None + + Returns + ------- + np.ndarray [shape=(n_mels, 1 + n_fft/2)] + Mel transform matrix + """ + from librosa.filters import mel as librosa_mel_fn + + return librosa_mel_fn( + sr=sr, + n_fft=n_fft, + n_mels=n_mels, + fmin=fmin, + fmax=fmax, + ) + + def mel_spectrogram( + self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs + ): + """Computes a Mel spectrogram. + + Parameters + ---------- + n_mels : int, optional + Number of mels, by default 80 + mel_fmin : float, optional + Lowest frequency, in Hz, by default 0.0 + mel_fmax : float, optional + Highest frequency, by default None + kwargs : dict, optional + Keyword arguments to self.stft(). + + Returns + ------- + torch.Tensor [shape=(batch, channels, mels, time)] + Mel spectrogram. + """ + stft = self.stft(**kwargs) + magnitude = torch.abs(stft) + + nf = magnitude.shape[2] + mel_basis = self.get_mel_filters( + sr=self.sample_rate, + n_fft=2 * (nf - 1), + n_mels=n_mels, + fmin=mel_fmin, + fmax=mel_fmax, + ) + mel_basis = torch.from_numpy(mel_basis).to(self.device) + + mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T + mel_spectrogram = mel_spectrogram.transpose(-1, 2) + return mel_spectrogram + + @staticmethod + @functools.lru_cache(None) + def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None): + """Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``), + it can be normalized depending on norm. 
For more information about dct: + http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II + + Parameters + ---------- + n_mfcc : int + Number of mfccs + n_mels : int + Number of mels + norm : str + Use "ortho" to get a orthogonal matrix or None, by default "ortho" + device : str, optional + Device to load the transformation matrix on, by default None + + Returns + ------- + torch.Tensor [shape=(n_mels, n_mfcc)] T + The dct transformation matrix. + """ + from torchaudio.functional import create_dct + + return create_dct(n_mfcc, n_mels, norm).to(device) + + def mfcc( + self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs + ): + """Computes mel-frequency cepstral coefficients (MFCCs). + + Parameters + ---------- + n_mfcc : int, optional + Number of mels, by default 40 + n_mels : int, optional + Number of mels, by default 80 + log_offset: float, optional + Small value to prevent numerical issues when trying to compute log(0), by default 1e-6 + kwargs : dict, optional + Keyword arguments to self.mel_spectrogram(), note that some of them will be used for self.stft() + + Returns + ------- + torch.Tensor [shape=(batch, channels, mfccs, time)] + MFCCs. + """ + + mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs) + mel_spectrogram = torch.log(mel_spectrogram + log_offset) + dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device) + + mfcc = mel_spectrogram.transpose(-1, -2) @ dct_mat + mfcc = mfcc.transpose(-1, -2) + return mfcc + + @property + def magnitude(self): + """Computes and returns the absolute value of the STFT, which + is the magnitude. This value can also be set to some tensor. + When set, ``self.stft_data`` is manipulated so that its magnitude + matches what this is set to, and modulated by the phase. + + Returns + ------- + torch.Tensor + Magnitude of STFT. + + Examples + -------- + >>> signal = AudioSignal(torch.randn(44100), 44100) + >>> magnitude = signal.magnitude # Computes stft if not computed + >>> magnitude[magnitude < magnitude.mean()] = 0 + >>> signal.magnitude = magnitude + >>> signal.istft() + """ + if self.stft_data is None: + self.stft() + return torch.abs(self.stft_data) + + @magnitude.setter + def magnitude(self, value): + self.stft_data = value * torch.exp(1j * self.phase) + return + + def log_magnitude( + self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0 + ): + """Computes the log-magnitude of the spectrogram. + + Parameters + ---------- + ref_value : float, optional + The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``. + Zeros in the output correspond to positions where ``S == ref``, + by default 1.0 + amin : float, optional + Minimum threshold for ``S`` and ``ref``, by default 1e-5 + top_db : float, optional + Threshold the output at ``top_db`` below the peak: + ``max(10 * log10(S/ref)) - top_db``, by default -80.0 + + Returns + ------- + torch.Tensor + Log-magnitude spectrogram + """ + magnitude = self.magnitude + + amin = amin**2 + log_spec = 10.0 * torch.log10(magnitude.pow(2).clamp(min=amin)) + log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value)) + + if top_db is not None: + log_spec = torch.maximum(log_spec, log_spec.max() - top_db) + return log_spec + + @property + def phase(self): + """Computes and returns the phase of the STFT. + This value can also be set to some tensor. + When set, ``self.stft_data`` is manipulated so that its phase + matches what this is set to, we original magnitudeith th. + + Returns + ------- + torch.Tensor + Phase of STFT. 
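# A short sketch of the mel-spectrogram and MFCC helpers defined above. Shapes follow
# the implementation: (batch, channels, n_mels, frames) and (batch, channels, n_mfcc,
# frames). Requires librosa and torchaudio, as in the code above.
import torch
from audiotools import AudioSignal

signal = AudioSignal(torch.randn(1, 1, 44100), 44100)
mels = signal.mel_spectrogram(n_mels=80)
mfccs = signal.mfcc(n_mfcc=40, n_mels=80)
print(mels.shape, mfccs.shape)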
+ + Examples + -------- + >>> signal = AudioSignal(torch.randn(44100), 44100) + >>> phase = signal.phase # Computes stft if not computed + >>> phase[phase < phase.mean()] = 0 + >>> signal.phase = phase + >>> signal.istft() + """ + if self.stft_data is None: + self.stft() + return torch.angle(self.stft_data) + + @phase.setter + def phase(self, value): + self.stft_data = self.magnitude * torch.exp(1j * value) + return + + # Operator overloading + def __add__(self, other): + new_signal = self.clone() + new_signal.audio_data += util._get_value(other) + return new_signal + + def __iadd__(self, other): + self.audio_data += util._get_value(other) + return self + + def __radd__(self, other): + return self + other + + def __sub__(self, other): + new_signal = self.clone() + new_signal.audio_data -= util._get_value(other) + return new_signal + + def __isub__(self, other): + self.audio_data -= util._get_value(other) + return self + + def __mul__(self, other): + new_signal = self.clone() + new_signal.audio_data *= util._get_value(other) + return new_signal + + def __imul__(self, other): + self.audio_data *= util._get_value(other) + return self + + def __rmul__(self, other): + return self * other + + # Representation + def _info(self): + dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]" + info = { + "duration": f"{dur} seconds", + "batch_size": self.batch_size, + "path": self.path_to_file if self.path_to_file else "path unknown", + "sample_rate": self.sample_rate, + "num_channels": self.num_channels if self.num_channels else "[unknown]", + "audio_data.shape": self.audio_data.shape, + "stft_params": self.stft_params, + "device": self.device, + } + + return info + + def markdown(self): + """Produces a markdown representation of AudioSignal, in a markdown table. + + Returns + ------- + str + Markdown representation of AudioSignal. 
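# A small sketch of the operator overloading defined above: +, - and * accept scalars,
# tensors, or other AudioSignals (resolved through util._get_value), and the
# non-inplace forms return cloned signals.
import torch
from audiotools import AudioSignal

a = AudioSignal(torch.randn(1, 1, 44100), 44100)
b = AudioSignal(torch.randn(1, 1, 44100), 44100)
mixed = a + 0.5 * b     # new AudioSignal; a and b are left untouched
quieter = mixed * 0.1
print(quieter.audio_data.abs().max())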
+ + Examples + -------- + >>> signal = AudioSignal(torch.randn(44100), 44100) + >>> print(signal.markdown()) + | Key | Value + |---|--- + | duration | 1.000 seconds | + | batch_size | 1 | + | path | path unknown | + | sample_rate | 44100 | + | num_channels | 1 | + | audio_data.shape | torch.Size([1, 1, 44100]) | + | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='sqrt_hann', match_stride=False) | + | device | cpu | + """ + info = self._info() + + FORMAT = "| Key | Value \n" "|---|--- \n" + for k, v in info.items(): + row = f"| {k} | {v} |\n" + FORMAT += row + return FORMAT + + def __str__(self): + info = self._info() + + desc = "" + for k, v in info.items(): + desc += f"{k}: {v}\n" + return desc + + def __rich__(self): + from rich.table import Table + + info = self._info() + + table = Table(title=f"{self.__class__.__name__}") + table.add_column("Key", style="green") + table.add_column("Value", style="cyan") + + for k, v in info.items(): + table.add_row(k, str(v)) + return table + + # Comparison + def __eq__(self, other): + for k, v in list(self.__dict__.items()): + if torch.is_tensor(v): + if not torch.allclose(v, other.__dict__[k], atol=1e-6): + max_error = (v - other.__dict__[k]).abs().max() + print(f"Max abs error for {k}: {max_error}") + return False + return True + + # Indexing + def __getitem__(self, key): + if torch.is_tensor(key) and key.ndim == 0 and key.item() is True: + assert self.batch_size == 1 + audio_data = self.audio_data + _loudness = self._loudness + stft_data = self.stft_data + + elif isinstance(key, (bool, int, list, slice, tuple)) or ( + torch.is_tensor(key) and key.ndim <= 1 + ): + # Indexing only on the batch dimension. + # Then let's copy over relevant stuff. + # Future work: make this work for time-indexing + # as well, using the hop length. + audio_data = self.audio_data[key] + _loudness = self._loudness[key] if self._loudness is not None else None + stft_data = self.stft_data[key] if self.stft_data is not None else None + + sources = None + + copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params) + copy._loudness = _loudness + copy._stft_data = stft_data + copy.sources = sources + + return copy + + def __setitem__(self, key, value): + if not isinstance(value, type(self)): + self.audio_data[key] = value + return + + if torch.is_tensor(key) and key.ndim == 0 and key.item() is True: + assert self.batch_size == 1 + self.audio_data = value.audio_data + self._loudness = value._loudness + self.stft_data = value.stft_data + return + + elif isinstance(key, (bool, int, list, slice, tuple)) or ( + torch.is_tensor(key) and key.ndim <= 1 + ): + if self.audio_data is not None and value.audio_data is not None: + self.audio_data[key] = value.audio_data + if self._loudness is not None and value._loudness is not None: + self._loudness[key] = value._loudness + if self.stft_data is not None and value.stft_data is not None: + self.stft_data[key] = value.stft_data + return + + def __ne__(self, other): + return not self == other diff --git a/audiotools/core/display.py b/audiotools/core/display.py new file mode 100644 index 0000000000000000000000000000000000000000..66cbcf34cb2cf9fdf8d67ec4418a887eba73f184 --- /dev/null +++ b/audiotools/core/display.py @@ -0,0 +1,194 @@ +import inspect +import typing +from functools import wraps + +from . import util + + +def format_figure(func): + """Decorator for formatting figures produced by the code below. + See :py:func:`audiotools.core.util.format_figure` for more. 
+ + Parameters + ---------- + func : Callable + Plotting function that is decorated by this function. + + """ + + @wraps(func) + def wrapper(*args, **kwargs): + f_keys = inspect.signature(util.format_figure).parameters.keys() + f_kwargs = {} + for k, v in list(kwargs.items()): + if k in f_keys: + kwargs.pop(k) + f_kwargs[k] = v + func(*args, **kwargs) + util.format_figure(**f_kwargs) + + return wrapper + + +class DisplayMixin: + @format_figure + def specshow( + self, + preemphasis: bool = False, + x_axis: str = "time", + y_axis: str = "linear", + n_mels: int = 128, + **kwargs, + ): + """Displays a spectrogram, using ``librosa.display.specshow``. + + Parameters + ---------- + preemphasis : bool, optional + Whether or not to apply preemphasis, which makes high + frequency detail easier to see, by default False + x_axis : str, optional + How to label the x axis, by default "time" + y_axis : str, optional + How to label the y axis, by default "linear" + n_mels : int, optional + If displaying a mel spectrogram with ``y_axis = "mel"``, + this controls the number of mels, by default 128. + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.util.format_figure`. + """ + import librosa + import librosa.display + + # Always re-compute the STFT data before showing it, in case + # it changed. + signal = self.clone() + signal.stft_data = None + + if preemphasis: + signal.preemphasis() + + ref = signal.magnitude.max() + log_mag = signal.log_magnitude(ref_value=ref) + + if y_axis == "mel": + log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10() + log_mag -= log_mag.max() + + librosa.display.specshow( + log_mag.numpy()[0].mean(axis=0), + x_axis=x_axis, + y_axis=y_axis, + sr=signal.sample_rate, + **kwargs, + ) + + @format_figure + def waveplot(self, x_axis: str = "time", **kwargs): + """Displays a waveform plot, using ``librosa.display.waveshow``. + + Parameters + ---------- + x_axis : str, optional + How to label the x axis, by default "time" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.util.format_figure`. + """ + import librosa + import librosa.display + + audio_data = self.audio_data[0].mean(dim=0) + audio_data = audio_data.cpu().numpy() + + plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot" + wave_plot_fn = getattr(librosa.display, plot_fn) + wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs) + + @format_figure + def wavespec(self, x_axis: str = "time", **kwargs): + """Displays a waveform plot, using ``librosa.display.waveshow``. + + Parameters + ---------- + x_axis : str, optional + How to label the x axis, by default "time" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`. + """ + import matplotlib.pyplot as plt + from matplotlib.gridspec import GridSpec + + gs = GridSpec(6, 1) + plt.subplot(gs[0, :]) + self.waveplot(x_axis=x_axis) + plt.subplot(gs[1:, :]) + self.specshow(x_axis=x_axis, **kwargs) + + def write_audio_to_tb( + self, + tag: str, + writer, + step: int = None, + plot_fn: typing.Union[typing.Callable, str] = "specshow", + **kwargs, + ): + """Writes a signal and its spectrogram to Tensorboard. Will show up + under the Audio and Images tab in Tensorboard. + + Parameters + ---------- + tag : str + Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be + written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``). + writer : SummaryWriter + A SummaryWriter object from PyTorch library. 
+ step : int, optional + The step to write the signal to, by default None + plot_fn : typing.Union[typing.Callable, str], optional + How to create the image. Set to ``None`` to avoid plotting, by default "specshow" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or + whatever ``plot_fn`` is set to. + """ + import matplotlib.pyplot as plt + + audio_data = self.audio_data[0, 0].detach().cpu() + sample_rate = self.sample_rate + writer.add_audio(tag, audio_data, step, sample_rate) + + if plot_fn is not None: + if isinstance(plot_fn, str): + plot_fn = getattr(self, plot_fn) + fig = plt.figure() + plt.clf() + plot_fn(**kwargs) + writer.add_figure(tag.replace("wav", "png"), fig, step) + + def save_image( + self, + image_path: str, + plot_fn: typing.Union[typing.Callable, str] = "specshow", + **kwargs, + ): + """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to + a specified file. + + Parameters + ---------- + image_path : str + Where to save the file to. + plot_fn : typing.Union[typing.Callable, str], optional + How to create the image. Set to ``None`` to avoid plotting, by default "specshow" + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or + whatever ``plot_fn`` is set to. + """ + import matplotlib.pyplot as plt + + if isinstance(plot_fn, str): + plot_fn = getattr(self, plot_fn) + + plt.clf() + plot_fn(**kwargs) + plt.savefig(image_path, bbox_inches="tight", pad_inches=0) + plt.close() diff --git a/audiotools/core/dsp.py b/audiotools/core/dsp.py new file mode 100644 index 0000000000000000000000000000000000000000..f9be51a119537b77e497ddc2dac126d569533d7c --- /dev/null +++ b/audiotools/core/dsp.py @@ -0,0 +1,390 @@ +import typing + +import julius +import numpy as np +import torch + +from . import util + + +class DSPMixin: + _original_batch_size = None + _original_num_channels = None + _padded_signal_length = None + + def _preprocess_signal_for_windowing(self, window_duration, hop_duration): + self._original_batch_size = self.batch_size + self._original_num_channels = self.num_channels + + window_length = int(window_duration * self.sample_rate) + hop_length = int(hop_duration * self.sample_rate) + + if window_length % hop_length != 0: + factor = window_length // hop_length + window_length = factor * hop_length + + self.zero_pad(hop_length, hop_length) + self._padded_signal_length = self.signal_length + + return window_length, hop_length + + def windows( + self, window_duration: float, hop_duration: float, preprocess: bool = True + ): + """Generator which yields windows of specified duration from signal with a specified + hop length. + + Parameters + ---------- + window_duration : float + Duration of every window in seconds. + hop_duration : float + Hop between windows in seconds. + preprocess : bool, optional + Whether to preprocess the signal, so that the first sample is in + the middle of the first window, by default True + + Yields + ------ + AudioSignal + Each window is returned as an AudioSignal. 
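# A minimal sketch of iterating over fixed-size windows, as described above; each
# yielded window is itself an AudioSignal of window_duration seconds.
import torch
from audiotools import AudioSignal

signal = AudioSignal(torch.randn(1, 1, 4 * 44100), 44100)
for window in signal.windows(window_duration=1.0, hop_duration=0.5):
    print(window.signal_length)  # 44100 samples per yielded window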
+ """ + if preprocess: + window_length, hop_length = self._preprocess_signal_for_windowing( + window_duration, hop_duration + ) + + self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length) + + for b in range(self.batch_size): + i = 0 + start_idx = i * hop_length + while True: + start_idx = i * hop_length + i += 1 + end_idx = start_idx + window_length + if end_idx > self.signal_length: + break + yield self[b, ..., start_idx:end_idx] + + def collect_windows( + self, window_duration: float, hop_duration: float, preprocess: bool = True + ): + """Reshapes signal into windows of specified duration from signal with a specified + hop length. Window are placed along the batch dimension. Use with + :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the + original signal. + + Parameters + ---------- + window_duration : float + Duration of every window in seconds. + hop_duration : float + Hop between windows in seconds. + preprocess : bool, optional + Whether to preprocess the signal, so that the first sample is in + the middle of the first window, by default True + + Returns + ------- + AudioSignal + AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)`` + """ + if preprocess: + window_length, hop_length = self._preprocess_signal_for_windowing( + window_duration, hop_duration + ) + + # self.audio_data: (nb, nch, nt). + unfolded = torch.nn.functional.unfold( + self.audio_data.reshape(-1, 1, 1, self.signal_length), + kernel_size=(1, window_length), + stride=(1, hop_length), + ) + # unfolded: (nb * nch, window_length, num_windows). + # -> (nb * nch * num_windows, 1, window_length) + unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length) + self.audio_data = unfolded + return self + + def overlap_and_add(self, hop_duration: float): + """Function which takes a list of windows and overlap adds them into a + signal the same length as ``audio_signal``. + + Parameters + ---------- + hop_duration : float + How much to shift for each window + (overlap is window_duration - hop_duration) in seconds. + + Returns + ------- + AudioSignal + overlap-and-added signal. + """ + hop_length = int(hop_duration * self.sample_rate) + window_length = self.signal_length + + nb, nch = self._original_batch_size, self._original_num_channels + + unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1) + folded = torch.nn.functional.fold( + unfolded, + output_size=(1, self._padded_signal_length), + kernel_size=(1, window_length), + stride=(1, hop_length), + ) + + norm = torch.ones_like(unfolded, device=unfolded.device) + norm = torch.nn.functional.fold( + norm, + output_size=(1, self._padded_signal_length), + kernel_size=(1, window_length), + stride=(1, hop_length), + ) + + folded = folded / norm + + folded = folded.reshape(nb, nch, -1) + self.audio_data = folded + self.trim(hop_length, hop_length) + return self + + def low_pass( + self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51 + ): + """Low-passes the signal in-place. Each item in the batch + can have a different low-pass cutoff, if the input + to this signal is an array or tensor. If a float, all + items are given the same low-pass filter. + + Parameters + ---------- + cutoffs : typing.Union[torch.Tensor, np.ndarray, float] + Cutoff in Hz of low-pass filter. + zeros : int, optional + Number of taps to use in low-pass filter, by default 51 + + Returns + ------- + AudioSignal + Low-passed AudioSignal. 
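# A short sketch of per-item filtering as documented above: a tensor of cutoffs gives
# each batch item its own filter, while a float applies the same filter to every item.
import torch
from audiotools import AudioSignal

batch = AudioSignal(torch.randn(3, 1, 44100), 44100)
batch.low_pass(torch.tensor([1000.0, 4000.0, 8000.0]))  # one cutoff (Hz) per item
batch.high_pass(100.0)                                   # shared cutoff for all items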
+ """ + cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size) + cutoffs = cutoffs / self.sample_rate + filtered = torch.empty_like(self.audio_data) + + for i, cutoff in enumerate(cutoffs): + lp_filter = julius.LowPassFilter(cutoff.cpu(), zeros=zeros).to(self.device) + filtered[i] = lp_filter(self.audio_data[i]) + + self.audio_data = filtered + self.stft_data = None + return self + + def high_pass( + self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51 + ): + """High-passes the signal in-place. Each item in the batch + can have a different high-pass cutoff, if the input + to this signal is an array or tensor. If a float, all + items are given the same high-pass filter. + + Parameters + ---------- + cutoffs : typing.Union[torch.Tensor, np.ndarray, float] + Cutoff in Hz of high-pass filter. + zeros : int, optional + Number of taps to use in high-pass filter, by default 51 + + Returns + ------- + AudioSignal + High-passed AudioSignal. + """ + cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size) + cutoffs = cutoffs / self.sample_rate + filtered = torch.empty_like(self.audio_data) + + for i, cutoff in enumerate(cutoffs): + hp_filter = julius.HighPassFilter(cutoff.cpu(), zeros=zeros).to(self.device) + filtered[i] = hp_filter(self.audio_data[i]) + + self.audio_data = filtered + self.stft_data = None + return self + + def mask_frequencies( + self, + fmin_hz: typing.Union[torch.Tensor, np.ndarray, float], + fmax_hz: typing.Union[torch.Tensor, np.ndarray, float], + val: float = 0.0, + ): + """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them + with the value specified by ``val``. Useful for implementing SpecAug. + The min and max can be different for every item in the batch. + + Parameters + ---------- + fmin_hz : typing.Union[torch.Tensor, np.ndarray, float] + Lower end of band to mask out. + fmax_hz : typing.Union[torch.Tensor, np.ndarray, float] + Upper end of band to mask out. + val : float, optional + Value to fill in, by default 0.0 + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + # SpecAug + mag, phase = self.magnitude, self.phase + fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim) + fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim) + assert torch.all(fmin_hz < fmax_hz) + + # build mask + nbins = mag.shape[-2] + bins_hz = torch.linspace(0, self.sample_rate / 2, nbins, device=self.device) + bins_hz = bins_hz[None, None, :, None].repeat( + self.batch_size, 1, 1, mag.shape[-1] + ) + mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz) + mask = mask.to(self.device) + + mag = mag.masked_fill(mask, val) + phase = phase.masked_fill(mask, val) + self.stft_data = mag * torch.exp(1j * phase) + return self + + def mask_timesteps( + self, + tmin_s: typing.Union[torch.Tensor, np.ndarray, float], + tmax_s: typing.Union[torch.Tensor, np.ndarray, float], + val: float = 0.0, + ): + """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them + with the value specified by ``val``. Useful for implementing SpecAug. + The min and max can be different for every item in the batch. + + Parameters + ---------- + tmin_s : typing.Union[torch.Tensor, np.ndarray, float] + Lower end of timesteps to mask out. + tmax_s : typing.Union[torch.Tensor, np.ndarray, float] + Upper end of timesteps to mask out. + val : float, optional + Value to fill in, by default 0.0 + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. 
Apply ``.istft()`` to get the + masked audio data. + """ + # SpecAug + mag, phase = self.magnitude, self.phase + tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim) + tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim) + + assert torch.all(tmin_s < tmax_s) + + # build mask + nt = mag.shape[-1] + bins_t = torch.linspace(0, self.signal_duration, nt, device=self.device) + bins_t = bins_t[None, None, None, :].repeat( + self.batch_size, 1, mag.shape[-2], 1 + ) + mask = (tmin_s <= bins_t) & (bins_t < tmax_s) + + mag = mag.masked_fill(mask, val) + phase = phase.masked_fill(mask, val) + self.stft_data = mag * torch.exp(1j * phase) + return self + + def mask_low_magnitudes( + self, db_cutoff: typing.Union[torch.Tensor, np.ndarray, float], val: float = 0.0 + ): + """Mask away magnitudes below a specified threshold, which + can be different for every item in the batch. + + Parameters + ---------- + db_cutoff : typing.Union[torch.Tensor, np.ndarray, float] + Decibel value for which things below it will be masked away. + val : float, optional + Value to fill in for masked portions, by default 0.0 + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + mag = self.magnitude + log_mag = self.log_magnitude() + + db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim) + mask = log_mag < db_cutoff + mag = mag.masked_fill(mask, val) + + self.magnitude = mag + return self + + def shift_phase(self, shift: typing.Union[torch.Tensor, np.ndarray, float]): + """Shifts the phase by a constant value. + + Parameters + ---------- + shift : typing.Union[torch.Tensor, np.ndarray, float] + What to shift the phase by. + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + shift = util.ensure_tensor(shift, ndim=self.phase.ndim) + self.phase = self.phase + shift + return self + + def corrupt_phase(self, scale: typing.Union[torch.Tensor, np.ndarray, float]): + """Corrupts the phase randomly by some scaled value. + + Parameters + ---------- + scale : typing.Union[torch.Tensor, np.ndarray, float] + Standard deviation of noise to add to the phase. + + Returns + ------- + AudioSignal + Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the + masked audio data. + """ + scale = util.ensure_tensor(scale, ndim=self.phase.ndim) + self.phase = self.phase + scale * torch.randn_like(self.phase) + return self + + def preemphasis(self, coef: float = 0.85): + """Applies pre-emphasis to audio signal. + + Parameters + ---------- + coef : float, optional + How much pre-emphasis to apply, lower values do less. 0 does nothing. + by default 0.85 + + Returns + ------- + AudioSignal + Pre-emphasized signal. + """ + kernel = torch.tensor([1, -coef, 0]).view(1, 1, -1).to(self.device) + x = self.audio_data.reshape(-1, 1, self.signal_length) + x = torch.nn.functional.conv1d(x, kernel, padding=1) + self.audio_data = x.reshape(*self.audio_data.shape) + return self diff --git a/audiotools/core/effects.py b/audiotools/core/effects.py new file mode 100644 index 0000000000000000000000000000000000000000..fb534cbcb2d457575de685fc9248d1716879145b --- /dev/null +++ b/audiotools/core/effects.py @@ -0,0 +1,647 @@ +import typing + +import julius +import numpy as np +import torch +import torchaudio + +from . 
import util + + +class EffectMixin: + GAIN_FACTOR = np.log(10) / 20 + """Gain factor for converting between amplitude and decibels.""" + CODEC_PRESETS = { + "8-bit": {"format": "wav", "encoding": "ULAW", "bits_per_sample": 8}, + "GSM-FR": {"format": "gsm"}, + "MP3": {"format": "mp3", "compression": -9}, + "Vorbis": {"format": "vorbis", "compression": -1}, + "Ogg": { + "format": "ogg", + "compression": -1, + }, + "Amr-nb": {"format": "amr-nb"}, + } + """Presets for applying codecs via torchaudio.""" + + def mix( + self, + other, + snr: typing.Union[torch.Tensor, np.ndarray, float] = 10, + other_eq: typing.Union[torch.Tensor, np.ndarray] = None, + ): + """Mixes noise with signal at specified + signal-to-noise ratio. Optionally, the + other signal can be equalized in-place. + + + Parameters + ---------- + other : AudioSignal + AudioSignal object to mix with. + snr : typing.Union[torch.Tensor, np.ndarray, float], optional + Signal to noise ratio, by default 10 + other_eq : typing.Union[torch.Tensor, np.ndarray], optional + EQ curve to apply to other signal, if any, by default None + + Returns + ------- + AudioSignal + In-place modification of AudioSignal. + """ + snr = util.ensure_tensor(snr).to(self.device) + + pad_len = max(0, self.signal_length - other.signal_length) + other.zero_pad(0, pad_len) + other.truncate_samples(self.signal_length) + if other_eq is not None: + other = other.equalizer(other_eq) + + tgt_loudness = self.loudness() - snr + other = other.normalize(tgt_loudness) + + self.audio_data = self.audio_data + other.audio_data + return self + + def convolve(self, other, start_at_max: bool = True): + """Convolves self with other. + This function uses FFTs to do the convolution. + + Parameters + ---------- + other : AudioSignal + Signal to convolve with. + start_at_max : bool, optional + Whether to start at the max value of other signal, to + avoid inducing delays, by default True + + Returns + ------- + AudioSignal + Convolved signal, in-place. + """ + from . import AudioSignal + + pad_len = self.signal_length - other.signal_length + + if pad_len > 0: + other.zero_pad(0, pad_len) + else: + other.truncate_samples(self.signal_length) + + if start_at_max: + # Use roll to rotate over the max for every item + # so that the impulse responses don't induce any + # delay. + idx = other.audio_data.abs().argmax(axis=-1) + irs = torch.zeros_like(other.audio_data) + for i in range(other.batch_size): + irs[i] = torch.roll(other.audio_data[i], -idx[i].item(), -1) + other = AudioSignal(irs, other.sample_rate) + + delta = torch.zeros_like(other.audio_data) + delta[..., 0] = 1 + + length = self.signal_length + delta_fft = torch.fft.rfft(delta, length) + other_fft = torch.fft.rfft(other.audio_data, length) + self_fft = torch.fft.rfft(self.audio_data, length) + + convolved_fft = other_fft * self_fft + convolved_audio = torch.fft.irfft(convolved_fft, length) + + delta_convolved_fft = other_fft * delta_fft + delta_audio = torch.fft.irfft(delta_convolved_fft, length) + + # Use the delta to rescale the audio exactly as needed. + delta_max = delta_audio.abs().max(dim=-1, keepdims=True)[0] + scale = 1 / delta_max.clamp(1e-5) + convolved_audio = convolved_audio * scale + + self.audio_data = convolved_audio + + return self + + def apply_ir( + self, + ir, + drr: typing.Union[torch.Tensor, np.ndarray, float] = None, + ir_eq: typing.Union[torch.Tensor, np.ndarray] = None, + use_original_phase: bool = False, + ): + """Applies an impulse response to the signal. 
If ` is`ir_eq`` + is specified, the impulse response is equalized before + it is applied, using the given curve. + + Parameters + ---------- + ir : AudioSignal + Impulse response to convolve with. + drr : typing.Union[torch.Tensor, np.ndarray, float], optional + Direct-to-reverberant ratio that impulse response will be + altered to, if specified, by default None + ir_eq : typing.Union[torch.Tensor, np.ndarray], optional + Equalization that will be applied to impulse response + if specified, by default None + use_original_phase : bool, optional + Whether to use the original phase, instead of the convolved + phase, by default False + + Returns + ------- + AudioSignal + Signal with impulse response applied to it + """ + if ir_eq is not None: + ir = ir.equalizer(ir_eq) + if drr is not None: + ir = ir.alter_drr(drr) + + # Save the peak before + max_spk = self.audio_data.abs().max(dim=-1, keepdims=True).values + + # Augment the impulse response to simulate microphone effects + # and with varying direct-to-reverberant ratio. + phase = self.phase + self.convolve(ir) + + # Use the input phase + if use_original_phase: + self.stft() + self.stft_data = self.magnitude * torch.exp(1j * phase) + self.istft() + + # Rescale to the input's amplitude + max_transformed = self.audio_data.abs().max(dim=-1, keepdims=True).values + scale_factor = max_spk.clamp(1e-8) / max_transformed.clamp(1e-8) + self = self * scale_factor + + return self + + def ensure_max_of_audio(self, max: float = 1.0): + """Ensures that ``abs(audio_data) <= max``. + + Parameters + ---------- + max : float, optional + Max absolute value of signal, by default 1.0 + + Returns + ------- + AudioSignal + Signal with values scaled between -max and max. + """ + peak = self.audio_data.abs().max(dim=-1, keepdims=True)[0] + peak_gain = torch.ones_like(peak) + peak_gain[peak > max] = max / peak[peak > max] + self.audio_data = self.audio_data * peak_gain + return self + + def normalize(self, db: typing.Union[torch.Tensor, np.ndarray, float] = -24.0): + """Normalizes the signal's volume to the specified db, in LUFS. + This is GPU-compatible, making for very fast loudness normalization. + + Parameters + ---------- + db : typing.Union[torch.Tensor, np.ndarray, float], optional + Loudness to normalize to, by default -24.0 + + Returns + ------- + AudioSignal + Normalized audio signal. + """ + db = util.ensure_tensor(db).to(self.device) + ref_db = self.loudness() + gain = db - ref_db + gain = torch.exp(gain * self.GAIN_FACTOR) + + self.audio_data = self.audio_data * gain[:, None, None] + return self + + def volume_change(self, db: typing.Union[torch.Tensor, np.ndarray, float]): + """Change volume of signal by some amount, in dB. + + Parameters + ---------- + db : typing.Union[torch.Tensor, np.ndarray, float] + Amount to change volume by. + + Returns + ------- + AudioSignal + Signal at new volume. + """ + db = util.ensure_tensor(db, ndim=1).to(self.device) + gain = torch.exp(db * self.GAIN_FACTOR) + self.audio_data = self.audio_data * gain[:, None, None] + return self + + def _to_2d(self): + waveform = self.audio_data.reshape(-1, self.signal_length) + return waveform + + def _to_3d(self, waveform): + return waveform.reshape(self.batch_size, self.num_channels, -1) + + def pitch_shift(self, n_semitones: int, quick: bool = True): + """Pitch shift the signal. All items in the batch + get the same pitch shift. + + Parameters + ---------- + n_semitones : int + How many semitones to shift the signal by. 
+ quick : bool, optional + Using quick pitch shifting, by default True + + Returns + ------- + AudioSignal + Pitch shifted audio signal. + """ + device = self.device + effects = [ + ["pitch", str(n_semitones * 100)], + ["rate", str(self.sample_rate)], + ] + if quick: + effects[0].insert(1, "-q") + + waveform = self._to_2d().cpu() + waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor( + waveform, self.sample_rate, effects, channels_first=True + ) + self.sample_rate = sample_rate + self.audio_data = self._to_3d(waveform) + return self.to(device) + + def time_stretch(self, factor: float, quick: bool = True): + """Time stretch the audio signal. + + Parameters + ---------- + factor : float + Factor by which to stretch the AudioSignal. Typically + between 0.8 and 1.2. + quick : bool, optional + Whether to use quick time stretching, by default True + + Returns + ------- + AudioSignal + Time-stretched AudioSignal. + """ + device = self.device + effects = [ + ["tempo", str(factor)], + ["rate", str(self.sample_rate)], + ] + if quick: + effects[0].insert(1, "-q") + + waveform = self._to_2d().cpu() + waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor( + waveform, self.sample_rate, effects, channels_first=True + ) + self.sample_rate = sample_rate + self.audio_data = self._to_3d(waveform) + return self.to(device) + + def apply_codec( + self, + preset: str = None, + format: str = "wav", + encoding: str = None, + bits_per_sample: int = None, + compression: int = None, + ): # pragma: no cover + """Applies an audio codec to the signal. + + Parameters + ---------- + preset : str, optional + One of the keys in ``self.CODEC_PRESETS``, by default None + format : str, optional + Format for audio codec, by default "wav" + encoding : str, optional + Encoding to use, by default None + bits_per_sample : int, optional + How many bits per sample, by default None + compression : int, optional + Compression amount of codec, by default None + + Returns + ------- + AudioSignal + AudioSignal with codec applied. + + Raises + ------ + ValueError + If preset is not in ``self.CODEC_PRESETS``, an error + is thrown. + """ + torchaudio_version_070 = "0.7" in torchaudio.__version__ + if torchaudio_version_070: + return self + + kwargs = { + "format": format, + "encoding": encoding, + "bits_per_sample": bits_per_sample, + "compression": compression, + } + + if preset is not None: + if preset in self.CODEC_PRESETS: + kwargs = self.CODEC_PRESETS[preset] + else: + raise ValueError( + f"Unknown preset: {preset}. " + f"Known presets: {list(self.CODEC_PRESETS.keys())}" + ) + + waveform = self._to_2d() + if kwargs["format"] in ["vorbis", "mp3", "ogg", "amr-nb"]: + # Apply it in a for loop + augmented = torch.cat( + [ + torchaudio.functional.apply_codec( + waveform[i][None, :], self.sample_rate, **kwargs + ) + for i in range(waveform.shape[0]) + ], + dim=0, + ) + else: + augmented = torchaudio.functional.apply_codec( + waveform, self.sample_rate, **kwargs + ) + augmented = self._to_3d(augmented) + + self.audio_data = augmented + return self + + def mel_filterbank(self, n_bands: int): + """Breaks signal into mel bands. + + Parameters + ---------- + n_bands : int + Number of mel bands to use. + + Returns + ------- + torch.Tensor + Mel-filtered bands, with last axis being the band index. 
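+
+        Examples
+        --------
+        A rough sketch; the signal and ``n_bands=6`` are illustrative only:
+
+        >>> signal = AudioSignal(torch.randn(44100), 44100)
+        >>> bands = signal.mel_filterbank(n_bands=6)  # -> (1, 1, 44100, 6)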
+ """ + filterbank = ( + julius.SplitBands(self.sample_rate, n_bands).float().to(self.device) + ) + filtered = filterbank(self.audio_data) + return filtered.permute(1, 2, 3, 0) + + def equalizer(self, db: typing.Union[torch.Tensor, np.ndarray]): + """Applies a mel-spaced equalizer to the audio signal. + + Parameters + ---------- + db : typing.Union[torch.Tensor, np.ndarray] + EQ curve to apply. + + Returns + ------- + AudioSignal + AudioSignal with equalization applied. + """ + db = util.ensure_tensor(db) + n_bands = db.shape[-1] + fbank = self.mel_filterbank(n_bands) + + # If there's a batch dimension, make sure it's the same. + if db.ndim == 2: + if db.shape[0] != 1: + assert db.shape[0] == fbank.shape[0] + else: + db = db.unsqueeze(0) + + weights = (10**db).to(self.device).float() + fbank = fbank * weights[:, None, None, :] + eq_audio_data = fbank.sum(-1) + self.audio_data = eq_audio_data + return self + + def clip_distortion( + self, clip_percentile: typing.Union[torch.Tensor, np.ndarray, float] + ): + """Clips the signal at a given percentile. The higher it is, + the lower the threshold for clipping. + + Parameters + ---------- + clip_percentile : typing.Union[torch.Tensor, np.ndarray, float] + Values are between 0.0 to 1.0. Typical values are 0.1 or below. + + Returns + ------- + AudioSignal + Audio signal with clipped audio data. + """ + clip_percentile = util.ensure_tensor(clip_percentile, ndim=1) + min_thresh = torch.quantile(self.audio_data, clip_percentile / 2, dim=-1) + max_thresh = torch.quantile(self.audio_data, 1 - (clip_percentile / 2), dim=-1) + + nc = self.audio_data.shape[1] + min_thresh = min_thresh[:, :nc, :] + max_thresh = max_thresh[:, :nc, :] + + self.audio_data = self.audio_data.clamp(min_thresh, max_thresh) + + return self + + def quantization( + self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int] + ): + """Applies quantization to the input waveform. + + Parameters + ---------- + quantization_channels : typing.Union[torch.Tensor, np.ndarray, int] + Number of evenly spaced quantization channels to quantize + to. + + Returns + ------- + AudioSignal + Quantized AudioSignal. + """ + quantization_channels = util.ensure_tensor(quantization_channels, ndim=3) + + x = self.audio_data + x = (x + 1) / 2 + x = x * quantization_channels + x = x.floor() + x = x / quantization_channels + x = 2 * x - 1 + + residual = (self.audio_data - x).detach() + self.audio_data = self.audio_data - residual + return self + + def mulaw_quantization( + self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int] + ): + """Applies mu-law quantization to the input waveform. + + Parameters + ---------- + quantization_channels : typing.Union[torch.Tensor, np.ndarray, int] + Number of mu-law spaced quantization channels to quantize + to. + + Returns + ------- + AudioSignal + Quantized AudioSignal. 
+ """ + mu = quantization_channels - 1.0 + mu = util.ensure_tensor(mu, ndim=3) + + x = self.audio_data + + # quantize + x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu) + x = ((x + 1) / 2 * mu + 0.5).to(torch.int64) + + # unquantize + x = (x / mu) * 2 - 1.0 + x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu + + residual = (self.audio_data - x).detach() + self.audio_data = self.audio_data - residual + return self + + def __matmul__(self, other): + return self.convolve(other) + + +class ImpulseResponseMixin: + """These functions are generally only used with AudioSignals that are derived + from impulse responses, not other sources like music or speech. These methods + are used to replicate the data augmentation described in [1]. + + 1. Bryan, Nicholas J. "Impulse response data augmentation and deep + neural networks for blind room acoustic parameter estimation." + ICASSP 2020-2020 IEEE International Conference on Acoustics, + Speech and Signal Processing (ICASSP). IEEE, 2020. + """ + + def decompose_ir(self): + """Decomposes an impulse response into early and late + field responses. + """ + # Equations 1 and 2 + # ----------------- + # Breaking up into early + # response + late field response. + + td = torch.argmax(self.audio_data, dim=-1, keepdim=True) + t0 = int(self.sample_rate * 0.0025) + + idx = torch.arange(self.audio_data.shape[-1], device=self.device)[None, None, :] + idx = idx.expand(self.batch_size, -1, -1) + early_idx = (idx >= td - t0) * (idx <= td + t0) + + early_response = torch.zeros_like(self.audio_data, device=self.device) + early_response[early_idx] = self.audio_data[early_idx] + + late_idx = ~early_idx + late_field = torch.zeros_like(self.audio_data, device=self.device) + late_field[late_idx] = self.audio_data[late_idx] + + # Equation 4 + # ---------- + # Decompose early response into windowed + # direct path and windowed residual. + + window = torch.zeros_like(self.audio_data, device=self.device) + for idx in range(self.batch_size): + window_idx = early_idx[idx, 0].nonzero() + window[idx, ..., window_idx] = self.get_window( + "hann", window_idx.shape[-1], self.device + ) + return early_response, late_field, window + + def measure_drr(self): + """Measures the direct-to-reverberant ratio of the impulse + response. + + Returns + ------- + float + Direct-to-reverberant ratio + """ + early_response, late_field, _ = self.decompose_ir() + num = (early_response**2).sum(dim=-1) + den = (late_field**2).sum(dim=-1) + drr = 10 * torch.log10(num / den) + return drr + + @staticmethod + def solve_alpha(early_response, late_field, wd, target_drr): + """Used to solve for the alpha value, which is used + to alter the drr. + """ + # Equation 5 + # ---------- + # Apply the good ol' quadratic formula. + + wd_sq = wd**2 + wd_sq_1 = (1 - wd) ** 2 + e_sq = early_response**2 + l_sq = late_field**2 + a = (wd_sq * e_sq).sum(dim=-1) + b = (2 * (1 - wd) * wd * e_sq).sum(dim=-1) + c = (wd_sq_1 * e_sq).sum(dim=-1) - torch.pow(10, target_drr / 10) * l_sq.sum( + dim=-1 + ) + + expr = ((b**2) - 4 * a * c).sqrt() + alpha = torch.maximum( + (-b - expr) / (2 * a), + (-b + expr) / (2 * a), + ) + return alpha + + def alter_drr(self, drr: typing.Union[torch.Tensor, np.ndarray, float]): + """Alters the direct-to-reverberant ratio of the impulse response. 
+ + Parameters + ---------- + drr : typing.Union[torch.Tensor, np.ndarray, float] + Direct-to-reverberant ratio that impulse response will be + altered to, if specified, by default None + + Returns + ------- + AudioSignal + Altered impulse response. + """ + drr = util.ensure_tensor(drr, 2, self.batch_size).to(self.device) + + early_response, late_field, window = self.decompose_ir() + alpha = self.solve_alpha(early_response, late_field, window, drr) + min_alpha = ( + late_field.abs().max(dim=-1)[0] / early_response.abs().max(dim=-1)[0] + ) + alpha = torch.maximum(alpha, min_alpha)[..., None] + + aug_ir_data = ( + alpha * window * early_response + + ((1 - window) * early_response) + + late_field + ) + self.audio_data = aug_ir_data + self.ensure_max_of_audio() + return self diff --git a/audiotools/core/ffmpeg.py b/audiotools/core/ffmpeg.py new file mode 100644 index 0000000000000000000000000000000000000000..baf27ccca25ffbf9e915aa870ca8797c37187cdd --- /dev/null +++ b/audiotools/core/ffmpeg.py @@ -0,0 +1,204 @@ +import json +import shlex +import subprocess +import tempfile +from pathlib import Path +from typing import Tuple + +import ffmpy +import numpy as np +import torch + + +def r128stats(filepath: str, quiet: bool): + """Takes a path to an audio file, returns a dict with the loudness + stats computed by the ffmpeg ebur128 filter. + + Parameters + ---------- + filepath : str + Path to compute loudness stats on. + quiet : bool + Whether to show FFMPEG output during computation. + + Returns + ------- + dict + Dictionary containing loudness stats. + """ + ffargs = [ + "ffmpeg", + "-nostats", + "-i", + filepath, + "-filter_complex", + "ebur128", + "-f", + "null", + "-", + ] + if quiet: + ffargs += ["-hide_banner"] + proc = subprocess.Popen(ffargs, stderr=subprocess.PIPE, universal_newlines=True) + stats = proc.communicate()[1] + summary_index = stats.rfind("Summary:") + + summary_list = stats[summary_index:].split() + i_lufs = float(summary_list[summary_list.index("I:") + 1]) + i_thresh = float(summary_list[summary_list.index("I:") + 4]) + lra = float(summary_list[summary_list.index("LRA:") + 1]) + lra_thresh = float(summary_list[summary_list.index("LRA:") + 4]) + lra_low = float(summary_list[summary_list.index("low:") + 1]) + lra_high = float(summary_list[summary_list.index("high:") + 1]) + stats_dict = { + "I": i_lufs, + "I Threshold": i_thresh, + "LRA": lra, + "LRA Threshold": lra_thresh, + "LRA Low": lra_low, + "LRA High": lra_high, + } + + return stats_dict + + +def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]: + """Given a path to a file, returns the start time offset and codec of + the first audio stream. + """ + ff = ffmpy.FFprobe( + inputs={path: None}, + global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet", + ) + streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"] + seconds_offset = 0.0 + codec = None + + # Get the offset and codec of the first audio stream we find + # and return its start time, if it has one. + for stream in streams: + if stream["codec_type"] == "audio": + seconds_offset = stream.get("start_time", 0.0) + codec = stream.get("codec_name") + break + return float(seconds_offset), codec + + +class FFMPEGMixin: + _loudness = None + + def ffmpeg_loudness(self, quiet: bool = True): + """Computes loudness of audio file using FFMPEG. 
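+
+        Each item in the batch is written to a temporary wav file and analyzed
+        with FFMPEG's ``ebur128`` filter.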
+ + Parameters + ---------- + quiet : bool, optional + Whether to show FFMPEG output during computation, + by default True + + Returns + ------- + torch.Tensor + Loudness of every item in the batch, computed via + FFMPEG. + """ + loudness = [] + + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + for i in range(self.batch_size): + self[i].write(f.name) + loudness_stats = r128stats(f.name, quiet=quiet) + loudness.append(loudness_stats["I"]) + + self._loudness = torch.from_numpy(np.array(loudness)).float() + return self.loudness() + + def ffmpeg_resample(self, sample_rate: int, quiet: bool = True): + """Resamples AudioSignal using FFMPEG. More memory-efficient + than using julius.resample for long audio files. + + Parameters + ---------- + sample_rate : int + Sample rate to resample to. + quiet : bool, optional + Whether to show FFMPEG output during computation, + by default True + + Returns + ------- + AudioSignal + Resampled AudioSignal. + """ + from audiotools import AudioSignal + + if sample_rate == self.sample_rate: + return self + + with tempfile.NamedTemporaryFile(suffix=".wav") as f: + self.write(f.name) + f_out = f.name.replace("wav", "rs.wav") + command = f"ffmpeg -i {f.name} -ar {sample_rate} {f_out}" + if quiet: + command += " -hide_banner -loglevel error" + subprocess.check_call(shlex.split(command)) + resampled = AudioSignal(f_out) + Path.unlink(Path(f_out)) + return resampled + + @classmethod + def load_from_file_with_ffmpeg(cls, audio_path: str, quiet: bool = True, **kwargs): + """Loads AudioSignal object after decoding it to a wav file using FFMPEG. + Useful for loading audio that isn't covered by librosa's loading mechanism. Also + useful for loading mp3 files, without any offset. + + Parameters + ---------- + audio_path : str + Path to load AudioSignal from. + quiet : bool, optional + Whether to show FFMPEG output during computation, + by default True + + Returns + ------- + AudioSignal + AudioSignal loaded from file with FFMPEG. + """ + audio_path = str(audio_path) + with tempfile.TemporaryDirectory() as d: + wav_file = str(Path(d) / "extracted.wav") + padded_wav = str(Path(d) / "padded.wav") + + global_options = "-y" + if quiet: + global_options += " -loglevel error" + + ff = ffmpy.FFmpeg( + inputs={audio_path: None}, + outputs={wav_file: None}, + global_options=global_options, + ) + ff.run() + + # We pad the file using the start time offset in case it's an audio + # stream starting at some offset in a video container. + pad, codec = ffprobe_offset_and_codec(audio_path) + + # For mp3s, don't pad files with discrepancies less than 0.027s - + # it's likely due to codec latency. The amount of latency introduced + # by mp3 is 1152, which is 0.0261 44khz. So we set the threshold + # here slightly above that. + # Source: https://lame.sourceforge.io/tech-FAQ.txt. + if codec == "mp3" and pad < 0.027: + pad = 0.0 + ff = ffmpy.FFmpeg( + inputs={wav_file: None}, + outputs={padded_wav: f"-af 'adelay={pad*1000}:all=true'"}, + global_options=global_options, + ) + ff.run() + + signal = cls(padded_wav, **kwargs) + + return signal diff --git a/audiotools/core/loudness.py b/audiotools/core/loudness.py new file mode 100644 index 0000000000000000000000000000000000000000..cb3ee2675d7cb71f4c00106b0c1e901b8e51b842 --- /dev/null +++ b/audiotools/core/loudness.py @@ -0,0 +1,320 @@ +import copy + +import julius +import numpy as np +import scipy +import torch +import torch.nn.functional as F +import torchaudio + + +class Meter(torch.nn.Module): + """Tensorized version of pyloudnorm.Meter. 
Works with batched audio tensors. + + Parameters + ---------- + rate : int + Sample rate of audio. + filter_class : str, optional + Class of weighting filter used. + K-weighting' (default), 'Fenton/Lee 1' + 'Fenton/Lee 2', 'Dash et al.' + by default "K-weighting" + block_size : float, optional + Gating block size in seconds, by default 0.400 + zeros : int, optional + Number of zeros to use in FIR approximation of + IIR filters, by default 512 + use_fir : bool, optional + Whether to use FIR approximation or exact IIR formulation. + If computing on GPU, ``use_fir=True`` will be used, as its + much faster, by default False + """ + + def __init__( + self, + rate: int, + filter_class: str = "K-weighting", + block_size: float = 0.400, + zeros: int = 512, + use_fir: bool = False, + ): + super().__init__() + + self.rate = rate + self.filter_class = filter_class + self.block_size = block_size + self.use_fir = use_fir + + G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41])) + self.register_buffer("G", G) + + # Compute impulse responses so that filtering is fast via + # a convolution at runtime, on GPU, unlike lfilter. + impulse = np.zeros((zeros,)) + impulse[..., 0] = 1.0 + + firs = np.zeros((len(self._filters), 1, zeros)) + passband_gain = torch.zeros(len(self._filters)) + + for i, (_, filter_stage) in enumerate(self._filters.items()): + firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a, impulse) + passband_gain[i] = filter_stage.passband_gain + + firs = torch.from_numpy(firs[..., ::-1].copy()).float() + + self.register_buffer("firs", firs) + self.register_buffer("passband_gain", passband_gain) + + def apply_filter_gpu(self, data: torch.Tensor): + """Performs FIR approximation of loudness computation. + + Parameters + ---------- + data : torch.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + torch.Tensor + Filtered audio data. + """ + # Data is of shape (nb, nch, nt) + # Reshape to (nb*nch, 1, nt) + nb, nt, nch = data.shape + data = data.permute(0, 2, 1) + data = data.reshape(nb * nch, 1, nt) + + # Apply padding + pad_length = self.firs.shape[-1] + + # Apply filtering in sequence + for i in range(self.firs.shape[0]): + data = F.pad(data, (pad_length, pad_length)) + data = julius.fftconv.fft_conv1d(data, self.firs[i, None, ...]) + data = self.passband_gain[i] * data + data = data[..., 1 : nt + 1] + + data = data.permute(0, 2, 1) + data = data[:, :nt, :] + return data + + def apply_filter_cpu(self, data: torch.Tensor): + """Performs IIR formulation of loudness computation. + + Parameters + ---------- + data : torch.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + torch.Tensor + Filtered audio data. + """ + for _, filter_stage in self._filters.items(): + passband_gain = filter_stage.passband_gain + + a_coeffs = torch.from_numpy(filter_stage.a).float().to(data.device) + b_coeffs = torch.from_numpy(filter_stage.b).float().to(data.device) + + _data = data.permute(0, 2, 1) + filtered = torchaudio.functional.lfilter( + _data, a_coeffs, b_coeffs, clamp=False + ) + data = passband_gain * filtered.permute(0, 2, 1) + return data + + def apply_filter(self, data: torch.Tensor): + """Applies filter on either CPU or GPU, depending + on if the audio is on GPU or is on CPU, or if + ``self.use_fir`` is True. + + Parameters + ---------- + data : torch.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + torch.Tensor + Filtered audio data. 
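+
+        Examples
+        --------
+        A rough sketch; the layout below follows what ``integrated_loudness``
+        passes in, i.e. ``(nb, nt, nch)``:
+
+        >>> meter = Meter(44100)
+        >>> audio = torch.randn(1, 44100, 1)
+        >>> weighted = meter.apply_filter(audio)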
+ """ + if data.is_cuda or self.use_fir: + data = self.apply_filter_gpu(data) + else: + data = self.apply_filter_cpu(data) + return data + + def forward(self, data: torch.Tensor): + """Computes integrated loudness of data. + + Parameters + ---------- + data : torch.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + torch.Tensor + Filtered audio data. + """ + return self.integrated_loudness(data) + + def _unfold(self, input_data): + T_g = self.block_size + overlap = 0.75 # overlap of 75% of the block duration + step = 1.0 - overlap # step size by percentage + + kernel_size = int(T_g * self.rate) + stride = int(T_g * self.rate * step) + unfolded = julius.core.unfold(input_data.permute(0, 2, 1), kernel_size, stride) + unfolded = unfolded.transpose(-1, -2) + + return unfolded + + def integrated_loudness(self, data: torch.Tensor): + """Computes integrated loudness of data. + + Parameters + ---------- + data : torch.Tensor + Audio data of shape (nb, nch, nt). + + Returns + ------- + torch.Tensor + Filtered audio data. + """ + if not torch.is_tensor(data): + data = torch.from_numpy(data).float() + else: + data = data.float() + + input_data = copy.copy(data) + # Data always has a batch and channel dimension. + # Is of shape (nb, nt, nch) + if input_data.ndim < 2: + input_data = input_data.unsqueeze(-1) + if input_data.ndim < 3: + input_data = input_data.unsqueeze(0) + + nb, nt, nch = input_data.shape + + # Apply frequency weighting filters - account + # for the acoustic respose of the head and auditory system + input_data = self.apply_filter(input_data) + + G = self.G # channel gains + T_g = self.block_size # 400 ms gating block standard + Gamma_a = -70.0 # -70 LKFS = absolute loudness threshold + + unfolded = self._unfold(input_data) + + z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2) + l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True)) + l = l.expand_as(z) + + # find gating block indices above absolute threshold + z_avg_gated = z + z_avg_gated[l <= Gamma_a] = 0 + masked = l > Gamma_a + z_avg_gated = z_avg_gated.sum(2) / masked.sum(2) + + # calculate the relative threshold value (see eq. 6) + Gamma_r = ( + -0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0 + ) + Gamma_r = Gamma_r[:, None, None] + Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1]) + + # find gating block indices above relative and absolute thresholds (end of eq. 
7) + z_avg_gated = z + z_avg_gated[l <= Gamma_a] = 0 + z_avg_gated[l <= Gamma_r] = 0 + masked = (l > Gamma_a) * (l > Gamma_r) + z_avg_gated = z_avg_gated.sum(2) / masked.sum(2) + + # # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version) + # z_avg_gated = torch.nan_to_num(z_avg_gated) + z_avg_gated = torch.where( + z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated + ) + z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max) + z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min) + + LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1)) + return LUFS.float() + + @property + def filter_class(self): + return self._filter_class + + @filter_class.setter + def filter_class(self, value): + from pyloudnorm import Meter + + meter = Meter(self.rate) + meter.filter_class = value + self._filter_class = value + self._filters = meter._filters + + +class LoudnessMixin: + _loudness = None + MIN_LOUDNESS = -70 + """Minimum loudness possible.""" + + def loudness( + self, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs + ): + """Calculates loudness using an implementation of ITU-R BS.1770-4. + Allows control over gating block size and frequency weighting filters for + additional control. Measure the integrated gated loudness of a signal. + + API is derived from PyLoudnorm, but this implementation is ported to PyTorch + and is tensorized across batches. When on GPU, an FIR approximation of the IIR + filters is used to compute loudness for speed. + + Uses the weighting filters and block size defined by the meter + the integrated loudness is measured based upon the gating algorithm + defined in the ITU-R BS.1770-4 specification. + + Parameters + ---------- + filter_class : str, optional + Class of weighting filter used. + K-weighting' (default), 'Fenton/Lee 1' + 'Fenton/Lee 2', 'Dash et al.' + by default "K-weighting" + block_size : float, optional + Gating block size in seconds, by default 0.400 + kwargs : dict, optional + Keyword arguments to :py:func:`audiotools.core.loudness.Meter`. + + Returns + ------- + torch.Tensor + Loudness of audio data. + """ + if self._loudness is not None: + return self._loudness.to(self.device) + original_length = self.signal_length + if self.signal_duration < 0.5: + pad_len = int((0.5 - self.signal_duration) * self.sample_rate) + self.zero_pad(0, pad_len) + + # create BS.1770 meter + meter = Meter( + self.sample_rate, filter_class=filter_class, block_size=block_size, **kwargs + ) + meter = meter.to(self.device) + # measure loudness + loudness = meter.integrated_loudness(self.audio_data.permute(0, 2, 1)) + self.truncate_samples(original_length) + min_loudness = ( + torch.ones_like(loudness, device=loudness.device) * self.MIN_LOUDNESS + ) + self._loudness = torch.maximum(loudness, min_loudness) + + return self._loudness.to(self.device) diff --git a/audiotools/core/playback.py b/audiotools/core/playback.py new file mode 100644 index 0000000000000000000000000000000000000000..5d0f21aaa392494f35305c0084c05b87667ea14d --- /dev/null +++ b/audiotools/core/playback.py @@ -0,0 +1,252 @@ +""" +These are utilities that allow one to embed an AudioSignal +as a playable object in a Jupyter notebook, or to play audio from +the terminal, etc. +""" # fmt: skip +import base64 +import io +import random +import string +import subprocess +from tempfile import NamedTemporaryFile + +import importlib_resources as pkg_resources + +from . 
import templates +from .util import _close_temp_files +from .util import format_figure + +headers = pkg_resources.files(templates).joinpath("headers.html").read_text() +widget = pkg_resources.files(templates).joinpath("widget.html").read_text() + +DEFAULT_EXTENSION = ".wav" + + +def _check_imports(): # pragma: no cover + try: + import ffmpy + except: + ffmpy = False + + try: + import IPython + except: + raise ImportError("IPython must be installed in order to use this function!") + return ffmpy, IPython + + +class PlayMixin: + def embed(self, ext: str = None, display: bool = True, return_html: bool = False): + """Embeds audio as a playable audio embed in a notebook, or HTML + document, etc. + + Parameters + ---------- + ext : str, optional + Extension to use when saving the audio, by default ".wav" + display : bool, optional + This controls whether or not to display the audio when called. This + is used when the embed is the last line in a Jupyter cell, to prevent + the audio from being embedded twice, by default True + return_html : bool, optional + Whether to return the data wrapped in an HTML audio element, by default False + + Returns + ------- + str + Either the element for display, or the HTML string of it. + """ + if ext is None: + ext = DEFAULT_EXTENSION + ext = f".{ext}" if not ext.startswith(".") else ext + ffmpy, IPython = _check_imports() + sr = self.sample_rate + tmpfiles = [] + + with _close_temp_files(tmpfiles): + tmp_wav = NamedTemporaryFile(mode="w+", suffix=".wav", delete=False) + tmpfiles.append(tmp_wav) + self.write(tmp_wav.name) + if ext != ".wav" and ffmpy: + tmp_converted = NamedTemporaryFile(mode="w+", suffix=ext, delete=False) + tmpfiles.append(tmp_wav) + ff = ffmpy.FFmpeg( + inputs={tmp_wav.name: None}, + outputs={ + tmp_converted.name: "-write_xing 0 -codec:a libmp3lame -b:a 128k -y -hide_banner -loglevel error" + }, + ) + ff.run() + else: + tmp_converted = tmp_wav + + audio_element = IPython.display.Audio(data=tmp_converted.name, rate=sr) + if display: + IPython.display.display(audio_element) + + if return_html: + audio_element = ( + f" " + ) + return audio_element + + def widget( + self, + title: str = None, + ext: str = ".wav", + add_headers: bool = True, + player_width: str = "100%", + margin: str = "10px", + plot_fn: str = "specshow", + return_html: bool = False, + **kwargs, + ): + """Creates a playable widget with spectrogram. Inspired (heavily) by + https://sjvasquez.github.io/blog/melnet/. + + Parameters + ---------- + title : str, optional + Title of plot, placed in upper right of top-most axis. + ext : str, optional + Extension for embedding, by default ".mp3" + add_headers : bool, optional + Whether or not to add headers (use for first embed, False for later embeds), by default True + player_width : str, optional + Width of the player, as a string in a CSS rule, by default "100%" + margin : str, optional + Margin on all sides of player, by default "10px" + plot_fn : function, optional + Plotting function to use (by default self.specshow). + return_html : bool, optional + Whether to return the data wrapped in an HTML audio element, by default False + kwargs : dict, optional + Keyword arguments to plot_fn (by default self.specshow). + + Returns + ------- + HTML + HTML object. 
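+
+        Examples
+        --------
+        A minimal sketch; the path and title are placeholders:
+
+        >>> signal = AudioSignal("path/to/audio.wav")
+        >>> html = signal.widget("Spectrogram", return_html=True)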
+ """ + import matplotlib.pyplot as plt + + def _save_fig_to_tag(): + buffer = io.BytesIO() + + plt.savefig(buffer, bbox_inches="tight", pad_inches=0) + plt.close() + + buffer.seek(0) + data_uri = base64.b64encode(buffer.read()).decode("ascii") + tag = "data:image/png;base64,{0}".format(data_uri) + + return tag + + _, IPython = _check_imports() + + header_html = "" + + if add_headers: + header_html = headers.replace("PLAYER_WIDTH", str(player_width)) + header_html = header_html.replace("MARGIN", str(margin)) + IPython.display.display(IPython.display.HTML(header_html)) + + widget_html = widget + if isinstance(plot_fn, str): + plot_fn = getattr(self, plot_fn) + kwargs["title"] = title + plot_fn(**kwargs) + + fig = plt.gcf() + pixels = fig.get_size_inches() * fig.dpi + + tag = _save_fig_to_tag() + + # Make the source image for the levels + self.specshow() + format_figure((12, 1.5)) + levels_tag = _save_fig_to_tag() + + player_id = "".join(random.choice(string.ascii_uppercase) for _ in range(10)) + + audio_elem = self.embed(ext=ext, display=False) + widget_html = widget_html.replace("AUDIO_SRC", audio_elem.src_attr()) + widget_html = widget_html.replace("IMAGE_SRC", tag) + widget_html = widget_html.replace("LEVELS_SRC", levels_tag) + widget_html = widget_html.replace("PLAYER_ID", player_id) + + # Calculate width/height of figure based on figure size. + widget_html = widget_html.replace("PADDING_AMOUNT", f"{int(pixels[1])}px") + widget_html = widget_html.replace("MAX_WIDTH", f"{int(pixels[0])}px") + + IPython.display.display(IPython.display.HTML(widget_html)) + + if return_html: + html = header_html if add_headers else "" + html += widget_html + return html + + def play(self): + """ + Plays an audio signal if ffplay from the ffmpeg suite of tools is installed. + Otherwise, will fail. The audio signal is written to a temporary file + and then played with ffplay. 
+ """ + tmpfiles = [] + with _close_temp_files(tmpfiles): + tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False) + tmpfiles.append(tmp_wav) + self.write(tmp_wav.name) + print(self) + subprocess.call( + [ + "ffplay", + "-nodisp", + "-autoexit", + "-hide_banner", + "-loglevel", + "error", + tmp_wav.name, + ] + ) + return self + + +if __name__ == "__main__": # pragma: no cover + from audiotools import AudioSignal + + signal = AudioSignal( + "tests/audio/spk/f10_script4_produced.mp3", offset=5, duration=5 + ) + + wave_html = signal.widget( + "Waveform", + plot_fn="waveplot", + return_html=True, + ) + + spec_html = signal.widget("Spectrogram", return_html=True, add_headers=False) + + combined_html = signal.widget( + "Waveform + spectrogram", + plot_fn="wavespec", + return_html=True, + add_headers=False, + ) + + signal.low_pass(8000) + lowpass_html = signal.widget( + "Lowpassed audio", + plot_fn="wavespec", + return_html=True, + add_headers=False, + ) + + with open("/tmp/index.html", "w") as f: + f.write(wave_html) + f.write(spec_html) + f.write(combined_html) + f.write(lowpass_html) diff --git a/audiotools/core/templates/__init__.py b/audiotools/core/templates/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/audiotools/core/templates/headers.html b/audiotools/core/templates/headers.html new file mode 100644 index 0000000000000000000000000000000000000000..9eaef4a94d575f7826608ad63dcc77fab13b7b19 --- /dev/null +++ b/audiotools/core/templates/headers.html @@ -0,0 +1,322 @@ + + + + + + diff --git a/audiotools/core/templates/pandoc.css b/audiotools/core/templates/pandoc.css new file mode 100644 index 0000000000000000000000000000000000000000..842be7be6d65580dab44c6a8013259644f38e6ee --- /dev/null +++ b/audiotools/core/templates/pandoc.css @@ -0,0 +1,407 @@ +/* +Copyright (c) 2017 Chris Patuzzo +https://twitter.com/chrispatuzzo + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. 
+*/ + +body { + font-family: Helvetica, arial, sans-serif; + font-size: 14px; + line-height: 1.6; + padding-top: 10px; + padding-bottom: 10px; + background-color: white; + padding: 30px; + color: #333; +} + +body > *:first-child { + margin-top: 0 !important; +} + +body > *:last-child { + margin-bottom: 0 !important; +} + +a { + color: #4183C4; + text-decoration: none; +} + +a.absent { + color: #cc0000; +} + +a.anchor { + display: block; + padding-left: 30px; + margin-left: -30px; + cursor: pointer; + position: absolute; + top: 0; + left: 0; + bottom: 0; +} + +h1, h2, h3, h4, h5, h6 { + margin: 20px 0 10px; + padding: 0; + font-weight: bold; + -webkit-font-smoothing: antialiased; + cursor: text; + position: relative; +} + +h2:first-child, h1:first-child, h1:first-child + h2, h3:first-child, h4:first-child, h5:first-child, h6:first-child { + margin-top: 0; + padding-top: 0; +} + +h1:hover a.anchor, h2:hover a.anchor, h3:hover a.anchor, h4:hover a.anchor, h5:hover a.anchor, h6:hover a.anchor { + text-decoration: none; +} + +h1 tt, h1 code { + font-size: inherit; +} + +h2 tt, h2 code { + font-size: inherit; +} + +h3 tt, h3 code { + font-size: inherit; +} + +h4 tt, h4 code { + font-size: inherit; +} + +h5 tt, h5 code { + font-size: inherit; +} + +h6 tt, h6 code { + font-size: inherit; +} + +h1 { + font-size: 28px; + color: black; +} + +h2 { + font-size: 24px; + border-bottom: 1px solid #cccccc; + color: black; +} + +h3 { + font-size: 18px; +} + +h4 { + font-size: 16px; +} + +h5 { + font-size: 14px; +} + +h6 { + color: #777777; + font-size: 14px; +} + +p, blockquote, ul, ol, dl, li, table, pre { + margin: 15px 0; +} + +hr { + border: 0 none; + color: #cccccc; + height: 4px; + padding: 0; +} + +body > h2:first-child { + margin-top: 0; + padding-top: 0; +} + +body > h1:first-child { + margin-top: 0; + padding-top: 0; +} + +body > h1:first-child + h2 { + margin-top: 0; + padding-top: 0; +} + +body > h3:first-child, body > h4:first-child, body > h5:first-child, body > h6:first-child { + margin-top: 0; + padding-top: 0; +} + +a:first-child h1, a:first-child h2, a:first-child h3, a:first-child h4, a:first-child h5, a:first-child h6 { + margin-top: 0; + padding-top: 0; +} + +h1 p, h2 p, h3 p, h4 p, h5 p, h6 p { + margin-top: 0; +} + +li p.first { + display: inline-block; +} + +ul, ol { + padding-left: 30px; +} + +ul :first-child, ol :first-child { + margin-top: 0; +} + +ul :last-child, ol :last-child { + margin-bottom: 0; +} + +dl { + padding: 0; +} + +dl dt { + font-size: 14px; + font-weight: bold; + font-style: italic; + padding: 0; + margin: 15px 0 5px; +} + +dl dt:first-child { + padding: 0; +} + +dl dt > :first-child { + margin-top: 0; +} + +dl dt > :last-child { + margin-bottom: 0; +} + +dl dd { + margin: 0 0 15px; + padding: 0 15px; +} + +dl dd > :first-child { + margin-top: 0; +} + +dl dd > :last-child { + margin-bottom: 0; +} + +blockquote { + border-left: 4px solid #dddddd; + padding: 0 15px; + color: #777777; +} + +blockquote > :first-child { + margin-top: 0; +} + +blockquote > :last-child { + margin-bottom: 0; +} + +table { + padding: 0; +} +table tr { + border-top: 1px solid #cccccc; + background-color: white; + margin: 0; + padding: 0; +} + +table tr:nth-child(2n) { + background-color: #f8f8f8; +} + +table tr th { + font-weight: bold; + border: 1px solid #cccccc; + text-align: left; + margin: 0; + padding: 6px 13px; +} + +table tr td { + border: 1px solid #cccccc; + text-align: left; + margin: 0; + padding: 6px 13px; +} + +table tr th :first-child, table tr td :first-child { + margin-top: 
0; +} + +table tr th :last-child, table tr td :last-child { + margin-bottom: 0; +} + +img { + max-width: 100%; +} + +span.frame { + display: block; + overflow: hidden; +} + +span.frame > span { + border: 1px solid #dddddd; + display: block; + float: left; + overflow: hidden; + margin: 13px 0 0; + padding: 7px; + width: auto; +} + +span.frame span img { + display: block; + float: left; +} + +span.frame span span { + clear: both; + color: #333333; + display: block; + padding: 5px 0 0; +} + +span.align-center { + display: block; + overflow: hidden; + clear: both; +} + +span.align-center > span { + display: block; + overflow: hidden; + margin: 13px auto 0; + text-align: center; +} + +span.align-center span img { + margin: 0 auto; + text-align: center; +} + +span.align-right { + display: block; + overflow: hidden; + clear: both; +} + +span.align-right > span { + display: block; + overflow: hidden; + margin: 13px 0 0; + text-align: right; +} + +span.align-right span img { + margin: 0; + text-align: right; +} + +span.float-left { + display: block; + margin-right: 13px; + overflow: hidden; + float: left; +} + +span.float-left span { + margin: 13px 0 0; +} + +span.float-right { + display: block; + margin-left: 13px; + overflow: hidden; + float: right; +} + +span.float-right > span { + display: block; + overflow: hidden; + margin: 13px auto 0; + text-align: right; +} + +code, tt { + margin: 0 2px; + padding: 0 5px; + white-space: nowrap; + border-radius: 3px; +} + +pre code { + margin: 0; + padding: 0; + white-space: pre; + border: none; + background: transparent; +} + +.highlight pre { + font-size: 13px; + line-height: 19px; + overflow: auto; + padding: 6px 10px; + border-radius: 3px; +} + +pre { + font-size: 13px; + line-height: 19px; + overflow: auto; + padding: 6px 10px; + border-radius: 3px; +} + +pre code, pre tt { + background-color: transparent; + border: none; +} + +body { + max-width: 600px; +} diff --git a/audiotools/core/templates/widget.html b/audiotools/core/templates/widget.html new file mode 100644 index 0000000000000000000000000000000000000000..0b44e8aec64fd1db929da5fa6208dee00247c967 --- /dev/null +++ b/audiotools/core/templates/widget.html @@ -0,0 +1,52 @@ +
+ + diff --git a/audiotools/core/util.py b/audiotools/core/util.py new file mode 100644 index 0000000000000000000000000000000000000000..ece1344658d10836aa2eb693f275294ad8cdbb52 --- /dev/null +++ b/audiotools/core/util.py @@ -0,0 +1,671 @@ +import csv +import glob +import math +import numbers +import os +import random +import typing +from contextlib import contextmanager +from dataclasses import dataclass +from pathlib import Path +from typing import Dict +from typing import List + +import numpy as np +import torch +import torchaudio +from flatten_dict import flatten +from flatten_dict import unflatten + + +@dataclass +class Info: + """Shim for torchaudio.info API changes.""" + + sample_rate: float + num_frames: int + + @property + def duration(self) -> float: + return self.num_frames / self.sample_rate + + +def info(audio_path: str): + """Shim for torchaudio.info to make 0.7.2 API match 0.8.0. + + Parameters + ---------- + audio_path : str + Path to audio file. + """ + # try default backend first, then fallback to soundfile + try: + info = torchaudio.info(str(audio_path)) + except: # pragma: no cover + info = torchaudio.backend.soundfile_backend.info(str(audio_path)) + + if isinstance(info, tuple): # pragma: no cover + signal_info = info[0] + info = Info(sample_rate=signal_info.rate, num_frames=signal_info.length) + else: + info = Info(sample_rate=info.sample_rate, num_frames=info.num_frames) + + return info + + +def ensure_tensor( + x: typing.Union[np.ndarray, torch.Tensor, float, int], + ndim: int = None, + batch_size: int = None, +): + """Ensures that the input ``x`` is a tensor of specified + dimensions and batch size. + + Parameters + ---------- + x : typing.Union[np.ndarray, torch.Tensor, float, int] + Data that will become a tensor on its way out. + ndim : int, optional + How many dimensions should be in the output, by default None + batch_size : int, optional + The batch size of the output, by default None + + Returns + ------- + torch.Tensor + Modified version of ``x`` as a tensor. + """ + if not torch.is_tensor(x): + x = torch.as_tensor(x) + if ndim is not None: + assert x.ndim <= ndim + while x.ndim < ndim: + x = x.unsqueeze(-1) + if batch_size is not None: + if x.shape[0] != batch_size: + shape = list(x.shape) + shape[0] = batch_size + x = x.expand(*shape) + return x + + +def _get_value(other): + from . import AudioSignal + + if isinstance(other, AudioSignal): + return other.audio_data + return other + + +def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int): + """Closest frequency bin given a frequency, number + of bins, and a sampling rate. + + Parameters + ---------- + hz : torch.Tensor + Tensor of frequencies in Hz. + n_fft : int + Number of FFT bins. + sample_rate : int + Sample rate of audio. + + Returns + ------- + torch.Tensor + Closest bins to the data. + """ + shape = hz.shape + hz = hz.flatten() + freqs = torch.linspace(0, sample_rate / 2, 2 + n_fft // 2) + hz[hz > sample_rate / 2] = sample_rate / 2 + + closest = (hz[None, :] - freqs[:, None]).abs() + closest_bins = closest.min(dim=0).indices + + return closest_bins.reshape(*shape) + + +def random_state(seed: typing.Union[int, np.random.RandomState]): + """ + Turn seed into a np.random.RandomState instance. + + Parameters + ---------- + seed : typing.Union[int, np.random.RandomState] or None + If seed is None, return the RandomState singleton used by np.random. + If seed is an int, return a new RandomState instance seeded with seed. + If seed is already a RandomState instance, return it. 
+ Otherwise raise ValueError. + + Returns + ------- + np.random.RandomState + Random state object. + + Raises + ------ + ValueError + If seed is not valid, an error is thrown. + """ + if seed is None or seed is np.random: + return np.random.mtrand._rand + elif isinstance(seed, (numbers.Integral, np.integer, int)): + return np.random.RandomState(seed) + elif isinstance(seed, np.random.RandomState): + return seed + else: + raise ValueError( + "%r cannot be used to seed a numpy.random.RandomState" " instance" % seed + ) + + +def seed(random_seed, set_cudnn=False): + """ + Seeds all random states with the same random seed + for reproducibility. Seeds ``numpy``, ``random`` and ``torch`` + random generators. + For full reproducibility, two further options must be set + according to the torch documentation: + https://pytorch.org/docs/stable/notes/randomness.html + To do this, ``set_cudnn`` must be True. It defaults to + False, since setting it to True results in a performance + hit. + + Args: + random_seed (int): integer corresponding to random seed to + use. + set_cudnn (bool): Whether or not to set cudnn into determinstic + mode and off of benchmark mode. Defaults to False. + """ + + torch.manual_seed(random_seed) + np.random.seed(random_seed) + random.seed(random_seed) + + if set_cudnn: + torch.backends.cudnn.deterministic = True + torch.backends.cudnn.benchmark = False + + +@contextmanager +def _close_temp_files(tmpfiles: list): + """Utility function for creating a context and closing all temporary files + once the context is exited. For correct functionality, all temporary file + handles created inside the context must be appended to the ```tmpfiles``` + list. + + This function is taken wholesale from Scaper. + + Parameters + ---------- + tmpfiles : list + List of temporary file handles + """ + + def _close(): + for t in tmpfiles: + try: + t.close() + os.unlink(t.name) + except: + pass + + try: + yield + except: # pragma: no cover + _close() + raise + _close() + + +AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"] + + +def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS): + """Finds all audio files in a directory recursively. + Returns a list. + + Parameters + ---------- + folder : str + Folder to look for audio files in, recursively. + ext : List[str], optional + Extensions to look for without the ., by default + ``['.wav', '.flac', '.mp3', '.mp4']``. + """ + folder = Path(folder) + # Take care of case where user has passed in an audio file directly + # into one of the calling functions. + if str(folder).endswith(tuple(ext)): + # if, however, there's a glob in the path, we need to + # return the glob, not the file. + if "*" in str(folder): + return glob.glob(str(folder), recursive=("**" in str(folder))) + else: + return [folder] + + files = [] + for x in ext: + files += folder.glob(f"**/*{x}") + return files + + +def read_sources( + sources: List[str], + remove_empty: bool = True, + relative_path: str = "", + ext: List[str] = AUDIO_EXTENSIONS, +): + """Reads audio sources that can either be folders + full of audio files, or CSV files that contain paths + to audio files. CSV files that adhere to the expected + format can be generated by + :py:func:`audiotools.data.preprocess.create_csv`. + + Parameters + ---------- + sources : List[str] + List of audio sources to be converted into a + list of lists of audio files. + remove_empty : bool, optional + Whether or not to remove rows with an empty "path" + from each CSV file, by default True. 
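+    relative_path : str, optional
+        Root to prepend to every path read from a CSV or found in a folder,
+        by default "".
+    ext : List[str], optional
+        Extensions to look for when a source is a folder, by default
+        ``AUDIO_EXTENSIONS``.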
+ + Returns + ------- + list + List of lists of rows of CSV files. + """ + files = [] + relative_path = Path(relative_path) + for source in sources: + source = str(source) + _files = [] + if source.endswith(".csv"): + with open(source, "r") as f: + reader = csv.DictReader(f) + for x in reader: + if remove_empty and x["path"] == "": + continue + if x["path"] != "": + x["path"] = str(relative_path / x["path"]) + _files.append(x) + else: + for x in find_audio(source, ext=ext): + x = str(relative_path / x) + _files.append({"path": x}) + files.append(sorted(_files, key=lambda x: x["path"])) + return files + + +def choose_from_list_of_lists( + state: np.random.RandomState, list_of_lists: list, p: float = None +): + """Choose a single item from a list of lists. + + Parameters + ---------- + state : np.random.RandomState + Random state to use when choosing an item. + list_of_lists : list + A list of lists from which items will be drawn. + p : float, optional + Probabilities of each list, by default None + + Returns + ------- + typing.Any + An item from the list of lists. + """ + source_idx = state.choice(list(range(len(list_of_lists))), p=p) + item_idx = state.randint(len(list_of_lists[source_idx])) + return list_of_lists[source_idx][item_idx], source_idx, item_idx + + +@contextmanager +def chdir(newdir: typing.Union[Path, str]): + """ + Context manager for switching directories to run a + function. Useful for when you want to use relative + paths to different runs. + + Parameters + ---------- + newdir : typing.Union[Path, str] + Directory to switch to. + """ + curdir = os.getcwd() + try: + os.chdir(newdir) + yield + finally: + os.chdir(curdir) + + +def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"): + """Moves items in a batch (typically generated by a DataLoader as a list + or a dict) to the specified device. This works even if dictionaries + are nested. + + Parameters + ---------- + batch : typing.Union[dict, list, torch.Tensor] + Batch, typically generated by a dataloader, that will be moved to + the device. + device : str, optional + Device to move batch to, by default "cpu" + + Returns + ------- + typing.Union[dict, list, torch.Tensor] + Batch with all values moved to the specified device. + """ + if isinstance(batch, dict): + batch = flatten(batch) + for key, val in batch.items(): + try: + batch[key] = val.to(device) + except: + pass + batch = unflatten(batch) + elif torch.is_tensor(batch): + batch = batch.to(device) + elif isinstance(batch, list): + for i in range(len(batch)): + try: + batch[i] = batch[i].to(device) + except: + pass + return batch + + +def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None): + """Samples from a distribution defined by a tuple. The first + item in the tuple is the distribution type, and the rest of the + items are arguments to that distribution. The distribution function + is gotten from the ``np.random.RandomState`` object. + + Parameters + ---------- + dist_tuple : tuple + Distribution tuple + state : np.random.RandomState, optional + Random state, or seed to use, by default None + + Returns + ------- + typing.Union[float, int, str] + Draw from the distribution. 
+ + Examples + -------- + Sample from a uniform distribution: + + >>> dist_tuple = ("uniform", 0, 1) + >>> sample_from_dist(dist_tuple) + + Sample from a constant distribution: + + >>> dist_tuple = ("const", 0) + >>> sample_from_dist(dist_tuple) + + Sample from a normal distribution: + + >>> dist_tuple = ("normal", 0, 0.5) + >>> sample_from_dist(dist_tuple) + + """ + if dist_tuple[0] == "const": + return dist_tuple[1] + state = random_state(state) + dist_fn = getattr(state, dist_tuple[0]) + return dist_fn(*dist_tuple[1:]) + + +def collate(list_of_dicts: list, n_splits: int = None): + """Collates a list of dictionaries (e.g. as returned by a + dataloader) into a dictionary with batched values. This routine + uses the default torch collate function for everything + except AudioSignal objects, which are handled by the + :py:func:`audiotools.core.audio_signal.AudioSignal.batch` + function. + + This function takes n_splits to enable splitting a batch + into multiple sub-batches for the purposes of gradient accumulation, + etc. + + Parameters + ---------- + list_of_dicts : list + List of dictionaries to be collated. + n_splits : int + Number of splits to make when creating the batches (split into + sub-batches). Useful for things like gradient accumulation. + + Returns + ------- + dict + Dictionary containing batched data. + """ + + from . import AudioSignal + + batches = [] + list_len = len(list_of_dicts) + + return_list = False if n_splits is None else True + n_splits = 1 if n_splits is None else n_splits + n_items = int(math.ceil(list_len / n_splits)) + + for i in range(0, list_len, n_items): + # Flatten the dictionaries to avoid recursion. + list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]] + dict_of_lists = { + k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0] + } + + batch = {} + for k, v in dict_of_lists.items(): + if isinstance(v, list): + if all(isinstance(s, AudioSignal) for s in v): + batch[k] = AudioSignal.batch(v, pad_signals=True) + else: + # Borrow the default collate fn from torch. + batch[k] = torch.utils.data._utils.collate.default_collate(v) + batches.append(unflatten(batch)) + + batches = batches[0] if not return_list else batches + return batches + + +BASE_SIZE = 864 +DEFAULT_FIG_SIZE = (9, 3) + + +def format_figure( + fig_size: tuple = None, + title: str = None, + fig=None, + format_axes: bool = True, + format: bool = True, + font_color: str = "white", +): + """Prettifies the spectrogram and waveform plots. A title + can be inset into the top right corner, and the axes can be + inset into the figure, allowing the data to take up the entire + image. 
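+    A minimal usage sketch (the audio path is illustrative; ``format=False``
+    skips the automatic formatting so it can be applied manually here):
+
+    >>> import matplotlib.pyplot as plt
+    >>> from audiotools import AudioSignal
+    >>> from audiotools.core.util import format_figure
+    >>> signal = AudioSignal("tests/audio/spk/f10_script4_produced.wav", duration=2)
+    >>> signal.specshow(format=False)
+    >>> format_figure(fig_size=(9, 3), title="spectrogram")
+    >>> plt.savefig("spectrogram.png")
+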
Used in + + - :py:func:`audiotools.core.display.DisplayMixin.specshow` + - :py:func:`audiotools.core.display.DisplayMixin.waveplot` + - :py:func:`audiotools.core.display.DisplayMixin.wavespec` + + Parameters + ---------- + fig_size : tuple, optional + Size of figure, by default (9, 3) + title : str, optional + Title to inset in top right, by default None + fig : matplotlib.figure.Figure, optional + Figure object, if None ``plt.gcf()`` will be used, by default None + format_axes : bool, optional + Format the axes to be inside the figure, by default True + format : bool, optional + This formatting can be skipped entirely by passing ``format=False`` + to any of the plotting functions that use this formater, by default True + font_color : str, optional + Color of font of axes, by default "white" + """ + import matplotlib + import matplotlib.pyplot as plt + + if fig_size is None: + fig_size = DEFAULT_FIG_SIZE + if not format: + return + if fig is None: + fig = plt.gcf() + fig.set_size_inches(*fig_size) + axs = fig.axes + + pixels = (fig.get_size_inches() * fig.dpi)[0] + font_scale = pixels / BASE_SIZE + + if format_axes: + axs = fig.axes + + for ax in axs: + ymin, _ = ax.get_ylim() + xmin, _ = ax.get_xlim() + + ticks = ax.get_yticks() + for t in ticks[2:-1]: + t = axs[0].annotate( + f"{(t / 1000):2.1f}k", + xy=(xmin, t), + xycoords="data", + xytext=(5, -5), + textcoords="offset points", + ha="left", + va="top", + color=font_color, + fontsize=12 * font_scale, + alpha=0.75, + ) + + ticks = ax.get_xticks()[2:] + for t in ticks[:-1]: + t = axs[0].annotate( + f"{t:2.1f}s", + xy=(t, ymin), + xycoords="data", + xytext=(5, 5), + textcoords="offset points", + ha="center", + va="bottom", + color=font_color, + fontsize=12 * font_scale, + alpha=0.75, + ) + + ax.margins(0, 0) + ax.set_axis_off() + ax.xaxis.set_major_locator(plt.NullLocator()) + ax.yaxis.set_major_locator(plt.NullLocator()) + + plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0) + + if title is not None: + t = axs[0].annotate( + title, + xy=(1, 1), + xycoords="axes fraction", + fontsize=20 * font_scale, + xytext=(-5, -5), + textcoords="offset points", + ha="right", + va="top", + color="white", + ) + t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black")) + + +def generate_chord_dataset( + max_voices: int = 8, + sample_rate: int = 44100, + num_items: int = 5, + duration: float = 1.0, + min_note: str = "C2", + max_note: str = "C6", + output_dir: Path = "chords", +): + """ + Generates a toy multitrack dataset of chords, synthesized from sine waves. + + + Parameters + ---------- + max_voices : int, optional + Maximum number of voices in a chord, by default 8 + sample_rate : int, optional + Sample rate of audio, by default 44100 + num_items : int, optional + Number of items to generate, by default 5 + duration : float, optional + Duration of each item, by default 1.0 + min_note : str, optional + Minimum note in the dataset, by default "C2" + max_note : str, optional + Maximum note in the dataset, by default "C6" + output_dir : Path, optional + Directory to save the dataset, by default "chords" + + """ + import librosa + from . 
import AudioSignal + from ..data.preprocess import create_csv + + min_midi = librosa.note_to_midi(min_note) + max_midi = librosa.note_to_midi(max_note) + + tracks = [] + for idx in range(num_items): + track = {} + # figure out how many voices to put in this track + num_voices = random.randint(1, max_voices) + for voice_idx in range(num_voices): + # choose some random params + midinote = random.randint(min_midi, max_midi) + dur = random.uniform(0.85 * duration, duration) + + sig = AudioSignal.wave( + frequency=librosa.midi_to_hz(midinote), + duration=dur, + sample_rate=sample_rate, + shape="sine", + ) + track[f"voice_{voice_idx}"] = sig + tracks.append(track) + + # save the tracks to disk + output_dir = Path(output_dir) + output_dir.mkdir(exist_ok=True) + for idx, track in enumerate(tracks): + track_dir = output_dir / f"track_{idx}" + track_dir.mkdir(exist_ok=True) + for voice_name, sig in track.items(): + sig.write(track_dir / f"{voice_name}.wav") + + all_voices = list(set([k for track in tracks for k in track.keys()])) + voice_lists = {voice: [] for voice in all_voices} + for track in tracks: + for voice_name in all_voices: + if voice_name in track: + voice_lists[voice_name].append(track[voice_name].path_to_file) + else: + voice_lists[voice_name].append("") + + for voice_name, paths in voice_lists.items(): + create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True) + + return output_dir diff --git a/audiotools/core/whisper.py b/audiotools/core/whisper.py new file mode 100644 index 0000000000000000000000000000000000000000..46c071f934fc3e2be3138e7596b1c6d2ef79eade --- /dev/null +++ b/audiotools/core/whisper.py @@ -0,0 +1,97 @@ +import torch + + +class WhisperMixin: + is_initialized = False + + def setup_whisper( + self, + pretrained_model_name_or_path: str = "openai/whisper-base.en", + device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu"), + ): + from transformers import WhisperForConditionalGeneration + from transformers import WhisperProcessor + + self.whisper_device = device + self.whisper_processor = WhisperProcessor.from_pretrained( + pretrained_model_name_or_path + ) + self.whisper_model = WhisperForConditionalGeneration.from_pretrained( + pretrained_model_name_or_path + ).to(self.whisper_device) + self.is_initialized = True + + def get_whisper_features(self) -> torch.Tensor: + """Preprocess audio signal as per the whisper model's training config. + + Returns + ------- + torch.Tensor + The prepinput features of the audio signal. Shape: (1, channels, seq_len) + """ + import torch + + if not self.is_initialized: + self.setup_whisper() + + signal = self.to(self.device) + raw_speech = list( + ( + signal.clone() + .resample(self.whisper_processor.feature_extractor.sampling_rate) + .audio_data[:, 0, :] + .numpy() + ) + ) + + with torch.inference_mode(): + input_features = self.whisper_processor( + raw_speech, + sampling_rate=self.whisper_processor.feature_extractor.sampling_rate, + return_tensors="pt", + ).input_features + + return input_features + + def get_whisper_transcript(self) -> str: + """Get the transcript of the audio signal using the whisper model. + + Returns + ------- + str + The transcript of the audio signal, including special tokens such as <|startoftranscript|> and <|endoftext|>. 
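+
+        Examples
+        --------
+        A rough usage sketch; this downloads the ``openai/whisper-base.en``
+        checkpoint on first call, and the audio path is illustrative:
+
+        >>> from audiotools import AudioSignal
+        >>> signal = AudioSignal("tests/audio/spk/f10_script4_produced.wav", duration=5)
+        >>> transcript = signal.get_whisper_transcript()
+        >>> print(transcript)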
+ """ + + if not self.is_initialized: + self.setup_whisper() + + input_features = self.get_whisper_features() + + with torch.inference_mode(): + input_features = input_features.to(self.whisper_device) + generated_ids = self.whisper_model.generate(inputs=input_features) + + transcription = self.whisper_processor.batch_decode(generated_ids) + return transcription[0] + + def get_whisper_embeddings(self) -> torch.Tensor: + """Get the last hidden state embeddings of the audio signal using the whisper model. + + Returns + ------- + torch.Tensor + The Whisper embeddings of the audio signal. Shape: (1, seq_len, hidden_size) + """ + import torch + + if not self.is_initialized: + self.setup_whisper() + + input_features = self.get_whisper_features() + encoder = self.whisper_model.get_encoder() + + with torch.inference_mode(): + input_features = input_features.to(self.whisper_device) + embeddings = encoder(input_features) + + return embeddings.last_hidden_state diff --git a/audiotools/data/__init__.py b/audiotools/data/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..aead269f26f3782043e68418b4c87ee323cbd015 --- /dev/null +++ b/audiotools/data/__init__.py @@ -0,0 +1,3 @@ +from . import datasets +from . import preprocess +from . import transforms diff --git a/audiotools/data/datasets.py b/audiotools/data/datasets.py new file mode 100644 index 0000000000000000000000000000000000000000..12e7a60963399aa15ff865de2d06537818ce18ee --- /dev/null +++ b/audiotools/data/datasets.py @@ -0,0 +1,517 @@ +from pathlib import Path +from typing import Callable +from typing import Dict +from typing import List +from typing import Union + +import numpy as np +from torch.utils.data import SequentialSampler +from torch.utils.data.distributed import DistributedSampler + +from ..core import AudioSignal +from ..core import util + + +class AudioLoader: + """Loads audio endlessly from a list of audio sources + containing paths to audio files. Audio sources can be + folders full of audio files (which are found via file + extension) or by providing a CSV file which contains paths + to audio files. + + Parameters + ---------- + sources : List[str], optional + Sources containing folders, or CSVs with + paths to audio files, by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + relative_path : str, optional + Path audio should be loaded relative to, by default "" + transform : Callable, optional + Transform to instantiate alongside audio sample, + by default None + ext : List[str] + List of extensions to find audio within each source by. Can + also be a file name (e.g. "vocals.wav"). by default + ``['.wav', '.flac', '.mp3', '.mp4']``. + shuffle: bool + Whether to shuffle the files within the dataloader. Defaults to True. + shuffle_state: int + State to use to seed the shuffle of the files. 
+ """ + + def __init__( + self, + sources: List[str] = None, + weights: List[float] = None, + transform: Callable = None, + relative_path: str = "", + ext: List[str] = util.AUDIO_EXTENSIONS, + shuffle: bool = True, + shuffle_state: int = 0, + ): + self.audio_lists = util.read_sources( + sources, relative_path=relative_path, ext=ext + ) + + self.audio_indices = [ + (src_idx, item_idx) + for src_idx, src in enumerate(self.audio_lists) + for item_idx in range(len(src)) + ] + if shuffle: + state = util.random_state(shuffle_state) + state.shuffle(self.audio_indices) + + self.sources = sources + self.weights = weights + self.transform = transform + + def __call__( + self, + state, + sample_rate: int, + duration: float, + loudness_cutoff: float = -40, + num_channels: int = 1, + offset: float = None, + source_idx: int = None, + item_idx: int = None, + global_idx: int = None, + ): + if source_idx is not None and item_idx is not None: + try: + audio_info = self.audio_lists[source_idx][item_idx] + except: + audio_info = {"path": "none"} + elif global_idx is not None: + source_idx, item_idx = self.audio_indices[ + global_idx % len(self.audio_indices) + ] + audio_info = self.audio_lists[source_idx][item_idx] + else: + audio_info, source_idx, item_idx = util.choose_from_list_of_lists( + state, self.audio_lists, p=self.weights + ) + + path = audio_info["path"] + signal = AudioSignal.zeros(duration, sample_rate, num_channels) + + if path != "none": + if offset is None: + signal = AudioSignal.salient_excerpt( + path, + duration=duration, + state=state, + loudness_cutoff=loudness_cutoff, + ) + else: + signal = AudioSignal( + path, + offset=offset, + duration=duration, + ) + + if num_channels == 1: + signal = signal.to_mono() + signal = signal.resample(sample_rate) + + if signal.duration < duration: + signal = signal.zero_pad_to(int(duration * sample_rate)) + + for k, v in audio_info.items(): + signal.metadata[k] = v + + item = { + "signal": signal, + "source_idx": source_idx, + "item_idx": item_idx, + "source": str(self.sources[source_idx]), + "path": str(path), + } + if self.transform is not None: + item["transform_args"] = self.transform.instantiate(state, signal=signal) + return item + + +def default_matcher(x, y): + return Path(x).parent == Path(y).parent + + +def align_lists(lists, matcher: Callable = default_matcher): + longest_list = lists[np.argmax([len(l) for l in lists])] + for i, x in enumerate(longest_list): + for l in lists: + if i >= len(l): + l.append({"path": "none"}) + elif not matcher(l[i]["path"], x["path"]): + l.insert(i, {"path": "none"}) + return lists + + +class AudioDataset: + """Loads audio from multiple loaders (with associated transforms) + for a specified number of samples. Excerpts are drawn randomly + of the specified duration, above a specified loudness threshold + and are resampled on the fly to the desired sample rate + (if it is different from the audio source sample rate). + + This takes either a single AudioLoader object, + a dictionary of AudioLoader objects, or a dictionary of AudioLoader + objects. Each AudioLoader is called by the dataset, and the + result is placed in the output dictionary. A transform can also be + specified for the entire dataset, rather than for each specific + loader. This transform can be applied to the output of all the + loaders if desired. + + AudioLoader objects can be specified as aligned, which means the + loaders correspond to multitrack audio (e.g. a vocals, bass, + drums, and other loader for multitrack music mixtures). 
+ + + Parameters + ---------- + loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]] + AudioLoaders to sample audio from. + sample_rate : int + Desired sample rate. + n_examples : int, optional + Number of examples (length of dataset), by default 1000 + duration : float, optional + Duration of audio samples, by default 0.5 + loudness_cutoff : float, optional + Loudness cutoff threshold for audio samples, by default -40 + num_channels : int, optional + Number of channels in output audio, by default 1 + transform : Callable, optional + Transform to instantiate alongside each dataset item, by default None + aligned : bool, optional + Whether the loaders should be sampled in an aligned manner (e.g. same + offset, duration, and matched file name), by default False + shuffle_loaders : bool, optional + Whether to shuffle the loaders before sampling from them, by default False + matcher : Callable + How to match files from adjacent audio lists (e.g. for a multitrack audio loader), + by default uses the parent directory of each file. + without_replacement : bool + Whether to choose files with or without replacement, by default True. + + + Examples + -------- + >>> from audiotools.data.datasets import AudioLoader + >>> from audiotools.data.datasets import AudioDataset + >>> from audiotools import transforms as tfm + >>> import numpy as np + >>> + >>> loaders = [ + >>> AudioLoader( + >>> sources=[f"tests/audio/spk"], + >>> transform=tfm.Equalizer(), + >>> ext=["wav"], + >>> ) + >>> for i in range(5) + >>> ] + >>> + >>> dataset = AudioDataset( + >>> loaders = loaders, + >>> sample_rate = 44100, + >>> duration = 1.0, + >>> transform = tfm.RescaleAudio(), + >>> ) + >>> + >>> item = dataset[np.random.randint(len(dataset))] + >>> + >>> for i in range(len(loaders)): + >>> item[i]["signal"] = loaders[i].transform( + >>> item[i]["signal"], **item[i]["transform_args"] + >>> ) + >>> item[i]["signal"].widget(i) + >>> + >>> mix = sum([item[i]["signal"] for i in range(len(loaders))]) + >>> mix = dataset.transform(mix, **item["transform_args"]) + >>> mix.widget("mix") + + Below is an example of how one could load MUSDB multitrack data: + + >>> import audiotools as at + >>> from pathlib import Path + >>> from audiotools import transforms as tfm + >>> import numpy as np + >>> import torch + >>> + >>> def build_dataset( + >>> sample_rate: int = 44100, + >>> duration: float = 5.0, + >>> musdb_path: str = "~/.data/musdb/", + >>> ): + >>> musdb_path = Path(musdb_path).expanduser() + >>> loaders = { + >>> src: at.datasets.AudioLoader( + >>> sources=[musdb_path], + >>> transform=tfm.Compose( + >>> tfm.VolumeNorm(("uniform", -20, -10)), + >>> tfm.Silence(prob=0.1), + >>> ), + >>> ext=[f"{src}.wav"], + >>> ) + >>> for src in ["vocals", "bass", "drums", "other"] + >>> } + >>> + >>> dataset = at.datasets.AudioDataset( + >>> loaders=loaders, + >>> sample_rate=sample_rate, + >>> duration=duration, + >>> num_channels=1, + >>> aligned=True, + >>> transform=tfm.RescaleAudio(), + >>> shuffle_loaders=True, + >>> ) + >>> return dataset, list(loaders.keys()) + >>> + >>> train_data, sources = build_dataset() + >>> dataloader = torch.utils.data.DataLoader( + >>> train_data, + >>> batch_size=16, + >>> num_workers=0, + >>> collate_fn=train_data.collate, + >>> ) + >>> batch = next(iter(dataloader)) + >>> + >>> for k in sources: + >>> src = batch[k] + >>> src["transformed"] = train_data.loaders[k].transform( + >>> src["signal"].clone(), **src["transform_args"] + >>> ) + >>> + >>> mixture = sum(batch[k]["transformed"] for 
k in sources) + >>> mixture = train_data.transform(mixture, **batch["transform_args"]) + >>> + >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time). + >>> # Construct the targets: + >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1) + + Similarly, here's example code for loading Slakh data: + + >>> import audiotools as at + >>> from pathlib import Path + >>> from audiotools import transforms as tfm + >>> import numpy as np + >>> import torch + >>> import glob + >>> + >>> def build_dataset( + >>> sample_rate: int = 16000, + >>> duration: float = 10.0, + >>> slakh_path: str = "~/.data/slakh/", + >>> ): + >>> slakh_path = Path(slakh_path).expanduser() + >>> + >>> # Find the max number of sources in Slakh + >>> src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)] + >>> n_sources = len(list(set(src_names))) + >>> + >>> loaders = { + >>> f"S{i:02d}": at.datasets.AudioLoader( + >>> sources=[slakh_path], + >>> transform=tfm.Compose( + >>> tfm.VolumeNorm(("uniform", -20, -10)), + >>> tfm.Silence(prob=0.1), + >>> ), + >>> ext=[f"S{i:02d}.wav"], + >>> ) + >>> for i in range(n_sources) + >>> } + >>> dataset = at.datasets.AudioDataset( + >>> loaders=loaders, + >>> sample_rate=sample_rate, + >>> duration=duration, + >>> num_channels=1, + >>> aligned=True, + >>> transform=tfm.RescaleAudio(), + >>> shuffle_loaders=False, + >>> ) + >>> + >>> return dataset, list(loaders.keys()) + >>> + >>> train_data, sources = build_dataset() + >>> dataloader = torch.utils.data.DataLoader( + >>> train_data, + >>> batch_size=16, + >>> num_workers=0, + >>> collate_fn=train_data.collate, + >>> ) + >>> batch = next(iter(dataloader)) + >>> + >>> for k in sources: + >>> src = batch[k] + >>> src["transformed"] = train_data.loaders[k].transform( + >>> src["signal"].clone(), **src["transform_args"] + >>> ) + >>> + >>> mixture = sum(batch[k]["transformed"] for k in sources) + >>> mixture = train_data.transform(mixture, **batch["transform_args"]) + + """ + + def __init__( + self, + loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]], + sample_rate: int, + n_examples: int = 1000, + duration: float = 0.5, + offset: float = None, + loudness_cutoff: float = -40, + num_channels: int = 1, + transform: Callable = None, + aligned: bool = False, + shuffle_loaders: bool = False, + matcher: Callable = default_matcher, + without_replacement: bool = True, + ): + # Internally we convert loaders to a dictionary + if isinstance(loaders, list): + loaders = {i: l for i, l in enumerate(loaders)} + elif isinstance(loaders, AudioLoader): + loaders = {0: loaders} + + self.loaders = loaders + self.loudness_cutoff = loudness_cutoff + self.num_channels = num_channels + + self.length = n_examples + self.transform = transform + self.sample_rate = sample_rate + self.duration = duration + self.offset = offset + self.aligned = aligned + self.shuffle_loaders = shuffle_loaders + self.without_replacement = without_replacement + + if aligned: + loaders_list = list(loaders.values()) + for i in range(len(loaders_list[0].audio_lists)): + input_lists = [l.audio_lists[i] for l in loaders_list] + # Alignment happens in-place + align_lists(input_lists, matcher) + + def __getitem__(self, idx): + state = util.random_state(idx) + offset = None if self.offset is None else self.offset + item = {} + + keys = list(self.loaders.keys()) + if self.shuffle_loaders: + state.shuffle(keys) + + loader_kwargs = { + "state": state, + "sample_rate": self.sample_rate, + "duration": 
self.duration, + "loudness_cutoff": self.loudness_cutoff, + "num_channels": self.num_channels, + "global_idx": idx if self.without_replacement else None, + } + + # Draw item from first loader + loader = self.loaders[keys[0]] + item[keys[0]] = loader(**loader_kwargs) + + for key in keys[1:]: + loader = self.loaders[key] + if self.aligned: + # Path mapper takes the current loader + everything + # returned by the first loader. + offset = item[keys[0]]["signal"].metadata["offset"] + loader_kwargs.update( + { + "offset": offset, + "source_idx": item[keys[0]]["source_idx"], + "item_idx": item[keys[0]]["item_idx"], + } + ) + item[key] = loader(**loader_kwargs) + + # Sort dictionary back into original order + keys = list(self.loaders.keys()) + item = {k: item[k] for k in keys} + + item["idx"] = idx + if self.transform is not None: + item["transform_args"] = self.transform.instantiate( + state=state, signal=item[keys[0]]["signal"] + ) + + # If there's only one loader, pop it up + # to the main dictionary, instead of keeping it + # nested. + if len(keys) == 1: + item.update(item.pop(keys[0])) + + return item + + def __len__(self): + return self.length + + @staticmethod + def collate(list_of_dicts: Union[list, dict], n_splits: int = None): + """Collates items drawn from this dataset. Uses + :py:func:`audiotools.core.util.collate`. + + Parameters + ---------- + list_of_dicts : typing.Union[list, dict] + Data drawn from each item. + n_splits : int + Number of splits to make when creating the batches (split into + sub-batches). Useful for things like gradient accumulation. + + Returns + ------- + dict + Dictionary of batched data. + """ + return util.collate(list_of_dicts, n_splits=n_splits) + + +class ConcatDataset(AudioDataset): + def __init__(self, datasets: list): + self.datasets = datasets + + def __len__(self): + return sum([len(d) for d in self.datasets]) + + def __getitem__(self, idx): + dataset = self.datasets[idx % len(self.datasets)] + return dataset[idx // len(self.datasets)] + + +class ResumableDistributedSampler(DistributedSampler): # pragma: no cover + """Distributed sampler that can be resumed from a given start index.""" + + def __init__(self, dataset, start_idx: int = None, **kwargs): + super().__init__(dataset, **kwargs) + # Start index, allows to resume an experiment at the index it was + self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0 + + def __iter__(self): + for i, idx in enumerate(super().__iter__()): + if i >= self.start_idx: + yield idx + self.start_idx = 0 # set the index back to 0 so for the next epoch + + +class ResumableSequentialSampler(SequentialSampler): # pragma: no cover + """Sequential sampler that can be resumed from a given start index.""" + + def __init__(self, dataset, start_idx: int = None, **kwargs): + super().__init__(dataset, **kwargs) + # Start index, allows to resume an experiment at the index it was + self.start_idx = start_idx if start_idx is not None else 0 + + def __iter__(self): + for i, idx in enumerate(super().__iter__()): + if i >= self.start_idx: + yield idx + self.start_idx = 0 # set the index back to 0 so for the next epoch diff --git a/audiotools/data/preprocess.py b/audiotools/data/preprocess.py new file mode 100644 index 0000000000000000000000000000000000000000..d90de210115e45838bc8d69b350f7516ba730406 --- /dev/null +++ b/audiotools/data/preprocess.py @@ -0,0 +1,81 @@ +import csv +import os +from pathlib import Path + +from tqdm import tqdm + +from ..core import AudioSignal + + +def create_csv( + audio_files: 
list, output_csv: Path, loudness: bool = False, data_path: str = None +): + """Converts a folder of audio files to a CSV file. If ``loudness = True``, + the output of this function will create a CSV file that looks something + like: + + .. csv-table:: + :header: path,loudness + + daps/produced/f1_script1_produced.wav,-16.299999237060547 + daps/produced/f1_script2_produced.wav,-16.600000381469727 + daps/produced/f1_script3_produced.wav,-17.299999237060547 + daps/produced/f1_script4_produced.wav,-16.100000381469727 + daps/produced/f1_script5_produced.wav,-16.700000762939453 + daps/produced/f3_script1_produced.wav,-16.5 + + .. note:: + The paths above are written relative to the ``data_path`` argument + which defaults to the environment variable ``PATH_TO_DATA`` if + it isn't passed to this function, and defaults to the empty string + if that environment variable is not set. + + You can produce a CSV file from a directory of audio files via: + + >>> import audiotools + >>> directory = ... + >>> audio_files = audiotools.util.find_audio(directory) + >>> output_path = "train.csv" + >>> audiotools.data.preprocess.create_csv( + >>> audio_files, output_csv, loudness=True + >>> ) + + Note that you can create empty rows in the CSV file by passing an empty + string or None in the ``audio_files`` list. This is useful if you want to + sync multiple CSV files in a multitrack setting. The loudness of these + empty rows will be set to -inf. + + Parameters + ---------- + audio_files : list + List of audio files. + output_csv : Path + Output CSV, with each row containing the relative path of every file + to ``data_path``, if specified (defaults to None). + loudness : bool + Compute loudness of entire file and store alongside path. + """ + + info = [] + pbar = tqdm(audio_files) + for af in pbar: + af = Path(af) + pbar.set_description(f"Processing {af.name}") + _info = {} + if af.name == "": + _info["path"] = "" + if loudness: + _info["loudness"] = -float("inf") + else: + _info["path"] = af.relative_to(data_path) if data_path is not None else af + if loudness: + _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item() + + info.append(_info) + + with open(output_csv, "w") as f: + writer = csv.DictWriter(f, fieldnames=list(info[0].keys())) + writer.writeheader() + + for item in info: + writer.writerow(item) diff --git a/audiotools/data/transforms.py b/audiotools/data/transforms.py new file mode 100644 index 0000000000000000000000000000000000000000..504e87dc61777e36ba95eb794f497bed4cdc7d2c --- /dev/null +++ b/audiotools/data/transforms.py @@ -0,0 +1,1592 @@ +import copy +from contextlib import contextmanager +from inspect import signature +from typing import List + +import numpy as np +import torch +from flatten_dict import flatten +from flatten_dict import unflatten +from numpy.random import RandomState + +from .. import ml +from ..core import AudioSignal +from ..core import util +from .datasets import AudioLoader + +tt = torch.tensor +"""Shorthand for converting things to torch.tensor.""" + + +class BaseTransform: + """This is the base class for all transforms that are implemented + in this library. Transforms have two main operations: ``transform`` + and ``instantiate``. + + ``instantiate`` sets the parameters randomly + from distribution tuples for each parameter. For example, for the + ``BackgroundNoise`` transform, the signal-to-noise ratio (``snr``) + is chosen randomly by instantiate. By default, it chosen uniformly + between 10.0 and 30.0 (the tuple is set to ``("uniform", 10.0, 30.0)``). 
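+
+    The instantiated parameters come back as a dictionary nested under the
+    transform's name, together with a ``mask`` entry recording whether the
+    transform should be applied at all (see ``prob``). A rough sketch of what
+    that looks like for a simple transform:
+
+    >>> from audiotools import transforms as tfm
+    >>> transform = tfm.VolumeChange(db=("uniform", -12.0, 0.0))
+    >>> kwargs = transform.instantiate(state=0)
+    >>> sorted(kwargs["VolumeChange"].keys())
+    ['db', 'mask']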
+ + ``transform`` applies the transform using the instantiated parameters. + A simple example is as follows: + + >>> seed = 0 + >>> signal = ... + >>> transform = transforms.NoiseFloor(db = ("uniform", -50.0, -30.0)) + >>> kwargs = transform.instantiate() + >>> output = transform(signal.clone(), **kwargs) + + By breaking apart the instantiation of parameters from the actual audio + processing of the transform, we can make things more reproducible, while + also applying the transform on batches of data efficiently on GPU, + rather than on individual audio samples. + + .. note:: + We call ``signal.clone()`` for the input to the ``transform`` function + because signals are modified in-place! If you don't clone the signal, + you will lose the original data. + + Parameters + ---------- + keys : list, optional + Keys that the transform looks for when + calling ``self.transform``, by default []. In general this is + set automatically, and you won't need to manipulate this argument. + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + + Examples + -------- + + >>> seed = 0 + >>> + >>> audio_path = "tests/audio/spk/f10_script4_produced.wav" + >>> signal = AudioSignal(audio_path, offset=10, duration=2) + >>> transform = tfm.Compose( + >>> [ + >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), + >>> ], + >>> ) + >>> + >>> kwargs = transform.instantiate(seed, signal) + >>> output = transform(signal, **kwargs) + + """ + + def __init__(self, keys: list = [], name: str = None, prob: float = 1.0): + # Get keys from the _transform signature. + tfm_keys = list(signature(self._transform).parameters.keys()) + + # Filter out signal and kwargs keys. + ignore_keys = ["signal", "kwargs"] + tfm_keys = [k for k in tfm_keys if k not in ignore_keys] + + # Combine keys specified by the child class, the keys found in + # _transform signature, and the mask key. + self.keys = keys + tfm_keys + ["mask"] + + self.prob = prob + + if name is None: + name = self.__class__.__name__ + self.name = name + + def _prepare(self, batch: dict): + sub_batch = batch[self.name] + + for k in self.keys: + assert k in sub_batch.keys(), f"{k} not in batch" + + return sub_batch + + def _transform(self, signal): + return signal + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + return {} + + @staticmethod + def apply_mask(batch: dict, mask: torch.Tensor): + """Applies a mask to the batch. + + Parameters + ---------- + batch : dict + Batch whose values will be masked in the ``transform`` pass. + mask : torch.Tensor + Mask to apply to batch. + + Returns + ------- + dict + A dictionary that contains values only where ``mask = True``. + """ + masked_batch = {k: v[mask] for k, v in flatten(batch).items()} + return unflatten(masked_batch) + + def transform(self, signal: AudioSignal, **kwargs): + """Apply the transform to the audio signal, + with given keyword arguments. + + Parameters + ---------- + signal : AudioSignal + Signal that will be modified by the transforms in-place. + kwargs: dict + Keyword arguments to the specific transforms ``self._transform`` + function. + + Returns + ------- + AudioSignal + Transformed AudioSignal. 
+ + Examples + -------- + + >>> for seed in range(10): + >>> kwargs = transform.instantiate(seed, signal) + >>> output = transform(signal.clone(), **kwargs) + + """ + tfm_kwargs = self._prepare(kwargs) + mask = tfm_kwargs["mask"] + + if torch.any(mask): + tfm_kwargs = self.apply_mask(tfm_kwargs, mask) + tfm_kwargs = {k: v for k, v in tfm_kwargs.items() if k != "mask"} + signal[mask] = self._transform(signal[mask], **tfm_kwargs) + + return signal + + def __call__(self, *args, **kwargs): + return self.transform(*args, **kwargs) + + def instantiate( + self, + state: RandomState = None, + signal: AudioSignal = None, + ): + """Instantiates parameters for the transform. + + Parameters + ---------- + state : RandomState, optional + _description_, by default None + signal : AudioSignal, optional + _description_, by default None + + Returns + ------- + dict + Dictionary containing instantiated arguments for every keyword + argument to ``self._transform``. + + Examples + -------- + + >>> for seed in range(10): + >>> kwargs = transform.instantiate(seed, signal) + >>> output = transform(signal.clone(), **kwargs) + + """ + state = util.random_state(state) + + # Not all instantiates need the signal. Check if signal + # is needed before passing it in, so that the end-user + # doesn't need to have variables they're not using flowing + # into their function. + needs_signal = "signal" in set(signature(self._instantiate).parameters.keys()) + kwargs = {} + if needs_signal: + kwargs = {"signal": signal} + + # Instantiate the parameters for the transform. + params = self._instantiate(state, **kwargs) + for k in list(params.keys()): + v = params[k] + if isinstance(v, (AudioSignal, torch.Tensor, dict)): + params[k] = v + else: + params[k] = tt(v) + mask = state.rand() <= self.prob + params[f"mask"] = tt(mask) + + # Put the params into a nested dictionary that will be + # used later when calling the transform. This is to avoid + # collisions in the dictionary. + params = {self.name: params} + + return params + + def batch_instantiate( + self, + states: list = None, + signal: AudioSignal = None, + ): + """Instantiates arguments for every item in a batch, + given a list of states. Each state in the list + corresponds to one item in the batch. + + Parameters + ---------- + states : list, optional + List of states, by default None + signal : AudioSignal, optional + AudioSignal to pass to the ``self.instantiate`` section + if it is needed for this transform, by default None + + Returns + ------- + dict + Collated dictionary of arguments. + + Examples + -------- + + >>> batch_size = 4 + >>> signal = AudioSignal(audio_path, offset=10, duration=2) + >>> signal_batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)]) + >>> + >>> states = [seed + idx for idx in list(range(batch_size))] + >>> kwargs = transform.batch_instantiate(states, signal_batch) + >>> batch_output = transform(signal_batch, **kwargs) + """ + kwargs = [] + for state in states: + kwargs.append(self.instantiate(state, signal)) + kwargs = util.collate(kwargs) + return kwargs + + +class Identity(BaseTransform): + """This transform just returns the original signal.""" + + pass + + +class SpectralTransform(BaseTransform): + """Spectral transforms require STFT data to exist, since manipulations + of the STFT require the spectrogram. This just calls ``stft`` before + the transform is called, and calls ``istft`` after the transform is + called so that the audio data is written to after the spectral + manipulation. 
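+
+    A subclass only needs to implement ``_transform`` against
+    ``signal.stft_data``; the surrounding ``stft``/``istft`` calls are
+    handled by this class. A hedged sketch (``HalveMagnitudes`` is made up
+    purely for illustration, and the audio path is illustrative):
+
+    >>> from audiotools import AudioSignal
+    >>> from audiotools.data.transforms import SpectralTransform
+    >>> class HalveMagnitudes(SpectralTransform):
+    ...     def _transform(self, signal):
+    ...         signal.stft_data = 0.5 * signal.stft_data
+    ...         return signal
+    >>> signal = AudioSignal("tests/audio/spk/f10_script4_produced.wav", duration=2)
+    >>> transform = HalveMagnitudes()
+    >>> kwargs = transform.instantiate(state=0)
+    >>> output = transform(signal.clone(), **kwargs)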
+ """ + + def transform(self, signal, **kwargs): + signal.stft() + super().transform(signal, **kwargs) + signal.istft() + return signal + + +class Compose(BaseTransform): + """Compose applies transforms in sequence, one after the other. The + transforms are passed in as positional arguments or as a list like so: + + >>> transform = tfm.Compose( + >>> [ + >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]), + >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]), + >>> ], + >>> ) + + This will convolve the signal with a room impulse response, and then + add background noise to the signal. Instantiate instantiates + all the parameters for every transform in the transform list so the + interface for using the Compose transform is the same as everything + else: + + >>> kwargs = transform.instantiate() + >>> output = transform(signal.clone(), **kwargs) + + Under the hood, the transform maps each transform to a unique name + under the hood of the form ``{position}.{name}``, where ``position`` + is the index of the transform in the list. ``Compose`` can nest + within other ``Compose`` transforms, like so: + + >>> preprocess = transforms.Compose( + >>> tfm.GlobalVolumeNorm(), + >>> tfm.CrossTalk(), + >>> name="preprocess", + >>> ) + >>> augment = transforms.Compose( + >>> tfm.RoomImpulseResponse(), + >>> tfm.BackgroundNoise(), + >>> name="augment", + >>> ) + >>> postprocess = transforms.Compose( + >>> tfm.VolumeChange(), + >>> tfm.RescaleAudio(), + >>> tfm.ShiftPhase(), + >>> name="postprocess", + >>> ) + >>> transform = transforms.Compose(preprocess, augment, postprocess), + + This defines 3 composed transforms, and then composes them in sequence + with one another. + + Parameters + ---------- + *transforms : list + List of transforms to apply + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__(self, *transforms: list, name: str = None, prob: float = 1.0): + if isinstance(transforms[0], list): + transforms = transforms[0] + + for i, tfm in enumerate(transforms): + tfm.name = f"{i}.{tfm.name}" + + keys = [tfm.name for tfm in transforms] + super().__init__(keys=keys, name=name, prob=prob) + + self.transforms = transforms + self.transforms_to_apply = keys + + @contextmanager + def filter(self, *names: list): + """This can be used to skip transforms entirely when applying + the sequence of transforms to a signal. For example, take + the following transforms with the names ``preprocess, augment, postprocess``. 
+ + >>> preprocess = transforms.Compose( + >>> tfm.GlobalVolumeNorm(), + >>> tfm.CrossTalk(), + >>> name="preprocess", + >>> ) + >>> augment = transforms.Compose( + >>> tfm.RoomImpulseResponse(), + >>> tfm.BackgroundNoise(), + >>> name="augment", + >>> ) + >>> postprocess = transforms.Compose( + >>> tfm.VolumeChange(), + >>> tfm.RescaleAudio(), + >>> tfm.ShiftPhase(), + >>> name="postprocess", + >>> ) + >>> transform = transforms.Compose(preprocess, augment, postprocess) + + If we wanted to apply all 3 to a signal, we do: + + >>> kwargs = transform.instantiate() + >>> output = transform(signal.clone(), **kwargs) + + But if we only wanted to apply the ``preprocess`` and ``postprocess`` + transforms to the signal, we do: + + >>> with transform_fn.filter("preprocess", "postprocess"): + >>> output = transform(signal.clone(), **kwargs) + + Parameters + ---------- + *names : list + List of transforms, identified by name, to apply to signal. + """ + old_transforms = self.transforms_to_apply + self.transforms_to_apply = names + yield + self.transforms_to_apply = old_transforms + + def _transform(self, signal, **kwargs): + for transform in self.transforms: + if any([x in transform.name for x in self.transforms_to_apply]): + signal = transform(signal, **kwargs) + return signal + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + parameters = {} + for transform in self.transforms: + parameters.update(transform.instantiate(state, signal=signal)) + return parameters + + def __getitem__(self, idx): + return self.transforms[idx] + + def __len__(self): + return len(self.transforms) + + def __iter__(self): + for transform in self.transforms: + yield transform + + +class Choose(Compose): + """Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`, + but instead of applying all the transforms in sequence, it applies just a single transform, + which is chosen for each item in the batch. + + Parameters + ---------- + *transforms : list + List of transforms to apply + weights : list + Probability of choosing any specific transform. + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + + Examples + -------- + + >>> transforms.Choose(tfm.LowPass(), tfm.HighPass()) + """ + + def __init__( + self, + *transforms: list, + weights: list = None, + name: str = None, + prob: float = 1.0, + ): + super().__init__(*transforms, name=name, prob=prob) + + if weights is None: + _len = len(self.transforms) + weights = [1 / _len for _ in range(_len)] + self.weights = np.array(weights) + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + kwargs = super()._instantiate(state, signal) + tfm_idx = list(range(len(self.transforms))) + tfm_idx = state.choice(tfm_idx, p=self.weights) + one_hot = [] + for i, t in enumerate(self.transforms): + mask = kwargs[t.name]["mask"] + if mask.item(): + kwargs[t.name]["mask"] = tt(i == tfm_idx) + one_hot.append(kwargs[t.name]["mask"]) + kwargs["one_hot"] = one_hot + return kwargs + + +class Repeat(Compose): + """Repeatedly applies a given transform ``n_repeat`` times." + + Parameters + ---------- + transform : BaseTransform + Transform to repeat. 
+ n_repeat : int, optional + Number of times to repeat transform, by default 1 + """ + + def __init__( + self, + transform, + n_repeat: int = 1, + name: str = None, + prob: float = 1.0, + ): + transforms = [copy.copy(transform) for _ in range(n_repeat)] + super().__init__(transforms, name=name, prob=prob) + + self.n_repeat = n_repeat + + +class RepeatUpTo(Choose): + """Repeatedly applies a given transform up to ``max_repeat`` times." + + Parameters + ---------- + transform : BaseTransform + Transform to repeat. + max_repeat : int, optional + Max number of times to repeat transform, by default 1 + weights : list + Probability of choosing any specific number up to ``max_repeat``. + """ + + def __init__( + self, + transform, + max_repeat: int = 5, + weights: list = None, + name: str = None, + prob: float = 1.0, + ): + transforms = [] + for n in range(1, max_repeat): + transforms.append(Repeat(transform, n_repeat=n)) + super().__init__(transforms, name=name, prob=prob, weights=weights) + + self.max_repeat = max_repeat + + +class ClippingDistortion(BaseTransform): + """Adds clipping distortion to signal. Corresponds + to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`. + + Parameters + ---------- + perc : tuple, optional + Clipping percentile. Values are between 0.0 to 1.0. + Typical values are 0.1 or below, by default ("uniform", 0.0, 0.1) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + perc: tuple = ("uniform", 0.0, 0.1), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.perc = perc + + def _instantiate(self, state: RandomState): + return {"perc": util.sample_from_dist(self.perc, state)} + + def _transform(self, signal, perc): + return signal.clip_distortion(perc) + + +class Equalizer(BaseTransform): + """Applies an equalization curve to the audio signal. Corresponds + to :py:func:`audiotools.core.effects.EffectMixin.equalizer`. + + Parameters + ---------- + eq_amount : tuple, optional + The maximum dB cut to apply to the audio in any band, + by default ("const", 1.0 dB) + n_bands : int, optional + Number of bands in EQ, by default 6 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + eq_amount: tuple = ("const", 1.0), + n_bands: int = 6, + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.eq_amount = eq_amount + self.n_bands = n_bands + + def _instantiate(self, state: RandomState): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + return {"eq": eq} + + def _transform(self, signal, eq): + return signal.equalizer(eq) + + +class Quantization(BaseTransform): + """Applies quantization to the input waveform. Corresponds + to :py:func:`audiotools.core.effects.EffectMixin.quantization`. 
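+
+    A small sketch of forcing a fixed number of quantization levels (the
+    audio path is illustrative):
+
+    >>> from audiotools import AudioSignal
+    >>> from audiotools import transforms as tfm
+    >>> signal = AudioSignal("tests/audio/spk/f10_script4_produced.wav", duration=2)
+    >>> transform = tfm.Quantization(channels=("const", 8))
+    >>> kwargs = transform.instantiate(state=0)
+    >>> output = transform(signal.clone(), **kwargs)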
+ + Parameters + ---------- + channels : tuple, optional + Number of evenly spaced quantization channels to quantize + to, by default ("choice", [8, 32, 128, 256, 1024]) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + channels: tuple = ("choice", [8, 32, 128, 256, 1024]), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.channels = channels + + def _instantiate(self, state: RandomState): + return {"channels": util.sample_from_dist(self.channels, state)} + + def _transform(self, signal, channels): + return signal.quantization(channels) + + +class MuLawQuantization(BaseTransform): + """Applies mu-law quantization to the input waveform. Corresponds + to :py:func:`audiotools.core.effects.EffectMixin.mulaw_quantization`. + + Parameters + ---------- + channels : tuple, optional + Number of mu-law spaced quantization channels to quantize + to, by default ("choice", [8, 32, 128, 256, 1024]) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + channels: tuple = ("choice", [8, 32, 128, 256, 1024]), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.channels = channels + + def _instantiate(self, state: RandomState): + return {"channels": util.sample_from_dist(self.channels, state)} + + def _transform(self, signal, channels): + return signal.mulaw_quantization(channels) + + +class NoiseFloor(BaseTransform): + """Adds a noise floor of Gaussian noise to the signal at a specified + dB. + + Parameters + ---------- + db : tuple, optional + Level of noise to add to signal, by default ("const", -50.0) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db: tuple = ("const", -50.0), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.db = db + + def _instantiate(self, state: RandomState, signal: AudioSignal): + db = util.sample_from_dist(self.db, state) + audio_data = state.randn(signal.num_channels, signal.signal_length) + nz_signal = AudioSignal(audio_data, signal.sample_rate) + nz_signal.normalize(db) + return {"nz_signal": nz_signal} + + def _transform(self, signal, nz_signal): + # Clone bg_signal so that transform can be repeatedly applied + # to different signals with the same effect. + return signal + nz_signal + + +class BackgroundNoise(BaseTransform): + """Adds background noise from audio specified by a set of CSV files. + A valid CSV file looks like, and is typically generated by + :py:func:`audiotools.data.preprocess.create_csv`: + + .. csv-table:: + :header: path + + room_tone/m6_script2_clean.wav + room_tone/m6_script2_cleanraw.wav + room_tone/m6_script2_ipad_balcony1.wav + room_tone/m6_script2_ipad_bedroom1.wav + room_tone/m6_script2_ipad_confroom1.wav + room_tone/m6_script2_ipad_confroom2.wav + room_tone/m6_script2_ipad_livingroom1.wav + room_tone/m6_script2_ipad_office1.wav + + .. 
note:: + All paths are relative to an environment variable called ``PATH_TO_DATA``, + so that CSV files are portable across machines where data may be + located in different places. + + This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix` + and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the + hood. + + Parameters + ---------- + snr : tuple, optional + Signal-to-noise ratio, by default ("uniform", 10.0, 30.0) + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + eq_amount : tuple, optional + Amount of equalization to apply, by default ("const", 1.0) + n_bands : int, optional + Number of bands in equalizer, by default 3 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + loudness_cutoff : float, optional + Loudness cutoff when loading from audio files, by default None + """ + + def __init__( + self, + snr: tuple = ("uniform", 10.0, 30.0), + sources: List[str] = None, + weights: List[float] = None, + eq_amount: tuple = ("const", 1.0), + n_bands: int = 3, + name: str = None, + prob: float = 1.0, + loudness_cutoff: float = None, + ): + super().__init__(name=name, prob=prob) + + self.snr = snr + self.eq_amount = eq_amount + self.n_bands = n_bands + self.loader = AudioLoader(sources, weights) + self.loudness_cutoff = loudness_cutoff + + def _instantiate(self, state: RandomState, signal: AudioSignal): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + snr = util.sample_from_dist(self.snr, state) + + bg_signal = self.loader( + state, + signal.sample_rate, + duration=signal.signal_duration, + loudness_cutoff=self.loudness_cutoff, + num_channels=signal.num_channels, + )["signal"] + + return {"eq": eq, "bg_signal": bg_signal, "snr": snr} + + def _transform(self, signal, bg_signal, snr, eq): + # Clone bg_signal so that transform can be repeatedly applied + # to different signals with the same effect. + return signal.mix(bg_signal.clone(), snr, eq) + + +class CrossTalk(BaseTransform): + """Adds crosstalk between speakers, whose audio is drawn from a CSV file + that was produced via :py:func:`audiotools.data.preprocess.create_csv`. + + This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix` + under the hood. 
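+
+    Note that the mixed output is re-normalized to the loudness of the
+    original signal, so cross-talk changes the content but not the overall
+    level. A hedged sketch (the CSV of speech excerpts is illustrative):
+
+    >>> from audiotools import AudioSignal
+    >>> from audiotools import transforms as tfm
+    >>> signal = AudioSignal("tests/audio/spk/f10_script4_produced.wav", duration=2)
+    >>> transform = tfm.CrossTalk(sources=["tests/audio/spk.csv"])
+    >>> kwargs = transform.instantiate(state=0, signal=signal)
+    >>> output = transform(signal.clone(), **kwargs)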
+ + Parameters + ---------- + snr : tuple, optional + How loud cross-talk speaker is relative to original signal in dB, + by default ("uniform", 0.0, 10.0) + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + loudness_cutoff : float, optional + Loudness cutoff when loading from audio files, by default -40 + """ + + def __init__( + self, + snr: tuple = ("uniform", 0.0, 10.0), + sources: List[str] = None, + weights: List[float] = None, + name: str = None, + prob: float = 1.0, + loudness_cutoff: float = -40, + ): + super().__init__(name=name, prob=prob) + + self.snr = snr + self.loader = AudioLoader(sources, weights) + self.loudness_cutoff = loudness_cutoff + + def _instantiate(self, state: RandomState, signal: AudioSignal): + snr = util.sample_from_dist(self.snr, state) + crosstalk_signal = self.loader( + state, + signal.sample_rate, + duration=signal.signal_duration, + loudness_cutoff=self.loudness_cutoff, + num_channels=signal.num_channels, + )["signal"] + + return {"crosstalk_signal": crosstalk_signal, "snr": snr} + + def _transform(self, signal, crosstalk_signal, snr): + # Clone bg_signal so that transform can be repeatedly applied + # to different signals with the same effect. + loudness = signal.loudness() + mix = signal.mix(crosstalk_signal.clone(), snr) + mix.normalize(loudness) + return mix + + +class RoomImpulseResponse(BaseTransform): + """Convolves signal with a room impulse response, at a specified + direct-to-reverberant ratio, with equalization applied. Room impulse + response data is drawn from a CSV file that was produced via + :py:func:`audiotools.data.preprocess.create_csv`. + + This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir` + under the hood. 
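+
+    A hedged sketch of reverberating a signal with impulse responses listed
+    in a CSV (the CSV path matches the test fixtures used in other examples
+    in this library, and the audio path is illustrative):
+
+    >>> from audiotools import AudioSignal
+    >>> from audiotools import transforms as tfm
+    >>> signal = AudioSignal("tests/audio/spk/f10_script4_produced.wav", duration=2)
+    >>> transform = tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"])
+    >>> kwargs = transform.instantiate(state=0, signal=signal)
+    >>> output = transform(signal.clone(), **kwargs)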
+ + Parameters + ---------- + drr : tuple, optional + _description_, by default ("uniform", 0.0, 30.0) + sources : List[str], optional + Sources containing folders, or CSVs with paths to audio files, + by default None + weights : List[float], optional + Weights to sample audio files from each source, by default None + eq_amount : tuple, optional + Amount of equalization to apply, by default ("const", 1.0) + n_bands : int, optional + Number of bands in equalizer, by default 6 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + use_original_phase : bool, optional + Whether or not to use the original phase, by default False + offset : float, optional + Offset from each impulse response file to use, by default 0.0 + duration : float, optional + Duration of each impulse response, by default 1.0 + """ + + def __init__( + self, + drr: tuple = ("uniform", 0.0, 30.0), + sources: List[str] = None, + weights: List[float] = None, + eq_amount: tuple = ("const", 1.0), + n_bands: int = 6, + name: str = None, + prob: float = 1.0, + use_original_phase: bool = False, + offset: float = 0.0, + duration: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.drr = drr + self.eq_amount = eq_amount + self.n_bands = n_bands + self.use_original_phase = use_original_phase + + self.loader = AudioLoader(sources, weights) + self.offset = offset + self.duration = duration + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + eq_amount = util.sample_from_dist(self.eq_amount, state) + eq = -eq_amount * state.rand(self.n_bands) + drr = util.sample_from_dist(self.drr, state) + + ir_signal = self.loader( + state, + signal.sample_rate, + offset=self.offset, + duration=self.duration, + loudness_cutoff=None, + num_channels=signal.num_channels, + )["signal"] + ir_signal.zero_pad_to(signal.sample_rate) + + return {"eq": eq, "ir_signal": ir_signal, "drr": drr} + + def _transform(self, signal, ir_signal, drr, eq): + # Clone ir_signal so that transform can be repeatedly applied + # to different signals with the same effect. + return signal.apply_ir( + ir_signal.clone(), drr, eq, use_original_phase=self.use_original_phase + ) + + +class VolumeChange(BaseTransform): + """Changes the volume of the input signal. + + Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`. + + Parameters + ---------- + db : tuple, optional + Change in volume in decibels, by default ("uniform", -12.0, 0.0) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db: tuple = ("uniform", -12.0, 0.0), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + self.db = db + + def _instantiate(self, state: RandomState): + return {"db": util.sample_from_dist(self.db, state)} + + def _transform(self, signal, db): + return signal.volume_change(db) + + +class VolumeNorm(BaseTransform): + """Normalizes the volume of the excerpt to a specified decibel. + + Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`. 
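+
+    A minimal sketch; with the default ``("const", -24)`` distribution the
+    excerpt's loudness is normalized to -24 dB:
+
+    >>> import torch
+    >>> transform = VolumeNorm()
+    >>> signal = AudioSignal(torch.randn(44100), 44100)
+    >>> kwargs = transform.instantiate(state=0)
+    >>> output = transform(signal.clone(), **kwargs)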
+ + Parameters + ---------- + db : tuple, optional + dB to normalize signal to, by default ("const", -24) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db: tuple = ("const", -24), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.db = db + + def _instantiate(self, state: RandomState): + return {"db": util.sample_from_dist(self.db, state)} + + def _transform(self, signal, db): + return signal.normalize(db) + + +class GlobalVolumeNorm(BaseTransform): + """Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this + transform also normalizes the volume of a signal, but it uses + the volume of the entire audio file the loaded excerpt comes from, + rather than the volume of just the excerpt. The volume of the + entire audio file is expected in ``signal.metadata["loudness"]``. + If loading audio from a CSV generated by :py:func:`audiotools.data.preprocess.create_csv` + with ``loudness = True``, like the following: + + .. csv-table:: + :header: path,loudness + + daps/produced/f1_script1_produced.wav,-16.299999237060547 + daps/produced/f1_script2_produced.wav,-16.600000381469727 + daps/produced/f1_script3_produced.wav,-17.299999237060547 + daps/produced/f1_script4_produced.wav,-16.100000381469727 + daps/produced/f1_script5_produced.wav,-16.700000762939453 + daps/produced/f3_script1_produced.wav,-16.5 + + The ``AudioLoader`` will automatically load the loudness column into + the metadata of the signal. + + Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`. + + Parameters + ---------- + db : tuple, optional + dB to normalize signal to, by default ("const", -24) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db: tuple = ("const", -24), + name: str = None, + prob: float = 1.0, + ): + super().__init__(name=name, prob=prob) + + self.db = db + + def _instantiate(self, state: RandomState, signal: AudioSignal): + if "loudness" not in signal.metadata: + db_change = 0.0 + elif float(signal.metadata["loudness"]) == float("-inf"): + db_change = 0.0 + else: + db = util.sample_from_dist(self.db, state) + db_change = db - float(signal.metadata["loudness"]) + + return {"db": db_change} + + def _transform(self, signal, db): + return signal.volume_change(db) + + +class Silence(BaseTransform): + """Zeros out the signal with some probability. + + Parameters + ---------- + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 0.1 + """ + + def __init__(self, name: str = None, prob: float = 0.1): + super().__init__(name=name, prob=prob) + + def _transform(self, signal): + _loudness = signal._loudness + signal = AudioSignal( + torch.zeros_like(signal.audio_data), + sample_rate=signal.sample_rate, + stft_params=signal.stft_params, + ) + # So that the amound of noise added is as if it wasn't silenced. + # TODO: improve this hack + signal._loudness = _loudness + + return signal + + +class LowPass(BaseTransform): + """Applies a LowPass filter. 
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`.
+
+    Parameters
+    ----------
+    cutoff : tuple, optional
+        Cutoff frequency distribution,
+        by default ``("choice", [4000, 8000, 16000])``
+    zeros : int, optional
+        Number of zero-crossings in filter, argument to
+        ``julius.LowPassFilters``, by default 51
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        cutoff: tuple = ("choice", [4000, 8000, 16000]),
+        zeros: int = 51,
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.cutoff = cutoff
+        self.zeros = zeros
+
+    def _instantiate(self, state: RandomState):
+        return {"cutoff": util.sample_from_dist(self.cutoff, state)}
+
+    def _transform(self, signal, cutoff):
+        return signal.low_pass(cutoff, zeros=self.zeros)
+
+
+class HighPass(BaseTransform):
+    """Applies a HighPass filter.
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`.
+
+    Parameters
+    ----------
+    cutoff : tuple, optional
+        Cutoff frequency distribution,
+        by default ``("choice", [50, 100, 250, 500, 1000])``
+    zeros : int, optional
+        Number of zero-crossings in filter, argument to
+        ``julius.LowPassFilters``, by default 51
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(
+        self,
+        cutoff: tuple = ("choice", [50, 100, 250, 500, 1000]),
+        zeros: int = 51,
+        name: str = None,
+        prob: float = 1,
+    ):
+        super().__init__(name=name, prob=prob)
+
+        self.cutoff = cutoff
+        self.zeros = zeros
+
+    def _instantiate(self, state: RandomState):
+        return {"cutoff": util.sample_from_dist(self.cutoff, state)}
+
+    def _transform(self, signal, cutoff):
+        return signal.high_pass(cutoff, zeros=self.zeros)
+
+
+class RescaleAudio(BaseTransform):
+    """Rescales the audio so it is in between ``-val`` and ``val``
+    only if the original audio exceeds those bounds. Useful if
+    transforms have caused the audio to clip.
+
+    Uses :py:func:`audiotools.core.effects.EffectMixin.ensure_max_of_audio`.
+
+    Parameters
+    ----------
+    val : float, optional
+        Max absolute value of signal, by default 1.0
+    name : str, optional
+        Name of this transform, used to identify it in the dictionary
+        produced by ``self.instantiate``, by default None
+    prob : float, optional
+        Probability of applying this transform, by default 1.0
+    """
+
+    def __init__(self, val: float = 1.0, name: str = None, prob: float = 1):
+        super().__init__(name=name, prob=prob)
+
+        self.val = val
+
+    def _transform(self, signal):
+        return signal.ensure_max_of_audio(self.val)
+
+
+class ShiftPhase(SpectralTransform):
+    """Shifts the phase of the audio.
+
+    Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
+ + Parameters + ---------- + shift : tuple, optional + How much to shift phase by, by default ("uniform", -np.pi, np.pi) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + shift: tuple = ("uniform", -np.pi, np.pi), + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + self.shift = shift + + def _instantiate(self, state: RandomState): + return {"shift": util.sample_from_dist(self.shift, state)} + + def _transform(self, signal, shift): + return signal.shift_phase(shift) + + +class InvertPhase(ShiftPhase): + """Inverts the phase of the audio. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`. + + Parameters + ---------- + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__(self, name: str = None, prob: float = 1): + super().__init__(shift=("const", np.pi), name=name, prob=prob) + + +class CorruptPhase(SpectralTransform): + """Corrupts the phase of the audio. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.corrupt_phase`. + + Parameters + ---------- + scale : tuple, optional + How much to corrupt phase by, by default ("uniform", 0, np.pi) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, scale: tuple = ("uniform", 0, np.pi), name: str = None, prob: float = 1 + ): + super().__init__(name=name, prob=prob) + self.scale = scale + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + scale = util.sample_from_dist(self.scale, state) + corruption = state.normal(scale=scale, size=signal.phase.shape[1:]) + return {"corruption": corruption.astype("float32")} + + def _transform(self, signal, corruption): + return signal.shift_phase(shift=corruption) + + +class FrequencyMask(SpectralTransform): + """Masks a band of frequencies at a center frequency + from the audio. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`. 
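+
+    A rough sketch; ``signal`` is passed to ``instantiate`` because the mask
+    edges are converted from relative frequencies to Hz using its sample rate:
+
+    >>> import torch
+    >>> transform = FrequencyMask(f_width=("const", 0.1))
+    >>> signal = AudioSignal(torch.randn(44100), 44100)
+    >>> kwargs = transform.instantiate(state=0, signal=signal)
+    >>> output = transform(signal.clone(), **kwargs)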
+ + Parameters + ---------- + f_center : tuple, optional + Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0) + f_width : tuple, optional + Width of zero'd out band, by default ("const", 0.1) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + f_center: tuple = ("uniform", 0.0, 1.0), + f_width: tuple = ("const", 0.1), + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + self.f_center = f_center + self.f_width = f_width + + def _instantiate(self, state: RandomState, signal: AudioSignal): + f_center = util.sample_from_dist(self.f_center, state) + f_width = util.sample_from_dist(self.f_width, state) + + fmin = max(f_center - (f_width / 2), 0.0) + fmax = min(f_center + (f_width / 2), 1.0) + + fmin_hz = (signal.sample_rate / 2) * fmin + fmax_hz = (signal.sample_rate / 2) * fmax + + return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz} + + def _transform(self, signal, fmin_hz: float, fmax_hz: float): + return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz) + + +class TimeMask(SpectralTransform): + """Masks out contiguous time-steps from signal. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`. + + Parameters + ---------- + t_center : tuple, optional + Center time in terms of 0.0 and 1.0 (duration of signal), + by default ("uniform", 0.0, 1.0) + t_width : tuple, optional + Width of dropped out portion, by default ("const", 0.025) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + t_center: tuple = ("uniform", 0.0, 1.0), + t_width: tuple = ("const", 0.025), + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + self.t_center = t_center + self.t_width = t_width + + def _instantiate(self, state: RandomState, signal: AudioSignal): + t_center = util.sample_from_dist(self.t_center, state) + t_width = util.sample_from_dist(self.t_width, state) + + tmin = max(t_center - (t_width / 2), 0.0) + tmax = min(t_center + (t_width / 2), 1.0) + + tmin_s = signal.signal_duration * tmin + tmax_s = signal.signal_duration * tmax + return {"tmin_s": tmin_s, "tmax_s": tmax_s} + + def _transform(self, signal, tmin_s: float, tmax_s: float): + return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s) + + +class MaskLowMagnitudes(SpectralTransform): + """Masks low magnitude regions out of signal. + + Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_low_magnitudes`. 
+ + Parameters + ---------- + db_cutoff : tuple, optional + Decibel value for which things below it will be masked away, + by default ("uniform", -10, 10) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + db_cutoff: tuple = ("uniform", -10, 10), + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + self.db_cutoff = db_cutoff + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + return {"db_cutoff": util.sample_from_dist(self.db_cutoff, state)} + + def _transform(self, signal, db_cutoff: float): + return signal.mask_low_magnitudes(db_cutoff) + + +class Smoothing(BaseTransform): + """Convolves the signal with a smoothing window. + + Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`. + + Parameters + ---------- + window_type : tuple, optional + Type of window to use, by default ("const", "average") + window_length : tuple, optional + Length of smoothing window, by + default ("choice", [8, 16, 32, 64, 128, 256, 512]) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + window_type: tuple = ("const", "average"), + window_length: tuple = ("choice", [8, 16, 32, 64, 128, 256, 512]), + name: str = None, + prob: float = 1, + ): + super().__init__(name=name, prob=prob) + self.window_type = window_type + self.window_length = window_length + + def _instantiate(self, state: RandomState, signal: AudioSignal = None): + window_type = util.sample_from_dist(self.window_type, state) + window_length = util.sample_from_dist(self.window_length, state) + window = signal.get_window( + window_type=window_type, window_length=window_length, device="cpu" + ) + return {"window": AudioSignal(window, signal.sample_rate)} + + def _transform(self, signal, window): + sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values + sscale[sscale == 0.0] = 1.0 + + out = signal.convolve(window) + + oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values + oscale[oscale == 0.0] = 1.0 + + out = out * (sscale / oscale) + return out + + +class TimeNoise(TimeMask): + """Similar to :py:func:`audiotools.data.transforms.TimeMask`, but + replaces with noise instead of zeros. 
+ + Parameters + ---------- + t_center : tuple, optional + Center time in terms of 0.0 and 1.0 (duration of signal), + by default ("uniform", 0.0, 1.0) + t_width : tuple, optional + Width of dropped out portion, by default ("const", 0.025) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + t_center: tuple = ("uniform", 0.0, 1.0), + t_width: tuple = ("const", 0.025), + name: str = None, + prob: float = 1, + ): + super().__init__(t_center=t_center, t_width=t_width, name=name, prob=prob) + + def _transform(self, signal, tmin_s: float, tmax_s: float): + signal = signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s, val=0.0) + mag, phase = signal.magnitude, signal.phase + + mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase) + mask = (mag == 0.0) * (phase == 0.0) + + mag[mask] = mag_r[mask] + phase[mask] = phase_r[mask] + + signal.magnitude = mag + signal.phase = phase + return signal + + +class FrequencyNoise(FrequencyMask): + """Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but + replaces with noise instead of zeros. + + Parameters + ---------- + f_center : tuple, optional + Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0) + f_width : tuple, optional + Width of zero'd out band, by default ("const", 0.1) + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + f_center: tuple = ("uniform", 0.0, 1.0), + f_width: tuple = ("const", 0.1), + name: str = None, + prob: float = 1, + ): + super().__init__(f_center=f_center, f_width=f_width, name=name, prob=prob) + + def _transform(self, signal, fmin_hz: float, fmax_hz: float): + signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz) + mag, phase = signal.magnitude, signal.phase + + mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase) + mask = (mag == 0.0) * (phase == 0.0) + + mag[mask] = mag_r[mask] + phase[mask] = phase_r[mask] + + signal.magnitude = mag + signal.phase = phase + return signal + + +class SpectralDenoising(Equalizer): + """Applies denoising algorithm detailed in + :py:func:`audiotools.ml.layers.spectral_gate.SpectralGate`, + using a randomly generated noise signal for denoising. 
+ + Parameters + ---------- + eq_amount : tuple, optional + Amount of eq to apply to noise signal, by default ("const", 1.0) + denoise_amount : tuple, optional + Amount to denoise by, by default ("uniform", 0.8, 1.0) + nz_volume : float, optional + Volume of noise to denoise with, by default -40 + n_bands : int, optional + Number of bands in equalizer, by default 6 + n_freq : int, optional + Number of frequency bins to smooth by, by default 3 + n_time : int, optional + Number of time bins to smooth by, by default 5 + name : str, optional + Name of this transform, used to identify it in the dictionary + produced by ``self.instantiate``, by default None + prob : float, optional + Probability of applying this transform, by default 1.0 + """ + + def __init__( + self, + eq_amount: tuple = ("const", 1.0), + denoise_amount: tuple = ("uniform", 0.8, 1.0), + nz_volume: float = -40, + n_bands: int = 6, + n_freq: int = 3, + n_time: int = 5, + name: str = None, + prob: float = 1, + ): + super().__init__(eq_amount=eq_amount, n_bands=n_bands, name=name, prob=prob) + + self.nz_volume = nz_volume + self.denoise_amount = denoise_amount + self.spectral_gate = ml.layers.SpectralGate(n_freq, n_time) + + def _transform(self, signal, nz, eq, denoise_amount): + nz = nz.normalize(self.nz_volume).equalizer(eq) + self.spectral_gate = self.spectral_gate.to(signal.device) + signal = self.spectral_gate(signal, nz, denoise_amount) + return signal + + def _instantiate(self, state: RandomState): + kwargs = super()._instantiate(state) + kwargs["denoise_amount"] = util.sample_from_dist(self.denoise_amount, state) + kwargs["nz"] = AudioSignal(state.randn(22050), 44100) + return kwargs diff --git a/audiotools/metrics/__init__.py b/audiotools/metrics/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..c9c8d2df61f94afae8e39e57abf156e8e4059a9e --- /dev/null +++ b/audiotools/metrics/__init__.py @@ -0,0 +1,6 @@ +""" +Functions for comparing AudioSignal objects to one another. +""" # fmt: skip +from . import distance +from . import quality +from . import spectral diff --git a/audiotools/metrics/distance.py b/audiotools/metrics/distance.py new file mode 100644 index 0000000000000000000000000000000000000000..ce78739bfc29f9ddc39b23063b4243ddac10adaf --- /dev/null +++ b/audiotools/metrics/distance.py @@ -0,0 +1,131 @@ +import torch +from torch import nn + +from .. import AudioSignal + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. + weight : float, optional + Weight of this loss, defaults to 1.0. + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. + """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. 
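+
+    As implemented here, the returned value is the negative SI-SDR (in dB),
+    so minimizing this loss maximizes the SI-SDR of the estimates. A minimal
+    sketch:
+
+    >>> import torch
+    >>> loss_fn = SISDRLoss()
+    >>> x = AudioSignal(torch.randn(4, 1, 44100), 44100)
+    >>> y = AudioSignal(torch.randn(4, 1, 44100), 44100)
+    >>> loss = loss_fn(x, y)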
+ + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr diff --git a/audiotools/metrics/quality.py b/audiotools/metrics/quality.py new file mode 100644 index 0000000000000000000000000000000000000000..1608f25507082b49ccbf49289025a5a94a422808 --- /dev/null +++ b/audiotools/metrics/quality.py @@ -0,0 +1,159 @@ +import os + +import numpy as np +import torch + +from .. import AudioSignal + + +def stoi( + estimates: AudioSignal, + references: AudioSignal, + extended: int = False, +): + """Short term objective intelligibility + Computes the STOI (See [1][2]) of a denoised signal compared to a clean + signal, The output is expected to have a monotonic relation with the + subjective speech-intelligibility, where a higher score denotes better + speech intelligibility. Uses pystoi under the hood. + + Parameters + ---------- + estimates : AudioSignal + Denoised speech + references : AudioSignal + Clean original speech + extended : int, optional + Boolean, whether to use the extended STOI described in [3], by default False + + Returns + ------- + Tensor[float] + Short time objective intelligibility measure between clean and + denoised speech + + References + ---------- + 1. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time + Objective Intelligibility Measure for Time-Frequency Weighted Noisy + Speech', ICASSP 2010, Texas, Dallas. + 2. 
C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
+        Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
+        IEEE Transactions on Audio, Speech, and Language Processing, 2011.
+    3. Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
+        Intelligibility of Speech Masked by Modulated Noise Maskers',
+        IEEE Transactions on Audio, Speech and Language Processing, 2016.
+    """
+    import pystoi
+
+    estimates = estimates.clone().to_mono()
+    references = references.clone().to_mono()
+
+    stois = []
+    for i in range(estimates.batch_size):
+        _stoi = pystoi.stoi(
+            references.audio_data[i, 0].detach().cpu().numpy(),
+            estimates.audio_data[i, 0].detach().cpu().numpy(),
+            references.sample_rate,
+            extended=extended,
+        )
+        stois.append(_stoi)
+    return torch.from_numpy(np.array(stois))
+
+
+def pesq(
+    estimates: AudioSignal,
+    references: AudioSignal,
+    mode: str = "wb",
+    target_sr: float = 16000,
+):
+    """Perceptual Evaluation of Speech Quality (PESQ). Computes the PESQ
+    score (P.862.2 / MOS-LQO) of ``estimates`` relative to ``references``,
+    after downmixing both to mono and resampling to ``target_sr``. Uses the
+    ``pesq`` package under the hood.
+
+    Parameters
+    ----------
+    estimates : AudioSignal
+        Degraded AudioSignal
+    references : AudioSignal
+        Reference AudioSignal
+    mode : str, optional
+        'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
+    target_sr : float, optional
+        Target sample rate, by default 16000
+
+    Returns
+    -------
+    Tensor[float]
+        PESQ score: P.862.2 Prediction (MOS-LQO)
+    """
+    from pesq import pesq as pesq_fn
+
+    estimates = estimates.clone().to_mono().resample(target_sr)
+    references = references.clone().to_mono().resample(target_sr)
+
+    pesqs = []
+    for i in range(estimates.batch_size):
+        _pesq = pesq_fn(
+            estimates.sample_rate,
+            references.audio_data[i, 0].detach().cpu().numpy(),
+            estimates.audio_data[i, 0].detach().cpu().numpy(),
+            mode,
+        )
+        pesqs.append(_pesq)
+    return torch.from_numpy(np.array(pesqs))
+
+
+def visqol(
+    estimates: AudioSignal,
+    references: AudioSignal,
+    mode: str = "audio",
+):  # pragma: no cover
+    """ViSQOL score.
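+
+    A rough usage sketch (requires the ``visqol`` Python package; the file
+    paths below are hypothetical):
+
+    >>> estimates = AudioSignal("enhanced.wav")
+    >>> references = AudioSignal("clean.wav")
+    >>> score = visqol(estimates, references, mode="speech")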
+ + Parameters + ---------- + estimates : AudioSignal + Degraded AudioSignal + references : AudioSignal + Reference AudioSignal + mode : str, optional + 'audio' or 'speech', by default 'audio' + + Returns + ------- + Tensor[float] + ViSQOL score (MOS-LQO) + """ + from visqol import visqol_lib_py + from visqol.pb2 import visqol_config_pb2 + from visqol.pb2 import similarity_result_pb2 + + config = visqol_config_pb2.VisqolConfig() + if mode == "audio": + target_sr = 48000 + config.options.use_speech_scoring = False + svr_model_path = "libsvm_nu_svr_model.txt" + elif mode == "speech": + target_sr = 16000 + config.options.use_speech_scoring = True + svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite" + else: + raise ValueError(f"Unrecognized mode: {mode}") + config.audio.sample_rate = target_sr + config.options.svr_model_path = os.path.join( + os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path + ) + + api = visqol_lib_py.VisqolApi() + api.Create(config) + + estimates = estimates.clone().to_mono().resample(target_sr) + references = references.clone().to_mono().resample(target_sr) + + visqols = [] + for i in range(estimates.batch_size): + _visqol = api.Measure( + references.audio_data[i, 0].detach().cpu().numpy().astype(float), + estimates.audio_data[i, 0].detach().cpu().numpy().astype(float), + ) + visqols.append(_visqol.moslqo) + return torch.from_numpy(np.array(visqols)) diff --git a/audiotools/metrics/spectral.py b/audiotools/metrics/spectral.py new file mode 100644 index 0000000000000000000000000000000000000000..7ce953882efa4e5b777a0348bee6c1be39279a6c --- /dev/null +++ b/audiotools/metrics/spectral.py @@ -0,0 +1,247 @@ +import typing +from typing import List + +import numpy as np +from torch import nn + +from .. import AudioSignal +from .. import STFTParams + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. 
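+
+    Example
+    -------
+    A minimal sketch comparing two random signals:
+
+    >>> import torch
+    >>> loss_fn = MultiScaleSTFTLoss()
+    >>> x = AudioSignal(torch.randn(1, 1, 44100), 44100)
+    >>> y = AudioSignal(torch.randn(1, 1, 44100), 44100)
+    >>> loss = loss_fn(x, y)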
+ """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. 
+        """
+        loss = 0.0
+        for n_mels, fmin, fmax, s in zip(
+            self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
+        ):
+            kwargs = {
+                "window_length": s.window_length,
+                "hop_length": s.hop_length,
+                "window_type": s.window_type,
+            }
+            x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+            y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
+
+            loss += self.log_weight * self.loss_fn(
+                x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+                y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
+            )
+            loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
+        return loss
+
+
+class PhaseLoss(nn.Module):
+    """Difference between phase spectrograms.
+
+    Parameters
+    ----------
+    window_length : int, optional
+        Length of STFT window, by default 2048
+    hop_length : int, optional
+        Hop length of STFT window, by default 512
+    weight : float, optional
+        Weight of loss, by default 1.0
+    """
+
+    def __init__(
+        self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
+    ):
+        super().__init__()
+
+        self.weight = weight
+        self.stft_params = STFTParams(window_length, hop_length)
+
+    def forward(self, x: AudioSignal, y: AudioSignal):
+        """Computes phase loss between an estimate and a reference
+        signal.
+
+        Parameters
+        ----------
+        x : AudioSignal
+            Estimate signal
+        y : AudioSignal
+            Reference signal
+
+        Returns
+        -------
+        torch.Tensor
+            Phase loss.
+        """
+        s = self.stft_params
+        x.stft(s.window_length, s.hop_length, s.window_type)
+        y.stft(s.window_length, s.hop_length, s.window_type)
+
+        # Take circular difference, wrapping values outside [-pi, pi]
+        # back into that range.
+        diff = x.phase - y.phase
+        diff[diff < -np.pi] += 2 * np.pi
+        diff[diff > np.pi] -= 2 * np.pi
+
+        # Scale true magnitude to weights in [0, 1]
+        x_min, x_max = x.magnitude.min(), x.magnitude.max()
+        weights = (x.magnitude - x_min) / (x_max - x_min)
+
+        # Take weighted mean of all phase errors
+        loss = ((weights * diff) ** 2).mean()
+        return loss
diff --git a/audiotools/ml/__init__.py b/audiotools/ml/__init__.py
new file mode 100644
index 0000000000000000000000000000000000000000..a9ca69977bad57e1a92b7551d601d9224ee854ab
--- /dev/null
+++ b/audiotools/ml/__init__.py
@@ -0,0 +1,5 @@
+from . import decorators
+from . import layers
+from .accelerator import Accelerator
+from .experiment import Experiment
+from .layers import BaseModel
diff --git a/audiotools/ml/accelerator.py b/audiotools/ml/accelerator.py
new file mode 100644
index 0000000000000000000000000000000000000000..37c6e8d954f112b8b0aff257894e62add8874e30
--- /dev/null
+++ b/audiotools/ml/accelerator.py
@@ -0,0 +1,184 @@
+import os
+import typing
+
+import torch
+import torch.distributed as dist
+from torch.nn.parallel import DataParallel
+from torch.nn.parallel import DistributedDataParallel
+
+from ..data.datasets import ResumableDistributedSampler as DistributedSampler
+from ..data.datasets import ResumableSequentialSampler as SequentialSampler
+
+
+class Accelerator:  # pragma: no cover
+    """This class is used to prepare models and dataloaders for
+    usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
+    prepare the respective objects. In the case of models, they are moved to
+    the appropriate GPU and SyncBatchNorm is applied to them. In the case of
+    dataloaders, a sampler is created and the dataloader is initialized with
+    that sampler.
+
+    If the world size is 1, prepare_model and prepare_dataloader are
+    no-ops.
If the environment variable ``LOCAL_RANK`` is not set, then the + script was launched without ``torchrun``, and ``DataParallel`` + will be used instead of ``DistributedDataParallel`` (not recommended), if + the world size (number of GPUs) is greater than 1. + + Parameters + ---------- + amp : bool, optional + Whether or not to enable automatic mixed precision, by default False + """ + + def __init__(self, amp: bool = False): + local_rank = os.getenv("LOCAL_RANK", None) + self.world_size = torch.cuda.device_count() + + self.use_ddp = self.world_size > 1 and local_rank is not None + self.use_dp = self.world_size > 1 and local_rank is None + self.device = "cpu" if self.world_size == 0 else "cuda" + + if self.use_ddp: + local_rank = int(local_rank) + dist.init_process_group( + "nccl", + init_method="env://", + world_size=self.world_size, + rank=local_rank, + ) + + self.local_rank = 0 if local_rank is None else local_rank + self.amp = amp + + class DummyScaler: + def __init__(self): + pass + + def step(self, optimizer): + optimizer.step() + + def scale(self, loss): + return loss + + def unscale_(self, optimizer): + return optimizer + + def update(self): + pass + + self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler() + self.device_ctx = ( + torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None + ) + + def __enter__(self): + if self.device_ctx is not None: + self.device_ctx.__enter__() + return self + + def __exit__(self, exc_type, exc_value, traceback): + if self.device_ctx is not None: + self.device_ctx.__exit__(exc_type, exc_value, traceback) + + def prepare_model(self, model: torch.nn.Module, **kwargs): + """Prepares model for DDP or DP. The model is moved to + the device of the correct rank. + + Parameters + ---------- + model : torch.nn.Module + Model that is converted for DDP or DP. + + Returns + ------- + torch.nn.Module + Wrapped model, or original model if DDP and DP are turned off. + """ + model = model.to(self.device) + if self.use_ddp: + model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model) + model = DistributedDataParallel( + model, device_ids=[self.local_rank], **kwargs + ) + elif self.use_dp: + model = DataParallel(model, **kwargs) + return model + + # Automatic mixed-precision utilities + def autocast(self, *args, **kwargs): + """Context manager for autocasting. Arguments + go to ``torch.cuda.amp.autocast``. + """ + return torch.cuda.amp.autocast(self.amp, *args, **kwargs) + + def backward(self, loss: torch.Tensor): + """Backwards pass, after scaling the loss if ``amp`` is + enabled. + + Parameters + ---------- + loss : torch.Tensor + Loss value. + """ + self.scaler.scale(loss).backward() + + def step(self, optimizer: torch.optim.Optimizer): + """Steps the optimizer, using a ``scaler`` if ``amp`` is + enabled. + + Parameters + ---------- + optimizer : torch.optim.Optimizer + Optimizer to step forward. + """ + self.scaler.step(optimizer) + + def update(self): + """Updates the scale factor.""" + self.scaler.update() + + def prepare_dataloader( + self, dataset: typing.Iterable, start_idx: int = None, **kwargs + ): + """Wraps a dataset with a DataLoader, using the correct sampler if DDP is + enabled. + + Parameters + ---------- + dataset : typing.Iterable + Dataset to build Dataloader around. 
+ start_idx : int, optional + Start index of sampler, useful if resuming from some epoch, + by default None + + Returns + ------- + _type_ + _description_ + """ + + if self.use_ddp: + sampler = DistributedSampler( + dataset, + start_idx, + num_replicas=self.world_size, + rank=self.local_rank, + ) + if "num_workers" in kwargs: + kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1) + kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1) + else: + sampler = SequentialSampler(dataset, start_idx) + + dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs) + return dataloader + + @staticmethod + def unwrap(model): + """Unwraps the model if it was wrapped in DDP or DP, otherwise + just returns the model. Use this to unwrap the model returned by + :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`. + """ + if hasattr(model, "module"): + return model.module + return model diff --git a/audiotools/ml/decorators.py b/audiotools/ml/decorators.py new file mode 100644 index 0000000000000000000000000000000000000000..834ec10270ff9e8e84a5fa99e13a686516a4af41 --- /dev/null +++ b/audiotools/ml/decorators.py @@ -0,0 +1,440 @@ +import math +import os +import time +from collections import defaultdict +from functools import wraps + +import torch +import torch.distributed as dist +from rich import box +from rich.console import Console +from rich.console import Group +from rich.live import Live +from rich.markdown import Markdown +from rich.padding import Padding +from rich.panel import Panel +from rich.progress import BarColumn +from rich.progress import Progress +from rich.progress import SpinnerColumn +from rich.progress import TimeElapsedColumn +from rich.progress import TimeRemainingColumn +from rich.rule import Rule +from rich.table import Table +from torch.utils.tensorboard import SummaryWriter + + +# This is here so that the history can be pickled. +def default_list(): + return [] + + +class Mean: + """Keeps track of the running mean, along with the latest + value. + """ + + def __init__(self): + self.reset() + + def __call__(self): + mean = self.total / max(self.count, 1) + return mean + + def reset(self): + self.count = 0 + self.total = 0 + + def update(self, val): + if math.isfinite(val): + self.count += 1 + self.total += val + + +def when(condition): + """Runs a function only when the condition is met. The condition is + a function that is run. + + Parameters + ---------- + condition : Callable + Function to run to check whether or not to run the decorated + function. + + Example + ------- + Checkpoint only runs every 100 iterations, and only if the + local rank is 0. + + >>> i = 0 + >>> rank = 0 + >>> + >>> @when(lambda: i % 100 == 0 and rank == 0) + >>> def checkpoint(): + >>> print("Saving to /runs/exp1") + >>> + >>> for i in range(1000): + >>> checkpoint() + + """ + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + if condition(): + return fn(*args, **kwargs) + + return decorated + + return decorator + + +def timer(prefix: str = "time"): + """Adds execution time to the output dictionary of the decorated + function. The function decorated by this must output a dictionary. + The key added will follow the form "[prefix]/[name_of_function]" + + Parameters + ---------- + prefix : str, optional + The key added will follow the form "[prefix]/[name_of_function]", + by default "time". 
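+
+    Example
+    -------
+    A small sketch; the decorated function must return a dictionary:
+
+    >>> @timer()
+    >>> def train_step():
+    >>>     return {"loss": 0.0}
+    >>>
+    >>> out = train_step()
+    >>> # out also contains "time/train_step", the elapsed time in seconds.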
+ """ + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + s = time.perf_counter() + output = fn(*args, **kwargs) + assert isinstance(output, dict) + e = time.perf_counter() + output[f"{prefix}/{fn.__name__}"] = e - s + return output + + return decorated + + return decorator + + +class Tracker: + """ + A tracker class that helps to monitor the progress of training and logging the metrics. + + Attributes + ---------- + metrics : dict + A dictionary containing the metrics for each label. + history : dict + A dictionary containing the history of metrics for each label. + writer : SummaryWriter + A SummaryWriter object for logging the metrics. + rank : int + The rank of the current process. + step : int + The current step of the training. + tasks : dict + A dictionary containing the progress bars and tables for each label. + pbar : Progress + A progress bar object for displaying the progress. + consoles : list + A list of console objects for logging. + live : Live + A Live object for updating the display live. + + Methods + ------- + print(msg: str) + Prints the given message to all consoles. + update(label: str, fn_name: str) + Updates the progress bar and table for the given label. + done(label: str, title: str) + Resets the progress bar and table for the given label and prints the final result. + track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ) + A decorator for tracking the progress and metrics of a function. + log(label: str, value_type: str = "value", history: bool = True) + A decorator for logging the metrics of a function. + is_best(label: str, key: str) -> bool + Checks if the latest value of the given key in the label is the best so far. + state_dict() -> dict + Returns a dictionary containing the state of the tracker. + load_state_dict(state_dict: dict) -> Tracker + Loads the state of the tracker from the given state dictionary. + """ + + def __init__( + self, + writer: SummaryWriter = None, + log_file: str = None, + rank: int = 0, + console_width: int = 100, + step: int = 0, + ): + """ + Initializes the Tracker object. + + Parameters + ---------- + writer : SummaryWriter, optional + A SummaryWriter object for logging the metrics, by default None. + log_file : str, optional + The path to the log file, by default None. + rank : int, optional + The rank of the current process, by default 0. + console_width : int, optional + The width of the console, by default 100. + step : int, optional + The current step of the training, by default 0. + """ + self.metrics = {} + self.history = {} + self.writer = writer + self.rank = rank + self.step = step + + # Create progress bars etc. + self.tasks = {} + self.pbar = Progress( + SpinnerColumn(), + "[progress.description]{task.description}", + "{task.completed}/{task.total}", + BarColumn(), + TimeElapsedColumn(), + "/", + TimeRemainingColumn(), + ) + self.consoles = [Console(width=console_width)] + self.live = Live(console=self.consoles[0], refresh_per_second=10) + if log_file is not None: + self.consoles.append(Console(width=console_width, file=open(log_file, "a"))) + + def print(self, msg): + """ + Prints the given message to all consoles. + + Parameters + ---------- + msg : str + The message to be printed. + """ + if self.rank == 0: + for c in self.consoles: + c.log(msg) + + def update(self, label, fn_name): + """ + Updates the progress bar and table for the given label. 
+ + Parameters + ---------- + label : str + The label of the progress bar and table to be updated. + fn_name : str + The name of the function associated with the label. + """ + if self.rank == 0: + self.pbar.advance(self.tasks[label]["pbar"]) + + # Create table + table = Table(title=label, expand=True, box=box.MINIMAL) + table.add_column("key", style="cyan") + table.add_column("value", style="bright_blue") + table.add_column("mean", style="bright_green") + + keys = self.metrics[label]["value"].keys() + for k in keys: + value = self.metrics[label]["value"][k] + mean = self.metrics[label]["mean"][k]() + table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}") + + self.tasks[label]["table"] = table + tables = [t["table"] for t in self.tasks.values()] + group = Group(*tables, self.pbar) + self.live.update( + Group( + Padding("", (0, 0)), + Rule(f"[italic]{fn_name}()", style="white"), + Padding("", (0, 0)), + Panel.fit( + group, padding=(0, 5), title="[b]Progress", border_style="blue" + ), + ) + ) + + def done(self, label: str, title: str): + """ + Resets the progress bar and table for the given label and prints the final result. + + Parameters + ---------- + label : str + The label of the progress bar and table to be reset. + title : str + The title to be displayed when printing the final result. + """ + for label in self.metrics: + for v in self.metrics[label]["mean"].values(): + v.reset() + + if self.rank == 0: + self.pbar.reset(self.tasks[label]["pbar"]) + tables = [t["table"] for t in self.tasks.values()] + group = Group(Markdown(f"# {title}"), *tables, self.pbar) + self.print(group) + + def track( + self, + label: str, + length: int, + completed: int = 0, + op: dist.ReduceOp = dist.ReduceOp.AVG, + ddp_active: bool = "LOCAL_RANK" in os.environ, + ): + """ + A decorator for tracking the progress and metrics of a function. + + Parameters + ---------- + label : str + The label to be associated with the progress and metrics. + length : int + The total number of iterations to be completed. + completed : int, optional + The number of iterations already completed, by default 0. + op : dist.ReduceOp, optional + The reduce operation to be used, by default dist.ReduceOp.AVG. + ddp_active : bool, optional + Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ. 
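+
+        Example
+        -------
+        A rough sketch of the intended usage; the label, length, and the
+        contents of the returned dictionary are up to the caller:
+
+        >>> tracker = Tracker()
+        >>>
+        >>> @tracker.track("train", length=100)
+        >>> def train_step():
+        >>>     return {"loss": 0.0}
+        >>>
+        >>> with tracker.live:
+        >>>     for _ in range(100):
+        >>>         train_step()
+        >>>     tracker.done("train", "Epoch 1")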
+ """ + self.tasks[label] = { + "pbar": self.pbar.add_task( + f"[white]Iteration ({label})", total=length, completed=completed + ), + "table": Table(), + } + self.metrics[label] = { + "value": defaultdict(), + "mean": defaultdict(lambda: Mean()), + } + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if not isinstance(output, dict): + self.update(label, fn.__name__) + return output + # Collect across all DDP processes + scalar_keys = [] + for k, v in output.items(): + if isinstance(v, (int, float)): + v = torch.tensor([v]) + if not torch.is_tensor(v): + continue + if ddp_active and v.is_cuda: # pragma: no cover + dist.all_reduce(v, op=op) + output[k] = v.detach() + if torch.numel(v) == 1: + scalar_keys.append(k) + output[k] = v.item() + + # Save the outputs to tracker + for k, v in output.items(): + if k not in scalar_keys: + continue + self.metrics[label]["value"][k] = v + # Update the running mean + self.metrics[label]["mean"][k].update(v) + + self.update(label, fn.__name__) + return output + + return decorated + + return decorator + + def log(self, label: str, value_type: str = "value", history: bool = True): + """ + A decorator for logging the metrics of a function. + + Parameters + ---------- + label : str + The label to be associated with the logging. + value_type : str, optional + The type of value to be logged, by default "value". + history : bool, optional + Whether to save the history of the metrics, by default True. + """ + assert value_type in ["mean", "value"] + if history: + if label not in self.history: + self.history[label] = defaultdict(default_list) + + def decorator(fn): + @wraps(fn) + def decorated(*args, **kwargs): + output = fn(*args, **kwargs) + if self.rank == 0: + nonlocal value_type, label + metrics = self.metrics[label][value_type] + for k, v in metrics.items(): + v = v() if isinstance(v, Mean) else v + if self.writer is not None: + self.writer.add_scalar(f"{k}/{label}", v, self.step) + if label in self.history: + self.history[label][k].append(v) + + if label in self.history: + self.history[label]["step"].append(self.step) + + return output + + return decorated + + return decorator + + def is_best(self, label, key): + """ + Checks if the latest value of the given key in the label is the best so far. + + Parameters + ---------- + label : str + The label of the metrics to be checked. + key : str + The key of the metric to be checked. + + Returns + ------- + bool + True if the latest value is the best so far, otherwise False. + """ + return self.history[label][key][-1] == min(self.history[label][key]) + + def state_dict(self): + """ + Returns a dictionary containing the state of the tracker. + + Returns + ------- + dict + A dictionary containing the history and step of the tracker. + """ + return {"history": self.history, "step": self.step} + + def load_state_dict(self, state_dict): + """ + Loads the state of the tracker from the given state dictionary. + + Parameters + ---------- + state_dict : dict + A dictionary containing the history and step of the tracker. + + Returns + ------- + Tracker + The tracker object with the loaded state. 
+ """ + self.history = state_dict["history"] + self.step = state_dict["step"] + return self diff --git a/audiotools/ml/experiment.py b/audiotools/ml/experiment.py new file mode 100644 index 0000000000000000000000000000000000000000..62833d0f8f80dcdf496a1a5d2785ef666e0a15b6 --- /dev/null +++ b/audiotools/ml/experiment.py @@ -0,0 +1,90 @@ +""" +Useful class for Experiment tracking, and ensuring code is +saved alongside files. +""" # fmt: skip +import datetime +import os +import shlex +import shutil +import subprocess +import typing +from pathlib import Path + +import randomname + + +class Experiment: + """This class contains utilities for managing experiments. + It is a context manager, that when you enter it, changes + your directory to a specified experiment folder (which + optionally can have an automatically generated experiment + name, or a specified one), and changes the CUDA device used + to the specified device (or devices). + + Parameters + ---------- + exp_directory : str + Folder where all experiments are saved, by default "runs/". + exp_name : str, optional + Name of the experiment, by default uses the current time, date, and + hostname to save. + """ + + def __init__( + self, + exp_directory: str = "runs/", + exp_name: str = None, + ): + if exp_name is None: + exp_name = self.generate_exp_name() + exp_dir = Path(exp_directory) / exp_name + exp_dir.mkdir(parents=True, exist_ok=True) + + self.exp_dir = exp_dir + self.exp_name = exp_name + self.git_tracked_files = ( + subprocess.check_output( + shlex.split("git ls-tree --full-tree --name-only -r HEAD") + ) + .decode("utf-8") + .splitlines() + ) + self.parent_directory = Path(".").absolute() + + def __enter__(self): + self.prev_dir = os.getcwd() + os.chdir(self.exp_dir) + return self + + def __exit__(self, exc_type, exc_value, traceback): + os.chdir(self.prev_dir) + + @staticmethod + def generate_exp_name(): + """Generates a random experiment name based on the date + and a randomly generated adjective-noun tuple. + + Returns + ------- + str + Randomly generated experiment name. + """ + date = datetime.datetime.now().strftime("%y%m%d") + name = f"{date}-{randomname.get_name()}" + return name + + def snapshot(self, filter_fn: typing.Callable = lambda f: True): + """Captures a full snapshot of all the files tracked by git at the time + the experiment is run. It also captures the diff against the committed + code as a separate file. 
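+
+        A rough sketch, assuming this is run from the root of a git
+        repository (the experiment name below is made up):
+
+        >>> with Experiment("runs/", "exp1") as exp:
+        >>>     exp.snapshot(lambda f: f.endswith(".py"))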
+ + Parameters + ---------- + filter_fn : typing.Callable, optional + Function that can be used to exclude some files + from the snapshot, by default accepts all files + """ + for f in self.git_tracked_files: + if filter_fn(f): + Path(f).parent.mkdir(parents=True, exist_ok=True) + shutil.copyfile(self.parent_directory / f, f) diff --git a/audiotools/ml/layers/__init__.py b/audiotools/ml/layers/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..92a016cab2ddf06bf5dadfae241b7e5d9def4878 --- /dev/null +++ b/audiotools/ml/layers/__init__.py @@ -0,0 +1,2 @@ +from .base import BaseModel +from .spectral_gate import SpectralGate diff --git a/audiotools/ml/layers/base.py b/audiotools/ml/layers/base.py new file mode 100644 index 0000000000000000000000000000000000000000..b82c96cdd7336ca6b8ed6fc7f0192d69a8e998dd --- /dev/null +++ b/audiotools/ml/layers/base.py @@ -0,0 +1,328 @@ +import inspect +import shutil +import tempfile +import typing +from pathlib import Path + +import torch +from torch import nn + + +class BaseModel(nn.Module): + """This is a class that adds useful save/load functionality to a + ``torch.nn.Module`` object. ``BaseModel`` objects can be saved + as ``torch.package`` easily, making them super easy to port between + machines without requiring a ton of dependencies. Files can also be + saved as just weights, in the standard way. + + >>> class Model(ml.BaseModel): + >>> def __init__(self, arg1: float = 1.0): + >>> super().__init__() + >>> self.arg1 = arg1 + >>> self.linear = nn.Linear(1, 1) + >>> + >>> def forward(self, x): + >>> return self.linear(x) + >>> + >>> model1 = Model() + >>> + >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f: + >>> model1.save( + >>> f.name, + >>> ) + >>> model2 = Model.load(f.name) + >>> out2 = seed_and_run(model2, x) + >>> assert torch.allclose(out1, out2) + >>> + >>> model1.save(f.name, package=True) + >>> model2 = Model.load(f.name) + >>> model2.save(f.name, package=False) + >>> model3 = Model.load(f.name) + >>> out3 = seed_and_run(model3, x) + >>> + >>> with tempfile.TemporaryDirectory() as d: + >>> model1.save_to_folder(d, {"data": 1.0}) + >>> Model.load_from_folder(d) + + """ + + EXTERN = [ + "audiotools.**", + "tqdm", + "__main__", + "numpy.**", + "julius.**", + "torchaudio.**", + "scipy.**", + "einops", + ] + """Names of libraries that are external to the torch.package saving mechanism. + Source code from these libraries will not be packaged into the model. This can + be edited by the user of this class by editing ``model.EXTERN``.""" + INTERN = [] + """Names of libraries that are internal to the torch.package saving mechanism. + Source code from these libraries will be saved alongside the model.""" + + def save( + self, + path: str, + metadata: dict = None, + package: bool = True, + intern: list = [], + extern: list = [], + mock: list = [], + ): + """Saves the model, either as a torch package, or just as + weights, alongside some specified metadata. + + Parameters + ---------- + path : str + Path to save model to. 
+ metadata : dict, optional + Any metadata to save alongside the model, + by default None + package : bool, optional + Whether to use ``torch.package`` to save the model in + a format that is portable, by default True + intern : list, optional + List of additional libraries that are internal + to the model, used with torch.package, by default [] + extern : list, optional + List of additional libraries that are external to + the model, used with torch.package, by default [] + mock : list, optional + List of libraries to mock, used with torch.package, + by default [] + + Returns + ------- + str + Path to saved model. + """ + sig = inspect.signature(self.__class__) + args = {} + + for key, val in sig.parameters.items(): + arg_val = val.default + if arg_val is not inspect.Parameter.empty: + args[key] = arg_val + + # Look up attibutes in self, and if any of them are in args, + # overwrite them in args. + for attribute in dir(self): + if attribute in args: + args[attribute] = getattr(self, attribute) + + metadata = {} if metadata is None else metadata + metadata["kwargs"] = args + if not hasattr(self, "metadata"): + self.metadata = {} + self.metadata.update(metadata) + + if not package: + state_dict = {"state_dict": self.state_dict(), "metadata": metadata} + torch.save(state_dict, path) + else: + self._save_package(path, intern=intern, extern=extern, mock=mock) + + return path + + @property + def device(self): + """Gets the device the model is on by looking at the device of + the first parameter. May not be valid if model is split across + multiple devices. + """ + return list(self.parameters())[0].device + + @classmethod + def load( + cls, + location: str, + *args, + package_name: str = None, + strict: bool = False, + **kwargs, + ): + """Load model from a path. Tries first to load as a package, and if + that fails, tries to load as weights. The arguments to the class are + specified inside the model weights file. + + Parameters + ---------- + location : str + Path to file. + package_name : str, optional + Name of package, by default ``cls.__name__``. + strict : bool, optional + Ignore unmatched keys, by default False + kwargs : dict + Additional keyword arguments to the model instantiation, if + not loading from package. + + Returns + ------- + BaseModel + A model that inherits from BaseModel. + """ + try: + model = cls._load_package(location, package_name=package_name) + except: + model_dict = torch.load(location, "cpu") + metadata = model_dict["metadata"] + metadata["kwargs"].update(kwargs) + + sig = inspect.signature(cls) + class_keys = list(sig.parameters.keys()) + for k in list(metadata["kwargs"].keys()): + if k not in class_keys: + metadata["kwargs"].pop(k) + + model = cls(*args, **metadata["kwargs"]) + model.load_state_dict(model_dict["state_dict"], strict=strict) + model.metadata = metadata + + return model + + def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs): + package_name = type(self).__name__ + resource_name = f"{type(self).__name__}.pth" + + # Below is for loading and re-saving a package. + if hasattr(self, "importer"): + kwargs["importer"] = (self.importer, torch.package.sys_importer) + del self.importer + + # Why do we use a tempfile, you ask? + # It's so we can load a packaged model and then re-save + # it to the same location. torch.package throws an + # error if it's loading and writing to the same + # file (this is undocumented). 
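A small sketch of the weights-only save/load round trip described above. The `Autoencoder` class and its `hidden` argument are hypothetical; they only show how constructor kwargs are captured in the saved metadata and reused by `load`.

```python
import tempfile

import torch
from torch import nn
from audiotools import ml


class Autoencoder(ml.BaseModel):
    def __init__(self, hidden: int = 16):
        super().__init__()
        self.hidden = hidden                 # attribute name matches the kwarg, so save() records its value
        self.net = nn.Linear(hidden, hidden)

    def forward(self, x):
        return self.net(x)


model = Autoencoder(hidden=32)
with tempfile.NamedTemporaryFile(suffix=".pth") as f:
    model.save(f.name, package=False)        # plain state_dict + metadata, no torch.package
    restored = Autoencoder.load(f.name)      # re-instantiated with hidden=32 taken from the saved kwargs
    x = torch.ones(1, 32)
    assert restored.hidden == 32
    assert torch.allclose(model(x), restored(x))
```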
+ with tempfile.NamedTemporaryFile(suffix=".pth") as f: + with torch.package.PackageExporter(f.name, **kwargs) as exp: + exp.intern(self.INTERN + intern) + exp.mock(mock) + exp.extern(self.EXTERN + extern) + exp.save_pickle(package_name, resource_name, self) + + if hasattr(self, "metadata"): + exp.save_pickle( + package_name, f"{package_name}.metadata", self.metadata + ) + + shutil.copyfile(f.name, path) + + # Must reset the importer back to `self` if it existed + # so that you can save the model again! + if "importer" in kwargs: + self.importer = kwargs["importer"][0] + return path + + @classmethod + def _load_package(cls, path, package_name=None): + package_name = cls.__name__ if package_name is None else package_name + resource_name = f"{package_name}.pth" + + imp = torch.package.PackageImporter(path) + model = imp.load_pickle(package_name, resource_name, "cpu") + try: + model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata") + except: # pragma: no cover + pass + model.importer = imp + + return model + + def save_to_folder( + self, + folder: typing.Union[str, Path], + extra_data: dict = None, + package: bool = True, + ): + """Dumps a model into a folder, as both a package + and as weights, as well as anything specified in + ``extra_data``. ``extra_data`` is a dictionary of other + pickleable files, with the keys being the paths + to save them in. The model is saved under a subfolder + specified by the name of the class (e.g. ``folder/generator/[package, weights].pth`` + if the model name was ``Generator``). + + >>> with tempfile.TemporaryDirectory() as d: + >>> extra_data = { + >>> "optimizer.pth": optimizer.state_dict() + >>> } + >>> model.save_to_folder(d, extra_data) + >>> Model.load_from_folder(d) + + Parameters + ---------- + folder : typing.Union[str, Path] + _description_ + extra_data : dict, optional + _description_, by default None + + Returns + ------- + str + Path to folder + """ + extra_data = {} if extra_data is None else extra_data + model_name = type(self).__name__.lower() + target_base = Path(f"{folder}/{model_name}/") + target_base.mkdir(exist_ok=True, parents=True) + + if package: + package_path = target_base / f"package.pth" + self.save(package_path) + + weights_path = target_base / f"weights.pth" + self.save(weights_path, package=False) + + for path, obj in extra_data.items(): + torch.save(obj, target_base / path) + + return target_base + + @classmethod + def load_from_folder( + cls, + folder: typing.Union[str, Path], + package: bool = True, + strict: bool = False, + **kwargs, + ): + """Loads the model from a folder generated by + :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`. + Like that function, this one looks for a subfolder that has + the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the + model name was ``Generator``). + + Parameters + ---------- + folder : typing.Union[str, Path] + _description_ + package : bool, optional + Whether to use ``torch.package`` to load the model, + loading the model from ``package.pth``. + strict : bool, optional + Ignore unmatched keys, by default False + + Returns + ------- + tuple + tuple of model and extra data as saved by + :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`. 
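And a sketch of the folder-based variant, continuing the hypothetical `Autoencoder` from the previous snippet; the optimizer is included only to show how `extra_data` comes back out of `load_from_folder`.

```python
import tempfile

import torch

model = Autoencoder(hidden=32)
opt = torch.optim.Adam(model.parameters())

with tempfile.TemporaryDirectory() as d:
    # Writes d/autoencoder/weights.pth plus d/autoencoder/optimizer.pth
    model.save_to_folder(d, {"optimizer.pth": opt.state_dict()}, package=False)
    restored, extra = Autoencoder.load_from_folder(d, package=False)
    opt.load_state_dict(extra["optimizer.pth"])
```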
+ """ + folder = Path(folder) / cls.__name__.lower() + model_pth = "package.pth" if package else "weights.pth" + model_pth = folder / model_pth + + model = cls.load(model_pth, strict=strict) + extra_data = {} + excluded = ["package.pth", "weights.pth"] + files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded] + for f in files: + extra_data[f.name] = torch.load(f, **kwargs) + + return model, extra_data diff --git a/audiotools/ml/layers/spectral_gate.py b/audiotools/ml/layers/spectral_gate.py new file mode 100644 index 0000000000000000000000000000000000000000..c4ae8b5eab2e56ce13541695f52a11a454759dae --- /dev/null +++ b/audiotools/ml/layers/spectral_gate.py @@ -0,0 +1,127 @@ +import torch +import torch.nn.functional as F +from torch import nn + +from ...core import AudioSignal +from ...core import STFTParams +from ...core import util + + +class SpectralGate(nn.Module): + """Spectral gating algorithm for noise reduction, + as in Audacity/Ocenaudio. The steps are as follows: + + 1. An FFT is calculated over the noise audio clip + 2. Statistics are calculated over FFT of the the noise + (in frequency) + 3. A threshold is calculated based upon the statistics + of the noise (and the desired sensitivity of the algorithm) + 4. An FFT is calculated over the signal + 5. A mask is determined by comparing the signal FFT to the + threshold + 6. The mask is smoothed with a filter over frequency and time + 7. The mask is appled to the FFT of the signal, and is inverted + + Implementation inspired by Tim Sainburg's noisereduce: + + https://timsainburg.com/noise-reduction-python.html + + Parameters + ---------- + n_freq : int, optional + Number of frequency bins to smooth by, by default 3 + n_time : int, optional + Number of time bins to smooth by, by default 5 + """ + + def __init__(self, n_freq: int = 3, n_time: int = 5): + super().__init__() + + smoothing_filter = torch.outer( + torch.cat( + [ + torch.linspace(0, 1, n_freq + 2)[:-1], + torch.linspace(1, 0, n_freq + 2), + ] + )[..., 1:-1], + torch.cat( + [ + torch.linspace(0, 1, n_time + 2)[:-1], + torch.linspace(1, 0, n_time + 2), + ] + )[..., 1:-1], + ) + smoothing_filter = smoothing_filter / smoothing_filter.sum() + smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0) + self.register_buffer("smoothing_filter", smoothing_filter) + + def forward( + self, + audio_signal: AudioSignal, + nz_signal: AudioSignal, + denoise_amount: float = 1.0, + n_std: float = 3.0, + win_length: int = 2048, + hop_length: int = 512, + ): + """Perform noise reduction. + + Parameters + ---------- + audio_signal : AudioSignal + Audio signal that noise will be removed from. + nz_signal : AudioSignal, optional + Noise signal to compute noise statistics from. + denoise_amount : float, optional + Amount to denoise by, by default 1.0 + n_std : float, optional + Number of standard deviations above which to consider + noise, by default 3.0 + win_length : int, optional + Length of window for STFT, by default 2048 + hop_length : int, optional + Hop length for STFT, by default 512 + + Returns + ------- + AudioSignal + Denoised audio signal. 
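A usage sketch for `SpectralGate`; the file paths are hypothetical, and any pair of an audio clip plus a noise-only clip would do.

```python
import torch
from audiotools import AudioSignal
from audiotools.ml.layers import SpectralGate

noisy = AudioSignal("recording.wav")    # hypothetical path: signal to be cleaned
noise = AudioSignal("room_tone.wav")    # hypothetical path: clip containing only the background noise

gate = SpectralGate(n_freq=3, n_time=5)
with torch.no_grad():
    denoised = gate(noisy, noise, denoise_amount=0.8, n_std=2.5)
denoised.write("denoised.wav")
```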
+ """ + stft_params = STFTParams(win_length, hop_length, "sqrt_hann") + + audio_signal = audio_signal.clone() + audio_signal.stft_data = None + audio_signal.stft_params = stft_params + + nz_signal = nz_signal.clone() + nz_signal.stft_params = stft_params + + nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10() + nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1) + nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1) + + nz_thresh = nz_freq_mean + nz_freq_std * n_std + + stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10() + nb, nac, nf, nt = stft_db.shape + db_thresh = nz_thresh.expand(nb, nac, -1, nt) + + stft_mask = (stft_db < db_thresh).float() + shape = stft_mask.shape + + stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt) + pad_tuple = ( + self.smoothing_filter.shape[-2] // 2, + self.smoothing_filter.shape[-1] // 2, + ) + stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple) + stft_mask = stft_mask.reshape(*shape) + stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to( + audio_signal.device + ) + stft_mask = 1 - stft_mask + + audio_signal.stft_data *= stft_mask + audio_signal.istft() + + return audio_signal diff --git a/audiotools/post.py b/audiotools/post.py new file mode 100644 index 0000000000000000000000000000000000000000..6ced2d1e66a4ffda3269685bd45593b01038739f --- /dev/null +++ b/audiotools/post.py @@ -0,0 +1,140 @@ +import tempfile +import typing +import zipfile +from pathlib import Path + +import markdown2 as md +import matplotlib.pyplot as plt +import torch +from IPython.display import HTML + + +def audio_table( + audio_dict: dict, + first_column: str = None, + format_fn: typing.Callable = None, + **kwargs, +): # pragma: no cover + """Embeds an audio table into HTML, or as the output cell + in a notebook. + + Parameters + ---------- + audio_dict : dict + Dictionary of data to embed. + first_column : str, optional + The label for the first column of the table, by default None + format_fn : typing.Callable, optional + How to format the data, by default None + + Returns + ------- + str + Table as a string + + Examples + -------- + + >>> audio_dict = {} + >>> for i in range(signal_batch.batch_size): + >>> audio_dict[i] = { + >>> "input": signal_batch[i], + >>> "output": output_batch[i] + >>> } + >>> audiotools.post.audio_zip(audio_dict) + + """ + from audiotools import AudioSignal + + output = [] + columns = None + + def _default_format_fn(label, x, **kwargs): + if torch.is_tensor(x): + x = x.tolist() + + if x is None: + return "." + elif isinstance(x, AudioSignal): + return x.embed(display=False, return_html=True, **kwargs) + else: + return str(x) + + if format_fn is None: + format_fn = _default_format_fn + + if first_column is None: + first_column = "." + + for k, v in audio_dict.items(): + if not isinstance(v, dict): + v = {"Audio": v} + + v_keys = list(v.keys()) + if columns is None: + columns = [first_column] + v_keys + output.append(" | ".join(columns)) + + layout = "|---" + len(v_keys) * "|:-:" + output.append(layout) + + formatted_audio = [] + for col in columns[1:]: + formatted_audio.append(format_fn(col, v[col], **kwargs)) + + row = f"| {k} | " + row += " | ".join(formatted_audio) + output.append(row) + + output = "\n" + "\n".join(output) + return output + + +def in_notebook(): # pragma: no cover + """Determines if code is running in a notebook. + + Returns + ------- + bool + Whether or not this is running in a notebook. 
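A sketch of rendering a comparison table with these helpers; `noisy` and `clean` are assumed to be existing `AudioSignal` objects, and `disp` (defined just below) switches between notebook and terminal output.

```python
from audiotools import post

# Assumes `noisy` and `clean` are AudioSignal objects created elsewhere.
audio_dict = {
    "example_1": {"input": noisy, "output": clean},
}
markdown_table = post.audio_table(audio_dict, first_column="id")
post.disp(audio_dict)  # HTML table in a notebook, raw markdown table when run in a terminal
```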
+ """ + try: + from IPython import get_ipython + + if "IPKernelApp" not in get_ipython().config: # pragma: no cover + return False + except ImportError: + return False + except AttributeError: + return False + return True + + +def disp(obj, **kwargs): # pragma: no cover + """Displays an object, depending on if its in a notebook + or not. + + Parameters + ---------- + obj : typing.Any + Any object to display. + + """ + from audiotools import AudioSignal + + IN_NOTEBOOK = in_notebook() + + if isinstance(obj, AudioSignal): + audio_elem = obj.embed(display=False, return_html=True) + if IN_NOTEBOOK: + return HTML(audio_elem) + else: + print(audio_elem) + if isinstance(obj, dict): + table = audio_table(obj, **kwargs) + if IN_NOTEBOOK: + return HTML(md.markdown(table, extras=["tables"])) + else: + print(table) + if isinstance(obj, plt.Figure): + plt.show() diff --git a/audiotools/preference.py b/audiotools/preference.py new file mode 100644 index 0000000000000000000000000000000000000000..800a852e8119dd18ea65784cf95182de2470fbc4 --- /dev/null +++ b/audiotools/preference.py @@ -0,0 +1,600 @@ +############################################################## +### Tools for creating preference tests (MUSHRA, ABX, etc) ### +############################################################## +import copy +import csv +import random +import sys +import traceback +from collections import defaultdict +from pathlib import Path +from typing import List + +import gradio as gr + +from audiotools.core.util import find_audio + +################################################################ +### Logic for audio player, and adding audio / play buttons. ### +################################################################ + +WAVESURFER = """
""" + +CUSTOM_CSS = """ +.gradio-container { + max-width: 840px !important; +} +region.wavesurfer-region:before { + content: attr(data-region-label); +} + +block { + min-width: 0 !important; +} + +#wave-timeline { + background-color: rgba(0, 0, 0, 0.8); +} + +.head.svelte-1cl284s { + display: none; +} +""" + +load_wavesurfer_js = """ +function load_wavesurfer() { + function load_script(url) { + const script = document.createElement('script'); + script.src = url; + document.body.appendChild(script); + + return new Promise((res, rej) => { + script.onload = function() { + res(); + } + script.onerror = function () { + rej(); + } + }); + } + + function create_wavesurfer() { + var options = { + container: '#waveform', + waveColor: '#F2F2F2', // Set a darker wave color + progressColor: 'white', // Set a slightly lighter progress color + loaderColor: 'white', // Set a slightly lighter loader color + cursorColor: 'black', // Set a slightly lighter cursor color + backgroundColor: '#00AAFF', // Set a black background color + barWidth: 4, + barRadius: 3, + barHeight: 1, // the height of the wave + plugins: [ + WaveSurfer.regions.create({ + regionsMinLength: 0.0, + dragSelection: { + slop: 5 + }, + color: 'hsla(200, 50%, 70%, 0.4)', + }), + WaveSurfer.timeline.create({ + container: "#wave-timeline", + primaryLabelInterval: 5.0, + secondaryLabelInterval: 1.0, + primaryFontColor: '#F2F2F2', + secondaryFontColor: '#F2F2F2', + }), + ] + }; + wavesurfer = WaveSurfer.create(options); + wavesurfer.on('region-created', region => { + wavesurfer.regions.clear(); + }); + wavesurfer.on('finish', function () { + var loop = document.getElementById("loop-button").textContent.includes("ON"); + if (loop) { + wavesurfer.play(); + } + else { + var button_elements = document.getElementsByClassName('playpause') + var buttons = Array.from(button_elements); + + for (let j = 0; j < buttons.length; j++) { + buttons[j].classList.remove("primary"); + buttons[j].classList.add("secondary"); + buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") + } + } + }); + + wavesurfer.on('region-out', function () { + var loop = document.getElementById("loop-button").textContent.includes("ON"); + if (!loop) { + var button_elements = document.getElementsByClassName('playpause') + var buttons = Array.from(button_elements); + + for (let j = 0; j < buttons.length; j++) { + buttons[j].classList.remove("primary"); + buttons[j].classList.add("secondary"); + buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") + } + wavesurfer.pause(); + } + }); + + console.log("Created WaveSurfer object.") + } + + load_script('https://unpkg.com/wavesurfer.js@6.6.4') + .then(() => { + load_script("https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.timeline.min.js") + .then(() => { + load_script('https://unpkg.com/wavesurfer.js@6.6.4/dist/plugin/wavesurfer.regions.min.js') + .then(() => { + console.log("Loaded regions"); + create_wavesurfer(); + document.getElementById("start-survey").click(); + }) + }) + }); +} +""" + +play = lambda i: """ +function play() { + var audio_elements = document.getElementsByTagName('audio'); + var button_elements = document.getElementsByClassName('playpause') + + var audio_array = Array.from(audio_elements); + var buttons = Array.from(button_elements); + + var src_link = audio_array[{i}].getAttribute("src"); + console.log(src_link); + + var loop = document.getElementById("loop-button").textContent.includes("ON"); + var playing = buttons[{i}].textContent.includes("Stop"); + + for (let j = 0; j 
< buttons.length; j++) { + if (j != {i} || playing) { + buttons[j].classList.remove("primary"); + buttons[j].classList.add("secondary"); + buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") + } + else { + buttons[j].classList.remove("secondary"); + buttons[j].classList.add("primary"); + buttons[j].textContent = buttons[j].textContent.replace("Play", "Stop") + } + } + + if (playing) { + wavesurfer.pause(); + wavesurfer.seekTo(0.0); + } + else { + wavesurfer.load(src_link); + wavesurfer.on('ready', function () { + var region = Object.values(wavesurfer.regions.list)[0]; + + if (region != null) { + region.loop = loop; + region.play(); + } else { + wavesurfer.play(); + } + }); + } +} +""".replace( + "{i}", str(i) +) + +clear_regions = """ +function clear_regions() { + wavesurfer.clearRegions(); +} +""" + +reset_player = """ +function reset_player() { + wavesurfer.clearRegions(); + wavesurfer.pause(); + wavesurfer.seekTo(0.0); + + var button_elements = document.getElementsByClassName('playpause') + var buttons = Array.from(button_elements); + + for (let j = 0; j < buttons.length; j++) { + buttons[j].classList.remove("primary"); + buttons[j].classList.add("secondary"); + buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play") + } +} +""" + +loop_region = """ +function loop_region() { + var element = document.getElementById("loop-button"); + var loop = element.textContent.includes("OFF"); + console.log(loop); + + try { + var region = Object.values(wavesurfer.regions.list)[0]; + region.loop = loop; + } catch {} + + if (loop) { + element.classList.remove("secondary"); + element.classList.add("primary"); + element.textContent = "Looping ON"; + } else { + element.classList.remove("primary"); + element.classList.add("secondary"); + element.textContent = "Looping OFF"; + } +} +""" + + +class Player: + def __init__(self, app): + self.app = app + + self.app.load(_js=load_wavesurfer_js) + self.app.css = CUSTOM_CSS + + self.wavs = [] + self.position = 0 + + def create(self): + gr.HTML(WAVESURFER) + gr.Markdown( + "Click and drag on the waveform above to select a region for playback. " + "Once created, the region can be moved around and resized. " + "Clear the regions using the button below. Hit play on one of the buttons below to start!" + ) + + with gr.Row(): + clear = gr.Button("Clear region") + loop = gr.Button("Looping OFF", elem_id="loop-button") + + loop.click(None, _js=loop_region) + clear.click(None, _js=clear_regions) + + gr.HTML("
") + + def add(self, name: str = "Play"): + i = self.position + self.wavs.append( + { + "audio": gr.Audio(visible=False), + "button": gr.Button(name, elem_classes=["playpause"]), + "position": i, + } + ) + self.wavs[-1]["button"].click(None, _js=play(i)) + self.position += 1 + return self.wavs[-1] + + def to_list(self): + return [x["audio"] for x in self.wavs] + + +############################################################ +### Keeping track of users, and CSS for the progress bar ### +############################################################ + +load_tracker = lambda name: """ +function load_name() { + function setCookie(name, value, exp_days) { + var d = new Date(); + d.setTime(d.getTime() + (exp_days*24*60*60*1000)); + var expires = "expires=" + d.toGMTString(); + document.cookie = name + "=" + value + ";" + expires + ";path=/"; + } + + function getCookie(name) { + var cname = name + "="; + var decodedCookie = decodeURIComponent(document.cookie); + var ca = decodedCookie.split(';'); + for(var i = 0; i < ca.length; i++){ + var c = ca[i]; + while(c.charAt(0) == ' '){ + c = c.substring(1); + } + if(c.indexOf(cname) == 0){ + return c.substring(cname.length, c.length); + } + } + return ""; + } + + name = getCookie("{name}"); + if (name == "") { + name = Math.random().toString(36).slice(2); + console.log(name); + setCookie("name", name, 30); + } + name = getCookie("{name}"); + return name; +} +""".replace( + "{name}", name +) + +# Progress bar + +progress_template = """ + + + + Progress Bar + + + +
+<!-- "Progress Bar" page markup not recovered; Samples.progress() fills this
+     template by substituting {PROGRESS} (bar width, in percent) and
+     {TEXT} (the status label shown below). -->
+{TEXT}
+"""
+
+
+def create_tracker(app, cookie_name="name"):
+    user = gr.Text(label="user", interactive=True, visible=False, elem_id="user")
+    app.load(_js=load_tracker(cookie_name), outputs=user)
+    return user
+
+
+#################################################################
+### CSS and HTML for labeling sliders for both ABX and MUSHRA ###
+#################################################################
+
+slider_abx = """
+<!-- "Labels Example" slider markup not recovered; the choice labels were: -->
+Prefer A
+Toss-up
+Prefer B
+"""
+
+
+slider_mushra = """
+<!-- "Labels Example" slider markup not recovered; the rating labels were: -->
+bad
+poor
+fair
+good
+excellent
+ + +""" + +######################################################### +### Handling loading audio and tracking session state ### +######################################################### + + +class Samples: + def __init__(self, folder: str, shuffle: bool = True, n_samples: int = None): + files = find_audio(folder) + samples = defaultdict(lambda: defaultdict()) + + for f in files: + condition = f.parent.stem + samples[f.name][condition] = f + + self.samples = samples + self.names = list(samples.keys()) + self.filtered = False + self.current = 0 + + if shuffle: + random.shuffle(self.names) + + self.n_samples = len(self.names) if n_samples is None else n_samples + + def get_updates(self, idx, order): + key = self.names[idx] + return [gr.update(value=str(self.samples[key][o])) for o in order] + + def progress(self): + try: + pct = self.current / len(self) * 100 + except: # pragma: no cover + pct = 100 + text = f"On {self.current} / {len(self)} samples" + pbar = ( + copy.copy(progress_template) + .replace("{PROGRESS}", str(pct)) + .replace("{TEXT}", str(text)) + ) + return gr.update(value=pbar) + + def __len__(self): + return self.n_samples + + def filter_completed(self, user, save_path): + if not self.filtered: + done = [] + if Path(save_path).exists(): + with open(save_path, "r") as f: + reader = csv.DictReader(f) + done = [r["sample"] for r in reader if r["user"] == user] + self.names = [k for k in self.names if k not in done] + self.names = self.names[: self.n_samples] + self.filtered = True # Avoid filtering more than once per session. + + def get_next_sample(self, reference, conditions): + random.shuffle(conditions) + if reference is not None: + self.order = [reference] + conditions + else: + self.order = conditions + + try: + updates = self.get_updates(self.current, self.order) + self.current += 1 + done = gr.update(interactive=True) + pbar = self.progress() + except: + traceback.print_exc() + updates = [gr.update() for _ in range(len(self.order))] + done = gr.update(value="No more samples!", interactive=False) + self.current = len(self) + pbar = self.progress() + + return updates, done, pbar + + +def save_result(result, save_path): + with open(save_path, mode="a", newline="") as file: + writer = csv.DictWriter(file, fieldnames=sorted(list(result.keys()))) + if file.tell() == 0: + writer.writeheader() + writer.writerow(result) diff --git a/src/inference.py b/src/inference.py new file mode 100644 index 0000000000000000000000000000000000000000..654090b58a9e85e740821e485428b5fb37766edb --- /dev/null +++ b/src/inference.py @@ -0,0 +1,169 @@ +import os +import random +import pandas as pd +import torch +import librosa +import numpy as np +import soundfile as sf +from tqdm import tqdm +from .utils import scale_shift_re + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
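A sketch of the bookkeeping pieces from `preference.py` above, used outside of any Gradio UI; the folder layout (`audio_root/<condition>/<file>.wav`), the user id, and the rating key are all hypothetical.

```python
from audiotools.preference import Samples, save_result

samples = Samples("audio_root", shuffle=True, n_samples=25)
samples.filter_completed(user="user123", save_path="results.csv")  # drop samples this user already rated

result = {"user": "user123", "sample": samples.names[0], "rating": 4}
save_result(result, "results.csv")  # appends one CSV row, writing the header on first use
```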
See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +@torch.no_grad() +def inference(autoencoder, unet, gt, gt_mask, + tokenizer, text_encoder, + params, noise_scheduler, + text_raw, neg_text=None, + audio_frames=500, + guidance_scale=3, guidance_rescale=0.0, + ddim_steps=50, eta=1, random_seed=2024, + device='cuda', + ): + if neg_text is None: + neg_text = [""] + if tokenizer is not None: + text_batch = tokenizer(text_raw, + max_length=params['text_encoder']['max_length'], + padding="max_length", truncation=True, return_tensors="pt") + text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool() + text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state + + uncond_text_batch = tokenizer(neg_text, + max_length=params['text_encoder']['max_length'], + padding="max_length", truncation=True, return_tensors="pt") + uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool() + uncond_text = text_encoder(input_ids=uncond_text, + attention_mask=uncond_text_mask).last_hidden_state + else: + text, text_mask = None, None + guidance_scale = None + + codec_dim = params['model']['out_chans'] + unet.eval() + + if random_seed is not None: + generator = torch.Generator(device=device).manual_seed(random_seed) + else: + generator = torch.Generator(device=device) + generator.seed() + + noise_scheduler.set_timesteps(ddim_steps) + + # init noise + noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device) + latents = noise + + for t in noise_scheduler.timesteps: + latents = noise_scheduler.scale_model_input(latents, t) + + if guidance_scale: + + latents_combined = torch.cat([latents, latents], dim=0) + text_combined = torch.cat([text, uncond_text], dim=0) + text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0) + + if gt is not None: + gt_combined = torch.cat([gt, gt], dim=0) + gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0) + else: + gt_combined = None + gt_mask_combined = None + + output_combined, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined, + cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined) + output_text, output_uncond = torch.chunk(output_combined, 2, dim=0) + + output_pred = output_uncond + guidance_scale * (output_text - output_uncond) + if guidance_rescale > 0.0: + output_pred = rescale_noise_cfg(output_pred, output_text, + guidance_rescale=guidance_rescale) + else: + output_pred, mae_mask = unet(latents, t, text, context_mask=text_mask, + cls_token=None, gt=gt, mae_mask_infer=gt_mask) + + latents = noise_scheduler.step(model_output=output_pred, timestep=t, + sample=latents, + eta=eta, generator=generator).prev_sample + + pred = scale_shift_re(latents, params['autoencoder']['scale'], + params['autoencoder']['shift']) + if gt is not None: + pred[~gt_mask] = gt[~gt_mask] + pred_wav = autoencoder(embedding=pred) + return pred_wav + + +@torch.no_grad() +def eval_udit(autoencoder, unet, + tokenizer, text_encoder, + params, 
noise_scheduler, + val_df, subset, + audio_frames, mae=False, + guidance_scale=3, guidance_rescale=0.0, + ddim_steps=50, eta=1, random_seed=2023, + device='cuda', + epoch=0, save_path='logs/eval/', val_num=5): + val_df = pd.read_csv(val_df) + val_df = val_df[val_df['split'] == subset] + if mae: + val_df = val_df[val_df['audio_length'] != 0] + + save_path = save_path + str(epoch) + '/' + os.makedirs(save_path, exist_ok=True) + + for i in tqdm(range(len(val_df))): + row = val_df.iloc[i] + text = [row['caption']] + if mae: + audio_path = params['data']['val_dir'] + str(row['audio_path']) + gt, sr = librosa.load(audio_path, sr=params['data']['sr']) + gt = gt / (np.max(np.abs(gt)) + 1e-9) + sf.write(save_path + text[0] + '_gt.wav', gt, samplerate=params['data']['sr']) + num_samples = 10 * sr + if len(gt) < num_samples: + padding = num_samples - len(gt) + gt = np.pad(gt, (0, padding), 'constant') + else: + gt = gt[:num_samples] + gt = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device) + gt = autoencoder(audio=gt) + B, D, L = gt.shape + mask_len = int(L * 0.2) + gt_mask = torch.zeros(B, D, L).to(device) + for _ in range(2): + start = random.randint(0, L - mask_len) + gt_mask[:, :, start:start + mask_len] = 1 + gt_mask = gt_mask.bool() + else: + gt = None + gt_mask = None + + pred = inference(autoencoder, unet, gt, gt_mask, + tokenizer, text_encoder, + params, noise_scheduler, + text, neg_text=None, + audio_frames=audio_frames, + guidance_scale=guidance_scale, guidance_rescale=guidance_rescale, + ddim_steps=ddim_steps, eta=eta, random_seed=random_seed, + device=device) + + pred = pred.cpu().numpy().squeeze(0).squeeze(0) + + sf.write(save_path + text[0] + '.wav', pred, samplerate=params['data']['sr']) + + if i + 1 >= val_num: + break diff --git a/src/inference_controlnet.py b/src/inference_controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..eeb5dd30228c9a597c25e523d0d94c564bbe910b --- /dev/null +++ b/src/inference_controlnet.py @@ -0,0 +1,129 @@ +import os +import random +import pandas as pd +import torch +import librosa +import numpy as np +import soundfile as sf +from tqdm import tqdm +from .utils import scale_shift_re + + +def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0): + """ + Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and + Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). 
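A small numeric sketch of the classifier-free-guidance combination used in `inference` above, with `rescale_noise_cfg` imported from this diff's `src/inference.py` (the import path and tensor shapes are assumptions; the tensors are random stand-ins for the conditional and unconditional predictions).

```python
import torch

from src.inference import rescale_noise_cfg  # path taken from this diff; requires src.utils to be importable

guidance_scale, guidance_rescale = 3.0, 0.7
output_text = torch.randn(1, 64, 500)    # conditional prediction (stand-in)
output_uncond = torch.randn(1, 64, 500)  # unconditional prediction (stand-in)

pred = output_uncond + guidance_scale * (output_text - output_uncond)
pred = rescale_noise_cfg(pred, output_text, guidance_rescale=guidance_rescale)
# guidance_rescale interpolates between the raw CFG output and a version whose per-sample
# std is matched back to the conditional branch, which counteracts over-saturation.
```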
See Section 3.4 + """ + std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True) + std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True) + # rescale the results from guidance (fixes overexposure) + noise_pred_rescaled = noise_cfg * (std_text / std_cfg) + # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images + noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg + return noise_cfg + + +@torch.no_grad() +def inference(autoencoder, unet, controlnet, + gt, gt_mask, condition, + tokenizer, text_encoder, + params, noise_scheduler, + text_raw, neg_text=None, + audio_frames=500, + guidance_scale=3, guidance_rescale=0.0, + ddim_steps=50, eta=1, random_seed=2024, + conditioning_scale=1.0, + device='cuda', + ): + if neg_text is None: + neg_text = [""] + if tokenizer is not None: + text_batch = tokenizer(text_raw, + max_length=params['text_encoder']['max_length'], + padding="max_length", truncation=True, return_tensors="pt") + text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool() + text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state + + uncond_text_batch = tokenizer(neg_text, + max_length=params['text_encoder']['max_length'], + padding="max_length", truncation=True, return_tensors="pt") + uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool() + uncond_text = text_encoder(input_ids=uncond_text, + attention_mask=uncond_text_mask).last_hidden_state + else: + text, text_mask = None, None + guidance_scale = None + + codec_dim = params['model']['out_chans'] + unet.eval() + controlnet.eval() + + if random_seed is not None: + generator = torch.Generator(device=device).manual_seed(random_seed) + else: + generator = torch.Generator(device=device) + generator.seed() + + noise_scheduler.set_timesteps(ddim_steps) + + # init noise + noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device) + latents = noise + + for t in noise_scheduler.timesteps: + latents = noise_scheduler.scale_model_input(latents, t) + + if guidance_scale: + latents_combined = torch.cat([latents, latents], dim=0) + text_combined = torch.cat([text, uncond_text], dim=0) + text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0) + condition_combined = torch.cat([condition, condition], dim=0) + + if gt is not None: + gt_combined = torch.cat([gt, gt], dim=0) + gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0) + else: + gt_combined = None + gt_mask_combined = None + + x, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined, + cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined, + forward_model=False) + controlnet_skips = controlnet(x, t, text_combined, + context_mask=text_mask_combined, + cls_token=None, + condition=condition_combined, + conditioning_scale=conditioning_scale) + output_combined = unet.model(x, t, text_combined, + context_mask=text_mask_combined, + cls_token=None, controlnet_skips=controlnet_skips) + + output_text, output_uncond = torch.chunk(output_combined, 2, dim=0) + + output_pred = output_uncond + guidance_scale * (output_text - output_uncond) + if guidance_rescale > 0.0: + output_pred = rescale_noise_cfg(output_pred, output_text, + guidance_rescale=guidance_rescale) + else: + x, _ = unet(latents, t, text, context_mask=text_mask, + cls_token=None, gt=gt, mae_mask_infer=gt_mask, + 
forward_model=False) + controlnet_skips = controlnet(x, t, text, + context_mask=text_mask, + cls_token=None, + condition=condition, + conditioning_scale=conditioning_scale) + output_pred = unet.model(x, t, text, + context_mask=text_mask, + cls_token=None, controlnet_skips=controlnet_skips) + + latents = noise_scheduler.step(model_output=output_pred, timestep=t, + sample=latents, + eta=eta, generator=generator).prev_sample + + pred = scale_shift_re(latents, params['autoencoder']['scale'], + params['autoencoder']['shift']) + if gt is not None: + pred[~gt_mask] = gt[~gt_mask] + pred_wav = autoencoder(embedding=pred) + return pred_wav \ No newline at end of file diff --git a/src/models/.ipynb_checkpoints/blocks-checkpoint.py b/src/models/.ipynb_checkpoints/blocks-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..8ef730009cb7cf664c5d9021e551b275680d11f3 --- /dev/null +++ b/src/models/.ipynb_checkpoints/blocks-checkpoint.py @@ -0,0 +1,325 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from .utils.attention import Attention, JointAttention +from .utils.modules import unpatchify, FeedForward +from .utils.modules import film_modulate + + +class AdaLN(nn.Module): + def __init__(self, dim, ada_mode='ada', r=None, alpha=None): + super().__init__() + self.ada_mode = ada_mode + self.scale_shift_table = None + if ada_mode == 'ada': + # move nn.silu outside + self.time_ada = nn.Linear(dim, 6 * dim, bias=True) + elif ada_mode == 'ada_single': + # adaln used in pixel-art alpha + self.scale_shift_table = nn.Parameter(torch.zeros(6, dim)) + elif ada_mode in ['ada_lora', 'ada_lora_bias']: + self.lora_a = nn.Linear(dim, r * 6, bias=False) + self.lora_b = nn.Linear(r * 6, dim * 6, bias=False) + self.scaling = alpha / r + if ada_mode == 'ada_lora_bias': + # take bias out for consistency + self.scale_shift_table = nn.Parameter(torch.zeros(6, dim)) + else: + raise NotImplementedError + + def forward(self, time_token=None, time_ada=None): + if self.ada_mode == 'ada': + assert time_ada is None + B = time_token.shape[0] + time_ada = self.time_ada(time_token).reshape(B, 6, -1) + elif self.ada_mode == 'ada_single': + B = time_ada.shape[0] + time_ada = time_ada.reshape(B, 6, -1) + time_ada = self.scale_shift_table[None] + time_ada + elif self.ada_mode in ['ada_lora', 'ada_lora_bias']: + B = time_ada.shape[0] + time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling + time_ada = time_ada + time_ada_lora + time_ada = time_ada.reshape(B, 6, -1) + if self.scale_shift_table is not None: + time_ada = self.scale_shift_table[None] + time_ada + else: + raise NotImplementedError + return time_ada + + +class DiTBlock(nn.Module): + """ + A modified PixArt block with adaptive layer norm (adaLN-single) conditioning. 
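A shape-level sketch of `AdaLN` in its per-block `'ada'` mode; the import path assumes the non-checkpoint copy of this file at `src/models/blocks.py`.

```python
import torch

from src.models.blocks import AdaLN  # path assumed; the diff above shows the .ipynb_checkpoints copy

dim = 8
adaln = AdaLN(dim, ada_mode='ada')        # per-block Linear(dim -> 6 * dim) on the pooled time token
time_token = torch.randn(2, dim)          # (B, dim) timestep embedding
time_ada = adaln(time_token=time_token)   # (B, 6, dim)
(shift_msa, scale_msa, gate_msa,
 shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)  # six (B, 1, dim) modulation tensors
```

In `'ada_single'` mode the 6x projection is instead shared across blocks (it lives in the backbone) and each `AdaLN` only adds its learned `scale_shift_table` offset, which is the PixArt-alpha style of conditioning referenced in the block docstrings.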
+ """ + + def __init__(self, dim, context_dim=None, + num_heads=8, mlp_ratio=4., + qkv_bias=False, qk_scale=None, qk_norm=None, + act_layer='gelu', norm_layer=nn.LayerNorm, + time_fusion='none', + ada_lora_rank=None, ada_lora_alpha=None, + skip=False, skip_norm=False, + rope_mode='none', + context_norm=False, + use_checkpoint=False): + + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, + qk_norm=qk_norm, + rope_mode=rope_mode) + + if context_dim is not None: + self.use_context = True + self.cross_attn = Attention(dim=dim, + num_heads=num_heads, + context_dim=context_dim, + qkv_bias=qkv_bias, qk_scale=qk_scale, + qk_norm=qk_norm, + rope_mode='none') + self.norm2 = norm_layer(dim) + if context_norm: + self.norm_context = norm_layer(context_dim) + else: + self.norm_context = nn.Identity() + else: + self.use_context = False + + self.norm3 = norm_layer(dim) + self.mlp = FeedForward(dim=dim, mult=mlp_ratio, + activation_fn=act_layer, dropout=0) + + self.use_adanorm = True if time_fusion != 'token' else False + if self.use_adanorm: + self.adaln = AdaLN(dim, ada_mode=time_fusion, + r=ada_lora_rank, alpha=ada_lora_alpha) + if skip: + self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity() + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None + + self.use_checkpoint = use_checkpoint + + def forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + if self.use_checkpoint: + return checkpoint(self._forward, x, + time_token, time_ada, skip, context, + x_mask, context_mask, extras, + use_reentrant=False) + else: + return self._forward(x, + time_token, time_ada, skip, context, + x_mask, context_mask, extras) + + def _forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + B, T, C = x.shape + if self.skip_linear is not None: + assert skip is not None + cat = torch.cat([x, skip], dim=-1) + cat = self.skip_norm(cat) + x = self.skip_linear(cat) + + if self.use_adanorm: + time_ada = self.adaln(time_token, time_ada) + (shift_msa, scale_msa, gate_msa, + shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1) + + # self attention + if self.use_adanorm: + x_norm = film_modulate(self.norm1(x), shift=shift_msa, + scale=scale_msa) + x = x + (1 - gate_msa) * self.attn(x_norm, context=None, + context_mask=x_mask, + extras=extras) + else: + x = x + self.attn(self.norm1(x), context=None, context_mask=x_mask, + extras=extras) + + # cross attention + if self.use_context: + assert context is not None + x = x + self.cross_attn(x=self.norm2(x), + context=self.norm_context(context), + context_mask=context_mask, extras=extras) + + # mlp + if self.use_adanorm: + x_norm = film_modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp) + x = x + (1 - gate_mlp) * self.mlp(x_norm) + else: + x = x + self.mlp(self.norm3(x)) + + return x + + +class JointDiTBlock(nn.Module): + """ + A modified PixArt block with adaptive layer norm (adaLN-single) conditioning. 
+ """ + + def __init__(self, dim, context_dim=None, + num_heads=8, mlp_ratio=4., + qkv_bias=False, qk_scale=None, qk_norm=None, + act_layer='gelu', norm_layer=nn.LayerNorm, + time_fusion='none', + ada_lora_rank=None, ada_lora_alpha=None, + skip=(False, False), + rope_mode=False, + context_norm=False, + use_checkpoint=False,): + + super().__init__() + # no cross attention + assert context_dim is None + self.attn_norm_x = norm_layer(dim) + self.attn_norm_c = norm_layer(dim) + self.attn = JointAttention(dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, + qk_norm=qk_norm, + rope_mode=rope_mode) + self.ffn_norm_x = norm_layer(dim) + self.ffn_norm_c = norm_layer(dim) + self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio, + activation_fn=act_layer, dropout=0) + self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio, + activation_fn=act_layer, dropout=0) + + # Zero-out the shift table + self.use_adanorm = True if time_fusion != 'token' else False + if self.use_adanorm: + self.adaln = AdaLN(dim, ada_mode=time_fusion, + r=ada_lora_rank, alpha=ada_lora_alpha) + + if skip is False: + skip_x, skip_c = False, False + else: + skip_x, skip_c = skip + + self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None + self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None + + self.use_checkpoint = use_checkpoint + + def forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + if self.use_checkpoint: + return checkpoint(self._forward, x, + time_token, time_ada, skip, + context, x_mask, context_mask, extras, + use_reentrant=False) + else: + return self._forward(x, + time_token, time_ada, skip, + context, x_mask, context_mask, extras) + + def _forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + + assert context is None and context_mask is None + + context, x = x[:, :extras, :], x[:, extras:, :] + context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:] + + if skip is not None: + skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :] + + B, T, C = x.shape + if self.skip_linear_x is not None: + x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1)) + + if self.skip_linear_c is not None: + context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1)) + + if self.use_adanorm: + time_ada = self.adaln(time_token, time_ada) + (shift_msa, scale_msa, gate_msa, + shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1) + + # self attention + x_norm = self.attn_norm_x(x) + c_norm = self.attn_norm_c(context) + if self.use_adanorm: + x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa) + x_out, c_out = self.attn(x_norm, context=c_norm, + x_mask=x_mask, context_mask=context_mask, + extras=extras) + if self.use_adanorm: + x = x + (1 - gate_msa) * x_out + else: + x = x + x_out + context = context + c_out + + # mlp + if self.use_adanorm: + x_norm = film_modulate(self.ffn_norm_x(x), + shift=shift_mlp, scale=scale_mlp) + x = x + (1 - gate_mlp) * self.mlp_x(x_norm) + else: + x = x + self.mlp_x(self.ffn_norm_x(x)) + + c_norm = self.ffn_norm_c(context) + context = context + self.mlp_c(c_norm) + + return torch.cat((context, x), dim=1) + + +class FinalBlock(nn.Module): + def __init__(self, embed_dim, patch_size, in_chans, + img_size, + input_type='2d', + norm_layer=nn.LayerNorm, + use_conv=True, + use_adanorm=True): + super().__init__() + self.in_chans = in_chans + self.img_size = img_size + self.input_type = input_type + + self.norm = 
norm_layer(embed_dim) + if use_adanorm: + self.use_adanorm = True + else: + self.use_adanorm = False + + if input_type == '2d': + self.patch_dim = patch_size ** 2 * in_chans + self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True) + if use_conv: + self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, + 3, padding=1) + else: + self.final_layer = nn.Identity() + + elif input_type == '1d': + self.patch_dim = patch_size * in_chans + self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True) + if use_conv: + self.final_layer = nn.Conv1d(self.in_chans, self.in_chans, + 3, padding=1) + else: + self.final_layer = nn.Identity() + + def forward(self, x, time_ada=None, extras=0): + B, T, C = x.shape + x = x[:, extras:, :] + # only handle generation target + if self.use_adanorm: + shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1) + x = film_modulate(self.norm(x), shift, scale) + else: + x = self.norm(x) + x = self.linear(x) + x = unpatchify(x, self.in_chans, self.input_type, self.img_size) + x = self.final_layer(x) + return x \ No newline at end of file diff --git a/src/models/.ipynb_checkpoints/conditioners-checkpoint.py b/src/models/.ipynb_checkpoints/conditioners-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..cade7febf61ef005f421c42cf17bb1bb2935a751 --- /dev/null +++ b/src/models/.ipynb_checkpoints/conditioners-checkpoint.py @@ -0,0 +1,183 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat +import math +from .udit import UDiT +from .utils.span_mask import compute_mask_indices + + +class EmbeddingCFG(nn.Module): + """ + Handles label dropout for classifier-free guidance. + """ + # todo: support 2D input + + def __init__(self, in_channels): + super().__init__() + self.cfg_embedding = nn.Parameter( + torch.randn(in_channels) / in_channels ** 0.5) + + def token_drop(self, condition, condition_mask, cfg_prob): + """ + Drops labels to enable classifier-free guidance. 
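A sketch of `EmbeddingCFG`'s label dropout as implemented in `token_drop`/`forward` just below; the import path assumes the non-checkpoint copy at `src/models/conditioners.py`, and the shapes are illustrative.

```python
import torch

from src.models.conditioners import EmbeddingCFG  # path assumed; the diff above shows the checkpoint copy

cfg = EmbeddingCFG(in_channels=16)
context = torch.randn(4, 10, 16)                       # (B, T, C) text-encoder hidden states
context_mask = torch.ones(4, 10, dtype=torch.bool)

cond, mask = cfg(context, context_mask, cfg_prob=0.2)  # roughly 20% of rows become the learned null embedding
```

For dropped rows the attention mask is collapsed to a single valid token; `DiscreteCFG`, defined just after, plays the same role when the conditioning is a sequence of token ids rather than embeddings.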
+ """ + b, t, device = condition.shape[0], condition.shape[1], condition.device + drop_ids = torch.rand(b, device=device) < cfg_prob + uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t) + condition = torch.where(drop_ids[:, None, None], uncond, condition) + if condition_mask is not None: + condition_mask[drop_ids] = False + condition_mask[drop_ids, 0] = True + + return condition, condition_mask + + def forward(self, condition, condition_mask, cfg_prob=0.0): + if condition_mask is not None: + condition_mask = condition_mask.clone() + if cfg_prob > 0: + condition, condition_mask = self.token_drop(condition, + condition_mask, + cfg_prob) + return condition, condition_mask + + +class DiscreteCFG(nn.Module): + def __init__(self, replace_id=2): + super(DiscreteCFG, self).__init__() + self.replace_id = replace_id + + def forward(self, context, context_mask, cfg_prob): + context = context.clone() + if context_mask is not None: + context_mask = context_mask.clone() + if cfg_prob > 0: + cfg_mask = torch.rand(len(context)) < cfg_prob + if torch.any(cfg_mask): + context[cfg_mask] = 0 + context[cfg_mask, 0] = self.replace_id + if context_mask is not None: + context_mask[cfg_mask] = False + context_mask[cfg_mask, 0] = True + return context, context_mask + + +class CFGModel(nn.Module): + def __init__(self, context_dim, backbone): + super().__init__() + self.model = backbone + self.context_cfg = EmbeddingCFG(context_dim) + + def forward(self, x, timesteps, + context, x_mask=None, context_mask=None, + cfg_prob=0.0): + context = self.context_cfg(context, cfg_prob) + x = self.model(x=x, timesteps=timesteps, + context=context, + x_mask=x_mask, context_mask=context_mask) + return x + + +class ConcatModel(nn.Module): + def __init__(self, backbone, in_dim, stride=[]): + super().__init__() + self.model = backbone + + self.downsample_layers = nn.ModuleList() + for i, s in enumerate(stride): + downsample_layer = nn.Conv1d( + in_dim, + in_dim * 2, + kernel_size=2 * s, + stride=s, + padding=math.ceil(s / 2), + ) + self.downsample_layers.append(downsample_layer) + in_dim = in_dim * 2 + + self.context_cfg = EmbeddingCFG(in_dim) + + def forward(self, x, timesteps, + context, x_mask=None, + cfg=False, cfg_prob=0.0): + + # todo: support 2D input + # x: B, C, L + # context: B, C, L + + for downsample_layer in self.downsample_layers: + context = downsample_layer(context) + + context = context.transpose(1, 2) + context = self.context_cfg(caption=context, + cfg=cfg, cfg_prob=cfg_prob) + context = context.transpose(1, 2) + + assert context.shape[-1] == x.shape[-1] + x = torch.cat([context, x], dim=1) + x = self.model(x=x, timesteps=timesteps, + context=None, x_mask=x_mask, context_mask=None) + return x + + +class MaskDiT(nn.Module): + def __init__(self, mae=False, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10, **kwargs): + super().__init__() + self.model = UDiT(**kwargs) + self.mae = mae + if self.mae: + out_channel = kwargs.pop('out_chans', None) + self.mask_embed = nn.Parameter(torch.zeros((out_channel))) + self.mae_prob = mae_prob + self.mask_ratio = mask_ratio + self.mask_span = mask_span + + def random_masking(self, gt, mask_ratios, mae_mask_infer=None): + B, D, L = gt.shape + if mae_mask_infer is None: + # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1) + mask_ratios = mask_ratios.cpu().numpy() + mask = compute_mask_indices(shape=[B, L], + padding_mask=None, + mask_prob=mask_ratios, + mask_length=self.mask_span, + mask_type="static", + mask_other=0.0, + min_masks=1, + no_overlap=False, + 
min_space=0,) + mask = mask.unsqueeze(1).expand_as(gt) + else: + mask = mae_mask_infer + mask = mask.expand_as(gt) + gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask] + return gt, mask.type_as(gt) + + def forward(self, x, timesteps, context, + x_mask=None, context_mask=None, cls_token=None, + gt=None, mae_mask_infer=None, + forward_model=True): + # todo: handle controlnet inside + mae_mask = torch.ones_like(x) + if self.mae: + if gt is not None: + B, D, L = gt.shape + mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device) + gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer) + # apply mae only to the selected batches + if mae_mask_infer is None: + # determine mae batch + mae_batch = torch.rand(B) < self.mae_prob + gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch] + mae_mask[~mae_batch] = 1.0 + else: + B, D, L = x.shape + gt = self.mask_embed.view(1, D, 1).expand_as(x) + x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1) + + if forward_model: + x = self.model(x=x, timesteps=timesteps, context=context, + x_mask=x_mask, context_mask=context_mask, + cls_token=cls_token) + # print(mae_mask[:, 0, :].sum(dim=-1)) + return x, mae_mask diff --git a/src/models/.ipynb_checkpoints/controlnet-checkpoint.py b/src/models/.ipynb_checkpoints/controlnet-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..1750621847ed116a6fbab55a50e67963699d6a5a --- /dev/null +++ b/src/models/.ipynb_checkpoints/controlnet-checkpoint.py @@ -0,0 +1,318 @@ +import torch +import torch.nn as nn + +from .utils.modules import PatchEmbed, TimestepEmbedder +from .utils.modules import PE_wrapper, RMSNorm +from .blocks import DiTBlock, JointDiTBlock +from .utils.span_mask import compute_mask_indices + + +class DiTControlNetEmbed(nn.Module): + def __init__(self, in_chans, out_chans, blocks, + cond_mask=False, cond_mask_prob=None, + cond_mask_ratio=None, cond_mask_span=None): + super().__init__() + self.conv_in = nn.Conv1d(in_chans, blocks[0], kernel_size=1) + + self.cond_mask = cond_mask + if self.cond_mask: + self.mask_embed = nn.Parameter(torch.zeros((blocks[0]))) + self.mask_prob = cond_mask_prob + self.mask_ratio = cond_mask_ratio + self.mask_span = cond_mask_span + blocks[0] = blocks[0] + 1 + + conv_blocks = [] + for i in range(len(blocks) - 1): + channel_in = blocks[i] + channel_out = blocks[i + 1] + block = nn.Sequential( + nn.Conv1d(channel_in, channel_in, kernel_size=3, padding=1), + nn.SiLU(), + nn.Conv1d(channel_in, channel_out, kernel_size=3, padding=1, stride=2), + nn.SiLU(),) + conv_blocks.append(block) + self.blocks = nn.ModuleList(conv_blocks) + + self.conv_out = nn.Conv1d(blocks[-1], out_chans, kernel_size=1) + nn.init.zeros_(self.conv_out.weight) + nn.init.zeros_(self.conv_out.bias) + + def random_masking(self, gt, mask_ratios, mae_mask_infer=None): + B, D, L = gt.shape + if mae_mask_infer is None: + # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1) + mask_ratios = mask_ratios.cpu().numpy() + mask = compute_mask_indices(shape=[B, L], + padding_mask=None, + mask_prob=mask_ratios, + mask_length=self.mask_span, + mask_type="static", + mask_other=0.0, + min_masks=1, + no_overlap=False, + min_space=0,) + # only apply mask to some batches + mask_batch = torch.rand(B) < self.mask_prob + mask[~mask_batch] = False + mask = mask.unsqueeze(1).expand_as(gt) + else: + mask = mae_mask_infer + mask = mask.expand_as(gt) + gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask].type_as(gt) + return gt, 
mask.type_as(gt) + + def forward(self, conditioning, cond_mask_infer=None): + embedding = self.conv_in(conditioning) + + if self.cond_mask: + B, D, L = embedding.shape + if not self.training and cond_mask_infer is None: + cond_mask_infer = torch.zeros_like(embedding).bool() + mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(embedding.device) + embedding, cond_mask = self.random_masking(embedding, mask_ratios, cond_mask_infer) + embedding = torch.cat([embedding, cond_mask[:, 0:1, :]], dim=1) + + for block in self.blocks: + embedding = block(embedding) + + embedding = self.conv_out(embedding) + + # B, L, C + embedding = embedding.transpose(1, 2).contiguous() + + return embedding + + +class DiTControlNet(nn.Module): + def __init__(self, + img_size=(224, 224), patch_size=16, in_chans=3, + input_type='2d', out_chans=None, + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., + qkv_bias=False, qk_scale=None, qk_norm=None, + act_layer='gelu', norm_layer='layernorm', + context_norm=False, + use_checkpoint=False, + # time fusion ada or token + time_fusion='token', + ada_lora_rank=None, ada_lora_alpha=None, + cls_dim=None, + # max length is only used for concat + context_dim=768, context_fusion='concat', + context_max_length=128, context_pe_method='sinu', + pe_method='abs', rope_mode='none', + use_conv=True, + skip=True, skip_norm=True, + # controlnet configs + cond_in=None, cond_blocks=None, + cond_mask=False, cond_mask_prob=None, + cond_mask_ratio=None, cond_mask_span=None, + **kwargs): + super().__init__() + self.num_features = self.embed_dim = embed_dim + # input + self.in_chans = in_chans + self.input_type = input_type + if self.input_type == '2d': + num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) + elif self.input_type == '1d': + num_patches = img_size // patch_size + self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, input_type=input_type) + out_chans = in_chans if out_chans is None else out_chans + self.out_chans = out_chans + + # position embedding + self.rope = rope_mode + self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method, + length=num_patches) + + print(f'x position embedding: {pe_method}') + print(f'rope mode: {self.rope}') + + # time embed + self.time_embed = TimestepEmbedder(embed_dim) + self.time_fusion = time_fusion + self.use_adanorm = False + + # cls embed + if cls_dim is not None: + self.cls_embed = nn.Sequential( + nn.Linear(cls_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True),) + else: + self.cls_embed = None + + # time fusion + if time_fusion == 'token': + # put token at the beginning of sequence + self.extras = 2 if self.cls_embed else 1 + self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras) + elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']: + self.use_adanorm = True + # aviod repetitive silu for each adaln block + self.time_act = nn.SiLU() + self.extras = 0 + if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']: + # shared adaln + self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True) + else: + self.time_ada = None + else: + raise NotImplementedError + print(f'time fusion mode: {self.time_fusion}') + + # context + # use a simple projection + self.use_context = False + self.context_cross = False + self.context_max_length = context_max_length + self.context_fusion = 'none' + if context_dim is not None: + self.use_context = True + self.context_embed = nn.Sequential( + nn.Linear(context_dim, 
embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True),) + self.context_fusion = context_fusion + if context_fusion == 'concat' or context_fusion == 'joint': + self.extras += context_max_length + self.context_pe = PE_wrapper(dim=embed_dim, + method=context_pe_method, + length=context_max_length) + # no cross attention layers + context_dim = None + elif context_fusion == 'cross': + self.context_pe = PE_wrapper(dim=embed_dim, + method=context_pe_method, + length=context_max_length) + self.context_cross = True + context_dim = embed_dim + else: + raise NotImplementedError + print(f'context fusion mode: {context_fusion}') + print(f'context position embedding: {context_pe_method}') + + if self.context_fusion == 'joint': + Block = JointDiTBlock + else: + Block = DiTBlock + + # norm layers + if norm_layer == 'layernorm': + norm_layer = nn.LayerNorm + elif norm_layer == 'rmsnorm': + norm_layer = RMSNorm + else: + raise NotImplementedError + + self.in_blocks = nn.ModuleList([ + Block( + dim=embed_dim, context_dim=context_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, + act_layer=act_layer, norm_layer=norm_layer, + time_fusion=time_fusion, + ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, + skip=False, skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + self.controlnet_pre = DiTControlNetEmbed(in_chans=cond_in, out_chans=embed_dim, + blocks=cond_blocks, + cond_mask=cond_mask, + cond_mask_prob=cond_mask_prob, + cond_mask_ratio=cond_mask_ratio, + cond_mask_span=cond_mask_span) + + controlnet_zero_blocks = [] + for i in range(depth // 2): + block = nn.Linear(embed_dim, embed_dim) + nn.init.zeros_(block.weight) + nn.init.zeros_(block.bias) + controlnet_zero_blocks.append(block) + self.controlnet_zero_blocks = nn.ModuleList(controlnet_zero_blocks) + + print('ControlNet ready \n') + + def set_trainable(self): + for param in self.parameters(): + param.requires_grad = False + + # only train input_proj, blocks, and output_proj + for module_name in ['controlnet_pre', 'in_blocks', 'controlnet_zero_blocks']: + module = getattr(self, module_name, None) + if module is not None: + for param in module.parameters(): + param.requires_grad = True + module.train() + else: + print(f'\n!!!warning missing trainable blocks: {module_name}!!!\n') + + def forward(self, x, timesteps, context, + x_mask=None, context_mask=None, + cls_token=None, + condition=None, cond_mask_infer=None, + conditioning_scale=1.0): + # make it compatible with int time step during inference + if timesteps.dim() == 0: + timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long) + + x = self.patch_embed(x) + # add condition to x + condition = self.controlnet_pre(condition) + x = x + condition + x = self.x_pe(x) + + B, L, D = x.shape + + if self.use_context: + context_token = self.context_embed(context) + context_token = self.context_pe(context_token) + if self.context_fusion == 'concat' or self.context_fusion == 'joint': + x, x_mask = self._concat_x_context(x=x, context=context_token, + x_mask=x_mask, + context_mask=context_mask) + context_token, context_mask = None, None + else: + context_token, context_mask = None, None + + time_token = self.time_embed(timesteps) + if self.cls_embed: + cls_token = self.cls_embed(cls_token) + time_ada = None + if self.use_adanorm: + if self.cls_embed: + time_token = time_token + cls_token + time_token = 
self.time_act(time_token) + if self.time_ada is not None: + time_ada = self.time_ada(time_token) + else: + time_token = time_token.unsqueeze(dim=1) + if self.cls_embed: + cls_token = cls_token.unsqueeze(dim=1) + time_token = torch.cat([time_token, cls_token], dim=1) + time_token = self.time_pe(time_token) + x = torch.cat((time_token, x), dim=1) + if x_mask is not None: + x_mask = torch.cat( + [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(), + x_mask], dim=1) + time_token = None + + skips = [] + for blk in self.in_blocks: + x = blk(x=x, time_token=time_token, time_ada=time_ada, + skip=None, context=context_token, + x_mask=x_mask, context_mask=context_mask, + extras=self.extras) + skips.append(x) + + controlnet_skips = [] + for skip, controlnet_block in zip(skips, self.controlnet_zero_blocks): + controlnet_skips.append(controlnet_block(skip) * conditioning_scale) + + return controlnet_skips \ No newline at end of file diff --git a/src/models/.ipynb_checkpoints/udit-checkpoint.py b/src/models/.ipynb_checkpoints/udit-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..e126efd370efabbfcc4f4359194f9c95c6e9d154 --- /dev/null +++ b/src/models/.ipynb_checkpoints/udit-checkpoint.py @@ -0,0 +1,365 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint +import math +from .utils.modules import PatchEmbed, TimestepEmbedder +from .utils.modules import PE_wrapper, RMSNorm +from .blocks import DiTBlock, JointDiTBlock, FinalBlock + + +class UDiT(nn.Module): + def __init__(self, + img_size=224, patch_size=16, in_chans=3, + input_type='2d', out_chans=None, + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., + qkv_bias=False, qk_scale=None, qk_norm=None, + act_layer='gelu', norm_layer='layernorm', + context_norm=False, + use_checkpoint=False, + # time fusion ada or token + time_fusion='token', + ada_lora_rank=None, ada_lora_alpha=None, + cls_dim=None, + # max length is only used for concat + context_dim=768, context_fusion='concat', + context_max_length=128, context_pe_method='sinu', + pe_method='abs', rope_mode='none', + use_conv=True, + skip=True, skip_norm=True): + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + # input + self.in_chans = in_chans + self.input_type = input_type + if self.input_type == '2d': + num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) + elif self.input_type == '1d': + num_patches = img_size // patch_size + self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, input_type=input_type) + out_chans = in_chans if out_chans is None else out_chans + self.out_chans = out_chans + + # position embedding + self.rope = rope_mode + self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method, + length=num_patches) + + print(f'x position embedding: {pe_method}') + print(f'rope mode: {self.rope}') + + # time embed + self.time_embed = TimestepEmbedder(embed_dim) + self.time_fusion = time_fusion + self.use_adanorm = False + + # cls embed + if cls_dim is not None: + self.cls_embed = nn.Sequential( + nn.Linear(cls_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True),) + else: + self.cls_embed = None + + # time fusion + if time_fusion == 'token': + # put token at the beginning of sequence + self.extras = 2 if self.cls_embed else 1 + self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras) + elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']: 
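+ # Note: in the adaLN variants the SiLU-activated timestep embedding is projected
+ # to 6 * embed_dim modulation parameters per block (shift/scale/gate for the
+ # attention branch and for the MLP branch; AdaLN and DiTBlock in blocks.py chunk
+ # time_ada into 6). 'ada' learns this projection inside every block, whereas
+ # 'ada_single' and the 'ada_lora*' modes share the single projection defined below.
+ # Roughly, assuming film_modulate follows the usual x * (1 + scale) + shift
+ # convention, each block computes:
+ #   shift_msa, scale_msa, gate_msa, shift_mlp, scale_mlp, gate_mlp = time_ada.chunk(6, dim=1)
+ #   x = x + (1 - gate_msa) * attn(film_modulate(norm1(x), shift_msa, scale_msa))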
+ self.use_adanorm = True + # aviod repetitive silu for each adaln block + self.time_act = nn.SiLU() + self.extras = 0 + self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True) + if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']: + # shared adaln + self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True) + else: + self.time_ada = None + else: + raise NotImplementedError + print(f'time fusion mode: {self.time_fusion}') + + # context + # use a simple projection + self.use_context = False + self.context_cross = False + self.context_max_length = context_max_length + self.context_fusion = 'none' + if context_dim is not None: + self.use_context = True + self.context_embed = nn.Sequential( + nn.Linear(context_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True),) + self.context_fusion = context_fusion + if context_fusion == 'concat' or context_fusion == 'joint': + self.extras += context_max_length + self.context_pe = PE_wrapper(dim=embed_dim, + method=context_pe_method, + length=context_max_length) + # no cross attention layers + context_dim = None + elif context_fusion == 'cross': + self.context_pe = PE_wrapper(dim=embed_dim, + method=context_pe_method, + length=context_max_length) + self.context_cross = True + context_dim = embed_dim + else: + raise NotImplementedError + print(f'context fusion mode: {context_fusion}') + print(f'context position embedding: {context_pe_method}') + + if self.context_fusion == 'joint': + Block = JointDiTBlock + self.use_skip = skip[0] + else: + Block = DiTBlock + self.use_skip = skip + + # norm layers + if norm_layer == 'layernorm': + norm_layer = nn.LayerNorm + elif norm_layer == 'rmsnorm': + norm_layer = RMSNorm + else: + raise NotImplementedError + + print(f'use long skip connection: {skip}') + self.in_blocks = nn.ModuleList([ + Block( + dim=embed_dim, context_dim=context_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, + act_layer=act_layer, norm_layer=norm_layer, + time_fusion=time_fusion, + ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, + skip=False, skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + self.mid_block = Block( + dim=embed_dim, context_dim=context_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, + act_layer=act_layer, norm_layer=norm_layer, + time_fusion=time_fusion, + ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, + skip=False, skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint) + + self.out_blocks = nn.ModuleList([ + Block( + dim=embed_dim, context_dim=context_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, + act_layer=act_layer, norm_layer=norm_layer, + time_fusion=time_fusion, + ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, + skip=skip, skip_norm=skip_norm, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + # FinalLayer block + self.use_conv = use_conv + self.final_block = FinalBlock(embed_dim=embed_dim, + patch_size=patch_size, + img_size=img_size, + in_chans=out_chans, + input_type=input_type, + norm_layer=norm_layer, + use_conv=use_conv, + use_adanorm=self.use_adanorm) + self.initialize_weights() + + def _init_ada(self): + if self.time_fusion == 'ada': + 
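+ # These adaLN projections are zero-initialised, so at the start of training the
+ # predicted shift/scale/gate values are all zero and independent of the timestep;
+ # the time conditioning only takes effect as these weights are learned.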
nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + for block in self.in_blocks: + nn.init.constant_(block.adaln.time_ada.weight, 0) + nn.init.constant_(block.adaln.time_ada.bias, 0) + nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0) + nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0) + for block in self.out_blocks: + nn.init.constant_(block.adaln.time_ada.weight, 0) + nn.init.constant_(block.adaln.time_ada.bias, 0) + elif self.time_fusion == 'ada_single': + nn.init.constant_(self.time_ada.weight, 0) + nn.init.constant_(self.time_ada.bias, 0) + nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + elif self.time_fusion in ['ada_lora', 'ada_lora_bias']: + nn.init.constant_(self.time_ada.weight, 0) + nn.init.constant_(self.time_ada.bias, 0) + nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + for block in self.in_blocks: + nn.init.kaiming_uniform_(block.adaln.lora_a.weight, + a=math.sqrt(5)) + nn.init.constant_(block.adaln.lora_b.weight, 0) + nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight, + a=math.sqrt(5)) + nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0) + for block in self.out_blocks: + nn.init.kaiming_uniform_(block.adaln.lora_a.weight, + a=math.sqrt(5)) + nn.init.constant_(block.adaln.lora_b.weight, 0) + + def initialize_weights(self): + # Basic init for all layers + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # init patch Conv like Linear + w = self.patch_embed.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.patch_embed.proj.bias, 0) + + # Zero-out AdaLN + if self.use_adanorm: + self._init_ada() + + # Zero-out Cross Attention + if self.context_cross: + for block in self.in_blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0) + nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0) + for block in self.out_blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out cls embedding + if self.cls_embed: + if self.use_adanorm: + nn.init.constant_(self.cls_embed[-1].weight, 0) + nn.init.constant_(self.cls_embed[-1].bias, 0) + + # Zero-out Output + # might not zero-out this when using v-prediction + # it could be good when using noise-prediction + # nn.init.constant_(self.final_block.linear.weight, 0) + # nn.init.constant_(self.final_block.linear.bias, 0) + # if self.use_conv: + # nn.init.constant_(self.final_block.final_layer.weight.data, 0) + # nn.init.constant_(self.final_block.final_layer.bias, 0) + + # init out Conv + if self.use_conv: + nn.init.xavier_uniform_(self.final_block.final_layer.weight) + nn.init.constant_(self.final_block.final_layer.bias, 0) + + def _concat_x_context(self, x, context, x_mask=None, context_mask=None): + assert context.shape[-2] == self.context_max_length + # Check if either x_mask or context_mask is provided + B = x.shape[0] + # Create default masks if they are not provided + if x_mask is None: + x_mask = torch.ones(B, x.shape[-2], device=x.device).bool() + if context_mask is None: + context_mask = torch.ones(B, context.shape[-2], + device=context.device).bool() + # Concatenate the masks along 
the second dimension (dim=1) + x_mask = torch.cat([context_mask, x_mask], dim=1) + # Concatenate context and x along the second dimension (dim=1) + x = torch.cat((context, x), dim=1) + return x, x_mask + + def forward(self, x, timesteps, context, + x_mask=None, context_mask=None, + cls_token=None, controlnet_skips=None, + ): + # make it compatible with int time step during inference + if timesteps.dim() == 0: + timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long) + + x = self.patch_embed(x) + x = self.x_pe(x) + + B, L, D = x.shape + + if self.use_context: + context_token = self.context_embed(context) + context_token = self.context_pe(context_token) + if self.context_fusion == 'concat' or self.context_fusion == 'joint': + x, x_mask = self._concat_x_context(x=x, context=context_token, + x_mask=x_mask, + context_mask=context_mask) + context_token, context_mask = None, None + else: + context_token, context_mask = None, None + + time_token = self.time_embed(timesteps) + if self.cls_embed: + cls_token = self.cls_embed(cls_token) + time_ada = None + time_ada_final = None + if self.use_adanorm: + if self.cls_embed: + time_token = time_token + cls_token + time_token = self.time_act(time_token) + time_ada_final = self.time_ada_final(time_token) + if self.time_ada is not None: + time_ada = self.time_ada(time_token) + else: + time_token = time_token.unsqueeze(dim=1) + if self.cls_embed: + cls_token = cls_token.unsqueeze(dim=1) + time_token = torch.cat([time_token, cls_token], dim=1) + time_token = self.time_pe(time_token) + x = torch.cat((time_token, x), dim=1) + if x_mask is not None: + x_mask = torch.cat( + [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(), + x_mask], dim=1) + time_token = None + + skips = [] + for blk in self.in_blocks: + x = blk(x=x, time_token=time_token, time_ada=time_ada, + skip=None, context=context_token, + x_mask=x_mask, context_mask=context_mask, + extras=self.extras) + if self.use_skip: + skips.append(x) + + x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada, + skip=None, context=context_token, + x_mask=x_mask, context_mask=context_mask, + extras=self.extras) + for blk in self.out_blocks: + if self.use_skip: + skip = skips.pop() + if controlnet_skips: + # add to skip like u-net controlnet + skip = skip + controlnet_skips.pop() + else: + skip = None + if controlnet_skips: + # directly add to x + x = x + controlnet_skips.pop() + + x = blk(x=x, time_token=time_token, time_ada=time_ada, + skip=skip, context=context_token, + x_mask=x_mask, context_mask=context_mask, + extras=self.extras) + + x = self.final_block(x, time_ada=time_ada_final, extras=self.extras) + + return x \ No newline at end of file diff --git a/src/models/__pycache__/attention.cpython-311.pyc b/src/models/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5f0d5a6c50f65f9163dc99a4b7d27343a5167677 Binary files /dev/null and b/src/models/__pycache__/attention.cpython-311.pyc differ diff --git a/src/models/__pycache__/blocks.cpython-310.pyc b/src/models/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7461fb52ed06192ceb625e9a60ee1e5fb68d6b9 Binary files /dev/null and b/src/models/__pycache__/blocks.cpython-310.pyc differ diff --git a/src/models/__pycache__/blocks.cpython-311.pyc b/src/models/__pycache__/blocks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6e2e896fbaf9999c48a80ab82f91eaf281fd6e31 Binary files /dev/null 
and b/src/models/__pycache__/blocks.cpython-311.pyc differ diff --git a/src/models/__pycache__/conditioners.cpython-310.pyc b/src/models/__pycache__/conditioners.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..cacc8c6625287466c020bb10ab7c4ecd352e2d72 Binary files /dev/null and b/src/models/__pycache__/conditioners.cpython-310.pyc differ diff --git a/src/models/__pycache__/conditioners.cpython-311.pyc b/src/models/__pycache__/conditioners.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5327c1cb1adb18a3c8e198a6aad33788d244680c Binary files /dev/null and b/src/models/__pycache__/conditioners.cpython-311.pyc differ diff --git a/src/models/__pycache__/controlnet.cpython-311.pyc b/src/models/__pycache__/controlnet.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fd2690660df68d7bcd3a749771c43615891d11d Binary files /dev/null and b/src/models/__pycache__/controlnet.cpython-311.pyc differ diff --git a/src/models/__pycache__/modules.cpython-311.pyc b/src/models/__pycache__/modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a075bf74a7e0860dfbb3bf07774efc46eda5face Binary files /dev/null and b/src/models/__pycache__/modules.cpython-311.pyc differ diff --git a/src/models/__pycache__/rotary.cpython-311.pyc b/src/models/__pycache__/rotary.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d1ba7bb807f377ad675aae2086a2f5147407c850 Binary files /dev/null and b/src/models/__pycache__/rotary.cpython-311.pyc differ diff --git a/src/models/__pycache__/timm.cpython-311.pyc b/src/models/__pycache__/timm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..381ee0b6c15f37a71522fc8d52501e1158d5e461 Binary files /dev/null and b/src/models/__pycache__/timm.cpython-311.pyc differ diff --git a/src/models/__pycache__/udit.cpython-310.pyc b/src/models/__pycache__/udit.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d83aeee69c058f58072944bf4188c549393916c0 Binary files /dev/null and b/src/models/__pycache__/udit.cpython-310.pyc differ diff --git a/src/models/__pycache__/udit.cpython-311.pyc b/src/models/__pycache__/udit.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1c2807d366b0eb4030274b26d5195f3e4e0d7604 Binary files /dev/null and b/src/models/__pycache__/udit.cpython-311.pyc differ diff --git a/src/models/blocks.py b/src/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..8ef730009cb7cf664c5d9021e551b275680d11f3 --- /dev/null +++ b/src/models/blocks.py @@ -0,0 +1,325 @@ +import torch +import torch.nn as nn +from torch.utils.checkpoint import checkpoint +from .utils.attention import Attention, JointAttention +from .utils.modules import unpatchify, FeedForward +from .utils.modules import film_modulate + + +class AdaLN(nn.Module): + def __init__(self, dim, ada_mode='ada', r=None, alpha=None): + super().__init__() + self.ada_mode = ada_mode + self.scale_shift_table = None + if ada_mode == 'ada': + # move nn.silu outside + self.time_ada = nn.Linear(dim, 6 * dim, bias=True) + elif ada_mode == 'ada_single': + # adaln used in pixel-art alpha + self.scale_shift_table = nn.Parameter(torch.zeros(6, dim)) + elif ada_mode in ['ada_lora', 'ada_lora_bias']: + self.lora_a = nn.Linear(dim, r * 6, bias=False) + self.lora_b = nn.Linear(r * 6, dim * 6, bias=False) + self.scaling = alpha / r + if ada_mode == 
'ada_lora_bias': + # take bias out for consistency + self.scale_shift_table = nn.Parameter(torch.zeros(6, dim)) + else: + raise NotImplementedError + + def forward(self, time_token=None, time_ada=None): + if self.ada_mode == 'ada': + assert time_ada is None + B = time_token.shape[0] + time_ada = self.time_ada(time_token).reshape(B, 6, -1) + elif self.ada_mode == 'ada_single': + B = time_ada.shape[0] + time_ada = time_ada.reshape(B, 6, -1) + time_ada = self.scale_shift_table[None] + time_ada + elif self.ada_mode in ['ada_lora', 'ada_lora_bias']: + B = time_ada.shape[0] + time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling + time_ada = time_ada + time_ada_lora + time_ada = time_ada.reshape(B, 6, -1) + if self.scale_shift_table is not None: + time_ada = self.scale_shift_table[None] + time_ada + else: + raise NotImplementedError + return time_ada + + +class DiTBlock(nn.Module): + """ + A modified PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + + def __init__(self, dim, context_dim=None, + num_heads=8, mlp_ratio=4., + qkv_bias=False, qk_scale=None, qk_norm=None, + act_layer='gelu', norm_layer=nn.LayerNorm, + time_fusion='none', + ada_lora_rank=None, ada_lora_alpha=None, + skip=False, skip_norm=False, + rope_mode='none', + context_norm=False, + use_checkpoint=False): + + super().__init__() + self.norm1 = norm_layer(dim) + self.attn = Attention(dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, + qk_norm=qk_norm, + rope_mode=rope_mode) + + if context_dim is not None: + self.use_context = True + self.cross_attn = Attention(dim=dim, + num_heads=num_heads, + context_dim=context_dim, + qkv_bias=qkv_bias, qk_scale=qk_scale, + qk_norm=qk_norm, + rope_mode='none') + self.norm2 = norm_layer(dim) + if context_norm: + self.norm_context = norm_layer(context_dim) + else: + self.norm_context = nn.Identity() + else: + self.use_context = False + + self.norm3 = norm_layer(dim) + self.mlp = FeedForward(dim=dim, mult=mlp_ratio, + activation_fn=act_layer, dropout=0) + + self.use_adanorm = True if time_fusion != 'token' else False + if self.use_adanorm: + self.adaln = AdaLN(dim, ada_mode=time_fusion, + r=ada_lora_rank, alpha=ada_lora_alpha) + if skip: + self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity() + self.skip_linear = nn.Linear(2 * dim, dim) + else: + self.skip_linear = None + + self.use_checkpoint = use_checkpoint + + def forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + if self.use_checkpoint: + return checkpoint(self._forward, x, + time_token, time_ada, skip, context, + x_mask, context_mask, extras, + use_reentrant=False) + else: + return self._forward(x, + time_token, time_ada, skip, context, + x_mask, context_mask, extras) + + def _forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + B, T, C = x.shape + if self.skip_linear is not None: + assert skip is not None + cat = torch.cat([x, skip], dim=-1) + cat = self.skip_norm(cat) + x = self.skip_linear(cat) + + if self.use_adanorm: + time_ada = self.adaln(time_token, time_ada) + (shift_msa, scale_msa, gate_msa, + shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1) + + # self attention + if self.use_adanorm: + x_norm = film_modulate(self.norm1(x), shift=shift_msa, + scale=scale_msa) + x = x + (1 - gate_msa) * self.attn(x_norm, context=None, + context_mask=x_mask, + extras=extras) + else: + x = x + self.attn(self.norm1(x), 
context=None, context_mask=x_mask, + extras=extras) + + # cross attention + if self.use_context: + assert context is not None + x = x + self.cross_attn(x=self.norm2(x), + context=self.norm_context(context), + context_mask=context_mask, extras=extras) + + # mlp + if self.use_adanorm: + x_norm = film_modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp) + x = x + (1 - gate_mlp) * self.mlp(x_norm) + else: + x = x + self.mlp(self.norm3(x)) + + return x + + +class JointDiTBlock(nn.Module): + """ + A modified PixArt block with adaptive layer norm (adaLN-single) conditioning. + """ + + def __init__(self, dim, context_dim=None, + num_heads=8, mlp_ratio=4., + qkv_bias=False, qk_scale=None, qk_norm=None, + act_layer='gelu', norm_layer=nn.LayerNorm, + time_fusion='none', + ada_lora_rank=None, ada_lora_alpha=None, + skip=(False, False), + rope_mode=False, + context_norm=False, + use_checkpoint=False,): + + super().__init__() + # no cross attention + assert context_dim is None + self.attn_norm_x = norm_layer(dim) + self.attn_norm_c = norm_layer(dim) + self.attn = JointAttention(dim=dim, + num_heads=num_heads, + qkv_bias=qkv_bias, qk_scale=qk_scale, + qk_norm=qk_norm, + rope_mode=rope_mode) + self.ffn_norm_x = norm_layer(dim) + self.ffn_norm_c = norm_layer(dim) + self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio, + activation_fn=act_layer, dropout=0) + self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio, + activation_fn=act_layer, dropout=0) + + # Zero-out the shift table + self.use_adanorm = True if time_fusion != 'token' else False + if self.use_adanorm: + self.adaln = AdaLN(dim, ada_mode=time_fusion, + r=ada_lora_rank, alpha=ada_lora_alpha) + + if skip is False: + skip_x, skip_c = False, False + else: + skip_x, skip_c = skip + + self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None + self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None + + self.use_checkpoint = use_checkpoint + + def forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + if self.use_checkpoint: + return checkpoint(self._forward, x, + time_token, time_ada, skip, + context, x_mask, context_mask, extras, + use_reentrant=False) + else: + return self._forward(x, + time_token, time_ada, skip, + context, x_mask, context_mask, extras) + + def _forward(self, x, time_token=None, time_ada=None, + skip=None, context=None, + x_mask=None, context_mask=None, extras=None): + + assert context is None and context_mask is None + + context, x = x[:, :extras, :], x[:, extras:, :] + context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:] + + if skip is not None: + skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :] + + B, T, C = x.shape + if self.skip_linear_x is not None: + x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1)) + + if self.skip_linear_c is not None: + context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1)) + + if self.use_adanorm: + time_ada = self.adaln(time_token, time_ada) + (shift_msa, scale_msa, gate_msa, + shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1) + + # self attention + x_norm = self.attn_norm_x(x) + c_norm = self.attn_norm_c(context) + if self.use_adanorm: + x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa) + x_out, c_out = self.attn(x_norm, context=c_norm, + x_mask=x_mask, context_mask=context_mask, + extras=extras) + if self.use_adanorm: + x = x + (1 - gate_msa) * x_out + else: + x = x + x_out + context = context + c_out + + # mlp + if self.use_adanorm: + x_norm = 
film_modulate(self.ffn_norm_x(x), + shift=shift_mlp, scale=scale_mlp) + x = x + (1 - gate_mlp) * self.mlp_x(x_norm) + else: + x = x + self.mlp_x(self.ffn_norm_x(x)) + + c_norm = self.ffn_norm_c(context) + context = context + self.mlp_c(c_norm) + + return torch.cat((context, x), dim=1) + + +class FinalBlock(nn.Module): + def __init__(self, embed_dim, patch_size, in_chans, + img_size, + input_type='2d', + norm_layer=nn.LayerNorm, + use_conv=True, + use_adanorm=True): + super().__init__() + self.in_chans = in_chans + self.img_size = img_size + self.input_type = input_type + + self.norm = norm_layer(embed_dim) + if use_adanorm: + self.use_adanorm = True + else: + self.use_adanorm = False + + if input_type == '2d': + self.patch_dim = patch_size ** 2 * in_chans + self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True) + if use_conv: + self.final_layer = nn.Conv2d(self.in_chans, self.in_chans, + 3, padding=1) + else: + self.final_layer = nn.Identity() + + elif input_type == '1d': + self.patch_dim = patch_size * in_chans + self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True) + if use_conv: + self.final_layer = nn.Conv1d(self.in_chans, self.in_chans, + 3, padding=1) + else: + self.final_layer = nn.Identity() + + def forward(self, x, time_ada=None, extras=0): + B, T, C = x.shape + x = x[:, extras:, :] + # only handle generation target + if self.use_adanorm: + shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1) + x = film_modulate(self.norm(x), shift, scale) + else: + x = self.norm(x) + x = self.linear(x) + x = unpatchify(x, self.in_chans, self.input_type, self.img_size) + x = self.final_layer(x) + return x \ No newline at end of file diff --git a/src/models/conditioners.py b/src/models/conditioners.py new file mode 100644 index 0000000000000000000000000000000000000000..cade7febf61ef005f421c42cf17bb1bb2935a751 --- /dev/null +++ b/src/models/conditioners.py @@ -0,0 +1,183 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import repeat +import math +from .udit import UDiT +from .utils.span_mask import compute_mask_indices + + +class EmbeddingCFG(nn.Module): + """ + Handles label dropout for classifier-free guidance. + """ + # todo: support 2D input + + def __init__(self, in_channels): + super().__init__() + self.cfg_embedding = nn.Parameter( + torch.randn(in_channels) / in_channels ** 0.5) + + def token_drop(self, condition, condition_mask, cfg_prob): + """ + Drops labels to enable classifier-free guidance. 
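+ Each batch item is dropped independently with probability ``cfg_prob``: its
+ context embedding is replaced by the learned ``cfg_embedding`` vector and its
+ mask is reduced to a single valid position, so the backbone also sees the
+ unconditional input required for classifier-free guidance at sampling time.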
+ """ + b, t, device = condition.shape[0], condition.shape[1], condition.device + drop_ids = torch.rand(b, device=device) < cfg_prob + uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t) + condition = torch.where(drop_ids[:, None, None], uncond, condition) + if condition_mask is not None: + condition_mask[drop_ids] = False + condition_mask[drop_ids, 0] = True + + return condition, condition_mask + + def forward(self, condition, condition_mask, cfg_prob=0.0): + if condition_mask is not None: + condition_mask = condition_mask.clone() + if cfg_prob > 0: + condition, condition_mask = self.token_drop(condition, + condition_mask, + cfg_prob) + return condition, condition_mask + + +class DiscreteCFG(nn.Module): + def __init__(self, replace_id=2): + super(DiscreteCFG, self).__init__() + self.replace_id = replace_id + + def forward(self, context, context_mask, cfg_prob): + context = context.clone() + if context_mask is not None: + context_mask = context_mask.clone() + if cfg_prob > 0: + cfg_mask = torch.rand(len(context)) < cfg_prob + if torch.any(cfg_mask): + context[cfg_mask] = 0 + context[cfg_mask, 0] = self.replace_id + if context_mask is not None: + context_mask[cfg_mask] = False + context_mask[cfg_mask, 0] = True + return context, context_mask + + +class CFGModel(nn.Module): + def __init__(self, context_dim, backbone): + super().__init__() + self.model = backbone + self.context_cfg = EmbeddingCFG(context_dim) + + def forward(self, x, timesteps, + context, x_mask=None, context_mask=None, + cfg_prob=0.0): + context = self.context_cfg(context, cfg_prob) + x = self.model(x=x, timesteps=timesteps, + context=context, + x_mask=x_mask, context_mask=context_mask) + return x + + +class ConcatModel(nn.Module): + def __init__(self, backbone, in_dim, stride=[]): + super().__init__() + self.model = backbone + + self.downsample_layers = nn.ModuleList() + for i, s in enumerate(stride): + downsample_layer = nn.Conv1d( + in_dim, + in_dim * 2, + kernel_size=2 * s, + stride=s, + padding=math.ceil(s / 2), + ) + self.downsample_layers.append(downsample_layer) + in_dim = in_dim * 2 + + self.context_cfg = EmbeddingCFG(in_dim) + + def forward(self, x, timesteps, + context, x_mask=None, + cfg=False, cfg_prob=0.0): + + # todo: support 2D input + # x: B, C, L + # context: B, C, L + + for downsample_layer in self.downsample_layers: + context = downsample_layer(context) + + context = context.transpose(1, 2) + context = self.context_cfg(caption=context, + cfg=cfg, cfg_prob=cfg_prob) + context = context.transpose(1, 2) + + assert context.shape[-1] == x.shape[-1] + x = torch.cat([context, x], dim=1) + x = self.model(x=x, timesteps=timesteps, + context=None, x_mask=x_mask, context_mask=None) + return x + + +class MaskDiT(nn.Module): + def __init__(self, mae=False, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10, **kwargs): + super().__init__() + self.model = UDiT(**kwargs) + self.mae = mae + if self.mae: + out_channel = kwargs.pop('out_chans', None) + self.mask_embed = nn.Parameter(torch.zeros((out_channel))) + self.mae_prob = mae_prob + self.mask_ratio = mask_ratio + self.mask_span = mask_span + + def random_masking(self, gt, mask_ratios, mae_mask_infer=None): + B, D, L = gt.shape + if mae_mask_infer is None: + # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1) + mask_ratios = mask_ratios.cpu().numpy() + mask = compute_mask_indices(shape=[B, L], + padding_mask=None, + mask_prob=mask_ratios, + mask_length=self.mask_span, + mask_type="static", + mask_other=0.0, + min_masks=1, + no_overlap=False, + 
min_space=0,) + mask = mask.unsqueeze(1).expand_as(gt) + else: + mask = mae_mask_infer + mask = mask.expand_as(gt) + gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask] + return gt, mask.type_as(gt) + + def forward(self, x, timesteps, context, + x_mask=None, context_mask=None, cls_token=None, + gt=None, mae_mask_infer=None, + forward_model=True): + # todo: handle controlnet inside + mae_mask = torch.ones_like(x) + if self.mae: + if gt is not None: + B, D, L = gt.shape + mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device) + gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer) + # apply mae only to the selected batches + if mae_mask_infer is None: + # determine mae batch + mae_batch = torch.rand(B) < self.mae_prob + gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch] + mae_mask[~mae_batch] = 1.0 + else: + B, D, L = x.shape + gt = self.mask_embed.view(1, D, 1).expand_as(x) + x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1) + + if forward_model: + x = self.model(x=x, timesteps=timesteps, context=context, + x_mask=x_mask, context_mask=context_mask, + cls_token=cls_token) + # print(mae_mask[:, 0, :].sum(dim=-1)) + return x, mae_mask diff --git a/src/models/conditions/.ipynb_checkpoints/__init__-checkpoint.py b/src/models/conditions/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..0a60c5319ac77a8cdebd2835527256b547101700 --- /dev/null +++ b/src/models/conditions/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1 @@ +from .condition_wrapper import Conditioner \ No newline at end of file diff --git a/src/models/conditions/.ipynb_checkpoints/chroma-checkpoint.py b/src/models/conditions/.ipynb_checkpoints/chroma-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a07aebb898f6d8048099a22f4f9b4d8f1a3117fd --- /dev/null +++ b/src/models/conditions/.ipynb_checkpoints/chroma-checkpoint.py @@ -0,0 +1,80 @@ +import typing as tp + +from einops import rearrange +from librosa import filters +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio + + +class ChromaExtractor(nn.Module): + """Chroma extraction and quantization. + + Args: + sample_rate (int): Sample rate for the chroma extraction. + n_chroma (int): Number of chroma bins for the chroma extraction. + radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12). + nfft (int, optional): Number of FFT. + winlen (int, optional): Window length. + winhop (int, optional): Window hop size. + argmax (bool, optional): Whether to use argmax. Defaults to False. + norm (float, optional): Norm for chroma normalization. Defaults to inf. 
+ """ + + def __init__(self, + sample_rate: int, + n_chroma: int = 12, radix2_exp: int = 12, + nfft: tp.Optional[int] = None, + winlen: tp.Optional[int] = None, + winhop: tp.Optional[int] = None, argmax: bool = True, + norm: float = torch.inf): + super().__init__() + self.winlen = winlen or 2 ** radix2_exp + self.nfft = nfft or self.winlen + self.winhop = winhop or (self.winlen // 4) + self.sample_rate = sample_rate + self.n_chroma = n_chroma + self.norm = norm + self.argmax = argmax + self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, + n_chroma=self.n_chroma)), persistent=False) + self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen, + hop_length=self.winhop, power=2, center=False, + pad=0, normalized=True) + + def forward(self, wav: torch.Tensor) -> torch.Tensor: + T = wav.shape[-1] + # in case we are getting a wav that was dropped out (nullified) + # from the conditioner, make sure wav length is no less that nfft + if T < self.nfft: + pad = self.nfft - T + r = 0 if pad % 2 == 0 else 1 + wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0) + assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}" + + wav = F.pad(wav, (int(self.nfft // 2 - self.winhop // 2 ), + int(self.nfft // 2 - self.winhop // 2 )), mode="reflect") + + spec = self.spec(wav).squeeze(1) + raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec) + norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6) + norm_chroma = rearrange(norm_chroma, 'b d t -> b t d') + + if self.argmax: + idx = norm_chroma.argmax(-1, keepdim=True) + norm_chroma[:] = 0 + norm_chroma.scatter_(dim=-1, index=idx, value=1) + + return norm_chroma + + +if __name__ == "__main__": + chroma = ChromaExtractor(sample_rate=16000, + n_chroma=4, + radix2_exp=None, + winlen=16000, + nfft=16000, + winhop=4000) + audio = torch.rand(1, 16000) + c = chroma(audio) \ No newline at end of file diff --git a/src/models/conditions/.ipynb_checkpoints/condition_wrapper-checkpoint.py b/src/models/conditions/.ipynb_checkpoints/condition_wrapper-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..20e44795abc16c3b9a9a30c8f579767879c25fe7 --- /dev/null +++ b/src/models/conditions/.ipynb_checkpoints/condition_wrapper-checkpoint.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +from .chroma import ChromaExtractor +from .energy import EnergyExtractor +from .voice import VoiceConversionExtractor +from .mbenergy import MultibandEnergyExtractor + + +class Conditioner(nn.Module): + def __init__(self, + condition_type, + **kwargs + ): + super().__init__() + if condition_type == 'energy': + self.conditioner = EnergyExtractor(**kwargs) + elif condition_type == 'chroma': + self.conditioner = ChromaExtractor(**kwargs) + elif condition_type == 'vc': + self.conditioner = VoiceConversionExtractor(**kwargs) + elif condition_type == 'mb_energy': + self.conditioner = MultibandEnergyExtractor(**kwargs) + else: + raise NotImplementedError + + def forward(self, waveform, latent_shape): + # B T C + condition = self.conditioner(waveform) + # B C T + condition = condition.permute(0, 2, 1).contiguous() + + if len(latent_shape) == 4: + # 2d spectrogram B C T F + assert (condition.shape[-1] % latent_shape[-2]) == 0 + X = latent_shape[-1] * condition.shape[-1] // latent_shape[-2] + # copy on F direction + condition = condition.unsqueeze(-1).expand(-1, -1, -1, X) + elif len(latent_shape) == 3: + condition = 
condition + else: + raise NotImplementedError + return condition + + +if __name__ == '__main__': + conditioner = Conditioner(condition_type='energy', + hop_size=160, window_size=1024, padding='reflect', + min_db=-80, norm=True) + audio = torch.rand(4, 16000) # Example audio signal + energy = conditioner(audio, (4, 8, 100, 64)) \ No newline at end of file diff --git a/src/models/conditions/.ipynb_checkpoints/debug-checkpoint.png b/src/models/conditions/.ipynb_checkpoints/debug-checkpoint.png new file mode 100644 index 0000000000000000000000000000000000000000..a092693dc8e7cacb09b8eef632d1e7ececbf51e3 Binary files /dev/null and b/src/models/conditions/.ipynb_checkpoints/debug-checkpoint.png differ diff --git a/src/models/conditions/.ipynb_checkpoints/energy-checkpoint.py b/src/models/conditions/.ipynb_checkpoints/energy-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..f86481723850af419fa701254863789108596e88 --- /dev/null +++ b/src/models/conditions/.ipynb_checkpoints/energy-checkpoint.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +class EnergyExtractor(nn.Module): + def __init__(self, hop_size: int = 512, window_size: int = 1024, + padding: str = 'reflect', min_db: float = -60, + norm: bool = True, quantize_levels: int = None): + super().__init__() + self.hop_size = hop_size + self.window_size = window_size + self.padding = padding + self.min_db = min_db + self.norm = norm + self.quantize_levels = quantize_levels + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + # Compute number of frames + n_frames = int(audio.size(-1) // self.hop_size) + + # Pad the audio signal + pad_amount = (self.window_size - self.hop_size) // 2 + audio_padded = F.pad(audio, (pad_amount, pad_amount), mode=self.padding) + + # Square the padded audio signal + audio_squared = audio_padded ** 2 + + # Compute the mean energy for each frame using unfold and mean + audio_squared = audio_squared[:, None, None, :] + energy = F.unfold(audio_squared, (1, self.window_size), stride=self.hop_size)[:, :, :n_frames] + energy = energy.mean(dim=1) + + # Compute the square root of the mean energy to get the RMS energy + # energy = torch.sqrt(energy) + + # Normalize the energy using the min_db value + gain = torch.maximum(energy, torch.tensor(np.power(10, self.min_db / 10), device=audio.device)) + gain_db = 10 * torch.log10(gain) + + if self.norm: + # Find the min and max of gain_db + # min_gain_db = torch.min(gain_db) + min_gain_db = self.min_db + max_gain_db = torch.max(gain_db, dim=-1, keepdim=True)[0] + + # Avoid numerical error by adding a small epsilon to the denominator + epsilon = 1e-8 + gain_db = (gain_db - min_gain_db) / (max_gain_db - min_gain_db + epsilon) + + if self.quantize_levels is not None: + # Quantize the result to the given number of levels + gain_db = torch.round(gain_db * (self.quantize_levels - 1)) / (self.quantize_levels - 1) + + return gain_db.unsqueeze(-1) + + +if __name__ == "__main__": + energy_extractor = EnergyExtractor(hop_size=512, window_size=1024, padding='reflect', + min_db=-60, norm=True) + audio = torch.rand(1, 16000) + energy = energy_extractor(audio) + print(energy.shape) + import librosa + import matplotlib.pyplot as plt + # a1, _ = librosa.load('eg1.wav', sr=16000) + # a2, _ = librosa.load('eg2.wav', sr=16000) + # audio = torch.tensor([a1[:5*16000], a2[:5*16000]]) + a1, _ = librosa.load('eg2.wav', sr=24000) + audio = torch.tensor(a1[:5*16000]).unsqueeze(0) + energy = 
energy_extractor(audio) + print(energy.shape) + + # Plot the energy for each audio sample + plt.figure(figsize=(12, 6)) + + for i in range(energy.shape[0]): + plt.plot(energy[i, :, 0].cpu().numpy(), label=f'Audio {i+1}') + + plt.xlabel('Frame') + plt.ylabel('Energy (dB)') + plt.title('Energy over Time') + plt.legend() + plt.savefig('debug.png') \ No newline at end of file diff --git a/src/models/conditions/.ipynb_checkpoints/mbenergy-checkpoint.py b/src/models/conditions/.ipynb_checkpoints/mbenergy-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..39dedf081f61f15a5e49a922a1863b560888aa6c --- /dev/null +++ b/src/models/conditions/.ipynb_checkpoints/mbenergy-checkpoint.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import julius +import soundfile as sf + + +class MultibandEnergyExtractor(nn.Module): + def __init__(self, hop_size: int = 512, window_size: int = 1024, + padding: str = 'reflect', min_db: float = -60, + norm: bool = True, quantize_levels: int = None, + n_bands: int = 8, control_bands: int = 4, + sample_rate: int = 24000,): + super().__init__() + self.hop_size = hop_size + self.window_size = window_size + self.padding = padding + self.min_db = min_db + self.norm = norm + self.quantize_levels = quantize_levels + self.n_bands = n_bands + self.control_bands = control_bands + self.sample_rate = sample_rate + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + # Split the audio into frequency bands + audio = julius.split_bands(audio, n_bands=self.n_bands, + sample_rate=self.sample_rate)[:self.control_bands].transpose(0, 1) + B, C, _ = audio.shape + for i in range(C): + sf.write(f'output_{i}.wav', audio[0][i], self.sample_rate) + + # Compute number of frames + n_frames = int(audio.size(-1) // self.hop_size) + + # Pad the audio signal + pad_amount = (self.window_size - self.hop_size) // 2 + audio_padded = F.pad(audio, (pad_amount, pad_amount), mode=self.padding) + + # Square the padded audio signal + audio_squared = audio_padded ** 2 + + # Compute the mean energy for each frame using unfold and mean + energy = audio_squared.unfold(dimension=-1, size=self.window_size, step=self.hop_size) + energy = energy[:, :, :n_frames] + print(energy.shape) + energy = energy.mean(dim=-1) + print(energy.shape) + + # Compute the square root of the mean energy to get the RMS energy + # energy = torch.sqrt(energy) + + # Normalize the energy using the min_db value + gain = torch.maximum(energy, torch.tensor(np.power(10, self.min_db / 10), device=audio.device)) + gain_db = 10 * torch.log10(gain) + + if self.norm: + # Find the min and max of gain_db + # min_gain_db = torch.min(gain_db) + min_gain_db = self.min_db + max_gain_db = torch.amax(gain_db, dim=(-1, -2), keepdim=True) + + # Avoid numerical error by adding a small epsilon to the denominator + epsilon = 1e-8 + gain_db = (gain_db - min_gain_db) / (max_gain_db - min_gain_db + epsilon) + + if self.quantize_levels is not None: + # Quantize the result to the given number of levels + gain_db = torch.round(gain_db * (self.quantize_levels - 1)) / (self.quantize_levels - 1) + + return gain_db.transpose(-1, -2) + + +if __name__ == "__main__": + energy_extractor = MultibandEnergyExtractor(hop_size=320, window_size=1280, + padding='reflect', + min_db=-60, norm=True) + audio = torch.rand(4, 24000) + energy = energy_extractor(audio) + print(energy.shape) + import librosa + import matplotlib.pyplot as plt + a1, _ = librosa.load('eg2.wav', sr=24000) + audio = 
torch.tensor(a1[:5*16000]).unsqueeze(0) + energy = energy_extractor(audio) + print(energy.shape) + + # Plot the energy for each audio sample + plt.figure(figsize=(12, 6)) + + for i in range(energy.shape[-1]): + plt.plot(energy[0, :, i].cpu().numpy(), label=f'Band {i+1}') + + plt.xlabel('Frame') + plt.ylabel('Energy (dB)') + plt.title('Energy over Time') + plt.legend() + plt.savefig('debug.png') diff --git a/src/models/conditions/.ipynb_checkpoints/voice-checkpoint.py b/src/models/conditions/.ipynb_checkpoints/voice-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..9739d69ee93d4b5b81c1e6bd83c94e48c1e5b2b3 --- /dev/null +++ b/src/models/conditions/.ipynb_checkpoints/voice-checkpoint.py @@ -0,0 +1,46 @@ +from transformers import HubertModel +import torch.nn as nn +import torch +import torch.nn.functional as F +import torchaudio +import librosa + + +class HubertModelWithFinalProj(HubertModel): + def __init__(self, config): + super().__init__(config) + + # The final projection layer is only used for backward compatibility. + # Following https://github.com/auspicious3000/contentvec/issues/6 + # Remove this layer is necessary to achieve the desired outcome. + self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) + + +class VoiceConversionExtractor(nn.Module): + # training on the fly might be slow + def __init__(self, config, sr): + super().__init__() + self.encoder = HubertModelWithFinalProj.from_pretrained(config) + self.encoder.eval() + self.sr = sr + self.target_sr = 16000 + if self.sr != self.target_sr: + self.resampler = torchaudio.transforms.Resample(orig_freq=self.sr, + new_freq=self.target_sr) + + def forward(self, audio): + if self.sr != self.target_sr: + audio = self.resampler(audio) + audio = F.pad(audio, ((400 - 320) // 2, (400 - 320) // 2)) + logits = self.encoder(audio)['last_hidden_state'] + return logits + + +if __name__ == '__main__': + model = VoiceConversionExtractor('lengyue233/content-vec-best', 24000) + audio, sr = librosa.load('test.wav', sr=24000) + audio = audio[:round(100*320*1.5)] + audio = torch.tensor([audio]) + with torch.no_grad(): + content = model(audio) + print(content.shape) \ No newline at end of file diff --git a/src/models/conditions/__init__.py b/src/models/conditions/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..0a60c5319ac77a8cdebd2835527256b547101700 --- /dev/null +++ b/src/models/conditions/__init__.py @@ -0,0 +1 @@ +from .condition_wrapper import Conditioner \ No newline at end of file diff --git a/src/models/conditions/__pycache__/__init__.cpython-311.pyc b/src/models/conditions/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3af6332d60ff9aad61b16821c885967ebe1dd22e Binary files /dev/null and b/src/models/conditions/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/models/conditions/__pycache__/chroma.cpython-311.pyc b/src/models/conditions/__pycache__/chroma.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..5c61a05935c9a648b6200517444dc4baa5be6712 Binary files /dev/null and b/src/models/conditions/__pycache__/chroma.cpython-311.pyc differ diff --git a/src/models/conditions/__pycache__/condition_wrapper.cpython-311.pyc b/src/models/conditions/__pycache__/condition_wrapper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..3e6f9d47688777ed749d1e6dad04e2f65269e48a Binary files /dev/null and 
b/src/models/conditions/__pycache__/condition_wrapper.cpython-311.pyc differ diff --git a/src/models/conditions/__pycache__/energy.cpython-311.pyc b/src/models/conditions/__pycache__/energy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..684b9a9193b8b9b28f1c518cc0e389ad1e08aace Binary files /dev/null and b/src/models/conditions/__pycache__/energy.cpython-311.pyc differ diff --git a/src/models/conditions/__pycache__/mbenergy.cpython-311.pyc b/src/models/conditions/__pycache__/mbenergy.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..116f7fe8ddc26a388ce59c8d337537a36f0a4e0c Binary files /dev/null and b/src/models/conditions/__pycache__/mbenergy.cpython-311.pyc differ diff --git a/src/models/conditions/__pycache__/sound_event.cpython-311.pyc b/src/models/conditions/__pycache__/sound_event.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96294a7643d5f951b6c1e3826f075aa89dd54cf0 Binary files /dev/null and b/src/models/conditions/__pycache__/sound_event.cpython-311.pyc differ diff --git a/src/models/conditions/__pycache__/voice.cpython-311.pyc b/src/models/conditions/__pycache__/voice.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51751a890b358f952f8549a118af45671a0411ce Binary files /dev/null and b/src/models/conditions/__pycache__/voice.cpython-311.pyc differ diff --git a/src/models/conditions/chroma.py b/src/models/conditions/chroma.py new file mode 100644 index 0000000000000000000000000000000000000000..a07aebb898f6d8048099a22f4f9b4d8f1a3117fd --- /dev/null +++ b/src/models/conditions/chroma.py @@ -0,0 +1,80 @@ +import typing as tp + +from einops import rearrange +from librosa import filters +import torch +from torch import nn +import torch.nn.functional as F +import torchaudio + + +class ChromaExtractor(nn.Module): + """Chroma extraction and quantization. + + Args: + sample_rate (int): Sample rate for the chroma extraction. + n_chroma (int): Number of chroma bins for the chroma extraction. + radix2_exp (int): Size of stft window for the chroma extraction (power of 2, e.g. 12 -> 2^12). + nfft (int, optional): Number of FFT. + winlen (int, optional): Window length. + winhop (int, optional): Window hop size. + argmax (bool, optional): Whether to use argmax. Defaults to False. + norm (float, optional): Norm for chroma normalization. Defaults to inf. 
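+ The extractor returns a tensor of shape (batch, frames, n_chroma). Note that the
+ constructor default used here is ``argmax=True``, in which case each frame is a
+ one-hot chroma vector rather than the normalized chroma energies.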
+ """ + + def __init__(self, + sample_rate: int, + n_chroma: int = 12, radix2_exp: int = 12, + nfft: tp.Optional[int] = None, + winlen: tp.Optional[int] = None, + winhop: tp.Optional[int] = None, argmax: bool = True, + norm: float = torch.inf): + super().__init__() + self.winlen = winlen or 2 ** radix2_exp + self.nfft = nfft or self.winlen + self.winhop = winhop or (self.winlen // 4) + self.sample_rate = sample_rate + self.n_chroma = n_chroma + self.norm = norm + self.argmax = argmax + self.register_buffer('fbanks', torch.from_numpy(filters.chroma(sr=sample_rate, n_fft=self.nfft, tuning=0, + n_chroma=self.n_chroma)), persistent=False) + self.spec = torchaudio.transforms.Spectrogram(n_fft=self.nfft, win_length=self.winlen, + hop_length=self.winhop, power=2, center=False, + pad=0, normalized=True) + + def forward(self, wav: torch.Tensor) -> torch.Tensor: + T = wav.shape[-1] + # in case we are getting a wav that was dropped out (nullified) + # from the conditioner, make sure wav length is no less that nfft + if T < self.nfft: + pad = self.nfft - T + r = 0 if pad % 2 == 0 else 1 + wav = F.pad(wav, (pad // 2, pad // 2 + r), 'constant', 0) + assert wav.shape[-1] == self.nfft, f"expected len {self.nfft} but got {wav.shape[-1]}" + + wav = F.pad(wav, (int(self.nfft // 2 - self.winhop // 2 ), + int(self.nfft // 2 - self.winhop // 2 )), mode="reflect") + + spec = self.spec(wav).squeeze(1) + raw_chroma = torch.einsum('cf,...ft->...ct', self.fbanks, spec) + norm_chroma = torch.nn.functional.normalize(raw_chroma, p=self.norm, dim=-2, eps=1e-6) + norm_chroma = rearrange(norm_chroma, 'b d t -> b t d') + + if self.argmax: + idx = norm_chroma.argmax(-1, keepdim=True) + norm_chroma[:] = 0 + norm_chroma.scatter_(dim=-1, index=idx, value=1) + + return norm_chroma + + +if __name__ == "__main__": + chroma = ChromaExtractor(sample_rate=16000, + n_chroma=4, + radix2_exp=None, + winlen=16000, + nfft=16000, + winhop=4000) + audio = torch.rand(1, 16000) + c = chroma(audio) \ No newline at end of file diff --git a/src/models/conditions/condition_wrapper.py b/src/models/conditions/condition_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..20e44795abc16c3b9a9a30c8f579767879c25fe7 --- /dev/null +++ b/src/models/conditions/condition_wrapper.py @@ -0,0 +1,50 @@ +import torch +import torch.nn as nn +from .chroma import ChromaExtractor +from .energy import EnergyExtractor +from .voice import VoiceConversionExtractor +from .mbenergy import MultibandEnergyExtractor + + +class Conditioner(nn.Module): + def __init__(self, + condition_type, + **kwargs + ): + super().__init__() + if condition_type == 'energy': + self.conditioner = EnergyExtractor(**kwargs) + elif condition_type == 'chroma': + self.conditioner = ChromaExtractor(**kwargs) + elif condition_type == 'vc': + self.conditioner = VoiceConversionExtractor(**kwargs) + elif condition_type == 'mb_energy': + self.conditioner = MultibandEnergyExtractor(**kwargs) + else: + raise NotImplementedError + + def forward(self, waveform, latent_shape): + # B T C + condition = self.conditioner(waveform) + # B C T + condition = condition.permute(0, 2, 1).contiguous() + + if len(latent_shape) == 4: + # 2d spectrogram B C T F + assert (condition.shape[-1] % latent_shape[-2]) == 0 + X = latent_shape[-1] * condition.shape[-1] // latent_shape[-2] + # copy on F direction + condition = condition.unsqueeze(-1).expand(-1, -1, -1, X) + elif len(latent_shape) == 3: + condition = condition + else: + raise NotImplementedError + return condition + + +if __name__ == 
'__main__': + conditioner = Conditioner(condition_type='energy', + hop_size=160, window_size=1024, padding='reflect', + min_db=-80, norm=True) + audio = torch.rand(4, 16000) # Example audio signal + energy = conditioner(audio, (4, 8, 100, 64)) \ No newline at end of file diff --git a/src/models/conditions/debug.png b/src/models/conditions/debug.png new file mode 100644 index 0000000000000000000000000000000000000000..999c9593896a2890de740ca5e3f878828c18d3df Binary files /dev/null and b/src/models/conditions/debug.png differ diff --git a/src/models/conditions/eg1.wav b/src/models/conditions/eg1.wav new file mode 100644 index 0000000000000000000000000000000000000000..6eb190a33a2d35c169811fa764df280fa2e906fb Binary files /dev/null and b/src/models/conditions/eg1.wav differ diff --git a/src/models/conditions/eg2.wav b/src/models/conditions/eg2.wav new file mode 100644 index 0000000000000000000000000000000000000000..281bda1f4c65d48de53f5a21aaa8e7f185cd914b Binary files /dev/null and b/src/models/conditions/eg2.wav differ diff --git a/src/models/conditions/energy.py b/src/models/conditions/energy.py new file mode 100644 index 0000000000000000000000000000000000000000..f86481723850af419fa701254863789108596e88 --- /dev/null +++ b/src/models/conditions/energy.py @@ -0,0 +1,85 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np + + +class EnergyExtractor(nn.Module): + def __init__(self, hop_size: int = 512, window_size: int = 1024, + padding: str = 'reflect', min_db: float = -60, + norm: bool = True, quantize_levels: int = None): + super().__init__() + self.hop_size = hop_size + self.window_size = window_size + self.padding = padding + self.min_db = min_db + self.norm = norm + self.quantize_levels = quantize_levels + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + # Compute number of frames + n_frames = int(audio.size(-1) // self.hop_size) + + # Pad the audio signal + pad_amount = (self.window_size - self.hop_size) // 2 + audio_padded = F.pad(audio, (pad_amount, pad_amount), mode=self.padding) + + # Square the padded audio signal + audio_squared = audio_padded ** 2 + + # Compute the mean energy for each frame using unfold and mean + audio_squared = audio_squared[:, None, None, :] + energy = F.unfold(audio_squared, (1, self.window_size), stride=self.hop_size)[:, :, :n_frames] + energy = energy.mean(dim=1) + + # Compute the square root of the mean energy to get the RMS energy + # energy = torch.sqrt(energy) + + # Normalize the energy using the min_db value + gain = torch.maximum(energy, torch.tensor(np.power(10, self.min_db / 10), device=audio.device)) + gain_db = 10 * torch.log10(gain) + + if self.norm: + # Find the min and max of gain_db + # min_gain_db = torch.min(gain_db) + min_gain_db = self.min_db + max_gain_db = torch.max(gain_db, dim=-1, keepdim=True)[0] + + # Avoid numerical error by adding a small epsilon to the denominator + epsilon = 1e-8 + gain_db = (gain_db - min_gain_db) / (max_gain_db - min_gain_db + epsilon) + + if self.quantize_levels is not None: + # Quantize the result to the given number of levels + gain_db = torch.round(gain_db * (self.quantize_levels - 1)) / (self.quantize_levels - 1) + + return gain_db.unsqueeze(-1) + + +if __name__ == "__main__": + energy_extractor = EnergyExtractor(hop_size=512, window_size=1024, padding='reflect', + min_db=-60, norm=True) + audio = torch.rand(1, 16000) + energy = energy_extractor(audio) + print(energy.shape) + import librosa + import matplotlib.pyplot as plt + # a1, _ = 
librosa.load('eg1.wav', sr=16000) + # a2, _ = librosa.load('eg2.wav', sr=16000) + # audio = torch.tensor([a1[:5*16000], a2[:5*16000]]) + a1, _ = librosa.load('eg2.wav', sr=24000) + audio = torch.tensor(a1[:5*16000]).unsqueeze(0) + energy = energy_extractor(audio) + print(energy.shape) + + # Plot the energy for each audio sample + plt.figure(figsize=(12, 6)) + + for i in range(energy.shape[0]): + plt.plot(energy[i, :, 0].cpu().numpy(), label=f'Audio {i+1}') + + plt.xlabel('Frame') + plt.ylabel('Energy (dB)') + plt.title('Energy over Time') + plt.legend() + plt.savefig('debug.png') \ No newline at end of file diff --git a/src/models/conditions/mbenergy.py b/src/models/conditions/mbenergy.py new file mode 100644 index 0000000000000000000000000000000000000000..39dedf081f61f15a5e49a922a1863b560888aa6c --- /dev/null +++ b/src/models/conditions/mbenergy.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import numpy as np +import julius +import soundfile as sf + + +class MultibandEnergyExtractor(nn.Module): + def __init__(self, hop_size: int = 512, window_size: int = 1024, + padding: str = 'reflect', min_db: float = -60, + norm: bool = True, quantize_levels: int = None, + n_bands: int = 8, control_bands: int = 4, + sample_rate: int = 24000,): + super().__init__() + self.hop_size = hop_size + self.window_size = window_size + self.padding = padding + self.min_db = min_db + self.norm = norm + self.quantize_levels = quantize_levels + self.n_bands = n_bands + self.control_bands = control_bands + self.sample_rate = sample_rate + + def forward(self, audio: torch.Tensor) -> torch.Tensor: + # Split the audio into frequency bands + audio = julius.split_bands(audio, n_bands=self.n_bands, + sample_rate=self.sample_rate)[:self.control_bands].transpose(0, 1) + B, C, _ = audio.shape + for i in range(C): + sf.write(f'output_{i}.wav', audio[0][i], self.sample_rate) + + # Compute number of frames + n_frames = int(audio.size(-1) // self.hop_size) + + # Pad the audio signal + pad_amount = (self.window_size - self.hop_size) // 2 + audio_padded = F.pad(audio, (pad_amount, pad_amount), mode=self.padding) + + # Square the padded audio signal + audio_squared = audio_padded ** 2 + + # Compute the mean energy for each frame using unfold and mean + energy = audio_squared.unfold(dimension=-1, size=self.window_size, step=self.hop_size) + energy = energy[:, :, :n_frames] + print(energy.shape) + energy = energy.mean(dim=-1) + print(energy.shape) + + # Compute the square root of the mean energy to get the RMS energy + # energy = torch.sqrt(energy) + + # Normalize the energy using the min_db value + gain = torch.maximum(energy, torch.tensor(np.power(10, self.min_db / 10), device=audio.device)) + gain_db = 10 * torch.log10(gain) + + if self.norm: + # Find the min and max of gain_db + # min_gain_db = torch.min(gain_db) + min_gain_db = self.min_db + max_gain_db = torch.amax(gain_db, dim=(-1, -2), keepdim=True) + + # Avoid numerical error by adding a small epsilon to the denominator + epsilon = 1e-8 + gain_db = (gain_db - min_gain_db) / (max_gain_db - min_gain_db + epsilon) + + if self.quantize_levels is not None: + # Quantize the result to the given number of levels + gain_db = torch.round(gain_db * (self.quantize_levels - 1)) / (self.quantize_levels - 1) + + return gain_db.transpose(-1, -2) + + +if __name__ == "__main__": + energy_extractor = MultibandEnergyExtractor(hop_size=320, window_size=1280, + padding='reflect', + min_db=-60, norm=True) + audio = torch.rand(4, 24000) + energy = 
energy_extractor(audio) + print(energy.shape) + import librosa + import matplotlib.pyplot as plt + a1, _ = librosa.load('eg2.wav', sr=24000) + audio = torch.tensor(a1[:5*16000]).unsqueeze(0) + energy = energy_extractor(audio) + print(energy.shape) + + # Plot the energy of each frequency band for the first audio sample + plt.figure(figsize=(12, 6)) + + for i in range(energy.shape[-1]): + plt.plot(energy[0, :, i].cpu().numpy(), label=f'Band {i+1}') + + plt.xlabel('Frame') + plt.ylabel('Energy (dB)') + plt.title('Energy over Time') + plt.legend() + plt.savefig('debug.png') diff --git a/src/models/conditions/output_0.wav b/src/models/conditions/output_0.wav new file mode 100644 index 0000000000000000000000000000000000000000..2f10de23902599f046a7e2d045d804a4979188bf Binary files /dev/null and b/src/models/conditions/output_0.wav differ diff --git a/src/models/conditions/output_1.wav b/src/models/conditions/output_1.wav new file mode 100644 index 0000000000000000000000000000000000000000..ce38cd82d88373c27b5bf02510c208c2cfa29bac Binary files /dev/null and b/src/models/conditions/output_1.wav differ diff --git a/src/models/conditions/output_2.wav b/src/models/conditions/output_2.wav new file mode 100644 index 0000000000000000000000000000000000000000..a8df7bd2f6b5c7478d57914bdc382138e4f7101d Binary files /dev/null and b/src/models/conditions/output_2.wav differ diff --git a/src/models/conditions/output_3.wav b/src/models/conditions/output_3.wav new file mode 100644 index 0000000000000000000000000000000000000000..5b56740b310fc6cecbcf342034b54e8cde7cc3f3 Binary files /dev/null and b/src/models/conditions/output_3.wav differ diff --git a/src/models/conditions/voice.py b/src/models/conditions/voice.py new file mode 100644 index 0000000000000000000000000000000000000000..9739d69ee93d4b5b81c1e6bd83c94e48c1e5b2b3 --- /dev/null +++ b/src/models/conditions/voice.py @@ -0,0 +1,46 @@ +from transformers import HubertModel +import torch.nn as nn +import torch +import torch.nn.functional as F +import torchaudio +import librosa + + +class HubertModelWithFinalProj(HubertModel): + def __init__(self, config): + super().__init__(config) + + # The final projection layer is only used for backward compatibility. + # Following https://github.com/auspicious3000/contentvec/issues/6, + # removing this layer is necessary to achieve the desired outcome.
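+ # (Note: `final_proj` is presumably defined only so that pretrained checkpoints
+ # containing a `final_proj` weight load without key mismatches; `forward` below
+ # returns `last_hidden_state` and never applies this projection.)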
+ self.final_proj = nn.Linear(config.hidden_size, config.classifier_proj_size) + + +class VoiceConversionExtractor(nn.Module): + # training on the fly might be slow + def __init__(self, config, sr): + super().__init__() + self.encoder = HubertModelWithFinalProj.from_pretrained(config) + self.encoder.eval() + self.sr = sr + self.target_sr = 16000 + if self.sr != self.target_sr: + self.resampler = torchaudio.transforms.Resample(orig_freq=self.sr, + new_freq=self.target_sr) + + def forward(self, audio): + if self.sr != self.target_sr: + audio = self.resampler(audio) + audio = F.pad(audio, ((400 - 320) // 2, (400 - 320) // 2)) + logits = self.encoder(audio)['last_hidden_state'] + return logits + + +if __name__ == '__main__': + model = VoiceConversionExtractor('lengyue233/content-vec-best', 24000) + audio, sr = librosa.load('test.wav', sr=24000) + audio = audio[:round(100*320*1.5)] + audio = torch.tensor([audio]) + with torch.no_grad(): + content = model(audio) + print(content.shape) \ No newline at end of file diff --git a/src/models/controlnet.py b/src/models/controlnet.py new file mode 100644 index 0000000000000000000000000000000000000000..1750621847ed116a6fbab55a50e67963699d6a5a --- /dev/null +++ b/src/models/controlnet.py @@ -0,0 +1,318 @@ +import torch +import torch.nn as nn + +from .utils.modules import PatchEmbed, TimestepEmbedder +from .utils.modules import PE_wrapper, RMSNorm +from .blocks import DiTBlock, JointDiTBlock +from .utils.span_mask import compute_mask_indices + + +class DiTControlNetEmbed(nn.Module): + def __init__(self, in_chans, out_chans, blocks, + cond_mask=False, cond_mask_prob=None, + cond_mask_ratio=None, cond_mask_span=None): + super().__init__() + self.conv_in = nn.Conv1d(in_chans, blocks[0], kernel_size=1) + + self.cond_mask = cond_mask + if self.cond_mask: + self.mask_embed = nn.Parameter(torch.zeros((blocks[0]))) + self.mask_prob = cond_mask_prob + self.mask_ratio = cond_mask_ratio + self.mask_span = cond_mask_span + blocks[0] = blocks[0] + 1 + + conv_blocks = [] + for i in range(len(blocks) - 1): + channel_in = blocks[i] + channel_out = blocks[i + 1] + block = nn.Sequential( + nn.Conv1d(channel_in, channel_in, kernel_size=3, padding=1), + nn.SiLU(), + nn.Conv1d(channel_in, channel_out, kernel_size=3, padding=1, stride=2), + nn.SiLU(),) + conv_blocks.append(block) + self.blocks = nn.ModuleList(conv_blocks) + + self.conv_out = nn.Conv1d(blocks[-1], out_chans, kernel_size=1) + nn.init.zeros_(self.conv_out.weight) + nn.init.zeros_(self.conv_out.bias) + + def random_masking(self, gt, mask_ratios, mae_mask_infer=None): + B, D, L = gt.shape + if mae_mask_infer is None: + # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1) + mask_ratios = mask_ratios.cpu().numpy() + mask = compute_mask_indices(shape=[B, L], + padding_mask=None, + mask_prob=mask_ratios, + mask_length=self.mask_span, + mask_type="static", + mask_other=0.0, + min_masks=1, + no_overlap=False, + min_space=0,) + # only apply mask to some batches + mask_batch = torch.rand(B) < self.mask_prob + mask[~mask_batch] = False + mask = mask.unsqueeze(1).expand_as(gt) + else: + mask = mae_mask_infer + mask = mask.expand_as(gt) + gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask].type_as(gt) + return gt, mask.type_as(gt) + + def forward(self, conditioning, cond_mask_infer=None): + embedding = self.conv_in(conditioning) + + if self.cond_mask: + B, D, L = embedding.shape + if not self.training and cond_mask_infer is None: + cond_mask_infer = torch.zeros_like(embedding).bool() + mask_ratios = 
torch.FloatTensor(B).uniform_(*self.mask_ratio).to(embedding.device) + embedding, cond_mask = self.random_masking(embedding, mask_ratios, cond_mask_infer) + embedding = torch.cat([embedding, cond_mask[:, 0:1, :]], dim=1) + + for block in self.blocks: + embedding = block(embedding) + + embedding = self.conv_out(embedding) + + # B, L, C + embedding = embedding.transpose(1, 2).contiguous() + + return embedding + + +class DiTControlNet(nn.Module): + def __init__(self, + img_size=(224, 224), patch_size=16, in_chans=3, + input_type='2d', out_chans=None, + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., + qkv_bias=False, qk_scale=None, qk_norm=None, + act_layer='gelu', norm_layer='layernorm', + context_norm=False, + use_checkpoint=False, + # time fusion ada or token + time_fusion='token', + ada_lora_rank=None, ada_lora_alpha=None, + cls_dim=None, + # max length is only used for concat + context_dim=768, context_fusion='concat', + context_max_length=128, context_pe_method='sinu', + pe_method='abs', rope_mode='none', + use_conv=True, + skip=True, skip_norm=True, + # controlnet configs + cond_in=None, cond_blocks=None, + cond_mask=False, cond_mask_prob=None, + cond_mask_ratio=None, cond_mask_span=None, + **kwargs): + super().__init__() + self.num_features = self.embed_dim = embed_dim + # input + self.in_chans = in_chans + self.input_type = input_type + if self.input_type == '2d': + num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) + elif self.input_type == '1d': + num_patches = img_size // patch_size + self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, input_type=input_type) + out_chans = in_chans if out_chans is None else out_chans + self.out_chans = out_chans + + # position embedding + self.rope = rope_mode + self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method, + length=num_patches) + + print(f'x position embedding: {pe_method}') + print(f'rope mode: {self.rope}') + + # time embed + self.time_embed = TimestepEmbedder(embed_dim) + self.time_fusion = time_fusion + self.use_adanorm = False + + # cls embed + if cls_dim is not None: + self.cls_embed = nn.Sequential( + nn.Linear(cls_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True),) + else: + self.cls_embed = None + + # time fusion + if time_fusion == 'token': + # put token at the beginning of sequence + self.extras = 2 if self.cls_embed else 1 + self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras) + elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']: + self.use_adanorm = True + # aviod repetitive silu for each adaln block + self.time_act = nn.SiLU() + self.extras = 0 + if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']: + # shared adaln + self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True) + else: + self.time_ada = None + else: + raise NotImplementedError + print(f'time fusion mode: {self.time_fusion}') + + # context + # use a simple projection + self.use_context = False + self.context_cross = False + self.context_max_length = context_max_length + self.context_fusion = 'none' + if context_dim is not None: + self.use_context = True + self.context_embed = nn.Sequential( + nn.Linear(context_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True),) + self.context_fusion = context_fusion + if context_fusion == 'concat' or context_fusion == 'joint': + self.extras += context_max_length + self.context_pe = PE_wrapper(dim=embed_dim, + method=context_pe_method, 
+ length=context_max_length) + # no cross attention layers + context_dim = None + elif context_fusion == 'cross': + self.context_pe = PE_wrapper(dim=embed_dim, + method=context_pe_method, + length=context_max_length) + self.context_cross = True + context_dim = embed_dim + else: + raise NotImplementedError + print(f'context fusion mode: {context_fusion}') + print(f'context position embedding: {context_pe_method}') + + if self.context_fusion == 'joint': + Block = JointDiTBlock + else: + Block = DiTBlock + + # norm layers + if norm_layer == 'layernorm': + norm_layer = nn.LayerNorm + elif norm_layer == 'rmsnorm': + norm_layer = RMSNorm + else: + raise NotImplementedError + + self.in_blocks = nn.ModuleList([ + Block( + dim=embed_dim, context_dim=context_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, + act_layer=act_layer, norm_layer=norm_layer, + time_fusion=time_fusion, + ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, + skip=False, skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + self.controlnet_pre = DiTControlNetEmbed(in_chans=cond_in, out_chans=embed_dim, + blocks=cond_blocks, + cond_mask=cond_mask, + cond_mask_prob=cond_mask_prob, + cond_mask_ratio=cond_mask_ratio, + cond_mask_span=cond_mask_span) + + controlnet_zero_blocks = [] + for i in range(depth // 2): + block = nn.Linear(embed_dim, embed_dim) + nn.init.zeros_(block.weight) + nn.init.zeros_(block.bias) + controlnet_zero_blocks.append(block) + self.controlnet_zero_blocks = nn.ModuleList(controlnet_zero_blocks) + + print('ControlNet ready \n') + + def set_trainable(self): + for param in self.parameters(): + param.requires_grad = False + + # only train input_proj, blocks, and output_proj + for module_name in ['controlnet_pre', 'in_blocks', 'controlnet_zero_blocks']: + module = getattr(self, module_name, None) + if module is not None: + for param in module.parameters(): + param.requires_grad = True + module.train() + else: + print(f'\n!!!warning missing trainable blocks: {module_name}!!!\n') + + def forward(self, x, timesteps, context, + x_mask=None, context_mask=None, + cls_token=None, + condition=None, cond_mask_infer=None, + conditioning_scale=1.0): + # make it compatible with int time step during inference + if timesteps.dim() == 0: + timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long) + + x = self.patch_embed(x) + # add condition to x + condition = self.controlnet_pre(condition) + x = x + condition + x = self.x_pe(x) + + B, L, D = x.shape + + if self.use_context: + context_token = self.context_embed(context) + context_token = self.context_pe(context_token) + if self.context_fusion == 'concat' or self.context_fusion == 'joint': + x, x_mask = self._concat_x_context(x=x, context=context_token, + x_mask=x_mask, + context_mask=context_mask) + context_token, context_mask = None, None + else: + context_token, context_mask = None, None + + time_token = self.time_embed(timesteps) + if self.cls_embed: + cls_token = self.cls_embed(cls_token) + time_ada = None + if self.use_adanorm: + if self.cls_embed: + time_token = time_token + cls_token + time_token = self.time_act(time_token) + if self.time_ada is not None: + time_ada = self.time_ada(time_token) + else: + time_token = time_token.unsqueeze(dim=1) + if self.cls_embed: + cls_token = cls_token.unsqueeze(dim=1) + time_token = torch.cat([time_token, cls_token], dim=1) + time_token = self.time_pe(time_token) + x = 
torch.cat((time_token, x), dim=1) + if x_mask is not None: + x_mask = torch.cat( + [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(), + x_mask], dim=1) + time_token = None + + skips = [] + for blk in self.in_blocks: + x = blk(x=x, time_token=time_token, time_ada=time_ada, + skip=None, context=context_token, + x_mask=x_mask, context_mask=context_mask, + extras=self.extras) + skips.append(x) + + controlnet_skips = [] + for skip, controlnet_block in zip(skips, self.controlnet_zero_blocks): + controlnet_skips.append(controlnet_block(skip) * conditioning_scale) + + return controlnet_skips \ No newline at end of file diff --git a/src/models/udit.py b/src/models/udit.py new file mode 100644 index 0000000000000000000000000000000000000000..e126efd370efabbfcc4f4359194f9c95c6e9d154 --- /dev/null +++ b/src/models/udit.py @@ -0,0 +1,365 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint +import math +from .utils.modules import PatchEmbed, TimestepEmbedder +from .utils.modules import PE_wrapper, RMSNorm +from .blocks import DiTBlock, JointDiTBlock, FinalBlock + + +class UDiT(nn.Module): + def __init__(self, + img_size=224, patch_size=16, in_chans=3, + input_type='2d', out_chans=None, + embed_dim=768, depth=12, num_heads=12, mlp_ratio=4., + qkv_bias=False, qk_scale=None, qk_norm=None, + act_layer='gelu', norm_layer='layernorm', + context_norm=False, + use_checkpoint=False, + # time fusion ada or token + time_fusion='token', + ada_lora_rank=None, ada_lora_alpha=None, + cls_dim=None, + # max length is only used for concat + context_dim=768, context_fusion='concat', + context_max_length=128, context_pe_method='sinu', + pe_method='abs', rope_mode='none', + use_conv=True, + skip=True, skip_norm=True): + super().__init__() + self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models + + # input + self.in_chans = in_chans + self.input_type = input_type + if self.input_type == '2d': + num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size) + elif self.input_type == '1d': + num_patches = img_size // patch_size + self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans, + embed_dim=embed_dim, input_type=input_type) + out_chans = in_chans if out_chans is None else out_chans + self.out_chans = out_chans + + # position embedding + self.rope = rope_mode + self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method, + length=num_patches) + + print(f'x position embedding: {pe_method}') + print(f'rope mode: {self.rope}') + + # time embed + self.time_embed = TimestepEmbedder(embed_dim) + self.time_fusion = time_fusion + self.use_adanorm = False + + # cls embed + if cls_dim is not None: + self.cls_embed = nn.Sequential( + nn.Linear(cls_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True),) + else: + self.cls_embed = None + + # time fusion + if time_fusion == 'token': + # put token at the beginning of sequence + self.extras = 2 if self.cls_embed else 1 + self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras) + elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']: + self.use_adanorm = True + # aviod repetitive silu for each adaln block + self.time_act = nn.SiLU() + self.extras = 0 + self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True) + if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']: + # shared adaln + self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True) + else: + self.time_ada = None + else: + raise NotImplementedError 
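+ # ('token' prepends the time embedding to the sequence as an extra token (see
+ # forward below); the 'ada*' modes instead drive AdaLN shift/scale modulation,
+ # where 'ada_single'/'ada_lora'/'ada_lora_bias' share one 6*embed_dim time_ada
+ # projection across blocks while plain 'ada' keeps one per block.)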
+ print(f'time fusion mode: {self.time_fusion}') + + # context + # use a simple projection + self.use_context = False + self.context_cross = False + self.context_max_length = context_max_length + self.context_fusion = 'none' + if context_dim is not None: + self.use_context = True + self.context_embed = nn.Sequential( + nn.Linear(context_dim, embed_dim, bias=True), + nn.SiLU(), + nn.Linear(embed_dim, embed_dim, bias=True),) + self.context_fusion = context_fusion + if context_fusion == 'concat' or context_fusion == 'joint': + self.extras += context_max_length + self.context_pe = PE_wrapper(dim=embed_dim, + method=context_pe_method, + length=context_max_length) + # no cross attention layers + context_dim = None + elif context_fusion == 'cross': + self.context_pe = PE_wrapper(dim=embed_dim, + method=context_pe_method, + length=context_max_length) + self.context_cross = True + context_dim = embed_dim + else: + raise NotImplementedError + print(f'context fusion mode: {context_fusion}') + print(f'context position embedding: {context_pe_method}') + + if self.context_fusion == 'joint': + Block = JointDiTBlock + self.use_skip = skip[0] + else: + Block = DiTBlock + self.use_skip = skip + + # norm layers + if norm_layer == 'layernorm': + norm_layer = nn.LayerNorm + elif norm_layer == 'rmsnorm': + norm_layer = RMSNorm + else: + raise NotImplementedError + + print(f'use long skip connection: {skip}') + self.in_blocks = nn.ModuleList([ + Block( + dim=embed_dim, context_dim=context_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, + act_layer=act_layer, norm_layer=norm_layer, + time_fusion=time_fusion, + ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, + skip=False, skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + self.mid_block = Block( + dim=embed_dim, context_dim=context_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, + act_layer=act_layer, norm_layer=norm_layer, + time_fusion=time_fusion, + ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, + skip=False, skip_norm=False, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint) + + self.out_blocks = nn.ModuleList([ + Block( + dim=embed_dim, context_dim=context_dim, num_heads=num_heads, + mlp_ratio=mlp_ratio, + qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm, + act_layer=act_layer, norm_layer=norm_layer, + time_fusion=time_fusion, + ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha, + skip=skip, skip_norm=skip_norm, + rope_mode=self.rope, + context_norm=context_norm, + use_checkpoint=use_checkpoint) + for _ in range(depth // 2)]) + + # FinalLayer block + self.use_conv = use_conv + self.final_block = FinalBlock(embed_dim=embed_dim, + patch_size=patch_size, + img_size=img_size, + in_chans=out_chans, + input_type=input_type, + norm_layer=norm_layer, + use_conv=use_conv, + use_adanorm=self.use_adanorm) + self.initialize_weights() + + def _init_ada(self): + if self.time_fusion == 'ada': + nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + for block in self.in_blocks: + nn.init.constant_(block.adaln.time_ada.weight, 0) + nn.init.constant_(block.adaln.time_ada.bias, 0) + nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0) + nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0) + for block in self.out_blocks: + 
nn.init.constant_(block.adaln.time_ada.weight, 0) + nn.init.constant_(block.adaln.time_ada.bias, 0) + elif self.time_fusion == 'ada_single': + nn.init.constant_(self.time_ada.weight, 0) + nn.init.constant_(self.time_ada.bias, 0) + nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + elif self.time_fusion in ['ada_lora', 'ada_lora_bias']: + nn.init.constant_(self.time_ada.weight, 0) + nn.init.constant_(self.time_ada.bias, 0) + nn.init.constant_(self.time_ada_final.weight, 0) + nn.init.constant_(self.time_ada_final.bias, 0) + for block in self.in_blocks: + nn.init.kaiming_uniform_(block.adaln.lora_a.weight, + a=math.sqrt(5)) + nn.init.constant_(block.adaln.lora_b.weight, 0) + nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight, + a=math.sqrt(5)) + nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0) + for block in self.out_blocks: + nn.init.kaiming_uniform_(block.adaln.lora_a.weight, + a=math.sqrt(5)) + nn.init.constant_(block.adaln.lora_b.weight, 0) + + def initialize_weights(self): + # Basic init for all layers + def _basic_init(module): + if isinstance(module, nn.Linear): + torch.nn.init.xavier_uniform_(module.weight) + if module.bias is not None: + nn.init.constant_(module.bias, 0) + self.apply(_basic_init) + + # init patch Conv like Linear + w = self.patch_embed.proj.weight.data + nn.init.xavier_uniform_(w.view([w.shape[0], -1])) + nn.init.constant_(self.patch_embed.proj.bias, 0) + + # Zero-out AdaLN + if self.use_adanorm: + self._init_ada() + + # Zero-out Cross Attention + if self.context_cross: + for block in self.in_blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0) + nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0) + for block in self.out_blocks: + nn.init.constant_(block.cross_attn.proj.weight, 0) + nn.init.constant_(block.cross_attn.proj.bias, 0) + + # Zero-out cls embedding + if self.cls_embed: + if self.use_adanorm: + nn.init.constant_(self.cls_embed[-1].weight, 0) + nn.init.constant_(self.cls_embed[-1].bias, 0) + + # Zero-out Output + # might not zero-out this when using v-prediction + # it could be good when using noise-prediction + # nn.init.constant_(self.final_block.linear.weight, 0) + # nn.init.constant_(self.final_block.linear.bias, 0) + # if self.use_conv: + # nn.init.constant_(self.final_block.final_layer.weight.data, 0) + # nn.init.constant_(self.final_block.final_layer.bias, 0) + + # init out Conv + if self.use_conv: + nn.init.xavier_uniform_(self.final_block.final_layer.weight) + nn.init.constant_(self.final_block.final_layer.bias, 0) + + def _concat_x_context(self, x, context, x_mask=None, context_mask=None): + assert context.shape[-2] == self.context_max_length + # Check if either x_mask or context_mask is provided + B = x.shape[0] + # Create default masks if they are not provided + if x_mask is None: + x_mask = torch.ones(B, x.shape[-2], device=x.device).bool() + if context_mask is None: + context_mask = torch.ones(B, context.shape[-2], + device=context.device).bool() + # Concatenate the masks along the second dimension (dim=1) + x_mask = torch.cat([context_mask, x_mask], dim=1) + # Concatenate context and x along the second dimension (dim=1) + x = torch.cat((context, x), dim=1) + return x, x_mask + + def forward(self, x, timesteps, context, + x_mask=None, context_mask=None, + cls_token=None, controlnet_skips=None, + ): + # make it compatible with int time step during 
inference + if timesteps.dim() == 0: + timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long) + + x = self.patch_embed(x) + x = self.x_pe(x) + + B, L, D = x.shape + + if self.use_context: + context_token = self.context_embed(context) + context_token = self.context_pe(context_token) + if self.context_fusion == 'concat' or self.context_fusion == 'joint': + x, x_mask = self._concat_x_context(x=x, context=context_token, + x_mask=x_mask, + context_mask=context_mask) + context_token, context_mask = None, None + else: + context_token, context_mask = None, None + + time_token = self.time_embed(timesteps) + if self.cls_embed: + cls_token = self.cls_embed(cls_token) + time_ada = None + time_ada_final = None + if self.use_adanorm: + if self.cls_embed: + time_token = time_token + cls_token + time_token = self.time_act(time_token) + time_ada_final = self.time_ada_final(time_token) + if self.time_ada is not None: + time_ada = self.time_ada(time_token) + else: + time_token = time_token.unsqueeze(dim=1) + if self.cls_embed: + cls_token = cls_token.unsqueeze(dim=1) + time_token = torch.cat([time_token, cls_token], dim=1) + time_token = self.time_pe(time_token) + x = torch.cat((time_token, x), dim=1) + if x_mask is not None: + x_mask = torch.cat( + [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(), + x_mask], dim=1) + time_token = None + + skips = [] + for blk in self.in_blocks: + x = blk(x=x, time_token=time_token, time_ada=time_ada, + skip=None, context=context_token, + x_mask=x_mask, context_mask=context_mask, + extras=self.extras) + if self.use_skip: + skips.append(x) + + x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada, + skip=None, context=context_token, + x_mask=x_mask, context_mask=context_mask, + extras=self.extras) + for blk in self.out_blocks: + if self.use_skip: + skip = skips.pop() + if controlnet_skips: + # add to skip like u-net controlnet + skip = skip + controlnet_skips.pop() + else: + skip = None + if controlnet_skips: + # directly add to x + x = x + controlnet_skips.pop() + + x = blk(x=x, time_token=time_token, time_ada=time_ada, + skip=skip, context=context_token, + x_mask=x_mask, context_mask=context_mask, + extras=self.extras) + + x = self.final_block(x, time_ada=time_ada_final, extras=self.extras) + + return x \ No newline at end of file diff --git a/src/models/utils/.ipynb_checkpoints/__init__-checkpoint.py b/src/models/utils/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py b/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..5f77c9148fc54916e1bedc2f36f77f6a2164986a --- /dev/null +++ b/src/models/utils/.ipynb_checkpoints/attention-checkpoint.py @@ -0,0 +1,290 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .rotary import RotaryEmbedding +from .modules import RMSNorm + + +if hasattr(nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + ATTENTION_MODE = 'math' +print(f'attention mode is {ATTENTION_MODE}') + + +def add_mask(sim, mask): + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = 
-torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + + +def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None): + def default(val, d): + return val if val is not None else (d() if isfunction(d) else d) + b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device + q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool)) + k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool)) + attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j') + return attn_mask + + +class Attention(nn.Module): + def __init__(self, dim, context_dim=None, num_heads=8, + qkv_bias=False, qk_scale=None, qk_norm=None, + attn_drop=0., proj_drop=0., rope_mode='none'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + if context_dim is None: + self.cross_attn = False + else: + self.cross_attn = True + + context_dim = dim if context_dim is None else context_dim + + self.to_q = nn.Linear(dim, dim, bias=qkv_bias) + self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias) + self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias) + + if qk_norm is None: + self.norm_q = nn.Identity() + self.norm_k = nn.Identity() + elif qk_norm == 'layernorm': + self.norm_q = nn.LayerNorm(head_dim) + self.norm_k = nn.LayerNorm(head_dim) + elif qk_norm == 'rmsnorm': + self.norm_q = RMSNorm(head_dim) + self.norm_k = RMSNorm(head_dim) + else: + raise NotImplementedError + + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if self.cross_attn: + assert rope_mode == 'none' + self.rope_mode = rope_mode + if self.rope_mode == 'shared' or self.rope_mode == 'x_only': + self.rotary = RotaryEmbedding(dim=head_dim) + elif self.rope_mode == 'dual': + self.rotary_x = RotaryEmbedding(dim=head_dim) + self.rotary_c = RotaryEmbedding(dim=head_dim) + + def _rotary(self, q, k, extras): + if self.rope_mode == 'shared': + q, k = self.rotary(q=q, k=k) + elif self.rope_mode == 'x_only': + q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :] + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'dual': + q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :]) + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'none': + pass + else: + raise NotImplementedError + return q, k + + def _attn(self, q, k, v, mask_binary): + if ATTENTION_MODE == 'flash': + x = F.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + else: + raise NotImplementedError + return x + + def forward(self, x, context=None, context_mask=None, extras=0): + B, L, C = x.shape + if context is None: + context = x + + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + + if context_mask is not None: + mask_binary = create_mask(x.shape, context.shape, + x.device, None, 
context_mask) + else: + mask_binary = None + + q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads) + k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads) + v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads) + + q = self.norm_q(q) + k = self.norm_k(k) + + q, k = self._rotary(q, k, extras) + + x = self._attn(q, k, v, mask_binary) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class JointAttention(nn.Module): + def __init__(self, dim, num_heads=8, + qkv_bias=False, qk_scale=None, qk_norm=None, + attn_drop=0., proj_drop=0., + rope_mode='none'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(dim, qkv_bias) + self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(dim, qkv_bias) + + self.norm_qx, self.norm_kx = self._make_norm_layers(qk_norm, head_dim) + self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim) + + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + + self.proj_x = nn.Linear(dim, dim) + self.proj_drop_x = nn.Dropout(proj_drop) + + self.proj_c = nn.Linear(dim, dim) + self.proj_drop_c = nn.Dropout(proj_drop) + + self.rope_mode = rope_mode + if self.rope_mode == 'shared' or self.rope_mode == 'x_only': + self.rotary = RotaryEmbedding(dim=head_dim) + elif self.rope_mode == 'dual': + self.rotary_x = RotaryEmbedding(dim=head_dim) + self.rotary_c = RotaryEmbedding(dim=head_dim) + + def _make_qkv_layers(self, dim, qkv_bias): + return (nn.Linear(dim, dim, bias=qkv_bias), + nn.Linear(dim, dim, bias=qkv_bias), + nn.Linear(dim, dim, bias=qkv_bias)) + + def _make_norm_layers(self, qk_norm, head_dim): + if qk_norm is None: + norm_q = nn.Identity() + norm_k = nn.Identity() + elif qk_norm == 'layernorm': + norm_q = nn.LayerNorm(head_dim) + norm_k = nn.LayerNorm(head_dim) + elif qk_norm == 'rmsnorm': + norm_q = RMSNorm(head_dim) + norm_k = RMSNorm(head_dim) + else: + raise NotImplementedError + return norm_q, norm_k + + def _rotary(self, q, k, extras): + if self.rope_mode == 'shared': + q, k = self.rotary(q=q, k=k) + elif self.rope_mode == 'x_only': + q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :] + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'dual': + q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :]) + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'none': + pass + else: + raise NotImplementedError + return q, k + + def _attn(self, q, k, v, mask_binary): + if ATTENTION_MODE == 'flash': + x = F.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + else: + raise NotImplementedError + return x + + def _cat_mask(self, x, context, x_mask=None, context_mask=None): + B = x.shape[0] + if x_mask is None: + x_mask = torch.ones(B, x.shape[-2], device=x.device).bool() + if context_mask is None: + context_mask = torch.ones(B, 
context.shape[-2], device=context.device).bool() + mask = torch.cat([context_mask, x_mask], dim=1) + return mask + + def forward(self, x, context, x_mask=None, context_mask=None, extras=0): + B, Lx, C = x.shape + _, Lc, _ = context.shape + if x_mask is not None or context_mask is not None: + mask = self._cat_mask(x, context, + x_mask=x_mask, + context_mask=context_mask) + shape = [B, Lx+Lc, C] + mask_binary = create_mask(q_shape=shape, k_shape=shape, + device=x.device, + q_mask=None, k_mask=mask) + else: + mask_binary = None + + qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x) + qc, kc, vc = self.to_qc(context), self.to_kc(context), self.to_vc(context) + + qx, kx, vx = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D', + H=self.num_heads), [qx, kx, vx]) + qc, kc, vc = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D', + H=self.num_heads), [qc, kc, vc]) + + qx, kx = self.norm_qx(qx), self.norm_kx(kx) + qc, kc = self.norm_qc(qc), self.norm_kc(kc) + + q, k, v = (torch.cat([qc, qx], dim=2), + torch.cat([kc, kx], dim=2), + torch.cat([vc, vx], dim=2)) + + q, k = self._rotary(q, k, extras) + + x = self._attn(q, k, v, mask_binary) + + context, x = x[:, :Lc, :], x[:, Lc:, :] + + x = self.proj_x(x) + x = self.proj_drop_x(x) + + context = self.proj_c(context) + context = self.proj_drop_c(context) + + return x, context \ No newline at end of file diff --git a/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py b/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..c825b988439d9b91e1e1d30c1cf842880252c0bf --- /dev/null +++ b/src/models/utils/.ipynb_checkpoints/modules-checkpoint.py @@ -0,0 +1,374 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.cuda.amp import autocast +import math +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .timm import trunc_normal_ + + +# disable in checkpoint mode +# @torch.jit.script +def film_modulate(x, shift, scale): + return x * (1 + scale) + shift + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. + """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. 
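+
+ Example (illustrative): TimestepEmbedder(hidden_size=768)(torch.tensor([0, 250, 999]))
+ returns a (3, 768) tensor: each timestep is mapped to a 256-dim sinusoidal encoding
+ and passed through the two-layer MLP below.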
+ """ + + def __init__(self, hidden_size, frequency_embedding_size=256, + out_size=None): + super().__init__() + if out_size is None: + out_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size).type( + self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +def patchify(imgs, patch_size, input_type='2d'): + if input_type == '2d': + x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) + elif input_type == '1d': + x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size) + return x + + +def unpatchify(x, channels=3, input_type='2d', img_size=None): + if input_type == '2d': + patch_size = int((x.shape[2] // channels) ** 0.5) + # h = w = int(x.shape[1] ** .5) + h, w = img_size[0] // patch_size, img_size[1] // patch_size + assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] + x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, + p1=patch_size, p2=patch_size) + elif input_type == '1d': + patch_size = int((x.shape[2] // channels)) + h = x.shape[1] + assert patch_size * channels == x.shape[2] + x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size) + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding + """ + + def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'): + super().__init__() + self.patch_size = patch_size + self.input_type = input_type + if input_type == '2d': + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + elif input_type == '1d': + self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + + def forward(self, x): + if self.input_type == '2d': + B, C, H, W = x.shape + assert H % self.patch_size == 0 and W % self.patch_size == 0 + elif self.input_type == '1d': + B, C, H = x.shape + assert H % self.patch_size == 0 + + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PositionalConvEmbedding(nn.Module): + """ + Relative positional embedding used in HuBERT + """ + + def __init__(self, dim=768, kernel_size=128, groups=16): + super().__init__() + self.conv = nn.Conv1d( + dim, + dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + bias=True + ) + self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x): + # B C T + x = self.conv(x) + x = F.gelu(x[:, :, :-1]) + return x + + +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, dim, length): + super(SinusoidalPositionalEncoding, self).__init__() + self.length = length + self.dim = dim + self.register_buffer('pe', self._generate_positional_encoding(length, dim)) + + def _generate_positional_encoding(self, length, dim): + pe = torch.zeros(length, dim) + position = torch.arange(0, length, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + return pe + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return x + + +class PE_wrapper(nn.Module): + def __init__(self, dim=768, method='abs', length=None, **kwargs): + 
super().__init__() + self.method = method + if method == 'abs': + # init absolute pe like UViT + self.length = length + self.abs_pe = nn.Parameter(torch.zeros(1, length, dim)) + trunc_normal_(self.abs_pe, std=.02) + elif method == 'conv': + self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs) + elif method == 'sinu': + self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length) + elif method == 'none': + # skip pe + self.id = nn.Identity() + else: + raise NotImplementedError + + def forward(self, x): + if self.method == 'abs': + _, L, _ = x.shape + assert L <= self.length + x = x + self.abs_pe[:, :L, :] + elif self.method == 'conv': + x = x + self.conv_pe(x) + elif self.method == 'sinu': + x = self.sinu_pe(x) + elif self.method == 'none': + x = self.id(x) + else: + raise NotImplementedError + return x + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. 
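+
+ In effect (sketch of the computation): output = x * rsqrt(mean(x**2, dim=-1) + eps) * weight,
+ where the normalization is evaluated in float32 and cast back to the input dtype
+ before scaling by weight.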
+ + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class GELU(nn.Module): + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", + bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), + approximate=self.approximate).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +# disable in checkpoint mode +# @torch.jit.script +def snake_beta(x, alpha, beta): + return x + beta * torch.sin(x * alpha).pow(2) + + +class Snake(nn.Module): + def __init__(self, dim_in, dim_out, bias, + alpha_trainable=True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.alpha = nn.Parameter(torch.ones(1, 1, dim_out)) + self.beta = nn.Parameter(torch.ones(1, 1, dim_out)) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + def forward(self, x): + x = self.proj(x) + x = snake_beta(x, self.alpha, self.beta) + return x + + +class GESnake(nn.Module): + def __init__(self, dim_in, dim_out, bias, + alpha_trainable=True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.alpha = nn.Parameter(torch.ones(1, 1, dim_out)) + self.beta = nn.Parameter(torch.ones(1, 1, dim_out)) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + def forward(self, x): + x = self.proj(x) + x, gate = x.chunk(2, dim=-1) + return x * snake_beta(gate, self.alpha, self.beta) + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out=None, + mult=4, + dropout=0.0, + activation_fn="geglu", + final_dropout=False, + inner_dim=None, + bias=True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "snake": + act_fn = Snake(dim, inner_dim, bias=bias) + elif activation_fn == "gesnake": + act_fn = GESnake(dim, inner_dim, bias=bias) 
+ else: + raise NotImplementedError + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states \ No newline at end of file diff --git a/src/models/utils/.ipynb_checkpoints/rotary-checkpoint.py b/src/models/utils/.ipynb_checkpoints/rotary-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..636fbf6558b0d469f6802b10c180bbbb6fc431cc --- /dev/null +++ b/src/models/utils/.ipynb_checkpoints/rotary-checkpoint.py @@ -0,0 +1,91 @@ +import torch + +"this rope is faster than llama rope with jit script" + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +# disable in checkpoint mode +# @torch.jit.script +def apply_rotary_pos_emb(x, cos, sin): + # NOTE: This could probably be moved to Triton + # Handle a possible sequence length mismatch in between q and k + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + return (x * cos) + (rotate_half(x) * sin) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + + .. 
warning: Please note that this embedding is not registered on purpose, as it is transformative + (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis + """ + + def __init__(self, dim: int): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=-2): + # expect input: B, H, L, D + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + # also make sure dtype wont change + if ( + seq_len != self._seq_len_cached + or self._cos_cached.device != x.device + or self._cos_cached.dtype != x.dtype + ): + self._seq_len_cached = seq_len + t = torch.arange( + x.shape[seq_dimension], device=x.device, dtype=torch.float32 + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype) + self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype) + + return self._cos_cached, self._sin_cached + + def forward(self, q, k): + self._cos_cached, self._sin_cached = self._update_cos_sin_tables( + q.float(), seq_dimension=-2 + ) + if k is not None: + return ( + apply_rotary_pos_emb(q.float(), + self._cos_cached, + self._sin_cached).type_as(q), + apply_rotary_pos_emb(k.float(), + self._cos_cached, + self._sin_cached).type_as(k), + ) + else: + return ( + apply_rotary_pos_emb(q.float(), + self._cos_cached, + self._sin_cached).type_as(q), + None + ) \ No newline at end of file diff --git a/src/models/utils/.ipynb_checkpoints/span_mask-checkpoint.py b/src/models/utils/.ipynb_checkpoints/span_mask-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..6d003a6c08c1675967f992e3d052b293a202d446 --- /dev/null +++ b/src/models/utils/.ipynb_checkpoints/span_mask-checkpoint.py @@ -0,0 +1,146 @@ +import numpy as np +import torch +from typing import Optional, Tuple + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from possion distribution with lambda = mask length + min_masks: minimum number of masked spans + no_overlap: if false, will switch to an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + # Convert mask_prob to a NumPy array + mask_prob = np.array(mask_prob) + + # Calculate all_num_mask for each element in the batch + all_num_mask = np.floor(mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)).astype(int) + + # Apply the max operation with min_masks for each element + all_num_mask = np.maximum(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask[i] + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + np.int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + # min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + # if len(mask_idc) > min_len: + # mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return torch.tensor(mask) + + +if __name__ == '__main__': + mask = compute_mask_indices( + shape=[4, 500], + padding_mask=None, + mask_prob=[0.65, 0.5, 0.65, 0.65], + mask_length=10, + mask_type="static", + mask_other=0.0, + min_masks=1, + no_overlap=False, + min_space=0, + ) + print(mask) + print(mask.sum(dim=1)) \ No newline at end of file diff --git 
a/src/models/utils/.ipynb_checkpoints/timm-checkpoint.py b/src/models/utils/.ipynb_checkpoints/timm-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ae72f4fca681617778028b64da5daed26b3ed3d4 --- /dev/null +++ b/src/models/utils/.ipynb_checkpoints/timm-checkpoint.py @@ -0,0 +1,114 @@ +# code from timm 0.3.2 +import torch +import torch.nn as nn +import math +import warnings + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. 
or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x \ No newline at end of file diff --git a/src/models/utils/__init__.py b/src/models/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/models/utils/__pycache__/__init__.cpython-310.pyc b/src/models/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fc31b0a19bf83eef2a703df69b4272a12bfbe577 Binary files /dev/null and b/src/models/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/__init__.cpython-311.pyc b/src/models/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c349d3285b7eb700500142d0d54cccaad6d0a80 Binary files /dev/null and b/src/models/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/models/utils/__pycache__/attention.cpython-310.pyc b/src/models/utils/__pycache__/attention.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9ebbb2fc665b2c2ff2115bc98ec054430333e6ee Binary files /dev/null and b/src/models/utils/__pycache__/attention.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/attention.cpython-311.pyc b/src/models/utils/__pycache__/attention.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..51bf17ebc671077c9fd467e530c92da966b3095a Binary files /dev/null and b/src/models/utils/__pycache__/attention.cpython-311.pyc differ diff --git a/src/models/utils/__pycache__/modules.cpython-310.pyc b/src/models/utils/__pycache__/modules.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d507df3df9cbf9fa29fbabb4591df36aedf6bdd4 Binary files /dev/null and b/src/models/utils/__pycache__/modules.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/modules.cpython-311.pyc b/src/models/utils/__pycache__/modules.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1e5628f9e5e0a0c83402e4c6a350bb50fdd802c9 Binary files /dev/null and b/src/models/utils/__pycache__/modules.cpython-311.pyc differ diff --git a/src/models/utils/__pycache__/rotary.cpython-310.pyc b/src/models/utils/__pycache__/rotary.cpython-310.pyc new file mode 100644 index 
0000000000000000000000000000000000000000..716bfd4abc834b3e912c4f4574ddc3d3597183a5 Binary files /dev/null and b/src/models/utils/__pycache__/rotary.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/rotary.cpython-311.pyc b/src/models/utils/__pycache__/rotary.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96fbebedb56e837be29b9c7b0fd2a3c053571679 Binary files /dev/null and b/src/models/utils/__pycache__/rotary.cpython-311.pyc differ diff --git a/src/models/utils/__pycache__/span_mask.cpython-310.pyc b/src/models/utils/__pycache__/span_mask.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8dc66ee9fc445b41e939f29ba32cfdeb6169bcc2 Binary files /dev/null and b/src/models/utils/__pycache__/span_mask.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/span_mask.cpython-311.pyc b/src/models/utils/__pycache__/span_mask.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1570b3fd02c86a35de5716cfa17d9c4384595bc7 Binary files /dev/null and b/src/models/utils/__pycache__/span_mask.cpython-311.pyc differ diff --git a/src/models/utils/__pycache__/timm.cpython-310.pyc b/src/models/utils/__pycache__/timm.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0fa38c2a3330015ecde142ac1e187afd6afd3aa5 Binary files /dev/null and b/src/models/utils/__pycache__/timm.cpython-310.pyc differ diff --git a/src/models/utils/__pycache__/timm.cpython-311.pyc b/src/models/utils/__pycache__/timm.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..9f30fbb13efb3ca0c053bbf2b0d9c48c71b82465 Binary files /dev/null and b/src/models/utils/__pycache__/timm.cpython-311.pyc differ diff --git a/src/models/utils/attention.py b/src/models/utils/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..5f77c9148fc54916e1bedc2f36f77f6a2164986a --- /dev/null +++ b/src/models/utils/attention.py @@ -0,0 +1,290 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .rotary import RotaryEmbedding +from .modules import RMSNorm + + +if hasattr(nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + ATTENTION_MODE = 'math' +print(f'attention mode is {ATTENTION_MODE}') + + +def add_mask(sim, mask): + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + + +def create_mask(q_shape, k_shape, device, q_mask=None, k_mask=None): + def default(val, d): + return val if val is not None else (d() if isfunction(d) else d) + b, i, j, device = q_shape[0], q_shape[-2], k_shape[-2], device + q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool)) + k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool)) + attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j') + return attn_mask + + +class Attention(nn.Module): + def __init__(self, dim, context_dim=None, num_heads=8, + qkv_bias=False, qk_scale=None, qk_norm=None, + attn_drop=0., proj_drop=0., rope_mode='none'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + if 
context_dim is None: + self.cross_attn = False + else: + self.cross_attn = True + + context_dim = dim if context_dim is None else context_dim + + self.to_q = nn.Linear(dim, dim, bias=qkv_bias) + self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias) + self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias) + + if qk_norm is None: + self.norm_q = nn.Identity() + self.norm_k = nn.Identity() + elif qk_norm == 'layernorm': + self.norm_q = nn.LayerNorm(head_dim) + self.norm_k = nn.LayerNorm(head_dim) + elif qk_norm == 'rmsnorm': + self.norm_q = RMSNorm(head_dim) + self.norm_k = RMSNorm(head_dim) + else: + raise NotImplementedError + + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + if self.cross_attn: + assert rope_mode == 'none' + self.rope_mode = rope_mode + if self.rope_mode == 'shared' or self.rope_mode == 'x_only': + self.rotary = RotaryEmbedding(dim=head_dim) + elif self.rope_mode == 'dual': + self.rotary_x = RotaryEmbedding(dim=head_dim) + self.rotary_c = RotaryEmbedding(dim=head_dim) + + def _rotary(self, q, k, extras): + if self.rope_mode == 'shared': + q, k = self.rotary(q=q, k=k) + elif self.rope_mode == 'x_only': + q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :] + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'dual': + q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :]) + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'none': + pass + else: + raise NotImplementedError + return q, k + + def _attn(self, q, k, v, mask_binary): + if ATTENTION_MODE == 'flash': + x = F.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + else: + raise NotImplementedError + return x + + def forward(self, x, context=None, context_mask=None, extras=0): + B, L, C = x.shape + if context is None: + context = x + + q = self.to_q(x) + k = self.to_k(context) + v = self.to_v(context) + + if context_mask is not None: + mask_binary = create_mask(x.shape, context.shape, + x.device, None, context_mask) + else: + mask_binary = None + + q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads) + k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads) + v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads) + + q = self.norm_q(q) + k = self.norm_k(k) + + q, k = self._rotary(q, k, extras) + + x = self._attn(q, k, v, mask_binary) + + x = self.proj(x) + x = self.proj_drop(x) + return x + + +class JointAttention(nn.Module): + def __init__(self, dim, num_heads=8, + qkv_bias=False, qk_scale=None, qk_norm=None, + attn_drop=0., proj_drop=0., + rope_mode='none'): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + self.to_qx, self.to_kx, self.to_vx = self._make_qkv_layers(dim, qkv_bias) + self.to_qc, self.to_kc, self.to_vc = self._make_qkv_layers(dim, qkv_bias) + + self.norm_qx, self.norm_kx 
= self._make_norm_layers(qk_norm, head_dim) + self.norm_qc, self.norm_kc = self._make_norm_layers(qk_norm, head_dim) + + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + + self.proj_x = nn.Linear(dim, dim) + self.proj_drop_x = nn.Dropout(proj_drop) + + self.proj_c = nn.Linear(dim, dim) + self.proj_drop_c = nn.Dropout(proj_drop) + + self.rope_mode = rope_mode + if self.rope_mode == 'shared' or self.rope_mode == 'x_only': + self.rotary = RotaryEmbedding(dim=head_dim) + elif self.rope_mode == 'dual': + self.rotary_x = RotaryEmbedding(dim=head_dim) + self.rotary_c = RotaryEmbedding(dim=head_dim) + + def _make_qkv_layers(self, dim, qkv_bias): + return (nn.Linear(dim, dim, bias=qkv_bias), + nn.Linear(dim, dim, bias=qkv_bias), + nn.Linear(dim, dim, bias=qkv_bias)) + + def _make_norm_layers(self, qk_norm, head_dim): + if qk_norm is None: + norm_q = nn.Identity() + norm_k = nn.Identity() + elif qk_norm == 'layernorm': + norm_q = nn.LayerNorm(head_dim) + norm_k = nn.LayerNorm(head_dim) + elif qk_norm == 'rmsnorm': + norm_q = RMSNorm(head_dim) + norm_k = RMSNorm(head_dim) + else: + raise NotImplementedError + return norm_q, norm_k + + def _rotary(self, q, k, extras): + if self.rope_mode == 'shared': + q, k = self.rotary(q=q, k=k) + elif self.rope_mode == 'x_only': + q_x, k_x = self.rotary(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = q[:, :, :extras, :], k[:, :, :extras, :] + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'dual': + q_x, k_x = self.rotary_x(q=q[:, :, extras:, :], k=k[:, :, extras:, :]) + q_c, k_c = self.rotary_c(q=q[:, :, :extras, :], k=k[:, :, :extras, :]) + q = torch.cat((q_c, q_x), dim=2) + k = torch.cat((k_c, k_x), dim=2) + elif self.rope_mode == 'none': + pass + else: + raise NotImplementedError + return q, k + + def _attn(self, q, k, v, mask_binary): + if ATTENTION_MODE == 'flash': + x = F.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + else: + raise NotImplementedError + return x + + def _cat_mask(self, x, context, x_mask=None, context_mask=None): + B = x.shape[0] + if x_mask is None: + x_mask = torch.ones(B, x.shape[-2], device=x.device).bool() + if context_mask is None: + context_mask = torch.ones(B, context.shape[-2], device=context.device).bool() + mask = torch.cat([context_mask, x_mask], dim=1) + return mask + + def forward(self, x, context, x_mask=None, context_mask=None, extras=0): + B, Lx, C = x.shape + _, Lc, _ = context.shape + if x_mask is not None or context_mask is not None: + mask = self._cat_mask(x, context, + x_mask=x_mask, + context_mask=context_mask) + shape = [B, Lx+Lc, C] + mask_binary = create_mask(q_shape=shape, k_shape=shape, + device=x.device, + q_mask=None, k_mask=mask) + else: + mask_binary = None + + qx, kx, vx = self.to_qx(x), self.to_kx(x), self.to_vx(x) + qc, kc, vc = self.to_qc(context), self.to_kc(context), self.to_vc(context) + + qx, kx, vx = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D', + H=self.num_heads), [qx, kx, vx]) + qc, kc, vc = map(lambda t: einops.rearrange(t, 'B L (H D) -> B H L D', + H=self.num_heads), [qc, kc, vc]) + + qx, kx = self.norm_qx(qx), 
self.norm_kx(kx) + qc, kc = self.norm_qc(qc), self.norm_kc(kc) + + q, k, v = (torch.cat([qc, qx], dim=2), + torch.cat([kc, kx], dim=2), + torch.cat([vc, vx], dim=2)) + + q, k = self._rotary(q, k, extras) + + x = self._attn(q, k, v, mask_binary) + + context, x = x[:, :Lc, :], x[:, Lc:, :] + + x = self.proj_x(x) + x = self.proj_drop_x(x) + + context = self.proj_c(context) + context = self.proj_drop_c(context) + + return x, context \ No newline at end of file diff --git a/src/models/utils/bk/.ipynb_checkpoints/attention-checkpoint.py b/src/models/utils/bk/.ipynb_checkpoints/attention-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..7ba4f700842a91611ad1eda0f872df04162d1e59 --- /dev/null +++ b/src/models/utils/bk/.ipynb_checkpoints/attention-checkpoint.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .rotary import RotaryEmbedding + +if hasattr(nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + ATTENTION_MODE = 'math' +print(f'attention mode is {ATTENTION_MODE}') + + +def add_mask(sim, mask): + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + + +def create_mask(q, k, q_mask=None, k_mask=None): + def default(val, d): + return val if val is not None else (d() if isfunction(d) else d) + + b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device + q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool)) + k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool)) + attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j') + return attn_mask + + +class Attention(nn.Module): + def __init__(self, dim, context_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, + attn_drop=0., proj_drop=0., use_rope=False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + context_dim = dim if context_dim is None else context_dim + + self.to_q = nn.Linear(dim, dim, bias=qkv_bias) + self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias) + self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias) + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.use_rope = use_rope + if self.use_rope: + self.rotary = RotaryEmbedding(dim=head_dim) + + def forward(self, x, context=None, context_mask=None): + B, L, C = x.shape + q = self.to_q(x) + if context is None: + context = x + else: + assert self.use_rope is False + + k = self.to_k(context) + v = self.to_v(context) + + if context_mask is not None: + mask_binary = create_mask(x, context, None, context_mask) + else: + mask_binary = None + + q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads).float() + k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads).float() + v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads).float() + + if self.use_rope: + q, k = self.rotary(q=q, k=k) + + if ATTENTION_MODE == 'flash': + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif 
ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, L, C) + else: + raise NotImplementedError + + x = self.proj(x) + x = self.proj_drop(x) + return x \ No newline at end of file diff --git a/src/models/utils/bk/.ipynb_checkpoints/llama_rotary-checkpoint.py b/src/models/utils/bk/.ipynb_checkpoints/llama_rotary-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..1126d1a589ca106a6399896742f511df43b0d0ae --- /dev/null +++ b/src/models/utils/bk/.ipynb_checkpoints/llama_rotary-checkpoint.py @@ -0,0 +1,74 @@ +import torch +from typing import Tuple +from rotary import RotaryEmbedding +import time + + +def precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, + x: torch.Tensor,): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def compute_rope(q, freqs_cis): + return q * freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + # xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + # xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq1, xq2 = xq.chunk(2, dim=-1) + xq_ = torch.view_as_complex(torch.stack((xq1, xq2), dim=-1).float()) + + xk1, xk2 = xk.chunk(2, dim=-1) + xk_ = torch.view_as_complex(torch.stack((xk1, xk2), dim=-1).float()) + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(compute_rope(xq_, freqs_cis)).flatten(3) + xk_out = torch.view_as_real(compute_rope(xk_, freqs_cis)).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +if __name__ == '__main__': + # Move data to CUDA + freq_cis = precompute_freqs_cis(4, 5).cuda() + x = torch.rand(1, 5, 1, 4).cuda() + y = torch.rand(1, 5, 1, 4).cuda() + + # First method + start_time = time.time() + for _ in range(20000): + x1, y1 = apply_rotary_emb(x, y, freq_cis) + end_time = time.time() + print(f"Method 1 time cost: {end_time - start_time} seconds") + + # Prepare data for the second method + x = x.permute(0, 2, 1, 3) + y = y.permute(0, 2, 1, 3) + rope = RotaryEmbedding(4).cuda() + + # Second method + start_time = time.time() + for _ in range(20000): + x2, y2 = rope(x, y) + end_time = time.time() + print(f"Method 2 time cost: {end_time - start_time} seconds") + + # Print the results + print(x1) + print(x2) \ No newline at end of file diff --git a/src/models/utils/bk/__pycache__/rotary.cpython-311.pyc b/src/models/utils/bk/__pycache__/rotary.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..de78be925584642b52de19239fd67bdcf6173d95 Binary files /dev/null and b/src/models/utils/bk/__pycache__/rotary.cpython-311.pyc differ diff --git a/src/models/utils/bk/attention.py b/src/models/utils/bk/attention.py new file mode 100644 index 0000000000000000000000000000000000000000..7ba4f700842a91611ad1eda0f872df04162d1e59 --- /dev/null +++ 
b/src/models/utils/bk/attention.py @@ -0,0 +1,99 @@ +import torch +import torch.nn as nn +import torch.utils.checkpoint +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .rotary import RotaryEmbedding + +if hasattr(nn.functional, 'scaled_dot_product_attention'): + ATTENTION_MODE = 'flash' +else: + ATTENTION_MODE = 'math' +print(f'attention mode is {ATTENTION_MODE}') + + +def add_mask(sim, mask): + b, ndim = sim.shape[0], mask.ndim + if ndim == 3: + mask = rearrange(mask, "b n m -> b 1 n m") + if ndim == 2: + mask = repeat(mask, "n m -> b 1 n m", b=b) + max_neg_value = -torch.finfo(sim.dtype).max + sim = sim.masked_fill(~mask, max_neg_value) + return sim + + +def create_mask(q, k, q_mask=None, k_mask=None): + def default(val, d): + return val if val is not None else (d() if isfunction(d) else d) + + b, i, j, device = q.shape[0], q.shape[-2], k.shape[-2], q.device + q_mask = default(q_mask, torch.ones((b, i), device=device, dtype=torch.bool)) + k_mask = default(k_mask, torch.ones((b, j), device=device, dtype=torch.bool)) + attn_mask = rearrange(q_mask, 'b i -> b 1 i 1') * rearrange(k_mask, 'b j -> b 1 1 j') + return attn_mask + + +class Attention(nn.Module): + def __init__(self, dim, context_dim=None, num_heads=8, qkv_bias=False, qk_scale=None, + attn_drop=0., proj_drop=0., use_rope=False): + super().__init__() + self.num_heads = num_heads + head_dim = dim // num_heads + self.scale = qk_scale or head_dim ** -0.5 + + context_dim = dim if context_dim is None else context_dim + + self.to_q = nn.Linear(dim, dim, bias=qkv_bias) + self.to_k = nn.Linear(context_dim, dim, bias=qkv_bias) + self.to_v = nn.Linear(context_dim, dim, bias=qkv_bias) + self.attn_drop_p = attn_drop + self.attn_drop = nn.Dropout(attn_drop) + self.proj = nn.Linear(dim, dim) + self.proj_drop = nn.Dropout(proj_drop) + + self.use_rope = use_rope + if self.use_rope: + self.rotary = RotaryEmbedding(dim=head_dim) + + def forward(self, x, context=None, context_mask=None): + B, L, C = x.shape + q = self.to_q(x) + if context is None: + context = x + else: + assert self.use_rope is False + + k = self.to_k(context) + v = self.to_v(context) + + if context_mask is not None: + mask_binary = create_mask(x, context, None, context_mask) + else: + mask_binary = None + + q = einops.rearrange(q, 'B L (H D) -> B H L D', H=self.num_heads).float() + k = einops.rearrange(k, 'B L (H D) -> B H L D', H=self.num_heads).float() + v = einops.rearrange(v, 'B L (H D) -> B H L D', H=self.num_heads).float() + + if self.use_rope: + q, k = self.rotary(q=q, k=k) + + if ATTENTION_MODE == 'flash': + x = torch.nn.functional.scaled_dot_product_attention(q, k, v, + dropout_p=self.attn_drop_p, + attn_mask=mask_binary) + x = einops.rearrange(x, 'B H L D -> B L (H D)') + elif ATTENTION_MODE == 'math': + attn = (q @ k.transpose(-2, -1)) * self.scale + attn = add_mask(attn, mask_binary) if mask_binary is not None else attn + attn = attn.softmax(dim=-1) + attn = self.attn_drop(attn) + x = (attn @ v).transpose(1, 2).reshape(B, L, C) + else: + raise NotImplementedError + + x = self.proj(x) + x = self.proj_drop(x) + return x \ No newline at end of file diff --git a/src/models/utils/bk/llama_rotary.py b/src/models/utils/bk/llama_rotary.py new file mode 100644 index 0000000000000000000000000000000000000000..1126d1a589ca106a6399896742f511df43b0d0ae --- /dev/null +++ b/src/models/utils/bk/llama_rotary.py @@ -0,0 +1,74 @@ +import torch +from typing import Tuple +from rotary import RotaryEmbedding +import time + + +def 
precompute_freqs_cis(dim: int, end: int, theta: float = 10000.0): + freqs = 1.0 / (theta ** (torch.arange(0, dim, 2)[: (dim // 2)].float() / dim)) + t = torch.arange(end, device=freqs.device, dtype=torch.float32) + freqs = torch.outer(t, freqs) + freqs_cis = torch.polar(torch.ones_like(freqs), freqs) # complex64 + return freqs_cis + + +def reshape_for_broadcast(freqs_cis: torch.Tensor, + x: torch.Tensor,): + ndim = x.ndim + assert 0 <= 1 < ndim + assert freqs_cis.shape == (x.shape[1], x.shape[-1]) + shape = [d if i == 1 or i == ndim - 1 else 1 for i, d in enumerate(x.shape)] + return freqs_cis.view(*shape) + + +def compute_rope(q, freqs_cis): + return q * freqs_cis + + +def apply_rotary_emb( + xq: torch.Tensor, + xk: torch.Tensor, + freqs_cis: torch.Tensor, +) -> Tuple[torch.Tensor, torch.Tensor]: + # xq_ = torch.view_as_complex(xq.float().reshape(*xq.shape[:-1], -1, 2)) + # xk_ = torch.view_as_complex(xk.float().reshape(*xk.shape[:-1], -1, 2)) + xq1, xq2 = xq.chunk(2, dim=-1) + xq_ = torch.view_as_complex(torch.stack((xq1, xq2), dim=-1).float()) + + xk1, xk2 = xk.chunk(2, dim=-1) + xk_ = torch.view_as_complex(torch.stack((xk1, xk2), dim=-1).float()) + + freqs_cis = reshape_for_broadcast(freqs_cis, xq_) + xq_out = torch.view_as_real(compute_rope(xq_, freqs_cis)).flatten(3) + xk_out = torch.view_as_real(compute_rope(xk_, freqs_cis)).flatten(3) + return xq_out.type_as(xq), xk_out.type_as(xk) + + +if __name__ == '__main__': + # Move data to CUDA + freq_cis = precompute_freqs_cis(4, 5).cuda() + x = torch.rand(1, 5, 1, 4).cuda() + y = torch.rand(1, 5, 1, 4).cuda() + + # First method + start_time = time.time() + for _ in range(20000): + x1, y1 = apply_rotary_emb(x, y, freq_cis) + end_time = time.time() + print(f"Method 1 time cost: {end_time - start_time} seconds") + + # Prepare data for the second method + x = x.permute(0, 2, 1, 3) + y = y.permute(0, 2, 1, 3) + rope = RotaryEmbedding(4).cuda() + + # Second method + start_time = time.time() + for _ in range(20000): + x2, y2 = rope(x, y) + end_time = time.time() + print(f"Method 2 time cost: {end_time - start_time} seconds") + + # Print the results + print(x1) + print(x2) \ No newline at end of file diff --git a/src/models/utils/modules.py b/src/models/utils/modules.py new file mode 100644 index 0000000000000000000000000000000000000000..c825b988439d9b91e1e1d30c1cf842880252c0bf --- /dev/null +++ b/src/models/utils/modules.py @@ -0,0 +1,374 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +import torch.utils.checkpoint +from torch.cuda.amp import autocast +import math +import einops +from einops import rearrange, repeat +from inspect import isfunction +from .timm import trunc_normal_ + + +# disable in checkpoint mode +# @torch.jit.script +def film_modulate(x, shift, scale): + return x * (1 + scale) + shift + + +def timestep_embedding(timesteps, dim, max_period=10000): + """ + Create sinusoidal timestep embeddings. + + :param timesteps: a 1-D Tensor of N indices, one per batch element. + These may be fractional. + :param dim: the dimension of the output. + :param max_period: controls the minimum frequency of the embeddings. + :return: an [N x dim] Tensor of positional embeddings. 
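+
+    A usage sketch (illustrative values; assumes the module-level ``torch`` import):
+    for ``dim=8`` the first four columns are cosines and the last four are sines of
+    the timesteps multiplied by geometrically spaced frequencies.
+
+    >>> t = torch.tensor([0.0, 10.0])
+    >>> timestep_embedding(t, dim=8).shape
+    torch.Size([2, 8])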
+ """ + half = dim // 2 + freqs = torch.exp( + -math.log(max_period) * torch.arange(start=0, end=half, dtype=torch.float32) / half + ).to(device=timesteps.device) + args = timesteps[:, None].float() * freqs[None] + embedding = torch.cat([torch.cos(args), torch.sin(args)], dim=-1) + if dim % 2: + embedding = torch.cat([embedding, torch.zeros_like(embedding[:, :1])], dim=-1) + return embedding + + +class TimestepEmbedder(nn.Module): + """ + Embeds scalar timesteps into vector representations. + """ + + def __init__(self, hidden_size, frequency_embedding_size=256, + out_size=None): + super().__init__() + if out_size is None: + out_size = hidden_size + self.mlp = nn.Sequential( + nn.Linear(frequency_embedding_size, hidden_size, bias=True), + nn.SiLU(), + nn.Linear(hidden_size, out_size, bias=True), + ) + self.frequency_embedding_size = frequency_embedding_size + + def forward(self, t): + t_freq = timestep_embedding(t, self.frequency_embedding_size).type( + self.mlp[0].weight.dtype) + t_emb = self.mlp(t_freq) + return t_emb + + +def patchify(imgs, patch_size, input_type='2d'): + if input_type == '2d': + x = einops.rearrange(imgs, 'B C (h p1) (w p2) -> B (h w) (p1 p2 C)', p1=patch_size, p2=patch_size) + elif input_type == '1d': + x = einops.rearrange(imgs, 'B C (h p1) -> B h (p1 C)', p1=patch_size) + return x + + +def unpatchify(x, channels=3, input_type='2d', img_size=None): + if input_type == '2d': + patch_size = int((x.shape[2] // channels) ** 0.5) + # h = w = int(x.shape[1] ** .5) + h, w = img_size[0] // patch_size, img_size[1] // patch_size + assert h * w == x.shape[1] and patch_size ** 2 * channels == x.shape[2] + x = einops.rearrange(x, 'B (h w) (p1 p2 C) -> B C (h p1) (w p2)', h=h, + p1=patch_size, p2=patch_size) + elif input_type == '1d': + patch_size = int((x.shape[2] // channels)) + h = x.shape[1] + assert patch_size * channels == x.shape[2] + x = einops.rearrange(x, 'B h (p1 C) -> B C (h p1)', h=h, p1=patch_size) + return x + + +class PatchEmbed(nn.Module): + """ + Image to Patch Embedding + """ + + def __init__(self, patch_size, in_chans=3, embed_dim=768, input_type='2d'): + super().__init__() + self.patch_size = patch_size + self.input_type = input_type + if input_type == '2d': + self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + elif input_type == '1d': + self.proj = nn.Conv1d(in_chans, embed_dim, kernel_size=patch_size, stride=patch_size, bias=True) + + def forward(self, x): + if self.input_type == '2d': + B, C, H, W = x.shape + assert H % self.patch_size == 0 and W % self.patch_size == 0 + elif self.input_type == '1d': + B, C, H = x.shape + assert H % self.patch_size == 0 + + x = self.proj(x).flatten(2).transpose(1, 2) + return x + + +class PositionalConvEmbedding(nn.Module): + """ + Relative positional embedding used in HuBERT + """ + + def __init__(self, dim=768, kernel_size=128, groups=16): + super().__init__() + self.conv = nn.Conv1d( + dim, + dim, + kernel_size=kernel_size, + padding=kernel_size // 2, + groups=groups, + bias=True + ) + self.conv = nn.utils.parametrizations.weight_norm(self.conv, name="weight", dim=2) + + def forward(self, x): + # B C T + x = self.conv(x) + x = F.gelu(x[:, :, :-1]) + return x + + +class SinusoidalPositionalEncoding(nn.Module): + def __init__(self, dim, length): + super(SinusoidalPositionalEncoding, self).__init__() + self.length = length + self.dim = dim + self.register_buffer('pe', self._generate_positional_encoding(length, dim)) + + def _generate_positional_encoding(self, length, dim): + pe = 
torch.zeros(length, dim) + position = torch.arange(0, length, dtype=torch.float).unsqueeze(1) + div_term = torch.exp(torch.arange(0, dim, 2).float() * (-math.log(10000.0) / dim)) + + pe[:, 0::2] = torch.sin(position * div_term) + pe[:, 1::2] = torch.cos(position * div_term) + + pe = pe.unsqueeze(0) + return pe + + def forward(self, x): + x = x + self.pe[:, :x.size(1)] + return x + + +class PE_wrapper(nn.Module): + def __init__(self, dim=768, method='abs', length=None, **kwargs): + super().__init__() + self.method = method + if method == 'abs': + # init absolute pe like UViT + self.length = length + self.abs_pe = nn.Parameter(torch.zeros(1, length, dim)) + trunc_normal_(self.abs_pe, std=.02) + elif method == 'conv': + self.conv_pe = PositionalConvEmbedding(dim=dim, **kwargs) + elif method == 'sinu': + self.sinu_pe = SinusoidalPositionalEncoding(dim=dim, length=length) + elif method == 'none': + # skip pe + self.id = nn.Identity() + else: + raise NotImplementedError + + def forward(self, x): + if self.method == 'abs': + _, L, _ = x.shape + assert L <= self.length + x = x + self.abs_pe[:, :L, :] + elif self.method == 'conv': + x = x + self.conv_pe(x) + elif self.method == 'sinu': + x = self.sinu_pe(x) + elif self.method == 'none': + x = self.id(x) + else: + raise NotImplementedError + return x + + +class RMSNorm(torch.nn.Module): + def __init__(self, dim: int, eps: float = 1e-6): + """ + Initialize the RMSNorm normalization layer. + + Args: + dim (int): The dimension of the input tensor. + eps (float, optional): A small value added to the denominator for numerical stability. Default is 1e-6. + + Attributes: + eps (float): A small value added to the denominator for numerical stability. + weight (nn.Parameter): Learnable scaling parameter. + + """ + super().__init__() + self.eps = eps + self.weight = nn.Parameter(torch.ones(dim)) + + def _norm(self, x): + """ + Apply the RMSNorm normalization to the input tensor. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The normalized tensor. + + """ + return x * torch.rsqrt(x.pow(2).mean(-1, keepdim=True) + self.eps) + + def forward(self, x): + """ + Forward pass through the RMSNorm layer. + + Args: + x (torch.Tensor): The input tensor. + + Returns: + torch.Tensor: The output tensor after applying RMSNorm. 
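+
+        A usage sketch (shapes are illustrative): normalization is over the last
+        dimension, so the output has the same shape as the input and each feature
+        vector is rescaled to approximately unit root-mean-square before the
+        learned ``weight`` is applied.
+
+        >>> norm = RMSNorm(dim=4)
+        >>> norm(torch.randn(2, 3, 4)).shape
+        torch.Size([2, 3, 4])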
+ + """ + output = self._norm(x.float()).type_as(x) + return output * self.weight + + +class GELU(nn.Module): + + def __init__(self, dim_in: int, dim_out: int, approximate: str = "none", + bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.approximate = approximate + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate, approximate=self.approximate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32), + approximate=self.approximate).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states = self.gelu(hidden_states) + return hidden_states + + +class GEGLU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + + def gelu(self, gate: torch.Tensor) -> torch.Tensor: + if gate.device.type != "mps": + return F.gelu(gate) + # mps: gelu is not implemented for float16 + return F.gelu(gate.to(dtype=torch.float32)).to(dtype=gate.dtype) + + def forward(self, hidden_states): + hidden_states = self.proj(hidden_states) + hidden_states, gate = hidden_states.chunk(2, dim=-1) + return hidden_states * self.gelu(gate) + + +class ApproximateGELU(nn.Module): + def __init__(self, dim_in: int, dim_out: int, bias: bool = True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + + def forward(self, x: torch.Tensor) -> torch.Tensor: + x = self.proj(x) + return x * torch.sigmoid(1.702 * x) + + +# disable in checkpoint mode +# @torch.jit.script +def snake_beta(x, alpha, beta): + return x + beta * torch.sin(x * alpha).pow(2) + + +class Snake(nn.Module): + def __init__(self, dim_in, dim_out, bias, + alpha_trainable=True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out, bias=bias) + self.alpha = nn.Parameter(torch.ones(1, 1, dim_out)) + self.beta = nn.Parameter(torch.ones(1, 1, dim_out)) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + def forward(self, x): + x = self.proj(x) + x = snake_beta(x, self.alpha, self.beta) + return x + + +class GESnake(nn.Module): + def __init__(self, dim_in, dim_out, bias, + alpha_trainable=True): + super().__init__() + self.proj = nn.Linear(dim_in, dim_out * 2, bias=bias) + self.alpha = nn.Parameter(torch.ones(1, 1, dim_out)) + self.beta = nn.Parameter(torch.ones(1, 1, dim_out)) + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + def forward(self, x): + x = self.proj(x) + x, gate = x.chunk(2, dim=-1) + return x * snake_beta(gate, self.alpha, self.beta) + + +class FeedForward(nn.Module): + def __init__( + self, + dim, + dim_out=None, + mult=4, + dropout=0.0, + activation_fn="geglu", + final_dropout=False, + inner_dim=None, + bias=True, + ): + super().__init__() + if inner_dim is None: + inner_dim = int(dim * mult) + dim_out = dim_out if dim_out is not None else dim + + if activation_fn == "gelu": + act_fn = GELU(dim, inner_dim, bias=bias) + elif activation_fn == "gelu-approximate": + act_fn = GELU(dim, inner_dim, approximate="tanh", bias=bias) + elif activation_fn == "geglu": + act_fn = GEGLU(dim, inner_dim, bias=bias) + elif activation_fn == "geglu-approximate": + act_fn = ApproximateGELU(dim, inner_dim, bias=bias) + elif activation_fn == "snake": + act_fn = Snake(dim, inner_dim, bias=bias) + elif activation_fn == "gesnake": + act_fn = GESnake(dim, inner_dim, bias=bias) 
+ else: + raise NotImplementedError + + self.net = nn.ModuleList([]) + # project in + self.net.append(act_fn) + # project dropout + self.net.append(nn.Dropout(dropout)) + # project out + self.net.append(nn.Linear(inner_dim, dim_out, bias=bias)) + # FF as used in Vision Transformer, MLP-Mixer, etc. have a final dropout + if final_dropout: + self.net.append(nn.Dropout(dropout)) + + def forward(self, hidden_states: torch.Tensor) -> torch.Tensor: + for module in self.net: + hidden_states = module(hidden_states) + return hidden_states \ No newline at end of file diff --git a/src/models/utils/rotary.py b/src/models/utils/rotary.py new file mode 100644 index 0000000000000000000000000000000000000000..636fbf6558b0d469f6802b10c180bbbb6fc431cc --- /dev/null +++ b/src/models/utils/rotary.py @@ -0,0 +1,91 @@ +import torch + +"this rope is faster than llama rope with jit script" + + +def rotate_half(x): + x1, x2 = x.chunk(2, dim=-1) + return torch.cat((-x2, x1), dim=-1) + + +# disable in checkpoint mode +# @torch.jit.script +def apply_rotary_pos_emb(x, cos, sin): + # NOTE: This could probably be moved to Triton + # Handle a possible sequence length mismatch in between q and k + cos = cos[:, :, : x.shape[-2], :] + sin = sin[:, :, : x.shape[-2], :] + return (x * cos) + (rotate_half(x) * sin) + + +class RotaryEmbedding(torch.nn.Module): + """ + The rotary position embeddings from RoFormer_ (Su et. al). + A crucial insight from the method is that the query and keys are + transformed by rotation matrices which depend on the relative positions. + + Other implementations are available in the Rotary Transformer repo_ and in + GPT-NeoX_, GPT-NeoX was an inspiration + + .. _RoFormer: https://arxiv.org/abs/2104.09864 + .. _repo: https://github.com/ZhuiyiTechnology/roformer + .. _GPT-NeoX: https://github.com/EleutherAI/gpt-neox + + + .. 
warning: Please note that this embedding is not registered on purpose, as it is transformative + (it does not create the embedding dimension) and will likely be picked up (imported) on a ad-hoc basis + """ + + def __init__(self, dim: int): + super().__init__() + # Generate and save the inverse frequency buffer (non trainable) + inv_freq = 1.0 / (10000 ** (torch.arange(0, dim, 2).float() / dim)) + self.register_buffer("inv_freq", inv_freq) + self._seq_len_cached = None + self._cos_cached = None + self._sin_cached = None + + def _update_cos_sin_tables(self, x, seq_dimension=-2): + # expect input: B, H, L, D + seq_len = x.shape[seq_dimension] + + # Reset the tables if the sequence length has changed, + # or if we're on a new device (possibly due to tracing for instance) + # also make sure dtype wont change + if ( + seq_len != self._seq_len_cached + or self._cos_cached.device != x.device + or self._cos_cached.dtype != x.dtype + ): + self._seq_len_cached = seq_len + t = torch.arange( + x.shape[seq_dimension], device=x.device, dtype=torch.float32 + ) + freqs = torch.einsum("i,j->ij", t, self.inv_freq.to(x.dtype)) + emb = torch.cat((freqs, freqs), dim=-1).to(x.device) + + self._cos_cached = emb.cos()[None, None, :, :].to(x.dtype) + self._sin_cached = emb.sin()[None, None, :, :].to(x.dtype) + + return self._cos_cached, self._sin_cached + + def forward(self, q, k): + self._cos_cached, self._sin_cached = self._update_cos_sin_tables( + q.float(), seq_dimension=-2 + ) + if k is not None: + return ( + apply_rotary_pos_emb(q.float(), + self._cos_cached, + self._sin_cached).type_as(q), + apply_rotary_pos_emb(k.float(), + self._cos_cached, + self._sin_cached).type_as(k), + ) + else: + return ( + apply_rotary_pos_emb(q.float(), + self._cos_cached, + self._sin_cached).type_as(q), + None + ) \ No newline at end of file diff --git a/src/models/utils/span_mask.py b/src/models/utils/span_mask.py new file mode 100644 index 0000000000000000000000000000000000000000..6d003a6c08c1675967f992e3d052b293a202d446 --- /dev/null +++ b/src/models/utils/span_mask.py @@ -0,0 +1,146 @@ +import numpy as np +import torch +from typing import Optional, Tuple + + +def compute_mask_indices( + shape: Tuple[int, int], + padding_mask: Optional[torch.Tensor], + mask_prob: float, + mask_length: int, + mask_type: str = "static", + mask_other: float = 0.0, + min_masks: int = 0, + no_overlap: bool = False, + min_space: int = 0, +) -> np.ndarray: + """ + Computes random mask spans for a given shape + + Args: + shape: the the shape for which to compute masks. + should be of size 2 where first element is batch size and 2nd is timesteps + padding_mask: optional padding mask of the same size as shape, which will prevent masking padded elements + mask_prob: probability for each token to be chosen as start of the span to be masked. this will be multiplied by + number of timesteps divided by length of mask span to mask approximately this percentage of all elements. + however due to overlaps, the actual number will be smaller (unless no_overlap is True) + mask_type: how to compute mask lengths + static = fixed size + uniform = sample from uniform distribution [mask_other, mask_length*2] + normal = sample from normal distribution with mean mask_length and stdev mask_other. 
mask is min 1 element + poisson = sample from Poisson distribution with lambda = mask_length + min_masks: minimum number of masked spans + no_overlap: if true, will use an alternative recursive algorithm that prevents spans from overlapping + min_space: only used if no_overlap is True, this is how many elements to keep unmasked between spans + """ + + bsz, all_sz = shape + mask = np.full((bsz, all_sz), False) + + # Convert mask_prob to a NumPy array + mask_prob = np.array(mask_prob) + + # Calculate all_num_mask for each element in the batch + all_num_mask = np.floor(mask_prob * all_sz / float(mask_length) + np.random.rand(bsz)).astype(int) + + # Apply the max operation with min_masks for each element + all_num_mask = np.maximum(min_masks, all_num_mask) + + mask_idcs = [] + for i in range(bsz): + if padding_mask is not None: + sz = all_sz - padding_mask[i].long().sum().item() + num_mask = int( + # add a random number for probabilistic rounding + mask_prob * sz / float(mask_length) + + np.random.rand() + ) + num_mask = max(min_masks, num_mask) + else: + sz = all_sz + num_mask = all_num_mask[i] + + if mask_type == "static": + lengths = np.full(num_mask, mask_length) + elif mask_type == "uniform": + lengths = np.random.randint(mask_other, mask_length * 2 + 1, size=num_mask) + elif mask_type == "normal": + lengths = np.random.normal(mask_length, mask_other, size=num_mask) + lengths = [max(1, int(round(x))) for x in lengths] + elif mask_type == "poisson": + lengths = np.random.poisson(mask_length, size=num_mask) + lengths = [int(round(x)) for x in lengths] + else: + raise Exception("unknown mask selection " + mask_type) + + if sum(lengths) == 0: + lengths[0] = min(mask_length, sz - 1) + + if no_overlap: + mask_idc = [] + + def arrange(s, e, length, keep_length): + span_start = np.random.randint(s, e - length) + mask_idc.extend(span_start + i for i in range(length)) + + new_parts = [] + if span_start - s - min_space >= keep_length: + new_parts.append((s, span_start - min_space + 1)) + if e - span_start - keep_length - min_space > keep_length: + new_parts.append((span_start + length + min_space, e)) + return new_parts + + parts = [(0, sz)] + min_length = min(lengths) + for length in sorted(lengths, reverse=True): + lens = np.fromiter( + (e - s if e - s >= length + min_space else 0 for s, e in parts), + int, + ) + l_sum = np.sum(lens) + if l_sum == 0: + break + probs = lens / np.sum(lens) + c = np.random.choice(len(parts), p=probs) + s, e = parts.pop(c) + parts.extend(arrange(s, e, length, min_length)) + mask_idc = np.asarray(mask_idc) + else: + min_len = min(lengths) + if sz - min_len <= num_mask: + min_len = sz - num_mask - 1 + + mask_idc = np.random.choice(sz - min_len, num_mask, replace=False) + + mask_idc = np.asarray( + [ + mask_idc[j] + offset + for j in range(len(mask_idc)) + for offset in range(lengths[j]) + ] + ) + + mask_idcs.append(np.unique(mask_idc[mask_idc < sz])) + # min_len = min([len(m) for m in mask_idcs]) + for i, mask_idc in enumerate(mask_idcs): + # if len(mask_idc) > min_len: + # mask_idc = np.random.choice(mask_idc, min_len, replace=False) + mask[i, mask_idc] = True + + return torch.tensor(mask) + + +if __name__ == '__main__': + mask = compute_mask_indices( + shape=[4, 500], + padding_mask=None, + mask_prob=[0.65, 0.5, 0.65, 0.65], + mask_length=10, + mask_type="static", + mask_other=0.0, + min_masks=1, + no_overlap=False, + min_space=0, + ) + print(mask) + print(mask.sum(dim=1)) \ No newline at end of file diff --git a/src/models/utils/timm.py
b/src/models/utils/timm.py new file mode 100644 index 0000000000000000000000000000000000000000..ae72f4fca681617778028b64da5daed26b3ed3d4 --- /dev/null +++ b/src/models/utils/timm.py @@ -0,0 +1,114 @@ +# code from timm 0.3.2 +import torch +import torch.nn as nn +import math +import warnings + + +def _no_grad_trunc_normal_(tensor, mean, std, a, b): + # Cut & paste from PyTorch official master until it's in a few official releases - RW + # Method based on https://people.sc.fsu.edu/~jburkardt/presentations/truncated_normal.pdf + def norm_cdf(x): + # Computes standard normal cumulative distribution function + return (1. + math.erf(x / math.sqrt(2.))) / 2. + + if (mean < a - 2 * std) or (mean > b + 2 * std): + warnings.warn("mean is more than 2 std from [a, b] in nn.init.trunc_normal_. " + "The distribution of values may be incorrect.", + stacklevel=2) + + with torch.no_grad(): + # Values are generated by using a truncated uniform distribution and + # then using the inverse CDF for the normal distribution. + # Get upper and lower cdf values + l = norm_cdf((a - mean) / std) + u = norm_cdf((b - mean) / std) + + # Uniformly fill tensor with values from [l, u], then translate to + # [2l-1, 2u-1]. + tensor.uniform_(2 * l - 1, 2 * u - 1) + + # Use inverse cdf transform for normal distribution to get truncated + # standard normal + tensor.erfinv_() + + # Transform to proper mean, std + tensor.mul_(std * math.sqrt(2.)) + tensor.add_(mean) + + # Clamp to ensure it's in the proper range + tensor.clamp_(min=a, max=b) + return tensor + + +def trunc_normal_(tensor, mean=0., std=1., a=-2., b=2.): + # type: (Tensor, float, float, float, float) -> Tensor + r"""Fills the input Tensor with values drawn from a truncated + normal distribution. The values are effectively drawn from the + normal distribution :math:`\mathcal{N}(\text{mean}, \text{std}^2)` + with values outside :math:`[a, b]` redrawn until they are within + the bounds. The method used for generating the random values works + best when :math:`a \leq \text{mean} \leq b`. + Args: + tensor: an n-dimensional `torch.Tensor` + mean: the mean of the normal distribution + std: the standard deviation of the normal distribution + a: the minimum cutoff value + b: the maximum cutoff value + Examples: + >>> w = torch.empty(3, 5) + >>> nn.init.trunc_normal_(w) + """ + return _no_grad_trunc_normal_(tensor, mean, std, a, b) + + +def drop_path(x, drop_prob: float = 0., training: bool = False): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). + + This is the same as the DropConnect impl I created for EfficientNet, etc networks, however, + the original name is misleading as 'Drop Connect' is a different form of dropout in a separate paper... + See discussion: https://github.com/tensorflow/tpu/issues/494#issuecomment-532968956 ... I've opted for + changing the layer and argument names to 'drop path' rather than mix DropConnect as a layer name and use + 'survival rate' as the argument. + + """ + if drop_prob == 0. or not training: + return x + keep_prob = 1 - drop_prob + shape = (x.shape[0],) + (1,) * (x.ndim - 1) # work with diff dim tensors, not just 2D ConvNets + random_tensor = keep_prob + torch.rand(shape, dtype=x.dtype, device=x.device) + random_tensor.floor_() # binarize + output = x.div(keep_prob) * random_tensor + return output + + +class DropPath(nn.Module): + """Drop paths (Stochastic Depth) per sample (when applied in main path of residual blocks). 
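+
+    A worked example (numbers are illustrative): with ``drop_prob=0.1`` in training
+    mode, each sample's residual-branch output is zeroed with probability 0.1 and the
+    surviving samples are scaled by ``1 / 0.9`` so the expected value is unchanged;
+    in eval mode (or with ``drop_prob=0``) the input is returned untouched.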
+ """ + + def __init__(self, drop_prob=None): + super(DropPath, self).__init__() + self.drop_prob = drop_prob + + def forward(self, x): + return drop_path(x, self.drop_prob, self.training) + + +class Mlp(nn.Module): + def __init__(self, in_features, hidden_features=None, out_features=None, + act_layer=nn.GELU, drop=0.): + super().__init__() + out_features = out_features or in_features + hidden_features = hidden_features or in_features + self.fc1 = nn.Linear(in_features, hidden_features) + self.act = act_layer() + self.fc2 = nn.Linear(hidden_features, out_features) + self.drop = nn.Dropout(drop) + + def forward(self, x): + x = self.fc1(x) + x = self.act(x) + x = self.drop(x) + x = self.fc2(x) + x = self.drop(x) + return x \ No newline at end of file diff --git a/src/modules/.ipynb_checkpoints/autoencoder_wrapper-checkpoint.py b/src/modules/.ipynb_checkpoints/autoencoder_wrapper-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf799cb010332a6adbb2c74213df24c2602a6e8 --- /dev/null +++ b/src/modules/.ipynb_checkpoints/autoencoder_wrapper-checkpoint.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +from .dac import DAC +from .stable_vae import load_vae + + +class Autoencoder(nn.Module): + def __init__(self, ckpt_path, model_type='dac', quantization_first=False): + super(Autoencoder, self).__init__() + self.model_type = model_type + if self.model_type == 'dac': + model = DAC.load(ckpt_path) + elif self.model_type == 'stable_vae': + model = load_vae(ckpt_path) + else: + raise NotImplementedError(f"Model type not implemented: {self.model_type}") + self.ae = model.eval() + self.quantization_first = quantization_first + print(f'Autoencoder quantization first mode: {quantization_first}') + + @torch.no_grad() + def forward(self, audio=None, embedding=None): + if self.model_type == 'dac': + return self.process_dac(audio, embedding) + elif self.model_type == 'encodec': + return self.process_encodec(audio, embedding) + elif self.model_type == 'stable_vae': + return self.process_stable_vae(audio, embedding) + else: + raise NotImplementedError(f"Model type not implemented: {self.model_type}") + + def process_dac(self, audio=None, embedding=None): + if audio is not None: + z = self.ae.encoder(audio) + if self.quantization_first: + z, *_ = self.ae.quantizer(z, None) + return z + elif embedding is not None: + z = embedding + if self.quantization_first: + audio = self.ae.decoder(z) + else: + z, *_ = self.ae.quantizer(z, None) + audio = self.ae.decoder(z) + return audio + else: + raise ValueError("Either audio or embedding must be provided.") + + def process_encodec(self, audio=None, embedding=None): + if audio is not None: + z = self.ae.encoder(audio) + if self.quantization_first: + code = self.ae.quantizer.encode(z) + z = self.ae.quantizer.decode(code) + return z + elif embedding is not None: + z = embedding + if self.quantization_first: + audio = self.ae.decoder(z) + else: + code = self.ae.quantizer.encode(z) + z = self.ae.quantizer.decode(code) + audio = self.ae.decoder(z) + return audio + else: + raise ValueError("Either audio or embedding must be provided.") + + def process_stable_vae(self, audio=None, embedding=None): + if audio is not None: + z = self.ae.encoder(audio) + if self.quantization_first: + z = self.ae.bottleneck.encode(z) + return z + if embedding is not None: + z = embedding + if self.quantization_first: + audio = self.ae.decoder(z) + else: + z = self.ae.bottleneck.encode(z) + audio = self.ae.decoder(z) + return audio + else: + raise 
ValueError("Either audio or embedding must be provided.") diff --git a/src/modules/.ipynb_checkpoints/clap_wrapper-checkpoint.py b/src/modules/.ipynb_checkpoints/clap_wrapper-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/modules/__pycache__/autoencoder_wrapper.cpython-310.pyc b/src/modules/__pycache__/autoencoder_wrapper.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..148ad771c41a0d927fe7068eeb6f755caba628eb Binary files /dev/null and b/src/modules/__pycache__/autoencoder_wrapper.cpython-310.pyc differ diff --git a/src/modules/__pycache__/autoencoder_wrapper.cpython-311.pyc b/src/modules/__pycache__/autoencoder_wrapper.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eb87e64ba7232c1d1a59a276db00eb8aea7f9fea Binary files /dev/null and b/src/modules/__pycache__/autoencoder_wrapper.cpython-311.pyc differ diff --git a/src/modules/autoencoder_wrapper.py b/src/modules/autoencoder_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..5cf799cb010332a6adbb2c74213df24c2602a6e8 --- /dev/null +++ b/src/modules/autoencoder_wrapper.py @@ -0,0 +1,83 @@ +import torch +import torch.nn as nn +from .dac import DAC +from .stable_vae import load_vae + + +class Autoencoder(nn.Module): + def __init__(self, ckpt_path, model_type='dac', quantization_first=False): + super(Autoencoder, self).__init__() + self.model_type = model_type + if self.model_type == 'dac': + model = DAC.load(ckpt_path) + elif self.model_type == 'stable_vae': + model = load_vae(ckpt_path) + else: + raise NotImplementedError(f"Model type not implemented: {self.model_type}") + self.ae = model.eval() + self.quantization_first = quantization_first + print(f'Autoencoder quantization first mode: {quantization_first}') + + @torch.no_grad() + def forward(self, audio=None, embedding=None): + if self.model_type == 'dac': + return self.process_dac(audio, embedding) + elif self.model_type == 'encodec': + return self.process_encodec(audio, embedding) + elif self.model_type == 'stable_vae': + return self.process_stable_vae(audio, embedding) + else: + raise NotImplementedError(f"Model type not implemented: {self.model_type}") + + def process_dac(self, audio=None, embedding=None): + if audio is not None: + z = self.ae.encoder(audio) + if self.quantization_first: + z, *_ = self.ae.quantizer(z, None) + return z + elif embedding is not None: + z = embedding + if self.quantization_first: + audio = self.ae.decoder(z) + else: + z, *_ = self.ae.quantizer(z, None) + audio = self.ae.decoder(z) + return audio + else: + raise ValueError("Either audio or embedding must be provided.") + + def process_encodec(self, audio=None, embedding=None): + if audio is not None: + z = self.ae.encoder(audio) + if self.quantization_first: + code = self.ae.quantizer.encode(z) + z = self.ae.quantizer.decode(code) + return z + elif embedding is not None: + z = embedding + if self.quantization_first: + audio = self.ae.decoder(z) + else: + code = self.ae.quantizer.encode(z) + z = self.ae.quantizer.decode(code) + audio = self.ae.decoder(z) + return audio + else: + raise ValueError("Either audio or embedding must be provided.") + + def process_stable_vae(self, audio=None, embedding=None): + if audio is not None: + z = self.ae.encoder(audio) + if self.quantization_first: + z = self.ae.bottleneck.encode(z) + return z + if embedding is not None: + z = embedding + if self.quantization_first: + audio = 
self.ae.decoder(z) + else: + z = self.ae.bottleneck.encode(z) + audio = self.ae.decoder(z) + return audio + else: + raise ValueError("Either audio or embedding must be provided.") diff --git a/src/modules/clap_wrapper.py b/src/modules/clap_wrapper.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/modules/dac/.ipynb_checkpoints/__init__-checkpoint.py b/src/modules/dac/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..51205ef6ded9c6735a988b76008e0f6bdce8e215 --- /dev/null +++ b/src/modules/dac/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,16 @@ +__version__ = "1.0.0" + +# preserved here for legacy reasons +__model_version__ = "latest" + +import audiotools + +audiotools.ml.BaseModel.INTERN += ["dac.**"] +audiotools.ml.BaseModel.EXTERN += ["einops"] + + +from . import nn +from . import model +from . import utils +from .model import DAC +from .model import DACFile diff --git a/src/modules/dac/__init__.py b/src/modules/dac/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..51205ef6ded9c6735a988b76008e0f6bdce8e215 --- /dev/null +++ b/src/modules/dac/__init__.py @@ -0,0 +1,16 @@ +__version__ = "1.0.0" + +# preserved here for legacy reasons +__model_version__ = "latest" + +import audiotools + +audiotools.ml.BaseModel.INTERN += ["dac.**"] +audiotools.ml.BaseModel.EXTERN += ["einops"] + + +from . import nn +from . import model +from . import utils +from .model import DAC +from .model import DACFile diff --git a/src/modules/dac/__main__.py b/src/modules/dac/__main__.py new file mode 100644 index 0000000000000000000000000000000000000000..2fa8d15307997663f8143669c2bd56e0889cb021 --- /dev/null +++ b/src/modules/dac/__main__.py @@ -0,0 +1,36 @@ +import sys + +import argbind + +from dac.utils import download +from dac.utils.decode import decode +from dac.utils.encode import encode + +STAGES = ["encode", "decode", "download"] + + +def run(stage: str): + """Run stages. + + Parameters + ---------- + stage : str + Stage to run + """ + if stage not in STAGES: + raise ValueError(f"Unknown command: {stage}. 
Allowed commands are {STAGES}") + stage_fn = globals()[stage] + + if stage == "download": + stage_fn() + return + + stage_fn() + + +if __name__ == "__main__": + group = sys.argv.pop(1) + args = argbind.parse_args(group=group) + + with argbind.scope(args): + run(group) diff --git a/src/modules/dac/__pycache__/__init__.cpython-310.pyc b/src/modules/dac/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..37ab8619ee9f735925d2c6134f01379f9f07fe80 Binary files /dev/null and b/src/modules/dac/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/modules/dac/__pycache__/__init__.cpython-311.pyc b/src/modules/dac/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..996c60afc76637baa76609f79e1180a9db1ba68c Binary files /dev/null and b/src/modules/dac/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/modules/dac/compare/__init__.py b/src/modules/dac/compare/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..e69de29bb2d1d6434b8b29ae775ad8c2e48c5391 diff --git a/src/modules/dac/compare/encodec.py b/src/modules/dac/compare/encodec.py new file mode 100644 index 0000000000000000000000000000000000000000..42877de3cffa7d681b28266e4e1f537d48b749eb --- /dev/null +++ b/src/modules/dac/compare/encodec.py @@ -0,0 +1,54 @@ +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from encodec import EncodecModel + + +class Encodec(BaseModel): + def __init__(self, sample_rate: int = 24000, bandwidth: float = 24.0): + super().__init__() + + if sample_rate == 24000: + self.model = EncodecModel.encodec_model_24khz() + else: + self.model = EncodecModel.encodec_model_48khz() + self.model.set_target_bandwidth(bandwidth) + self.sample_rate = 44100 + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = 44100, + n_quantizers: int = None, + ): + signal = AudioSignal(audio_data, sample_rate) + signal.resample(self.model.sample_rate) + recons = self.model(signal.audio_data) + recons = AudioSignal(recons, self.model.sample_rate) + recons.resample(sample_rate) + return {"audio": recons.audio_data} + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = Encodec() + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 
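# --- editorial note (not part of the diff): the loop around this point
# monkey-patches each module's ``extra_repr`` via ``functools.partial`` so
# that ``print(model)`` also reports a per-module parameter count in
# millions, e.g. ``Conv1d(...) 0.123M params.`` ---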
+ setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + + print(x.shape, out.shape) diff --git a/src/modules/dac/model/.ipynb_checkpoints/dac-checkpoint.py b/src/modules/dac/model/.ipynb_checkpoints/dac-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..6d44a18a9e98fdcce9377a744b6f9d7dfa6a607b --- /dev/null +++ b/src/modules/dac/model/.ipynb_checkpoints/dac-checkpoint.py @@ -0,0 +1,364 @@ +import math +from typing import List +from typing import Union + +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from torch import nn + +from .base import CodecMixin +from ..nn.layers import Snake1d +from ..nn.layers import WNConv1d +from ..nn.layers import WNConvTranspose1d +from ..nn.quantize import ResidualVectorQuantize + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 64, + ): + super().__init__() + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add 
upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 8, 4, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 44100, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) + + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. 
+ + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) + z, codes, latents, commitment_loss, codebook_loss = self.quantizer( + z, n_quantizers + ) + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + return self.decoder(z) + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + z, codes, latents, commitment_loss, codebook_loss = self.encode( + audio_data, n_quantizers + ) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = DAC().to("cpu") + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 
+ setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) + + # Create gradient variable + grad = torch.zeros_like(out) + grad[:, :, grad.shape[-1] // 2] = 1 + + # Make a backward pass + out.backward(grad) + + # Check non-zero values + gradmap = x.grad.squeeze(0) + gradmap = (gradmap != 0).sum(0) # sum across features + rf = (gradmap != 0).sum() + + print(f"Receptive field: {rf.item()}") + + x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/src/modules/dac/model/__init__.py b/src/modules/dac/model/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..02a75b7ad6028f5c41b6a8285b0257d4c23bdfcf --- /dev/null +++ b/src/modules/dac/model/__init__.py @@ -0,0 +1,4 @@ +from .base import CodecMixin +from .base import DACFile +from .dac import DAC +from .discriminator import Discriminator diff --git a/src/modules/dac/model/__pycache__/__init__.cpython-310.pyc b/src/modules/dac/model/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4f412ea22c0e7baf9a7fa637c96a7d84dda476b Binary files /dev/null and b/src/modules/dac/model/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/modules/dac/model/__pycache__/__init__.cpython-311.pyc b/src/modules/dac/model/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8c7e138a9f5631558eb4dbde2f623bf5b2863c03 Binary files /dev/null and b/src/modules/dac/model/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/modules/dac/model/__pycache__/base.cpython-310.pyc b/src/modules/dac/model/__pycache__/base.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d7dc58569bd1c938c3f478449d3a8636fcdfc4ad Binary files /dev/null and b/src/modules/dac/model/__pycache__/base.cpython-310.pyc differ diff --git a/src/modules/dac/model/__pycache__/base.cpython-311.pyc b/src/modules/dac/model/__pycache__/base.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..b89297ea647f3e03add49e7d8c9bb11f46fe3bad Binary files /dev/null and b/src/modules/dac/model/__pycache__/base.cpython-311.pyc differ diff --git a/src/modules/dac/model/__pycache__/dac.cpython-310.pyc b/src/modules/dac/model/__pycache__/dac.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..66adf0ce2b256c971fca5aedb051c1e9af4c4b1e Binary files /dev/null and b/src/modules/dac/model/__pycache__/dac.cpython-310.pyc differ diff --git a/src/modules/dac/model/__pycache__/dac.cpython-311.pyc b/src/modules/dac/model/__pycache__/dac.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0608cb168829bf41f769cd0f6b2923f7f421b17b Binary files /dev/null and b/src/modules/dac/model/__pycache__/dac.cpython-311.pyc differ diff --git a/src/modules/dac/model/__pycache__/discriminator.cpython-310.pyc b/src/modules/dac/model/__pycache__/discriminator.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..1af4a16e2b17cdf8620db8a033ebcee17c2ce516 Binary files /dev/null and b/src/modules/dac/model/__pycache__/discriminator.cpython-310.pyc differ 
diff --git a/src/modules/dac/model/__pycache__/discriminator.cpython-311.pyc b/src/modules/dac/model/__pycache__/discriminator.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..638e7e8901d6e40b105363abc2743bf326acb661 Binary files /dev/null and b/src/modules/dac/model/__pycache__/discriminator.cpython-311.pyc differ diff --git a/src/modules/dac/model/base.py b/src/modules/dac/model/base.py new file mode 100644 index 0000000000000000000000000000000000000000..546b3cb7092d6bd1837ec780228d2a5b3e01fe8d --- /dev/null +++ b/src/modules/dac/model/base.py @@ -0,0 +1,294 @@ +import math +from dataclasses import dataclass +from pathlib import Path +from typing import Union + +import numpy as np +import torch +import tqdm +from audiotools import AudioSignal +from torch import nn + +SUPPORTED_VERSIONS = ["1.0.0"] + + +@dataclass +class DACFile: + codes: torch.Tensor + + # Metadata + chunk_length: int + original_length: int + input_db: float + channels: int + sample_rate: int + padding: bool + dac_version: str + + def save(self, path): + artifacts = { + "codes": self.codes.numpy().astype(np.uint16), + "metadata": { + "input_db": self.input_db.numpy().astype(np.float32), + "original_length": self.original_length, + "sample_rate": self.sample_rate, + "chunk_length": self.chunk_length, + "channels": self.channels, + "padding": self.padding, + "dac_version": SUPPORTED_VERSIONS[-1], + }, + } + path = Path(path).with_suffix(".dac") + with open(path, "wb") as f: + np.save(f, artifacts) + return path + + @classmethod + def load(cls, path): + artifacts = np.load(path, allow_pickle=True)[()] + codes = torch.from_numpy(artifacts["codes"].astype(int)) + if artifacts["metadata"].get("dac_version", None) not in SUPPORTED_VERSIONS: + raise RuntimeError( + f"Given file {path} can't be loaded with this version of descript-audio-codec." 
+ ) + return cls(codes=codes, **artifacts["metadata"]) + + +class CodecMixin: + @property + def padding(self): + if not hasattr(self, "_padding"): + self._padding = True + return self._padding + + @padding.setter + def padding(self, value): + assert isinstance(value, bool) + + layers = [ + l for l in self.modules() if isinstance(l, (nn.Conv1d, nn.ConvTranspose1d)) + ] + + for layer in layers: + if value: + if hasattr(layer, "original_padding"): + layer.padding = layer.original_padding + else: + layer.original_padding = layer.padding + layer.padding = tuple(0 for _ in range(len(layer.padding))) + + self._padding = value + + def get_delay(self): + # Any number works here, delay is invariant to input length + l_out = self.get_output_length(0) + L = l_out + + layers = [] + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + layers.append(layer) + + for layer in reversed(layers): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.ConvTranspose1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.Conv1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.ceil(L) + + l_in = L + + return (l_in - l_out) // 2 + + def get_output_length(self, input_length): + L = input_length + # Calculate output length + for layer in self.modules(): + if isinstance(layer, (nn.Conv1d, nn.ConvTranspose1d)): + d = layer.dilation[0] + k = layer.kernel_size[0] + s = layer.stride[0] + + if isinstance(layer, nn.Conv1d): + L = ((L - d * (k - 1) - 1) / s) + 1 + elif isinstance(layer, nn.ConvTranspose1d): + L = (L - 1) * s + d * (k - 1) + 1 + + L = math.floor(L) + return L + + @torch.no_grad() + def compress( + self, + audio_path_or_signal: Union[str, Path, AudioSignal], + win_duration: float = 1.0, + verbose: bool = False, + normalize_db: float = -16, + n_quantizers: int = None, + ) -> DACFile: + """Processes an audio signal from a file or AudioSignal object into + discrete codes. This function processes the signal in short windows, + using constant GPU memory. 
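# --- editorial note (not part of the diff): a typical round trip through
# compress/decompress, assuming ``model`` is a DAC instance and that
# audiotools' ``AudioSignal.write`` is available for saving audio:
#     dac_file = model.compress("input.wav", win_duration=5.0)
#     dac_file.save("input.dac")              # DACFile.save writes a .dac archive
#     signal = model.decompress("input.dac")  # returns an AudioSignal
#     signal.write("recon.wav")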
+ + Parameters + ---------- + audio_path_or_signal : Union[str, Path, AudioSignal] + audio signal to reconstruct + win_duration : float, optional + window duration in seconds, by default 5.0 + verbose : bool, optional + by default False + normalize_db : float, optional + normalize db, by default -16 + + Returns + ------- + DACFile + Object containing compressed codes and metadata + required for decompression + """ + audio_signal = audio_path_or_signal + if isinstance(audio_signal, (str, Path)): + audio_signal = AudioSignal.load_from_file_with_ffmpeg(str(audio_signal)) + + self.eval() + original_padding = self.padding + original_device = audio_signal.device + + audio_signal = audio_signal.clone() + original_sr = audio_signal.sample_rate + + resample_fn = audio_signal.resample + loudness_fn = audio_signal.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if audio_signal.signal_duration >= 10 * 60 * 60: + resample_fn = audio_signal.ffmpeg_resample + loudness_fn = audio_signal.ffmpeg_loudness + + original_length = audio_signal.signal_length + resample_fn(self.sample_rate) + input_db = loudness_fn() + + if normalize_db is not None: + audio_signal.normalize(normalize_db) + audio_signal.ensure_max_of_audio() + + nb, nac, nt = audio_signal.audio_data.shape + audio_signal.audio_data = audio_signal.audio_data.reshape(nb * nac, 1, nt) + win_duration = ( + audio_signal.signal_duration if win_duration is None else win_duration + ) + + if audio_signal.signal_duration <= win_duration: + # Unchunked compression (used if signal length < win duration) + self.padding = True + n_samples = nt + hop = nt + else: + # Chunked inference + self.padding = False + # Zero-pad signal on either side by the delay + audio_signal.zero_pad(self.delay, self.delay) + n_samples = int(win_duration * self.sample_rate) + # Round n_samples to nearest hop length multiple + n_samples = int(math.ceil(n_samples / self.hop_length) * self.hop_length) + hop = self.get_output_length(n_samples) + + codes = [] + range_fn = range if not verbose else tqdm.trange + + for i in range_fn(0, nt, hop): + x = audio_signal[..., i : i + n_samples] + x = x.zero_pad(0, max(0, n_samples - x.shape[-1])) + + audio_data = x.audio_data.to(self.device) + audio_data = self.preprocess(audio_data, self.sample_rate) + _, c, _, _, _ = self.encode(audio_data, n_quantizers) + codes.append(c.to(original_device)) + chunk_length = c.shape[-1] + + codes = torch.cat(codes, dim=-1) + + dac_file = DACFile( + codes=codes, + chunk_length=chunk_length, + original_length=original_length, + input_db=input_db, + channels=nac, + sample_rate=original_sr, + padding=self.padding, + dac_version=SUPPORTED_VERSIONS[-1], + ) + + if n_quantizers is not None: + codes = codes[:, :n_quantizers, :] + + self.padding = original_padding + return dac_file + + @torch.no_grad() + def decompress( + self, + obj: Union[str, Path, DACFile], + verbose: bool = False, + ) -> AudioSignal: + """Reconstruct audio from a given .dac file + + Parameters + ---------- + obj : Union[str, Path, DACFile] + .dac file location or corresponding DACFile object. 
+ verbose : bool, optional + Prints progress if True, by default False + + Returns + ------- + AudioSignal + Object with the reconstructed audio + """ + self.eval() + if isinstance(obj, (str, Path)): + obj = DACFile.load(obj) + + original_padding = self.padding + self.padding = obj.padding + + range_fn = range if not verbose else tqdm.trange + codes = obj.codes + original_device = codes.device + chunk_length = obj.chunk_length + recons = [] + + for i in range_fn(0, codes.shape[-1], chunk_length): + c = codes[..., i : i + chunk_length].to(self.device) + z = self.quantizer.from_codes(c)[0] + r = self.decode(z) + recons.append(r.to(original_device)) + + recons = torch.cat(recons, dim=-1) + recons = AudioSignal(recons, self.sample_rate) + + resample_fn = recons.resample + loudness_fn = recons.loudness + + # If audio is > 10 minutes long, use the ffmpeg versions + if recons.signal_duration >= 10 * 60 * 60: + resample_fn = recons.ffmpeg_resample + loudness_fn = recons.ffmpeg_loudness + + recons.normalize(obj.input_db) + resample_fn(obj.sample_rate) + recons = recons[..., : obj.original_length] + loudness_fn() + recons.audio_data = recons.audio_data.reshape( + -1, obj.channels, obj.original_length + ) + + self.padding = original_padding + return recons diff --git a/src/modules/dac/model/dac.py b/src/modules/dac/model/dac.py new file mode 100644 index 0000000000000000000000000000000000000000..6d44a18a9e98fdcce9377a744b6f9d7dfa6a607b --- /dev/null +++ b/src/modules/dac/model/dac.py @@ -0,0 +1,364 @@ +import math +from typing import List +from typing import Union + +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.ml import BaseModel +from torch import nn + +from .base import CodecMixin +from ..nn.layers import Snake1d +from ..nn.layers import WNConv1d +from ..nn.layers import WNConvTranspose1d +from ..nn.quantize import ResidualVectorQuantize + + +def init_weights(m): + if isinstance(m, nn.Conv1d): + nn.init.trunc_normal_(m.weight, std=0.02) + nn.init.constant_(m.bias, 0) + + +class ResidualUnit(nn.Module): + def __init__(self, dim: int = 16, dilation: int = 1): + super().__init__() + pad = ((7 - 1) * dilation) // 2 + self.block = nn.Sequential( + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=7, dilation=dilation, padding=pad), + Snake1d(dim), + WNConv1d(dim, dim, kernel_size=1), + ) + + def forward(self, x): + y = self.block(x) + pad = (x.shape[-1] - y.shape[-1]) // 2 + if pad > 0: + x = x[..., pad:-pad] + return x + y + + +class EncoderBlock(nn.Module): + def __init__(self, dim: int = 16, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + ResidualUnit(dim // 2, dilation=1), + ResidualUnit(dim // 2, dilation=3), + ResidualUnit(dim // 2, dilation=9), + Snake1d(dim // 2), + WNConv1d( + dim // 2, + dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ) + + def forward(self, x): + return self.block(x) + + +class Encoder(nn.Module): + def __init__( + self, + d_model: int = 64, + strides: list = [2, 4, 8, 8], + d_latent: int = 64, + ): + super().__init__() + # Create first convolution + self.block = [WNConv1d(1, d_model, kernel_size=7, padding=3)] + + # Create EncoderBlocks that double channels as they downsample by `stride` + for stride in strides: + d_model *= 2 + self.block += [EncoderBlock(d_model, stride=stride)] + + # Create last convolution + self.block += [ + Snake1d(d_model), + WNConv1d(d_model, d_latent, kernel_size=3, padding=1), + ] + + # Wrap black into nn.Sequential + self.block = 
nn.Sequential(*self.block) + self.enc_dim = d_model + + def forward(self, x): + return self.block(x) + + +class DecoderBlock(nn.Module): + def __init__(self, input_dim: int = 16, output_dim: int = 8, stride: int = 1): + super().__init__() + self.block = nn.Sequential( + Snake1d(input_dim), + WNConvTranspose1d( + input_dim, + output_dim, + kernel_size=2 * stride, + stride=stride, + padding=math.ceil(stride / 2), + ), + ResidualUnit(output_dim, dilation=1), + ResidualUnit(output_dim, dilation=3), + ResidualUnit(output_dim, dilation=9), + ) + + def forward(self, x): + return self.block(x) + + +class Decoder(nn.Module): + def __init__( + self, + input_channel, + channels, + rates, + d_out: int = 1, + ): + super().__init__() + + # Add first conv layer + layers = [WNConv1d(input_channel, channels, kernel_size=7, padding=3)] + + # Add upsampling + MRF blocks + for i, stride in enumerate(rates): + input_dim = channels // 2**i + output_dim = channels // 2 ** (i + 1) + layers += [DecoderBlock(input_dim, output_dim, stride)] + + # Add final conv layer + layers += [ + Snake1d(output_dim), + WNConv1d(output_dim, d_out, kernel_size=7, padding=3), + nn.Tanh(), + ] + + self.model = nn.Sequential(*layers) + + def forward(self, x): + return self.model(x) + + +class DAC(BaseModel, CodecMixin): + def __init__( + self, + encoder_dim: int = 64, + encoder_rates: List[int] = [2, 4, 8, 8], + latent_dim: int = None, + decoder_dim: int = 1536, + decoder_rates: List[int] = [8, 8, 4, 2], + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: bool = False, + sample_rate: int = 44100, + ): + super().__init__() + + self.encoder_dim = encoder_dim + self.encoder_rates = encoder_rates + self.decoder_dim = decoder_dim + self.decoder_rates = decoder_rates + self.sample_rate = sample_rate + + if latent_dim is None: + latent_dim = encoder_dim * (2 ** len(encoder_rates)) + + self.latent_dim = latent_dim + + self.hop_length = np.prod(encoder_rates) + self.encoder = Encoder(encoder_dim, encoder_rates, latent_dim) + + self.n_codebooks = n_codebooks + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + self.quantizer = ResidualVectorQuantize( + input_dim=latent_dim, + n_codebooks=n_codebooks, + codebook_size=codebook_size, + codebook_dim=codebook_dim, + quantizer_dropout=quantizer_dropout, + ) + + self.decoder = Decoder( + latent_dim, + decoder_dim, + decoder_rates, + ) + self.sample_rate = sample_rate + self.apply(init_weights) + + self.delay = self.get_delay() + + def preprocess(self, audio_data, sample_rate): + if sample_rate is None: + sample_rate = self.sample_rate + assert sample_rate == self.sample_rate + + length = audio_data.shape[-1] + right_pad = math.ceil(length / self.hop_length) * self.hop_length - length + audio_data = nn.functional.pad(audio_data, (0, right_pad)) + + return audio_data + + def encode( + self, + audio_data: torch.Tensor, + n_quantizers: int = None, + ): + """Encode given audio data and return quantized latent codes + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + n_quantizers : int, optional + Number of quantizers to use, by default None + If None, all quantizers are used. 
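# --- editorial note (not part of the diff): at inference time, passing a
# smaller ``n_quantizers`` keeps only the first N codebooks, trading fidelity
# for bitrate, e.g. ``z, codes, *_ = model.encode(audio, n_quantizers=4)``
# yields ``codes`` of shape (B, 4, T) instead of (B, 9, T). ---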
+ + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + """ + z = self.encoder(audio_data) + z, codes, latents, commitment_loss, codebook_loss = self.quantizer( + z, n_quantizers + ) + return z, codes, latents, commitment_loss, codebook_loss + + def decode(self, z: torch.Tensor): + """Decode given latent codes and return audio data + + Parameters + ---------- + z : Tensor[B x D x T] + Quantized continuous representation of input + length : int, optional + Number of samples in output audio, by default None + + Returns + ------- + dict + A dictionary with the following keys: + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + return self.decoder(z) + + def forward( + self, + audio_data: torch.Tensor, + sample_rate: int = None, + n_quantizers: int = None, + ): + """Model forward pass + + Parameters + ---------- + audio_data : Tensor[B x 1 x T] + Audio data to encode + sample_rate : int, optional + Sample rate of audio data in Hz, by default None + If None, defaults to `self.sample_rate` + n_quantizers : int, optional + Number of quantizers to use, by default None. + If None, all quantizers are used. + + Returns + ------- + dict + A dictionary with the following keys: + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + "length" : int + Number of samples in input audio + "audio" : Tensor[B x 1 x length] + Decoded audio data. + """ + length = audio_data.shape[-1] + audio_data = self.preprocess(audio_data, sample_rate) + z, codes, latents, commitment_loss, codebook_loss = self.encode( + audio_data, n_quantizers + ) + + x = self.decode(z) + return { + "audio": x[..., :length], + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + +if __name__ == "__main__": + import numpy as np + from functools import partial + + model = DAC().to("cpu") + + for n, m in model.named_modules(): + o = m.extra_repr() + p = sum([np.prod(p.size()) for p in m.parameters()]) + fn = lambda o, p: o + f" {p/1e6:<.3f}M params." 
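# --- editorial note (not part of the diff): besides the parameter-count
# printout, the __main__ block below estimates the model's receptive field
# empirically: it backpropagates a gradient that is non-zero only at the
# centre sample of the output and counts how many input samples receive a
# non-zero gradient. ---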
+ setattr(m, "extra_repr", partial(fn, o=o, p=p)) + print(model) + print("Total # of params: ", sum([np.prod(p.size()) for p in model.parameters()])) + + length = 88200 * 2 + x = torch.randn(1, 1, length).to(model.device) + x.requires_grad_(True) + x.retain_grad() + + # Make a forward pass + out = model(x)["audio"] + print("Input shape:", x.shape) + print("Output shape:", out.shape) + + # Create gradient variable + grad = torch.zeros_like(out) + grad[:, :, grad.shape[-1] // 2] = 1 + + # Make a backward pass + out.backward(grad) + + # Check non-zero values + gradmap = x.grad.squeeze(0) + gradmap = (gradmap != 0).sum(0) # sum across features + rf = (gradmap != 0).sum() + + print(f"Receptive field: {rf.item()}") + + x = AudioSignal(torch.randn(1, 1, 44100 * 60), 44100) + model.decompress(model.compress(x, verbose=True), verbose=True) diff --git a/src/modules/dac/model/discriminator.py b/src/modules/dac/model/discriminator.py new file mode 100644 index 0000000000000000000000000000000000000000..09c79d1342ca46bef21daca64667577f05e61638 --- /dev/null +++ b/src/modules/dac/model/discriminator.py @@ -0,0 +1,228 @@ +import torch +import torch.nn as nn +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import ml +from audiotools import STFTParams +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv1d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +def WNConv2d(*args, **kwargs): + act = kwargs.pop("act", True) + conv = weight_norm(nn.Conv2d(*args, **kwargs)) + if not act: + return conv + return nn.Sequential(conv, nn.LeakyReLU(0.1)) + + +class MPD(nn.Module): + def __init__(self, period): + super().__init__() + self.period = period + self.convs = nn.ModuleList( + [ + WNConv2d(1, 32, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(32, 128, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(128, 512, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(512, 1024, (5, 1), (3, 1), padding=(2, 0)), + WNConv2d(1024, 1024, (5, 1), 1, padding=(2, 0)), + ] + ) + self.conv_post = WNConv2d( + 1024, 1, kernel_size=(3, 1), padding=(1, 0), act=False + ) + + def pad_to_period(self, x): + t = x.shape[-1] + x = F.pad(x, (0, self.period - t % self.period), mode="reflect") + return x + + def forward(self, x): + fmap = [] + + x = self.pad_to_period(x) + x = rearrange(x, "b c (l p) -> b c l p", p=self.period) + + for layer in self.convs: + x = layer(x) + fmap.append(x) + + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class MSD(nn.Module): + def __init__(self, rate: int = 1, sample_rate: int = 44100): + super().__init__() + self.convs = nn.ModuleList( + [ + WNConv1d(1, 16, 15, 1, padding=7), + WNConv1d(16, 64, 41, 4, groups=4, padding=20), + WNConv1d(64, 256, 41, 4, groups=16, padding=20), + WNConv1d(256, 1024, 41, 4, groups=64, padding=20), + WNConv1d(1024, 1024, 41, 4, groups=256, padding=20), + WNConv1d(1024, 1024, 5, 1, padding=2), + ] + ) + self.conv_post = WNConv1d(1024, 1, 3, 1, padding=1, act=False) + self.sample_rate = sample_rate + self.rate = rate + + def forward(self, x): + x = AudioSignal(x, self.sample_rate) + x.resample(self.sample_rate // self.rate) + x = x.audio_data + + fmap = [] + + for l in self.convs: + x = l(x) + fmap.append(x) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +BANDS = [(0.0, 0.1), (0.1, 0.25), (0.25, 0.5), (0.5, 0.75), (0.75, 1.0)] + + +class MRD(nn.Module): + def __init__( + self, 
+ window_length: int, + hop_factor: float = 0.25, + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Complex multi-band spectrogram discriminator. + Parameters + ---------- + window_length : int + Window length of STFT. + hop_factor : float, optional + Hop factor of the STFT, defaults to ``0.25 * window_length``. + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run discriminator over. + """ + super().__init__() + + self.window_length = window_length + self.hop_factor = hop_factor + self.sample_rate = sample_rate + self.stft_params = STFTParams( + window_length=window_length, + hop_length=int(window_length * hop_factor), + match_stride=True, + ) + + n_fft = window_length // 2 + 1 + bands = [(int(b[0] * n_fft), int(b[1] * n_fft)) for b in bands] + self.bands = bands + + ch = 32 + convs = lambda: nn.ModuleList( + [ + WNConv2d(2, ch, (3, 9), (1, 1), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 9), (1, 2), padding=(1, 4)), + WNConv2d(ch, ch, (3, 3), (1, 1), padding=(1, 1)), + ] + ) + self.band_convs = nn.ModuleList([convs() for _ in range(len(self.bands))]) + self.conv_post = WNConv2d(ch, 1, (3, 3), (1, 1), padding=(1, 1), act=False) + + def spectrogram(self, x): + x = AudioSignal(x, self.sample_rate, stft_params=self.stft_params) + x = torch.view_as_real(x.stft()) + x = rearrange(x, "b 1 f t c -> (b 1) c t f") + # Split into bands + x_bands = [x[..., b[0] : b[1]] for b in self.bands] + return x_bands + + def forward(self, x): + x_bands = self.spectrogram(x) + fmap = [] + + x = [] + for band, stack in zip(x_bands, self.band_convs): + for layer in stack: + band = layer(band) + fmap.append(band) + x.append(band) + + x = torch.cat(x, dim=-1) + x = self.conv_post(x) + fmap.append(x) + + return fmap + + +class Discriminator(ml.BaseModel): + def __init__( + self, + rates: list = [], + periods: list = [2, 3, 5, 7, 11], + fft_sizes: list = [2048, 1024, 512], + sample_rate: int = 44100, + bands: list = BANDS, + ): + """Discriminator that combines multiple discriminators. + + Parameters + ---------- + rates : list, optional + sampling rates (in Hz) to run MSD at, by default [] + If empty, MSD is not used. 
+ periods : list, optional + periods (of samples) to run MPD at, by default [2, 3, 5, 7, 11] + fft_sizes : list, optional + Window sizes of the FFT to run MRD at, by default [2048, 1024, 512] + sample_rate : int, optional + Sampling rate of audio in Hz, by default 44100 + bands : list, optional + Bands to run MRD at, by default `BANDS` + """ + super().__init__() + discs = [] + discs += [MPD(p) for p in periods] + discs += [MSD(r, sample_rate=sample_rate) for r in rates] + discs += [MRD(f, sample_rate=sample_rate, bands=bands) for f in fft_sizes] + self.discriminators = nn.ModuleList(discs) + + def preprocess(self, y): + # Remove DC offset + y = y - y.mean(dim=-1, keepdims=True) + # Peak normalize the volume of input audio + y = 0.8 * y / (y.abs().max(dim=-1, keepdim=True)[0] + 1e-9) + return y + + def forward(self, x): + x = self.preprocess(x) + fmaps = [d(x) for d in self.discriminators] + return fmaps + + +if __name__ == "__main__": + disc = Discriminator() + x = torch.zeros(1, 1, 44100) + results = disc(x) + for i, result in enumerate(results): + print(f"disc{i}") + for i, r in enumerate(result): + print(r.shape, r.mean(), r.min(), r.max()) + print() diff --git a/src/modules/dac/nn/.ipynb_checkpoints/__init__-checkpoint.py b/src/modules/dac/nn/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..6718c8b1a3d36c31655b030f4c515a144cde4db7 --- /dev/null +++ b/src/modules/dac/nn/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,3 @@ +from . import layers +from . import loss +from . import quantize diff --git a/src/modules/dac/nn/.ipynb_checkpoints/layers-checkpoint.py b/src/modules/dac/nn/.ipynb_checkpoints/layers-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94 --- /dev/null +++ b/src/modules/dac/nn/.ipynb_checkpoints/layers-checkpoint.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/src/modules/dac/nn/.ipynb_checkpoints/loss-checkpoint.py b/src/modules/dac/nn/.ipynb_checkpoints/loss-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b --- /dev/null +++ b/src/modules/dac/nn/.ipynb_checkpoints/loss-checkpoint.py @@ -0,0 +1,368 @@ +import typing +from typing import List + +import torch +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import STFTParams +from torch import nn + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. 
+ weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. + """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. + + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. 
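# --- editorial note (not part of the diff): per STFT scale s, the forward
# pass below computes (with magnitudes clamped at ``clamp_eps``)
#     log_weight * L1(log10 |S_s(x)|^pow, log10 |S_s(y)|^pow)
#   + mag_weight * L1(|S_s(x)|, |S_s(y)|)
# and sums the result over the configured window lengths. ---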
+ + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. 
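# --- editorial note (not part of the diff): this loss mirrors
# MultiScaleSTFTLoss but compares mel spectrograms; ``n_mels``, ``mel_fmin``,
# ``mel_fmax`` and ``window_lengths`` are zipped per scale in forward(), so
# all four lists should have the same length. ---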
+ + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. 
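# --- editorial note (not part of the diff): the methods below implement a
# least-squares GAN objective plus L1 feature matching over the
# discriminator's intermediate feature maps. Sketch of typical use,
# assuming ``fake`` and ``real`` are AudioSignals:
#     gan = GANLoss(Discriminator())
#     loss_d = gan.discriminator_loss(fake, real)       # mean d(fake)^2 + mean (1 - d(real))^2
#     loss_g, loss_fm = gan.generator_loss(fake, real)  # mean (1 - d(fake))^2, L1 feature match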
+ """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/src/modules/dac/nn/.ipynb_checkpoints/quantize-checkpoint.py b/src/modules/dac/nn/.ipynb_checkpoints/quantize-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..bc462bc56ab5191b65d58ebd8bcaff4af7fb5927 --- /dev/null +++ b/src/modules/dac/nn/.ipynb_checkpoints/quantize-checkpoint.py @@ -0,0 +1,262 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from .layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. 
of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. 
+ + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/src/modules/dac/nn/__init__.py b/src/modules/dac/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6718c8b1a3d36c31655b030f4c515a144cde4db7 --- /dev/null +++ b/src/modules/dac/nn/__init__.py @@ -0,0 +1,3 @@ +from . import layers +from . import loss +from . import quantize diff --git a/src/modules/dac/nn/__pycache__/__init__.cpython-310.pyc b/src/modules/dac/nn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f834bc23586af3072ded94d4b593733396f89f4c Binary files /dev/null and b/src/modules/dac/nn/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/modules/dac/nn/__pycache__/__init__.cpython-311.pyc b/src/modules/dac/nn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..eafea719f716b876b394d0d5fc1fefda77d5ba0d Binary files /dev/null and b/src/modules/dac/nn/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/modules/dac/nn/__pycache__/layers.cpython-310.pyc b/src/modules/dac/nn/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..be3e7d3ab89240ade829e5e4e82b4c88cc50f41f Binary files /dev/null and b/src/modules/dac/nn/__pycache__/layers.cpython-310.pyc differ diff --git a/src/modules/dac/nn/__pycache__/layers.cpython-311.pyc b/src/modules/dac/nn/__pycache__/layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c60216140d4846cb38f54cfb2ffef70b16f57c7c Binary files /dev/null and b/src/modules/dac/nn/__pycache__/layers.cpython-311.pyc differ diff --git a/src/modules/dac/nn/__pycache__/loss.cpython-310.pyc b/src/modules/dac/nn/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c4a3ecad3947fe1ac0a7e8058136c5d9c629b369 Binary files /dev/null and b/src/modules/dac/nn/__pycache__/loss.cpython-310.pyc differ diff --git a/src/modules/dac/nn/__pycache__/loss.cpython-311.pyc b/src/modules/dac/nn/__pycache__/loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..75b0df5a97a40a5493acf935aa0ab2d9a179afd8 Binary files /dev/null and b/src/modules/dac/nn/__pycache__/loss.cpython-311.pyc differ diff --git a/src/modules/dac/nn/__pycache__/quantize.cpython-310.pyc b/src/modules/dac/nn/__pycache__/quantize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..aaa26e8b6bdbb409f10fb8f5f229697197082fab Binary files /dev/null and b/src/modules/dac/nn/__pycache__/quantize.cpython-310.pyc differ diff --git 
a/src/modules/dac/nn/__pycache__/quantize.cpython-311.pyc b/src/modules/dac/nn/__pycache__/quantize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..8737b2fa25ecfe93c80d03713896279f36a047ba Binary files /dev/null and b/src/modules/dac/nn/__pycache__/quantize.cpython-311.pyc differ diff --git a/src/modules/dac/nn/layers.py b/src/modules/dac/nn/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94 --- /dev/null +++ b/src/modules/dac/nn/layers.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/src/modules/dac/nn/loss.py b/src/modules/dac/nn/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b --- /dev/null +++ b/src/modules/dac/nn/loss.py @@ -0,0 +1,368 @@ +import typing +from typing import List + +import torch +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import STFTParams +from torch import nn + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. + """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. + + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. 
Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. 
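# A minimal sanity-check sketch for the SISDRLoss defined above (assumes the
# class is in scope; raw tensors are used, which the forward() also accepts).
# The loss is the negative SI-SDR in dB, so a rescaled, lightly noised copy of
# the reference gives a strongly negative value, while an unrelated estimate
# gives a large positive one.
import torch

sisdr = SISDRLoss()                               # scaling=True -> scale-invariant
ref = torch.randn(2, 1, 16000)                    # (batch, channels, samples)
good = 0.5 * ref + 0.01 * torch.randn_like(ref)   # rescaled + slightly noisy copy
bad = torch.randn_like(ref)                       # unrelated noise
print(sisdr(ref, good))                           # strongly negative (roughly -30 dB)
print(sisdr(ref, bad))                            # large positive value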
+ + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. 
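# A minimal usage sketch for the two spectral losses in this file (assumes the
# classes are in scope). Both take AudioSignal estimates and references and sum
# a log-magnitude L1 term and a linear-magnitude L1 term over every configured
# STFT scale.
import torch
from audiotools import AudioSignal

est = AudioSignal(torch.randn(1, 1, 44100), 44100)
ref = AudioSignal(torch.randn(1, 1, 44100), 44100)

stft_loss = MultiScaleSTFTLoss()                 # 2048- and 512-sample windows
mel_loss = MelSpectrogramLoss()                  # 150- and 80-band mel filterbanks
total = stft_loss(est, ref) + mel_loss(est, ref)
print(total)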
+ + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/src/modules/dac/nn/quantize.py b/src/modules/dac/nn/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..bc462bc56ab5191b65d58ebd8bcaff4af7fb5927 --- /dev/null +++ b/src/modules/dac/nn/quantize.py @@ -0,0 +1,262 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from .layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. 
of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. 
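# A minimal usage sketch for the ResidualVectorQuantize defined above (assumes
# the class is in scope). Note that forward() returns a plain tuple, not the
# dict described in its docstring, so results are unpacked positionally here
# (the __main__ demo at the end of this file indexes the result as a dict,
# which does not match the tuple actually returned).
import torch

rvq = ResidualVectorQuantize(input_dim=512, n_codebooks=9, codebook_size=1024)
z = torch.randn(2, 512, 80)                      # (batch, latent dim, frames)
z_q, codes, latents, commit_loss, cb_loss = rvq(z)
print(z_q.shape, codes.shape, latents.shape)     # (2, 512, 80) (2, 9, 80) (2, 72, 80)

# The discrete codes alone are enough to rebuild the quantized representation:
z_q2, _, _ = rvq.from_codes(codes)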
+ + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/src/modules/dac/utils/.ipynb_checkpoints/__init__-checkpoint.py b/src/modules/dac/utils/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..36fbd70cf223b04135af71c4b322e1a92431d6ca --- /dev/null +++ b/src/modules/dac/utils/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,122 @@ +from pathlib import Path + +import argbind +from audiotools import ml + +from ..model import DAC + +Accelerator = ml.Accelerator + +__MODEL_LATEST_TAGS__ = { + ("44khz", "8kbps"): "0.0.1", + ("24khz", "8kbps"): "0.0.4", + ("16khz", "8kbps"): "0.0.5", + ("44khz", "16kbps"): "1.0.0", +} + +__MODEL_URLS__ = { + ( + "44khz", + "0.0.1", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", + ( + "24khz", + "0.0.4", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", + ( + "16khz", + "0.0.5", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", + ( + "44khz", + "1.0.0", + "16kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", +} + + +@argbind.bind(group="download", positional=True, without_prefix=True) +def download( + model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" +): + """ + Function that downloads the weights file from URL if a local cache is not found. + + Parameters + ---------- + model_type : str + The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + Only 44khz model supports 16kbps. + tag : str + The tag of the model to download. Defaults to "latest". + + Returns + ------- + Path + Directory path required to load model via audiotools. 
+ """ + model_type = model_type.lower() + tag = tag.lower() + + assert model_type in [ + "44khz", + "24khz", + "16khz", + ], "model_type must be one of '44khz', '24khz', or '16khz'" + + assert model_bitrate in [ + "8kbps", + "16kbps", + ], "model_bitrate must be one of '8kbps', or '16kbps'" + + if tag == "latest": + tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] + + download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) + + if download_link is None: + raise ValueError( + f"Could not find model with tag {tag} and model type {model_type}" + ) + + local_path = ( + Path.home() + / ".cache" + / "descript" + / "dac" + / f"weights_{model_type}_{model_bitrate}_{tag}.pth" + ) + if not local_path.exists(): + local_path.parent.mkdir(parents=True, exist_ok=True) + + # Download the model + import requests + + response = requests.get(download_link) + + if response.status_code != 200: + raise ValueError( + f"Could not download model. Received response code {response.status_code}" + ) + local_path.write_bytes(response.content) + + return local_path + + +def load_model( + model_type: str = "44khz", + model_bitrate: str = "8kbps", + tag: str = "latest", + load_path: str = None, +): + if not load_path: + load_path = download( + model_type=model_type, model_bitrate=model_bitrate, tag=tag + ) + generator = DAC.load(load_path) + return generator diff --git a/src/modules/dac/utils/__init__.py b/src/modules/dac/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..36fbd70cf223b04135af71c4b322e1a92431d6ca --- /dev/null +++ b/src/modules/dac/utils/__init__.py @@ -0,0 +1,122 @@ +from pathlib import Path + +import argbind +from audiotools import ml + +from ..model import DAC + +Accelerator = ml.Accelerator + +__MODEL_LATEST_TAGS__ = { + ("44khz", "8kbps"): "0.0.1", + ("24khz", "8kbps"): "0.0.4", + ("16khz", "8kbps"): "0.0.5", + ("44khz", "16kbps"): "1.0.0", +} + +__MODEL_URLS__ = { + ( + "44khz", + "0.0.1", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.1/weights.pth", + ( + "24khz", + "0.0.4", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.4/weights_24khz.pth", + ( + "16khz", + "0.0.5", + "8kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/0.0.5/weights_16khz.pth", + ( + "44khz", + "1.0.0", + "16kbps", + ): "https://github.com/descriptinc/descript-audio-codec/releases/download/1.0.0/weights_44khz_16kbps.pth", +} + + +@argbind.bind(group="download", positional=True, without_prefix=True) +def download( + model_type: str = "44khz", model_bitrate: str = "8kbps", tag: str = "latest" +): + """ + Function that downloads the weights file from URL if a local cache is not found. + + Parameters + ---------- + model_type : str + The type of model to download. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + Only 44khz model supports 16kbps. + tag : str + The tag of the model to download. Defaults to "latest". + + Returns + ------- + Path + Directory path required to load model via audiotools. 
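# A minimal usage sketch for the helpers above. Weights are cached under
# ~/.cache/descript/dac/ and re-downloaded only if missing. The `dac.utils`
# import path matches what decode.py / encode.py below use; with this repo's
# layout (src/modules/dac) the path may need to be adjusted.
from dac.utils import download, load_model

path = download(model_type="44khz", model_bitrate="8kbps", tag="latest")
print(path)                                   # the cached weights file itself,
                                              # e.g. .../weights_44khz_8kbps_0.0.1.pth
generator = load_model(model_type="44khz")    # downloads if needed, then builds DAC
generator.eval()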
+ """ + model_type = model_type.lower() + tag = tag.lower() + + assert model_type in [ + "44khz", + "24khz", + "16khz", + ], "model_type must be one of '44khz', '24khz', or '16khz'" + + assert model_bitrate in [ + "8kbps", + "16kbps", + ], "model_bitrate must be one of '8kbps', or '16kbps'" + + if tag == "latest": + tag = __MODEL_LATEST_TAGS__[(model_type, model_bitrate)] + + download_link = __MODEL_URLS__.get((model_type, tag, model_bitrate), None) + + if download_link is None: + raise ValueError( + f"Could not find model with tag {tag} and model type {model_type}" + ) + + local_path = ( + Path.home() + / ".cache" + / "descript" + / "dac" + / f"weights_{model_type}_{model_bitrate}_{tag}.pth" + ) + if not local_path.exists(): + local_path.parent.mkdir(parents=True, exist_ok=True) + + # Download the model + import requests + + response = requests.get(download_link) + + if response.status_code != 200: + raise ValueError( + f"Could not download model. Received response code {response.status_code}" + ) + local_path.write_bytes(response.content) + + return local_path + + +def load_model( + model_type: str = "44khz", + model_bitrate: str = "8kbps", + tag: str = "latest", + load_path: str = None, +): + if not load_path: + load_path = download( + model_type=model_type, model_bitrate=model_bitrate, tag=tag + ) + generator = DAC.load(load_path) + return generator diff --git a/src/modules/dac/utils/__pycache__/__init__.cpython-310.pyc b/src/modules/dac/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..88aac19a3b6f8d5c904369537211025c985d307a Binary files /dev/null and b/src/modules/dac/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/modules/dac/utils/__pycache__/__init__.cpython-311.pyc b/src/modules/dac/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ee7b69ad676b50b2c12c1c2414f54c7974e249a9 Binary files /dev/null and b/src/modules/dac/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/modules/dac/utils/decode.py b/src/modules/dac/utils/decode.py new file mode 100644 index 0000000000000000000000000000000000000000..08d44e8453ec4fa3433c2a9952d1a4da15315939 --- /dev/null +++ b/src/modules/dac/utils/decode.py @@ -0,0 +1,95 @@ +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from tqdm import tqdm + +from dac import DACFile +from dac.utils import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="decode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def decode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + device: str = "cuda", + model_type: str = "44khz", + verbose: bool = False, +): + """Decode audio from codes. + + Parameters + ---------- + input : str + Path to input directory or file + output : str, optional + Path to output directory, by default "". + If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. 
Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + device : str, optional + Device to use, by default "cuda". If "cpu", the model will be loaded on the CPU. + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". Defaults to "44khz". Ignored if `weights_path` is specified. + """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + + # Find all .dac files in input directory + _input = Path(input) + input_files = list(_input.glob("**/*.dac")) + + # If input is a .dac file, add it to the list + if _input.suffix == ".dac": + input_files.append(_input) + + # Create output directory + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(input_files)), desc=f"Decoding files"): + # Load file + artifact = DACFile.load(input_files[i]) + + # Reconstruct audio from codes + recons = generator.decompress(artifact, verbose=verbose) + + # Compute output path + relative_path = input_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = input_files[i] + output_name = relative_path.with_suffix(".wav").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + # Write to file + recons.write(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + decode() diff --git a/src/modules/dac/utils/encode.py b/src/modules/dac/utils/encode.py new file mode 100644 index 0000000000000000000000000000000000000000..aa3f6f44b3c210f485da1b1726b85494ff5e7804 --- /dev/null +++ b/src/modules/dac/utils/encode.py @@ -0,0 +1,94 @@ +import math +import warnings +from pathlib import Path + +import argbind +import numpy as np +import torch +from audiotools import AudioSignal +from audiotools.core import util +from tqdm import tqdm + +from dac.utils import load_model + +warnings.filterwarnings("ignore", category=UserWarning) + + +@argbind.bind(group="encode", positional=True, without_prefix=True) +@torch.inference_mode() +@torch.no_grad() +def encode( + input: str, + output: str = "", + weights_path: str = "", + model_tag: str = "latest", + model_bitrate: str = "8kbps", + n_quantizers: int = None, + device: str = "cuda", + model_type: str = "44khz", + win_duration: float = 5.0, + verbose: bool = False, +): + """Encode audio files in input path to .dac format. + + Parameters + ---------- + input : str + Path to input audio file or directory + output : str, optional + Path to output directory, by default "". If `input` is a directory, the directory sub-tree relative to `input` is re-created in `output`. + weights_path : str, optional + Path to weights file, by default "". If not specified, the weights file will be downloaded from the internet using the + model_tag and model_type. + model_tag : str, optional + Tag of the model to use, by default "latest". Ignored if `weights_path` is specified. + model_bitrate: str + Bitrate of the model. Must be one of "8kbps", or "16kbps". Defaults to "8kbps". + n_quantizers : int, optional + Number of quantizers to use, by default None. If not specified, all the quantizers will be used and the model will compress at maximum bitrate. + device : str, optional + Device to use, by default "cuda" + model_type : str, optional + The type of model to use. Must be one of "44khz", "24khz", or "16khz". 
Defaults to "44khz". Ignored if `weights_path` is specified. + """ + generator = load_model( + model_type=model_type, + model_bitrate=model_bitrate, + tag=model_tag, + load_path=weights_path, + ) + generator.to(device) + generator.eval() + kwargs = {"n_quantizers": n_quantizers} + + # Find all audio files in input path + input = Path(input) + audio_files = util.find_audio(input) + + output = Path(output) + output.mkdir(parents=True, exist_ok=True) + + for i in tqdm(range(len(audio_files)), desc="Encoding files"): + # Load file + signal = AudioSignal(audio_files[i]) + + # Encode audio to .dac format + artifact = generator.compress(signal, win_duration, verbose=verbose, **kwargs) + + # Compute output path + relative_path = audio_files[i].relative_to(input) + output_dir = output / relative_path.parent + if not relative_path.name: + output_dir = output + relative_path = audio_files[i] + output_name = relative_path.with_suffix(".dac").name + output_path = output_dir / output_name + output_path.parent.mkdir(parents=True, exist_ok=True) + + artifact.save(output_path) + + +if __name__ == "__main__": + args = argbind.parse_args() + with argbind.scope(args): + encode() diff --git a/src/modules/stable_vae/.ipynb_checkpoints/__init__-checkpoint.py b/src/modules/stable_vae/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..a8065ea65729a235c519de4cb86c5ea07e0ab7be --- /dev/null +++ b/src/modules/stable_vae/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,40 @@ +from .models.autoencoders import create_autoencoder_from_config +import os +import json +import torch +from torch.nn.utils import remove_weight_norm + + +def remove_all_weight_norm(model): + for name, module in model.named_modules(): + if hasattr(module, 'weight_g'): + remove_weight_norm(module) + + +def load_vae(ckpt_path, remove_weight_norm=False): + config_file = os.path.join(os.path.dirname(ckpt_path), 'config.json') + + # Load the model configuration + with open(config_file) as f: + model_config = json.load(f) + + # Create the model from the configuration + model = create_autoencoder_from_config(model_config) + + # Load the state dictionary from the checkpoint + model_dict = torch.load(ckpt_path, map_location='cpu')['state_dict'] + + # Strip the "autoencoder." 
prefix from the keys + model_dict = {key[len("autoencoder."):]: value for key, value in model_dict.items() if key.startswith("autoencoder.")} + + # Load the state dictionary into the model + model.load_state_dict(model_dict) + + # Remove weight normalization + if remove_weight_norm: + remove_all_weight_norm(model) + + # Set the model to evaluation mode + model.eval() + + return model diff --git a/src/modules/stable_vae/__init__.py b/src/modules/stable_vae/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..a8065ea65729a235c519de4cb86c5ea07e0ab7be --- /dev/null +++ b/src/modules/stable_vae/__init__.py @@ -0,0 +1,40 @@ +from .models.autoencoders import create_autoencoder_from_config +import os +import json +import torch +from torch.nn.utils import remove_weight_norm + + +def remove_all_weight_norm(model): + for name, module in model.named_modules(): + if hasattr(module, 'weight_g'): + remove_weight_norm(module) + + +def load_vae(ckpt_path, remove_weight_norm=False): + config_file = os.path.join(os.path.dirname(ckpt_path), 'config.json') + + # Load the model configuration + with open(config_file) as f: + model_config = json.load(f) + + # Create the model from the configuration + model = create_autoencoder_from_config(model_config) + + # Load the state dictionary from the checkpoint + model_dict = torch.load(ckpt_path, map_location='cpu')['state_dict'] + + # Strip the "autoencoder." prefix from the keys + model_dict = {key[len("autoencoder."):]: value for key, value in model_dict.items() if key.startswith("autoencoder.")} + + # Load the state dictionary into the model + model.load_state_dict(model_dict) + + # Remove weight normalization + if remove_weight_norm: + remove_all_weight_norm(model) + + # Set the model to evaluation mode + model.eval() + + return model diff --git a/src/modules/stable_vae/__pycache__/__init__.cpython-310.pyc b/src/modules/stable_vae/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..90d03c41f4de2fc900eed305af55b29a4741b26d Binary files /dev/null and b/src/modules/stable_vae/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/modules/stable_vae/__pycache__/__init__.cpython-311.pyc b/src/modules/stable_vae/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d3351c994034560d767031851b503b096ef03355 Binary files /dev/null and b/src/modules/stable_vae/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/autoencoders-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/autoencoders-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..2741d45e70e3c25bc28f2b4e43b1a3e925a4c9e3 --- /dev/null +++ b/src/modules/stable_vae/models/.ipynb_checkpoints/autoencoders-checkpoint.py @@ -0,0 +1,683 @@ +import torch +import math +import numpy as np + +from torch import nn +from torch.nn import functional as F +from torchaudio import transforms as T +from alias_free_torch import Activation1d +from .nn.layers import WNConv1d, WNConvTranspose1d +from typing import Literal, Dict, Any + +# from .inference.sampling import sample +from .utils import prepare_audio +from .blocks import SnakeBeta +from .bottleneck import Bottleneck, DiscreteBottleneck +from .factory import create_pretransform_from_config, create_bottleneck_from_config +from .pretransforms import Pretransform + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) 
+ return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: + if activation == "elu": + act = nn.ELU() + elif activation == "snake": + act = SnakeBeta(channels) + elif activation == "none": + act = nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + + if antialias: + act = Activation1d(act) + + return act + +class ResidualUnit(nn.Module): + def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False): + super().__init__() + + self.dilation = dilation + + padding = (dilation * (7-1)) // 2 + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=7, dilation=dilation, padding=padding), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=out_channels, out_channels=out_channels, + kernel_size=1) + ) + + def forward(self, x): + res = x + + #x = checkpoint(self.layers, x) + x = self.layers(x) + + return x + res + +class EncoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False): + super().__init__() + + self.layers = nn.Sequential( + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=9, use_snake=use_snake), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)), + ) + + def forward(self, x): + return self.layers(x) + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False): + super().__init__() + + if use_nearest_upsample: + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=stride, mode="nearest"), + WNConv1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, + stride=1, + bias=False, + padding='same') + ) + else: + upsample_layer = WNConvTranspose1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + upsample_layer, + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=9, use_snake=use_snake), + ) + + def forward(self, x): + return self.layers(x) + +class OobleckEncoder(nn.Module): + def __init__(self, + in_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False + ): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, 
padding=3) + ] + + for i in range(self.depth-1): + layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels), + WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class OobleckDecoder(nn.Module): + def __init__(self, + out_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=True): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3), + ] + + for i in range(self.depth-1, 0, -1): + layers += [DecoderBlock( + in_channels=c_mults[i]*channels, + out_channels=c_mults[i-1]*channels, + stride=strides[i-1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample + ) + ] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[0] * channels), + WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False), + nn.Tanh() if final_tanh else nn.Identity() + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class DACEncoderWrapper(nn.Module): + def __init__(self, in_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Encoder as DACEncoder + + latent_dim = kwargs.pop("latent_dim", None) + + encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"])) + self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs) + self.latent_dim = latent_dim + + # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility + self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity() + + if in_channels != 1: + self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3) + + def forward(self, x): + x = self.encoder(x) + x = self.proj_out(x) + return x + +class DACDecoderWrapper(nn.Module): + def __init__(self, latent_dim, out_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Decoder as DACDecoder + + self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels) + + self.latent_dim = latent_dim + + def forward(self, x): + return self.decoder(x) + +class AudioAutoencoder(nn.Module): + def __init__( + self, + encoder, + decoder, + latent_dim, + downsampling_ratio, + sample_rate, + io_channels=2, + bottleneck: Bottleneck = None, + pretransform: Pretransform = None, + in_channels = None, + out_channels = None, + soft_clip = False + ): + super().__init__() + + self.downsampling_ratio = downsampling_ratio + self.sample_rate = sample_rate + + self.latent_dim = latent_dim + self.io_channels = io_channels + self.in_channels = io_channels + self.out_channels = io_channels + + self.min_length = self.downsampling_ratio + + if in_channels is not None: + self.in_channels = in_channels + + if out_channels is not None: + self.out_channels = out_channels + + self.bottleneck = 
bottleneck + + self.encoder = encoder + + self.decoder = decoder + + self.pretransform = pretransform + + self.soft_clip = soft_clip + + self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete + + def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs): + + info = {} + + if self.pretransform is not None and not skip_pretransform: + if self.pretransform.enable_grad: + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + else: + with torch.no_grad(): + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + + if self.encoder is not None: + if iterate_batch: + latents = [] + for i in range(audio.shape[0]): + latents.append(self.encoder(audio[i:i+1])) + latents = torch.cat(latents, dim=0) + else: + latents = self.encoder(audio) + else: + latents = audio + + if self.bottleneck is not None: + # TODO: Add iterate batch logic, needs to merge the info dicts + latents, bottleneck_info = self.bottleneck.encode(latents, return_info=True, **kwargs) + + info.update(bottleneck_info) + + if return_info: + return latents, info + + return latents + + def decode(self, latents, iterate_batch=False, **kwargs): + + if self.bottleneck is not None: + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.bottleneck.decode(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + latents = self.bottleneck.decode(latents) + + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.decoder(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + decoded = self.decoder(latents, **kwargs) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + if iterate_batch: + decodeds = [] + for i in range(decoded.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + if iterate_batch: + decodeds = [] + for i in range(latents.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + + if self.soft_clip: + decoded = torch.tanh(decoded) + + return decoded + + def decode_tokens(self, tokens, **kwargs): + ''' + Decode discrete tokens to audio + Only works with discrete autoencoders + ''' + + assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders" + + latents = self.bottleneck.decode_tokens(tokens, **kwargs) + + return self.decode(latents, **kwargs) + + + def preprocess_audio_for_encoder(self, audio, in_sr): + ''' + Preprocess single audio tensor (Channels x Length) to be compatible with the encoder. + If the model is mono, stereo audio will be converted to mono. + Audio will be silence-padded to be a multiple of the model's downsampling ratio. + Audio will be resampled to the model's sample rate. 
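# A worked illustration of the pad-to-multiple step used by
# preprocess_audio_list_for_encoder below. 2048 is an assumed downsampling
# ratio for the example; self.min_length is set to the model's actual ratio
# in __init__.
def pad_to_multiple(length: int, multiple: int) -> int:
    return length + (multiple - (length % multiple)) % multiple

assert pad_to_multiple(44100, 2048) == 45056   # padded up to 22 * 2048
assert pad_to_multiple(4096, 2048) == 4096     # exact multiples are left alone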
+        The output will have batch size 1 and be shape (1 x Channels x Length)
+        '''
+        return self.preprocess_audio_list_for_encoder([audio], [in_sr])
+
+    def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list):
+        '''
+        Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatible with the encoder.
+        The audio in that list can be of different lengths and channels.
+        in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio.
+        All audio will be resampled to the model's sample rate.
+        Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio.
+        If the model is mono, all audio will be converted to mono.
+        The output will be a tensor of shape (Batch x Channels x Length)
+        '''
+        batch_size = len(audio_list)
+        if isinstance(in_sr_list, int):
+            in_sr_list = [in_sr_list]*batch_size
+        assert len(in_sr_list) == batch_size, "list of sample rates must be the same length as audio_list"
+        new_audio = []
+        max_length = 0
+        # resample & find the max length
+        for i in range(batch_size):
+            audio = audio_list[i]
+            in_sr = in_sr_list[i]
+            if len(audio.shape) == 3 and audio.shape[0] == 1:
+                # batch size 1 was given by accident. Just squeeze it.
+                audio = audio.squeeze(0)
+            elif len(audio.shape) == 1:
+                # Mono signal, channel dimension is missing, unsqueeze it in
+                audio = audio.unsqueeze(0)
+            assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension"
+            # Resample audio
+            if in_sr != self.sample_rate:
+                resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device)
+                audio = resample_tf(audio)
+            new_audio.append(audio)
+            if audio.shape[-1] > max_length:
+                max_length = audio.shape[-1]
+        # Pad every audio to the same length, a multiple of the model's downsampling ratio
+        padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length
+        for i in range(batch_size):
+            # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model
+            new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length,
+                                         target_channels=self.in_channels, device=new_audio[i].device).squeeze(0)
+        # convert to tensor
+        return torch.stack(new_audio)
+
+    def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs):
+        '''
+        Encode audios into latents. Audios should already be preprocessed by preprocess_audio_for_encoder.
+        If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap.
+        Overlap and chunk_size params are both measured in number of latents (not audio samples)
+        and therefore you likely could use the same values with decode_audio.
+        An overlap of zero will cause discontinuity artefacts. Overlap should be >= receptive field size.
+        Every autoencoder will have a different receptive field size, and thus ideal overlap.
+        You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff.
+        The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks.
+        Smaller chunk_size uses less memory, but more compute.
+        The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version.
+        For example, on an A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks.
+        '''
+        if not chunked:
+            # default behavior. 
Encode the entire audio in parallel + return self.encode(audio, **kwargs) + else: + # CHUNKED ENCODING + # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) + samples_per_latent = self.downsampling_ratio + total_size = audio.shape[2] # in samples + batch_size = audio.shape[0] + chunk_size *= samples_per_latent # converting metric in latents to samples + overlap *= samples_per_latent # converting metric in latents to samples + hop_size = chunk_size - overlap + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = audio[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = audio[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # Note: y_size might be a different value from the latent length used in diffusion training + # because we can encode audio of varying lengths + # However, the audio should've been padded to a multiple of samples_per_latent by now. + y_size = total_size // samples_per_latent + # Create an empty latent, we will populate it with chunks as we encode them + y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # encode the chunk + y_chunk = self.encode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size // samples_per_latent + t_end = t_start + chunk_size // samples_per_latent + # remove the edges of the overlaps + ol = overlap//samples_per_latent//2 + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Decode latents to audio. + If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. 
Decode the entire latent in parallel + return self.decode(latents, **kwargs) + else: + # chunked decoding + hop_size = chunk_size - overlap + total_size = latents.shape[2] + batch_size = latents.shape[0] + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = latents[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = latents[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # samples_per_latent is just the downsampling ratio + samples_per_latent = self.downsampling_ratio + # Create an empty waveform, we will populate it with chunks as decode them + y_size = total_size * samples_per_latent + y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # decode the chunk + y_chunk = self.decode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size * samples_per_latent + t_end = t_start + chunk_size * samples_per_latent + # remove the edges of the overlaps + ol = (overlap//2) * samples_per_latent + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + +# AE factories + +def create_encoder_from_config(encoder_config: Dict[str, Any]): + encoder_type = encoder_config.get("type", None) + assert encoder_type is not None, "Encoder type must be specified" + + if encoder_type == "oobleck": + encoder = OobleckEncoder( + **encoder_config["config"] + ) + + elif encoder_type == "seanet": + from encodec.modules import SEANetEncoder + seanet_encoder_config = encoder_config["config"] + + #SEANet encoder expects strides in reverse order + seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2]))) + encoder = SEANetEncoder( + **seanet_encoder_config + ) + elif encoder_type == "dac": + dac_config = encoder_config["config"] + + encoder = DACEncoderWrapper(**dac_config) + elif encoder_type == "local_attn": + from .local_attention import TransformerEncoder1D + + local_attn_config = encoder_config["config"] + + encoder = TransformerEncoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown encoder type {encoder_type}") + + requires_grad = encoder_config.get("requires_grad", True) + if not requires_grad: + for param in encoder.parameters(): + param.requires_grad = False + + return encoder + +def create_decoder_from_config(decoder_config: Dict[str, Any]): + decoder_type = decoder_config.get("type", None) + assert decoder_type is not None, "Decoder type must be specified" + + if decoder_type == "oobleck": + decoder = OobleckDecoder( + **decoder_config["config"] + ) + elif decoder_type == "seanet": + from encodec.modules import SEANetDecoder + + decoder = SEANetDecoder( + **decoder_config["config"] + ) + elif decoder_type == "dac": + dac_config = decoder_config["config"] + + decoder = DACDecoderWrapper(**dac_config) + elif decoder_type == "local_attn": + from .local_attention import TransformerDecoder1D + + local_attn_config = decoder_config["config"] + + decoder = 
TransformerDecoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown decoder type {decoder_type}") + + requires_grad = decoder_config.get("requires_grad", True) + if not requires_grad: + for param in decoder.parameters(): + param.requires_grad = False + + return decoder + +def create_autoencoder_from_config(config: Dict[str, Any]): + + ae_config = config["model"] + + encoder = create_encoder_from_config(ae_config["encoder"]) + decoder = create_decoder_from_config(ae_config["decoder"]) + + bottleneck = ae_config.get("bottleneck", None) + + latent_dim = ae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = ae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = ae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + in_channels = ae_config.get("in_channels", None) + out_channels = ae_config.get("out_channels", None) + + pretransform = ae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = create_bottleneck_from_config(bottleneck) + + soft_clip = ae_config["decoder"].get("soft_clip", False) + + return AudioAutoencoder( + encoder, + decoder, + io_channels=io_channels, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + sample_rate=sample_rate, + bottleneck=bottleneck, + pretransform=pretransform, + in_channels=in_channels, + out_channels=out_channels, + soft_clip=soft_clip + ) \ No newline at end of file diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/blocks-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/blocks-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..cb310c8980ef5dc0f138e6f9f3478d4cdc63354d --- /dev/null +++ b/src/modules/stable_vae/models/.ipynb_checkpoints/blocks-checkpoint.py @@ -0,0 +1,359 @@ +from functools import reduce +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torch.backends.cuda import sdp_kernel +from packaging import version + +from .nn.layers import Snake1d + + +class ResidualBlock(nn.Module): + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return self.main(input) + self.skip(input) + + +class ResConvBlock(ResidualBlock): + def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): + skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) + super().__init__([ + nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_mid), + Snake1d(c_mid) if use_snake else nn.GELU(), + nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), + (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), + ], skip) + + +class SelfAttention1d(nn.Module): + def __init__(self, c_in, n_head=1, dropout_rate=0.): + super().__init__() + assert c_in % n_head == 0 + self.norm = nn.GroupNorm(1, c_in) + 
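+        # Note: the attributes below build a fused QKV projection (a 1x1 Conv1d
+        # producing 3 * c_in channels) plus an output projection. On CUDA with
+        # torch >= 2.0, sdp_kernel_config is a boolean tuple unpacked into
+        # torch.backends.cuda.sdp_kernel() in forward() to restrict which
+        # scaled_dot_product_attention backends may be used; otherwise forward()
+        # falls back to an explicit softmax attention.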
self.n_head = n_head + self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) + self.out_proj = nn.Conv1d(c_in, c_in, 1) + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward(self, input): + n, c, s = input.shape + qkv = self.qkv_proj(self.norm(input)) + qkv = qkv.view( + [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) + q, k, v = qkv.chunk(3, dim=1) + scale = k.shape[3]**-0.25 + + if self.use_flash: + with sdp_kernel(*self.sdp_kernel_config): + y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) + else: + att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) + y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) + + + return input + self.dropout(self.out_proj(y)) + + +class SkipBlock(nn.Module): + def __init__(self, *main): + super().__init__() + self.main = nn.Sequential(*main) + + def forward(self, input): + return torch.cat([self.main(input), input], dim=1) + + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter(torch.randn( + [out_features // 2, in_features]) * std) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + + +def expand_to_planes(input, shape): + return input[..., None].repeat([1, 1, shape[2]]) + +_kernels = { + 'linear': + [1 / 8, 3 / 8, 3 / 8, 1 / 8], + 'cubic': + [-0.01171875, -0.03515625, 0.11328125, 0.43359375, + 0.43359375, 0.11328125, -0.03515625, -0.01171875], + 'lanczos3': + [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, + -0.066637322306633, 0.13550527393817902, 0.44638532400131226, + 0.44638532400131226, 0.13550527393817902, -0.066637322306633, + -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] +} + + +class Downsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, (self.pad,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv1d(x, weight, stride=2) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + + +class Upsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = 
x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + + +def Downsample1d_2( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor, + padding=factor * (kernel_multiplier // 2), + ) + + +def Upsample1d_2( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ), + ) + else: + return nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor, + padding=factor // 2 + factor % 2, + output_padding=factor % 2, + ) + + +def zero_init(layer): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + return layer + + +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +#rms_norm = torch.compile(rms_norm) + +class AdaRMSNorm(nn.Module): + def __init__(self, features, cond_features, eps=1e-6): + super().__init__() + self.eps = eps + self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) + + def extra_repr(self): + return f"eps={self.eps}," + + def forward(self, x, cond): + return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) + + +def normalize(x, eps=1e-4): + dim = list(range(1, x.ndim)) + n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) + alpha = np.sqrt(n.numel() / x.numel()) + return x / torch.add(eps, n, alpha=alpha) + + +class ForcedWNConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1): + super().__init__() + self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) + + def forward(self, x): + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(self.weight)) + + fan_in = self.weight[0].numel() + + w = normalize(self.weight) / math.sqrt(fan_in) + + return F.conv1d(x, w, padding='same') + +# Kernels + +use_compile = True + +def compile(function, *args, **kwargs): + if not use_compile: + return function + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + +@compile +def linear_geglu(x, weight, bias=None): + x = x @ weight.mT + if bias is not None: + x = x + bias + x, gate = x.chunk(2, dim=-1) + return x * F.gelu(gate) + + +@compile +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +# Layers + + +class LinearGEGLU(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + 
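+        # Note: GEGLU doubles the projection width; forward() splits the
+        # 2 * out_features output into a value half and a gate half and
+        # returns value * GELU(gate) via linear_geglu() above.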
super().__init__(in_features, out_features * 2, bias=bias) + self.out_features = out_features + + def forward(self, x): + return linear_geglu(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, shape, fix_scale = False, eps=1e-6): + super().__init__() + self.eps = eps + + if fix_scale: + self.register_buffer("scale", torch.ones(shape)) + else: + self.scale = nn.Parameter(torch.ones(shape)) + + def extra_repr(self): + return f"shape={tuple(self.scale.shape)}, eps={self.eps}" + + def forward(self, x): + return rms_norm(x, self.scale, self.eps) + + +# jit script make it 1.4x faster and save GPU memory +@torch.jit.script +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + +# try: +# snake_beta = torch.compile(snake_beta) +# except RuntimeError: +# pass + + +# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license +# License available in LICENSES/LICENSE_NVIDIA.txt +class SnakeBeta(nn.Module): + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: + # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: + # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + # self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) + # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x \ No newline at end of file diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/bottleneck-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/bottleneck-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..df88c5f1b1f5fa3675c1a42f42e5e31e27d00ed3 --- /dev/null +++ b/src/modules/stable_vae/models/.ipynb_checkpoints/bottleneck-checkpoint.py @@ -0,0 +1,346 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from einops import rearrange +from vector_quantize_pytorch import ResidualVQ, FSQ +from .nn.quantize import ResidualVectorQuantize as DACResidualVQ + + +class Bottleneck(nn.Module): + def __init__(self, is_discrete: bool = False): + super().__init__() + + self.is_discrete = is_discrete + + def encode(self, x, return_info=False, **kwargs): + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError + + +class DiscreteBottleneck(Bottleneck): + def __init__(self, num_quantizers, codebook_size, tokens_id): + super().__init__(is_discrete=True) + + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + self.tokens_id = tokens_id + + def decode_tokens(self, codes, **kwargs): + raise NotImplementedError + + +class TanhBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + self.tanh = nn.Tanh() + + def encode(self, x, return_info=False): + info = {} + + x = torch.tanh(x) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + 
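+# Note on the two jit-scripted helpers below: `scale` is mapped to a strictly
+# positive stdev with softplus, and latents are drawn by reparameterization
+# (latents = mean + stdev * noise). In vae_sample_kl the KL term against a
+# unit Gaussian is accumulated as (mean^2 + var - log var - 1), i.e. the
+# standard closed form without the usual 1/2 factor, which is presumably
+# absorbed into the KL loss weight.
+#
+# Illustrative usage with the VAEBottleneck defined further down (variable
+# names here are hypothetical):
+#   bottleneck = VAEBottleneck()
+#   z, info = bottleneck.encode(enc_out, return_info=True)  # enc_out: (B, 2 * latent_dim, T)
+#   kl_loss = info["kl"]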
+@torch.jit.script +def vae_sample_kl(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + latents = torch.randn_like(mean) * stdev + mean + + kl = (mean * mean + var - logvar - 1).sum(1).mean() + + return latents, kl + + +@torch.jit.script +def vae_sample(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + latents = torch.randn_like(mean) * stdev + mean + return latents + + +class VAEBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False, **kwargs): + mean, scale = x.chunk(2, dim=1) + + if return_info: + info = {} + x, kl = vae_sample_kl(mean, scale) + info["kl"] = kl + return x, info + else: + x = vae_sample(mean, scale) + return x + + def decode(self, x): + return x + + +def compute_mean_kernel(x, y): + kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] + return torch.exp(-kernel_input).mean() + + +def compute_mmd(latents): + latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) + noise = torch.randn_like(latents_reshaped) + + latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) + noise_kernel = compute_mean_kernel(noise, noise) + latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) + + mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel + return mmd.mean() + + +class WassersteinBottleneck(Bottleneck): + def __init__(self, noise_augment_dim: int = 0): + super().__init__(is_discrete=False) + + self.noise_augment_dim = noise_augment_dim + + def encode(self, x, return_info=False): + info = {} + + if self.training and return_info: + mmd = compute_mmd(x) + info["mmd"] = mmd + + if return_info: + return x, info + + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + + +class L2Bottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False): + info = {} + + x = F.normalize(x, dim=1) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return F.normalize(x, dim=1) + + +class RVQBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False, **kwargs): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + + +class RVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False): + info = {} + + x, kl = vae_sample(*x.chunk(2, 
dim=1)) + + info["kl"] = kl + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + + +class DACRVQBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, **kwargs): + info = {} + + info["pre_quantizer"] = x + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + + +class DACRVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, n_quantizers: int = None): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["pre_quantizer"] = x + info["kl"] = kl + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + + +class FSQBottleneck(DiscreteBottleneck): + def __init__(self, dim, levels): + super().__init__(num_quantizers = 1, codebook_size = levels ** dim, tokens_id = "quantizer_indices") + self.quantizer = FSQ(levels=[levels] * dim) + + def encode(self, x, return_info=False): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + + if return_info: + return x, 
info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, tokens, **kwargs): + latents = self.quantizer.indices_to_codes(tokens) + + return self.decode(latents, **kwargs) \ No newline at end of file diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/factory-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/factory-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..4188703000ee176342c7f329342f18d6fe747b04 --- /dev/null +++ b/src/modules/stable_vae/models/.ipynb_checkpoints/factory-checkpoint.py @@ -0,0 +1,153 @@ +import json + +def create_model_from_config(model_config): + model_type = model_config.get('model_type', None) + + assert model_type is not None, 'model_type must be specified in model config' + + if model_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + return create_autoencoder_from_config(model_config) + elif model_type == 'diffusion_uncond': + from .diffusion import create_diffusion_uncond_from_config + return create_diffusion_uncond_from_config(model_config) + elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior": + from .diffusion import create_diffusion_cond_from_config + return create_diffusion_cond_from_config(model_config) + elif model_type == 'diffusion_autoencoder': + from .autoencoders import create_diffAE_from_config + return create_diffAE_from_config(model_config) + elif model_type == 'lm': + from .lm import create_audio_lm_from_config + return create_audio_lm_from_config(model_config) + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + +def create_model_from_config_path(model_config_path): + with open(model_config_path) as f: + model_config = json.load(f) + + return create_model_from_config(model_config) + +def create_pretransform_from_config(pretransform_config, sample_rate): + pretransform_type = pretransform_config.get('type', None) + + assert pretransform_type is not None, 'type must be specified in pretransform config' + + if pretransform_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + from .pretransforms import AutoencoderPretransform + + # Create fake top-level config to pass sample rate to autoencoder constructor + # This is a bit of a hack but it keeps us from re-defining the sample rate in the config + autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} + autoencoder = create_autoencoder_from_config(autoencoder_config) + + scale = pretransform_config.get("scale", 1.0) + model_half = pretransform_config.get("model_half", False) + iterate_batch = pretransform_config.get("iterate_batch", False) + chunked = pretransform_config.get("chunked", False) + + pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) + elif pretransform_type == 'wavelet': + from .pretransforms import WaveletPretransform + + wavelet_config = pretransform_config["config"] + channels = wavelet_config["channels"] + levels = wavelet_config["levels"] + wavelet = wavelet_config["wavelet"] + + pretransform = WaveletPretransform(channels, levels, wavelet) + elif pretransform_type == 'pqmf': + from .pretransforms import PQMFPretransform + pqmf_config = pretransform_config["config"] + pretransform = PQMFPretransform(**pqmf_config) + elif pretransform_type == 'dac_pretrained': + from .pretransforms import PretrainedDACPretransform + 
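+        # Illustrative config shape for this branch (keys follow the
+        # PretrainedDACPretransform constructor defined in pretransforms.py):
+        #   {"type": "dac_pretrained",
+        #    "config": {"model_type": "44khz", "model_bitrate": "8kbps", "scale": 1.0}}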
pretrained_dac_config = pretransform_config["config"] + pretransform = PretrainedDACPretransform(**pretrained_dac_config) + elif pretransform_type == "audiocraft_pretrained": + from .pretransforms import AudiocraftCompressionPretransform + + audiocraft_config = pretransform_config["config"] + pretransform = AudiocraftCompressionPretransform(**audiocraft_config) + else: + raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') + + enable_grad = pretransform_config.get('enable_grad', False) + pretransform.enable_grad = enable_grad + + pretransform.eval().requires_grad_(pretransform.enable_grad) + + return pretransform + +def create_bottleneck_from_config(bottleneck_config): + bottleneck_type = bottleneck_config.get('type', None) + + assert bottleneck_type is not None, 'type must be specified in bottleneck config' + + if bottleneck_type == 'tanh': + from .bottleneck import TanhBottleneck + bottleneck = TanhBottleneck() + elif bottleneck_type == 'vae': + from .bottleneck import VAEBottleneck + bottleneck = VAEBottleneck() + elif bottleneck_type == 'rvq': + from .bottleneck import RVQBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQBottleneck(**quantizer_params) + elif bottleneck_type == "dac_rvq": + from .bottleneck import DACRVQBottleneck + + bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) + + elif bottleneck_type == 'rvq_vae': + from .bottleneck import RVQVAEBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQVAEBottleneck(**quantizer_params) + + elif bottleneck_type == 'dac_rvq_vae': + from .bottleneck import DACRVQVAEBottleneck + bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) + elif bottleneck_type == 'l2_norm': + from .bottleneck import L2Bottleneck + bottleneck = L2Bottleneck() + elif bottleneck_type == "wasserstein": + from .bottleneck import WassersteinBottleneck + bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) + elif bottleneck_type == "fsq": + from .bottleneck import FSQBottleneck + bottleneck = FSQBottleneck(**bottleneck_config["config"]) + else: + raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') + + requires_grad = bottleneck_config.get('requires_grad', True) + if not requires_grad: + for param in bottleneck.parameters(): + param.requires_grad = False + + return bottleneck diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/pretransforms-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/pretransforms-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..c9942db59908ce8135f44e45090e928f6c60393a --- /dev/null +++ b/src/modules/stable_vae/models/.ipynb_checkpoints/pretransforms-checkpoint.py @@ -0,0 +1,258 @@ +import torch +from einops import rearrange +from torch import nn + +class Pretransform(nn.Module): + def __init__(self, enable_grad, io_channels, is_discrete): + super().__init__() + + self.is_discrete = is_discrete + self.io_channels = io_channels + self.encoded_channels = None + self.downsampling_ratio = None + + self.enable_grad = enable_grad + + def encode(self, x): + raise 
NotImplementedError + + def decode(self, z): + raise NotImplementedError + + def tokenize(self, x): + raise NotImplementedError + + def decode_tokens(self, tokens): + raise NotImplementedError + +class AutoencoderPretransform(Pretransform): + def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): + super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) + self.model = model + self.model.requires_grad_(False).eval() + self.scale=scale + self.downsampling_ratio = model.downsampling_ratio + self.io_channels = model.io_channels + self.sample_rate = model.sample_rate + + self.model_half = model_half + self.iterate_batch = iterate_batch + + self.encoded_channels = model.latent_dim + + self.chunked = chunked + self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None + self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None + + if self.model_half: + self.model.half() + + def encode(self, x, **kwargs): + + if self.model_half: + x = x.half() + self.model.to(torch.float16) + + encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + encoded = encoded.float() + + return encoded / self.scale + + def decode(self, z, **kwargs): + z = z * self.scale + + if self.model_half: + z = z.half() + self.model.to(torch.float16) + + decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + decoded = decoded.float() + + return decoded + + def tokenize(self, x, **kwargs): + assert self.model.is_discrete, "Cannot tokenize with a continuous model" + + _, info = self.model.encode(x, return_info = True, **kwargs) + + return info[self.model.bottleneck.tokens_id] + + def decode_tokens(self, tokens, **kwargs): + assert self.model.is_discrete, "Cannot decode tokens with a continuous model" + + return self.model.decode_tokens(tokens, **kwargs) + + def load_state_dict(self, state_dict, strict=True): + self.model.load_state_dict(state_dict, strict=strict) + +class WaveletPretransform(Pretransform): + def __init__(self, channels, levels, wavelet): + super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) + + from .wavelets import WaveletEncode1d, WaveletDecode1d + + self.encoder = WaveletEncode1d(channels, levels, wavelet) + self.decoder = WaveletDecode1d(channels, levels, wavelet) + + self.downsampling_ratio = 2 ** levels + self.io_channels = channels + self.encoded_channels = channels * self.downsampling_ratio + + def encode(self, x): + return self.encoder(x) + + def decode(self, z): + return self.decoder(z) + +class PQMFPretransform(Pretransform): + def __init__(self, attenuation=100, num_bands=16): + # TODO: Fix PQMF to take in in-channels + super().__init__(enable_grad=False, io_channels=1, is_discrete=False) + from .pqmf import PQMF + self.pqmf = PQMF(attenuation, num_bands) + + + def encode(self, x): + # x is (Batch x Channels x Time) + x = self.pqmf.forward(x) + # pqmf.forward returns (Batch x Channels x Bands x Time) + # but Pretransform needs Batch x Channels x Time + # so concatenate channels and bands into one axis + return rearrange(x, "b c n t -> b (c n) t") + + def decode(self, x): + # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) + x = rearrange(x, 
"b (c n) t -> b c n t", n=self.pqmf.num_bands) + # returns (Batch x Channels x Time) + return self.pqmf.inverse(x) + +class PretrainedDACPretransform(Pretransform): + def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + import dac + + model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) + + self.model = dac.DAC.load(model_path) + + self.quantize_on_decode = quantize_on_decode + + if model_type == "44khz": + self.downsampling_ratio = 512 + else: + self.downsampling_ratio = 320 + + self.io_channels = 1 + + self.scale = scale + + self.chunked = chunked + + self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.n_codebooks + + self.codebook_size = self.model.codebook_size + + def encode(self, x): + + latents = self.model.encoder(x) + + if self.quantize_on_decode: + output = latents + else: + z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + output = z + + if self.scale != 1.0: + output = output / self.scale + + return output + + def decode(self, z): + + if self.scale != 1.0: + z = z * self.scale + + if self.quantize_on_decode: + z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + return self.model.decode(z) + + def tokenize(self, x): + return self.model.encode(x)[1] + + def decode_tokens(self, tokens): + latents = self.model.quantizer.from_codes(tokens) + return self.model.decode(latents) + +class AudiocraftCompressionPretransform(Pretransform): + def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + try: + from audiocraft.models import CompressionModel + except ImportError: + raise ImportError("Audiocraft is not installed. 
Please install audiocraft to use Audiocraft models.") + + self.model = CompressionModel.get_pretrained(model_type) + + self.quantize_on_decode = quantize_on_decode + + self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) + + self.sample_rate = self.model.sample_rate + + self.io_channels = self.model.channels + + self.scale = scale + + #self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.num_codebooks + + self.codebook_size = self.model.cardinality + + self.model.to(torch.float16).eval().requires_grad_(False) + + def encode(self, x): + + assert False, "Audiocraft compression models do not support continuous encoding" + + # latents = self.model.encoder(x) + + # if self.quantize_on_decode: + # output = latents + # else: + # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + # output = z + + # if self.scale != 1.0: + # output = output / self.scale + + # return output + + def decode(self, z): + + assert False, "Audiocraft compression models do not support continuous decoding" + + # if self.scale != 1.0: + # z = z * self.scale + + # if self.quantize_on_decode: + # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + # return self.model.decode(z) + + def tokenize(self, x): + with torch.cuda.amp.autocast(enabled=False): + return self.model.encode(x.to(torch.float16))[0] + + def decode_tokens(self, tokens): + with torch.cuda.amp.autocast(enabled=False): + return self.model.decode(tokens) diff --git a/src/modules/stable_vae/models/.ipynb_checkpoints/utils-checkpoint.py b/src/modules/stable_vae/models/.ipynb_checkpoints/utils-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..ec8eeaf773d47db2c000a3b2237d88d310214dcf --- /dev/null +++ b/src/modules/stable_vae/models/.ipynb_checkpoints/utils-checkpoint.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torchaudio import transforms as T + + +class PadCrop(nn.Module): + def __init__(self, n_samples, randomize=True): + super().__init__() + self.n_samples = n_samples + self.randomize = randomize + + def __call__(self, signal): + n, s = signal.shape + start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() + end = start + self.n_samples + output = signal.new_zeros([n, self.n_samples]) + output[:, :min(s, self.n_samples)] = signal[:, start:end] + return output + + +def set_audio_channels(audio, target_channels): + if target_channels == 1: + # Convert to mono + audio = audio.mean(1, keepdim=True) + elif target_channels == 2: + # Convert to stereo + if audio.shape[1] == 1: + audio = audio.repeat(1, 2, 1) + elif audio.shape[1] > 2: + audio = audio[:, :2, :] + return audio + +def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): + + audio = audio.to(device) + + if in_sr != target_sr: + resample_tf = T.Resample(in_sr, target_sr).to(device) + audio = resample_tf(audio) + + audio = PadCrop(target_length, randomize=False)(audio) + + # Add batch dimension + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + elif audio.dim() == 2: + audio = audio.unsqueeze(0) + + audio = set_audio_channels(audio, target_channels) + + return audio \ No newline at end of file diff --git a/src/modules/stable_vae/models/__pycache__/autoencoders.cpython-310.pyc b/src/modules/stable_vae/models/__pycache__/autoencoders.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4213aeeee249e1d2a179e446ba30328d221ea2cd 
Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/autoencoders.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/autoencoders.cpython-311.pyc b/src/modules/stable_vae/models/__pycache__/autoencoders.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..08682d55156c6ecc615f63e6f5965578b382715f Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/autoencoders.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/blocks.cpython-310.pyc b/src/modules/stable_vae/models/__pycache__/blocks.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..13098e65f73ca56e09af304cc58a5b91e09b9230 Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/blocks.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/blocks.cpython-311.pyc b/src/modules/stable_vae/models/__pycache__/blocks.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..0cda754279d6057bf47564662d75126fdc161716 Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/blocks.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/bottleneck.cpython-310.pyc b/src/modules/stable_vae/models/__pycache__/bottleneck.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..ed2e2fecc81e0df784bbe5fc60247202f5916b43 Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/bottleneck.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/bottleneck.cpython-311.pyc b/src/modules/stable_vae/models/__pycache__/bottleneck.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..280282dd1b335001ac67af2d0a6f567e498544aa Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/bottleneck.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/dac_layers.cpython-311.pyc b/src/modules/stable_vae/models/__pycache__/dac_layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..6023bf14f71f20035d26f964428a5cf099bce8ae Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/dac_layers.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/diffusion.cpython-310.pyc b/src/modules/stable_vae/models/__pycache__/diffusion.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..11fc6d584244a08c1ac553468ff9630b7fdfa42f Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/diffusion.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/factory.cpython-310.pyc b/src/modules/stable_vae/models/__pycache__/factory.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..f1103c38bb0506ad111c4a1231a93aa9bb117791 Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/factory.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/factory.cpython-311.pyc b/src/modules/stable_vae/models/__pycache__/factory.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..112c3a4193ea3cc5b7fe48acce3399214d9cb634 Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/factory.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/pretransforms.cpython-310.pyc b/src/modules/stable_vae/models/__pycache__/pretransforms.cpython-310.pyc new file mode 
100644 index 0000000000000000000000000000000000000000..735577fd4a79f20b8b1da0a4fc546714ab048eca Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/pretransforms.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/pretransforms.cpython-311.pyc b/src/modules/stable_vae/models/__pycache__/pretransforms.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d72ad1a597cb988276627b751e986040e489542d Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/pretransforms.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/utils.cpython-310.pyc b/src/modules/stable_vae/models/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..18170d2cfdfc233f2d04fde448d30c4a8e5a12c9 Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/utils.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/__pycache__/utils.cpython-311.pyc b/src/modules/stable_vae/models/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..79028d2ae1b49d2819c08ab69353eb6f15e2e455 Binary files /dev/null and b/src/modules/stable_vae/models/__pycache__/utils.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/autoencoders.py b/src/modules/stable_vae/models/autoencoders.py new file mode 100644 index 0000000000000000000000000000000000000000..2741d45e70e3c25bc28f2b4e43b1a3e925a4c9e3 --- /dev/null +++ b/src/modules/stable_vae/models/autoencoders.py @@ -0,0 +1,683 @@ +import torch +import math +import numpy as np + +from torch import nn +from torch.nn import functional as F +from torchaudio import transforms as T +from alias_free_torch import Activation1d +from .nn.layers import WNConv1d, WNConvTranspose1d +from typing import Literal, Dict, Any + +# from .inference.sampling import sample +from .utils import prepare_audio +from .blocks import SnakeBeta +from .bottleneck import Bottleneck, DiscreteBottleneck +from .factory import create_pretransform_from_config, create_bottleneck_from_config +from .pretransforms import Pretransform + +def checkpoint(function, *args, **kwargs): + kwargs.setdefault("use_reentrant", False) + return torch.utils.checkpoint.checkpoint(function, *args, **kwargs) + +def get_activation(activation: Literal["elu", "snake", "none"], antialias=False, channels=None) -> nn.Module: + if activation == "elu": + act = nn.ELU() + elif activation == "snake": + act = SnakeBeta(channels) + elif activation == "none": + act = nn.Identity() + else: + raise ValueError(f"Unknown activation {activation}") + + if antialias: + act = Activation1d(act) + + return act + +class ResidualUnit(nn.Module): + def __init__(self, in_channels, out_channels, dilation, use_snake=False, antialias_activation=False): + super().__init__() + + self.dilation = dilation + + padding = (dilation * (7-1)) // 2 + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=7, dilation=dilation, padding=padding), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=out_channels), + WNConv1d(in_channels=out_channels, out_channels=out_channels, + kernel_size=1) + ) + + def forward(self, x): + res = x + + #x = checkpoint(self.layers, x) + x = self.layers(x) + + return x + res + +class EncoderBlock(nn.Module): + def __init__(self, 
in_channels, out_channels, stride, use_snake=False, antialias_activation=False): + super().__init__() + + self.layers = nn.Sequential( + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=in_channels, + out_channels=in_channels, dilation=9, use_snake=use_snake), + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + WNConv1d(in_channels=in_channels, out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)), + ) + + def forward(self, x): + return self.layers(x) + +class DecoderBlock(nn.Module): + def __init__(self, in_channels, out_channels, stride, use_snake=False, antialias_activation=False, use_nearest_upsample=False): + super().__init__() + + if use_nearest_upsample: + upsample_layer = nn.Sequential( + nn.Upsample(scale_factor=stride, mode="nearest"), + WNConv1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, + stride=1, + bias=False, + padding='same') + ) + else: + upsample_layer = WNConvTranspose1d(in_channels=in_channels, + out_channels=out_channels, + kernel_size=2*stride, stride=stride, padding=math.ceil(stride/2)) + + self.layers = nn.Sequential( + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=in_channels), + upsample_layer, + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=1, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=3, use_snake=use_snake), + ResidualUnit(in_channels=out_channels, out_channels=out_channels, + dilation=9, use_snake=use_snake), + ) + + def forward(self, x): + return self.layers(x) + +class OobleckEncoder(nn.Module): + def __init__(self, + in_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False + ): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=in_channels, out_channels=c_mults[0] * channels, kernel_size=7, padding=3) + ] + + for i in range(self.depth-1): + layers += [EncoderBlock(in_channels=c_mults[i]*channels, out_channels=c_mults[i+1]*channels, stride=strides[i], use_snake=use_snake)] + + layers += [ + get_activation("snake" if use_snake else "elu", antialias=antialias_activation, channels=c_mults[-1] * channels), + WNConv1d(in_channels=c_mults[-1]*channels, out_channels=latent_dim, kernel_size=3, padding=1) + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class OobleckDecoder(nn.Module): + def __init__(self, + out_channels=2, + channels=128, + latent_dim=32, + c_mults = [1, 2, 4, 8], + strides = [2, 4, 8, 8], + use_snake=False, + antialias_activation=False, + use_nearest_upsample=False, + final_tanh=True): + super().__init__() + + c_mults = [1] + c_mults + + self.depth = len(c_mults) + + layers = [ + WNConv1d(in_channels=latent_dim, out_channels=c_mults[-1]*channels, kernel_size=7, padding=3), + ] + + for i in range(self.depth-1, 0, -1): + layers += [DecoderBlock( + in_channels=c_mults[i]*channels, + out_channels=c_mults[i-1]*channels, + stride=strides[i-1], + use_snake=use_snake, + antialias_activation=antialias_activation, + use_nearest_upsample=use_nearest_upsample + ) + ] + + layers += [ + get_activation("snake" if use_snake 
else "elu", antialias=antialias_activation, channels=c_mults[0] * channels), + WNConv1d(in_channels=c_mults[0] * channels, out_channels=out_channels, kernel_size=7, padding=3, bias=False), + nn.Tanh() if final_tanh else nn.Identity() + ] + + self.layers = nn.Sequential(*layers) + + def forward(self, x): + return self.layers(x) + + +class DACEncoderWrapper(nn.Module): + def __init__(self, in_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Encoder as DACEncoder + + latent_dim = kwargs.pop("latent_dim", None) + + encoder_out_dim = kwargs["d_model"] * (2 ** len(kwargs["strides"])) + self.encoder = DACEncoder(d_latent=encoder_out_dim, **kwargs) + self.latent_dim = latent_dim + + # Latent-dim support was added to DAC after this was first written, and implemented differently, so this is for backwards compatibility + self.proj_out = nn.Conv1d(self.encoder.enc_dim, latent_dim, kernel_size=1) if latent_dim is not None else nn.Identity() + + if in_channels != 1: + self.encoder.block[0] = WNConv1d(in_channels, kwargs.get("d_model", 64), kernel_size=7, padding=3) + + def forward(self, x): + x = self.encoder(x) + x = self.proj_out(x) + return x + +class DACDecoderWrapper(nn.Module): + def __init__(self, latent_dim, out_channels=1, **kwargs): + super().__init__() + + from dac.model.dac import Decoder as DACDecoder + + self.decoder = DACDecoder(**kwargs, input_channel = latent_dim, d_out=out_channels) + + self.latent_dim = latent_dim + + def forward(self, x): + return self.decoder(x) + +class AudioAutoencoder(nn.Module): + def __init__( + self, + encoder, + decoder, + latent_dim, + downsampling_ratio, + sample_rate, + io_channels=2, + bottleneck: Bottleneck = None, + pretransform: Pretransform = None, + in_channels = None, + out_channels = None, + soft_clip = False + ): + super().__init__() + + self.downsampling_ratio = downsampling_ratio + self.sample_rate = sample_rate + + self.latent_dim = latent_dim + self.io_channels = io_channels + self.in_channels = io_channels + self.out_channels = io_channels + + self.min_length = self.downsampling_ratio + + if in_channels is not None: + self.in_channels = in_channels + + if out_channels is not None: + self.out_channels = out_channels + + self.bottleneck = bottleneck + + self.encoder = encoder + + self.decoder = decoder + + self.pretransform = pretransform + + self.soft_clip = soft_clip + + self.is_discrete = self.bottleneck is not None and self.bottleneck.is_discrete + + def encode(self, audio, return_info=False, skip_pretransform=False, iterate_batch=False, **kwargs): + + info = {} + + if self.pretransform is not None and not skip_pretransform: + if self.pretransform.enable_grad: + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + else: + with torch.no_grad(): + if iterate_batch: + audios = [] + for i in range(audio.shape[0]): + audios.append(self.pretransform.encode(audio[i:i+1])) + audio = torch.cat(audios, dim=0) + else: + audio = self.pretransform.encode(audio) + + if self.encoder is not None: + if iterate_batch: + latents = [] + for i in range(audio.shape[0]): + latents.append(self.encoder(audio[i:i+1])) + latents = torch.cat(latents, dim=0) + else: + latents = self.encoder(audio) + else: + latents = audio + + if self.bottleneck is not None: + # TODO: Add iterate batch logic, needs to merge the info dicts + latents, bottleneck_info = self.bottleneck.encode(latents, 
return_info=True, **kwargs) + + info.update(bottleneck_info) + + if return_info: + return latents, info + + return latents + + def decode(self, latents, iterate_batch=False, **kwargs): + + if self.bottleneck is not None: + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.bottleneck.decode(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + latents = self.bottleneck.decode(latents) + + if iterate_batch: + decoded = [] + for i in range(latents.shape[0]): + decoded.append(self.decoder(latents[i:i+1])) + decoded = torch.cat(decoded, dim=0) + else: + decoded = self.decoder(latents, **kwargs) + + if self.pretransform is not None: + if self.pretransform.enable_grad: + if iterate_batch: + decodeds = [] + for i in range(decoded.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + else: + with torch.no_grad(): + if iterate_batch: + decodeds = [] + for i in range(latents.shape[0]): + decodeds.append(self.pretransform.decode(decoded[i:i+1])) + decoded = torch.cat(decodeds, dim=0) + else: + decoded = self.pretransform.decode(decoded) + + if self.soft_clip: + decoded = torch.tanh(decoded) + + return decoded + + def decode_tokens(self, tokens, **kwargs): + ''' + Decode discrete tokens to audio + Only works with discrete autoencoders + ''' + + assert isinstance(self.bottleneck, DiscreteBottleneck), "decode_tokens only works with discrete autoencoders" + + latents = self.bottleneck.decode_tokens(tokens, **kwargs) + + return self.decode(latents, **kwargs) + + + def preprocess_audio_for_encoder(self, audio, in_sr): + ''' + Preprocess single audio tensor (Channels x Length) to be compatible with the encoder. + If the model is mono, stereo audio will be converted to mono. + Audio will be silence-padded to be a multiple of the model's downsampling ratio. + Audio will be resampled to the model's sample rate. + The output will have batch size 1 and be shape (1 x Channels x Length) + ''' + return self.preprocess_audio_list_for_encoder([audio], [in_sr]) + + def preprocess_audio_list_for_encoder(self, audio_list, in_sr_list): + ''' + Preprocess a [list] of audio (Channels x Length) into a batch tensor to be compatable with the encoder. + The audio in that list can be of different lengths and channels. + in_sr can be an integer or list. If it's an integer it will be assumed it is the input sample_rate for every audio. + All audio will be resampled to the model's sample rate. + Audio will be silence-padded to the longest length, and further padded to be a multiple of the model's downsampling ratio. + If the model is mono, all audio will be converted to mono. + The output will be a tensor of shape (Batch x Channels x Length) + ''' + batch_size = len(audio_list) + if isinstance(in_sr_list, int): + in_sr_list = [in_sr_list]*batch_size + assert len(in_sr_list) == batch_size, "list of sample rates must be the same length of audio_list" + new_audio = [] + max_length = 0 + # resample & find the max length + for i in range(batch_size): + audio = audio_list[i] + in_sr = in_sr_list[i] + if len(audio.shape) == 3 and audio.shape[0] == 1: + # batchsize 1 was given by accident. Just squeeze it. 
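# --- Editor's sketch (annotation, not part of this diff) ---
# Minimal usage of the preprocessing helpers above together with encode(),
# assuming `model` is an already-constructed AudioAutoencoder and the file
# path below is a placeholder.
def _example_preprocess_and_encode(model, path="example.wav"):
    import torchaudio

    audio, in_sr = torchaudio.load(path)  # (channels, samples)
    # Resamples to model.sample_rate, pads to a multiple of the downsampling
    # ratio, fixes the channel count, and adds a batch dimension of 1.
    batch = model.preprocess_audio_for_encoder(audio, in_sr)
    # Latents: (1, latent_dim, padded_length // downsampling_ratio).
    return model.encode(batch)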
+ audio = audio.squeeze(0) + elif len(audio.shape) == 1: + # Mono signal, channel dimension is missing, unsqueeze it in + audio = audio.unsqueeze(0) + assert len(audio.shape)==2, "Audio should be shape (Channels x Length) with no batch dimension" + # Resample audio + if in_sr != self.sample_rate: + resample_tf = T.Resample(in_sr, self.sample_rate).to(audio.device) + audio = resample_tf(audio) + new_audio.append(audio) + if audio.shape[-1] > max_length: + max_length = audio.shape[-1] + # Pad every audio to the same length, multiple of model's downsampling ratio + padded_audio_length = max_length + (self.min_length - (max_length % self.min_length)) % self.min_length + for i in range(batch_size): + # Pad it & if necessary, mixdown/duplicate stereo/mono channels to support model + new_audio[i] = prepare_audio(new_audio[i], in_sr=in_sr, target_sr=in_sr, target_length=padded_audio_length, + target_channels=self.in_channels, device=new_audio[i].device).squeeze(0) + # convert to tensor + return torch.stack(new_audio) + + def encode_audio(self, audio, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Encode audios into latents. Audios should already be preprocesed by preprocess_audio_for_encoder. + If chunked is True, split the audio into chunks of a given maximum size chunk_size, with given overlap. + Overlap and chunk_size params are both measured in number of latents (not audio samples) + # and therefore you likely could use the same values with decode_audio. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked output and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. Encode the entire audio in parallel + return self.encode(audio, **kwargs) + else: + # CHUNKED ENCODING + # samples_per_latent is just the downsampling ratio (which is also the upsampling ratio) + samples_per_latent = self.downsampling_ratio + total_size = audio.shape[2] # in samples + batch_size = audio.shape[0] + chunk_size *= samples_per_latent # converting metric in latents to samples + overlap *= samples_per_latent # converting metric in latents to samples + hop_size = chunk_size - overlap + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = audio[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = audio[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # Note: y_size might be a different value from the latent length used in diffusion training + # because we can encode audio of varying lengths + # However, the audio should've been padded to a multiple of samples_per_latent by now. 
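# --- Editor's sketch (annotation, not part of this diff) ---
# The docstring above suggests choosing `overlap` empirically by diffing the
# chunked and unchunked encodings. A minimal check, assuming `model` is an
# AudioAutoencoder, `audio` is preprocessed to (B, C, T), and T is at least
# chunk_size * downsampling_ratio samples:
def _example_check_chunked_encode(model, audio, overlap=32, chunk_size=128):
    with torch.no_grad():
        full = model.encode_audio(audio, chunked=False)
        chunked = model.encode_audio(
            audio, chunked=True, overlap=overlap, chunk_size=chunk_size
        )
    # Should be near zero once `overlap` covers the encoder's receptive field.
    # Note: with a stochastic bottleneck (e.g. the VAE bottleneck), sampling
    # noise alone makes the two encodings differ.
    return (full - chunked).abs().max()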
+ y_size = total_size // samples_per_latent + # Create an empty latent, we will populate it with chunks as we encode them + y_final = torch.zeros((batch_size,self.latent_dim,y_size)).to(audio.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # encode the chunk + y_chunk = self.encode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size // samples_per_latent + t_end = t_start + chunk_size // samples_per_latent + # remove the edges of the overlaps + ol = overlap//samples_per_latent//2 + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + def decode_audio(self, latents, chunked=False, overlap=32, chunk_size=128, **kwargs): + ''' + Decode latents to audio. + If chunked is True, split the latents into chunks of a given maximum size chunk_size, with given overlap, both of which are measured in number of latents. + A overlap of zero will cause discontinuity artefacts. Overlap should be => receptive field size. + Every autoencoder will have a different receptive field size, and thus ideal overlap. + You can determine it empirically by diffing unchunked vs chunked audio and looking at maximum diff. + The final chunk may have a longer overlap in order to keep chunk_size consistent for all chunks. + Smaller chunk_size uses less memory, but more compute. + The chunk_size vs memory tradeoff isn't linear, and possibly depends on the GPU and CUDA version + For example, on a A6000 chunk_size 128 is overall faster than 256 and 512 even though it has more chunks + ''' + if not chunked: + # default behavior. 
Decode the entire latent in parallel + return self.decode(latents, **kwargs) + else: + # chunked decoding + hop_size = chunk_size - overlap + total_size = latents.shape[2] + batch_size = latents.shape[0] + chunks = [] + for i in range(0, total_size - chunk_size + 1, hop_size): + chunk = latents[:,:,i:i+chunk_size] + chunks.append(chunk) + if i+chunk_size != total_size: + # Final chunk + chunk = latents[:,:,-chunk_size:] + chunks.append(chunk) + chunks = torch.stack(chunks) + num_chunks = chunks.shape[0] + # samples_per_latent is just the downsampling ratio + samples_per_latent = self.downsampling_ratio + # Create an empty waveform, we will populate it with chunks as decode them + y_size = total_size * samples_per_latent + y_final = torch.zeros((batch_size,self.out_channels,y_size)).to(latents.device) + for i in range(num_chunks): + x_chunk = chunks[i,:] + # decode the chunk + y_chunk = self.decode(x_chunk) + # figure out where to put the audio along the time domain + if i == num_chunks-1: + # final chunk always goes at the end + t_end = y_size + t_start = t_end - y_chunk.shape[2] + else: + t_start = i * hop_size * samples_per_latent + t_end = t_start + chunk_size * samples_per_latent + # remove the edges of the overlaps + ol = (overlap//2) * samples_per_latent + chunk_start = 0 + chunk_end = y_chunk.shape[2] + if i > 0: + # no overlap for the start of the first chunk + t_start += ol + chunk_start += ol + if i < num_chunks-1: + # no overlap for the end of the last chunk + t_end -= ol + chunk_end -= ol + # paste the chunked audio into our y_final output audio + y_final[:,:,t_start:t_end] = y_chunk[:,:,chunk_start:chunk_end] + return y_final + + +# AE factories + +def create_encoder_from_config(encoder_config: Dict[str, Any]): + encoder_type = encoder_config.get("type", None) + assert encoder_type is not None, "Encoder type must be specified" + + if encoder_type == "oobleck": + encoder = OobleckEncoder( + **encoder_config["config"] + ) + + elif encoder_type == "seanet": + from encodec.modules import SEANetEncoder + seanet_encoder_config = encoder_config["config"] + + #SEANet encoder expects strides in reverse order + seanet_encoder_config["ratios"] = list(reversed(seanet_encoder_config.get("ratios", [2, 2, 2, 2, 2]))) + encoder = SEANetEncoder( + **seanet_encoder_config + ) + elif encoder_type == "dac": + dac_config = encoder_config["config"] + + encoder = DACEncoderWrapper(**dac_config) + elif encoder_type == "local_attn": + from .local_attention import TransformerEncoder1D + + local_attn_config = encoder_config["config"] + + encoder = TransformerEncoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown encoder type {encoder_type}") + + requires_grad = encoder_config.get("requires_grad", True) + if not requires_grad: + for param in encoder.parameters(): + param.requires_grad = False + + return encoder + +def create_decoder_from_config(decoder_config: Dict[str, Any]): + decoder_type = decoder_config.get("type", None) + assert decoder_type is not None, "Decoder type must be specified" + + if decoder_type == "oobleck": + decoder = OobleckDecoder( + **decoder_config["config"] + ) + elif decoder_type == "seanet": + from encodec.modules import SEANetDecoder + + decoder = SEANetDecoder( + **decoder_config["config"] + ) + elif decoder_type == "dac": + dac_config = decoder_config["config"] + + decoder = DACDecoderWrapper(**dac_config) + elif decoder_type == "local_attn": + from .local_attention import TransformerDecoder1D + + local_attn_config = decoder_config["config"] + + decoder = 
TransformerDecoder1D( + **local_attn_config + ) + else: + raise ValueError(f"Unknown decoder type {decoder_type}") + + requires_grad = decoder_config.get("requires_grad", True) + if not requires_grad: + for param in decoder.parameters(): + param.requires_grad = False + + return decoder + +def create_autoencoder_from_config(config: Dict[str, Any]): + + ae_config = config["model"] + + encoder = create_encoder_from_config(ae_config["encoder"]) + decoder = create_decoder_from_config(ae_config["decoder"]) + + bottleneck = ae_config.get("bottleneck", None) + + latent_dim = ae_config.get("latent_dim", None) + assert latent_dim is not None, "latent_dim must be specified in model config" + downsampling_ratio = ae_config.get("downsampling_ratio", None) + assert downsampling_ratio is not None, "downsampling_ratio must be specified in model config" + io_channels = ae_config.get("io_channels", None) + assert io_channels is not None, "io_channels must be specified in model config" + sample_rate = config.get("sample_rate", None) + assert sample_rate is not None, "sample_rate must be specified in model config" + + in_channels = ae_config.get("in_channels", None) + out_channels = ae_config.get("out_channels", None) + + pretransform = ae_config.get("pretransform", None) + + if pretransform is not None: + pretransform = create_pretransform_from_config(pretransform, sample_rate) + + if bottleneck is not None: + bottleneck = create_bottleneck_from_config(bottleneck) + + soft_clip = ae_config["decoder"].get("soft_clip", False) + + return AudioAutoencoder( + encoder, + decoder, + io_channels=io_channels, + latent_dim=latent_dim, + downsampling_ratio=downsampling_ratio, + sample_rate=sample_rate, + bottleneck=bottleneck, + pretransform=pretransform, + in_channels=in_channels, + out_channels=out_channels, + soft_clip=soft_clip + ) \ No newline at end of file diff --git a/src/modules/stable_vae/models/blocks.py b/src/modules/stable_vae/models/blocks.py new file mode 100644 index 0000000000000000000000000000000000000000..cb310c8980ef5dc0f138e6f9f3478d4cdc63354d --- /dev/null +++ b/src/modules/stable_vae/models/blocks.py @@ -0,0 +1,359 @@ +from functools import reduce +import math +import numpy as np +import torch +from torch import nn +from torch.nn import functional as F + +from torch.backends.cuda import sdp_kernel +from packaging import version + +from .nn.layers import Snake1d + + +class ResidualBlock(nn.Module): + def __init__(self, main, skip=None): + super().__init__() + self.main = nn.Sequential(*main) + self.skip = skip if skip else nn.Identity() + + def forward(self, input): + return self.main(input) + self.skip(input) + + +class ResConvBlock(ResidualBlock): + def __init__(self, c_in, c_mid, c_out, is_last=False, kernel_size=5, conv_bias=True, use_snake=False): + skip = None if c_in == c_out else nn.Conv1d(c_in, c_out, 1, bias=False) + super().__init__([ + nn.Conv1d(c_in, c_mid, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_mid), + Snake1d(c_mid) if use_snake else nn.GELU(), + nn.Conv1d(c_mid, c_out, kernel_size, padding=kernel_size//2, bias=conv_bias), + nn.GroupNorm(1, c_out) if not is_last else nn.Identity(), + (Snake1d(c_out) if use_snake else nn.GELU()) if not is_last else nn.Identity(), + ], skip) + + +class SelfAttention1d(nn.Module): + def __init__(self, c_in, n_head=1, dropout_rate=0.): + super().__init__() + assert c_in % n_head == 0 + self.norm = nn.GroupNorm(1, c_in) + self.n_head = n_head + self.qkv_proj = nn.Conv1d(c_in, c_in * 3, 1) + self.out_proj = 
nn.Conv1d(c_in, c_in, 1) + self.dropout = nn.Dropout(dropout_rate, inplace=True) + + self.use_flash = torch.cuda.is_available() and version.parse(torch.__version__) >= version.parse('2.0.0') + + if not self.use_flash: + return + + device_properties = torch.cuda.get_device_properties(torch.device('cuda')) + + if device_properties.major == 8 and device_properties.minor == 0: + # Use flash attention for A100 GPUs + self.sdp_kernel_config = (True, False, False) + else: + # Don't use flash attention for other GPUs + self.sdp_kernel_config = (False, True, True) + + def forward(self, input): + n, c, s = input.shape + qkv = self.qkv_proj(self.norm(input)) + qkv = qkv.view( + [n, self.n_head * 3, c // self.n_head, s]).transpose(2, 3) + q, k, v = qkv.chunk(3, dim=1) + scale = k.shape[3]**-0.25 + + if self.use_flash: + with sdp_kernel(*self.sdp_kernel_config): + y = F.scaled_dot_product_attention(q, k, v, is_causal=False).contiguous().view([n, c, s]) + else: + att = ((q * scale) @ (k.transpose(2, 3) * scale)).softmax(3) + y = (att @ v).transpose(2, 3).contiguous().view([n, c, s]) + + + return input + self.dropout(self.out_proj(y)) + + +class SkipBlock(nn.Module): + def __init__(self, *main): + super().__init__() + self.main = nn.Sequential(*main) + + def forward(self, input): + return torch.cat([self.main(input), input], dim=1) + + +class FourierFeatures(nn.Module): + def __init__(self, in_features, out_features, std=1.): + super().__init__() + assert out_features % 2 == 0 + self.weight = nn.Parameter(torch.randn( + [out_features // 2, in_features]) * std) + + def forward(self, input): + f = 2 * math.pi * input @ self.weight.T + return torch.cat([f.cos(), f.sin()], dim=-1) + + +def expand_to_planes(input, shape): + return input[..., None].repeat([1, 1, shape[2]]) + +_kernels = { + 'linear': + [1 / 8, 3 / 8, 3 / 8, 1 / 8], + 'cubic': + [-0.01171875, -0.03515625, 0.11328125, 0.43359375, + 0.43359375, 0.11328125, -0.03515625, -0.01171875], + 'lanczos3': + [0.003689131001010537, 0.015056144446134567, -0.03399861603975296, + -0.066637322306633, 0.13550527393817902, 0.44638532400131226, + 0.44638532400131226, 0.13550527393817902, -0.066637322306633, + -0.03399861603975296, 0.015056144446134567, 0.003689131001010537] +} + + +class Downsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, (self.pad,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv1d(x, weight, stride=2) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + + +class Upsample1d(nn.Module): + def __init__(self, kernel='linear', pad_mode='reflect', channels_last=False): + super().__init__() + self.pad_mode = pad_mode + kernel_1d = torch.tensor(_kernels[kernel]) * 2 + self.pad = kernel_1d.shape[0] // 2 - 1 + self.register_buffer('kernel', kernel_1d) + self.channels_last = channels_last + + def forward(self, x): + if self.channels_last: + x = x.permute(0, 2, 1) + x = F.pad(x, ((self.pad + 1) // 2,) * 2, self.pad_mode) + weight = x.new_zeros([x.shape[1], x.shape[1], self.kernel.shape[0]]) + indices = torch.arange(x.shape[1], 
device=x.device) + weight[indices, indices] = self.kernel.to(weight) + x = F.conv_transpose1d(x, weight, stride=2, padding=self.pad * 2 + 1) + if self.channels_last: + x = x.permute(0, 2, 1) + return x + + +def Downsample1d_2( + in_channels: int, out_channels: int, factor: int, kernel_multiplier: int = 2 +) -> nn.Module: + assert kernel_multiplier % 2 == 0, "Kernel multiplier must be even" + + return nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * kernel_multiplier + 1, + stride=factor, + padding=factor * (kernel_multiplier // 2), + ) + + +def Upsample1d_2( + in_channels: int, out_channels: int, factor: int, use_nearest: bool = False +) -> nn.Module: + + if factor == 1: + return nn.Conv1d( + in_channels=in_channels, out_channels=out_channels, kernel_size=3, padding=1 + ) + + if use_nearest: + return nn.Sequential( + nn.Upsample(scale_factor=factor, mode="nearest"), + nn.Conv1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=3, + padding=1, + ), + ) + else: + return nn.ConvTranspose1d( + in_channels=in_channels, + out_channels=out_channels, + kernel_size=factor * 2, + stride=factor, + padding=factor // 2 + factor % 2, + output_padding=factor % 2, + ) + + +def zero_init(layer): + nn.init.zeros_(layer.weight) + if layer.bias is not None: + nn.init.zeros_(layer.bias) + return layer + + +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +#rms_norm = torch.compile(rms_norm) + +class AdaRMSNorm(nn.Module): + def __init__(self, features, cond_features, eps=1e-6): + super().__init__() + self.eps = eps + self.linear = zero_init(nn.Linear(cond_features, features, bias=False)) + + def extra_repr(self): + return f"eps={self.eps}," + + def forward(self, x, cond): + return rms_norm(x, self.linear(cond)[:, None, :] + 1, self.eps) + + +def normalize(x, eps=1e-4): + dim = list(range(1, x.ndim)) + n = torch.linalg.vector_norm(x, dim=dim, keepdim=True) + alpha = np.sqrt(n.numel() / x.numel()) + return x / torch.add(eps, n, alpha=alpha) + + +class ForcedWNConv1d(nn.Module): + def __init__(self, in_channels, out_channels, kernel_size=1): + super().__init__() + self.weight = nn.Parameter(torch.randn([out_channels, in_channels, kernel_size])) + + def forward(self, x): + if self.training: + with torch.no_grad(): + self.weight.copy_(normalize(self.weight)) + + fan_in = self.weight[0].numel() + + w = normalize(self.weight) / math.sqrt(fan_in) + + return F.conv1d(x, w, padding='same') + +# Kernels + +use_compile = True + +def compile(function, *args, **kwargs): + if not use_compile: + return function + try: + return torch.compile(function, *args, **kwargs) + except RuntimeError: + return function + + +@compile +def linear_geglu(x, weight, bias=None): + x = x @ weight.mT + if bias is not None: + x = x + bias + x, gate = x.chunk(2, dim=-1) + return x * F.gelu(gate) + + +@compile +def rms_norm(x, scale, eps): + dtype = reduce(torch.promote_types, (x.dtype, scale.dtype, torch.float32)) + mean_sq = torch.mean(x.to(dtype)**2, dim=-1, keepdim=True) + scale = scale.to(dtype) * torch.rsqrt(mean_sq + eps) + return x * scale.to(x.dtype) + +# Layers + + +class LinearGEGLU(nn.Linear): + def __init__(self, in_features, out_features, bias=True): + super().__init__(in_features, out_features * 2, bias=bias) + self.out_features = out_features + + def forward(self, x): + 
return linear_geglu(x, self.weight, self.bias) + + +class RMSNorm(nn.Module): + def __init__(self, shape, fix_scale = False, eps=1e-6): + super().__init__() + self.eps = eps + + if fix_scale: + self.register_buffer("scale", torch.ones(shape)) + else: + self.scale = nn.Parameter(torch.ones(shape)) + + def extra_repr(self): + return f"shape={tuple(self.scale.shape)}, eps={self.eps}" + + def forward(self, x): + return rms_norm(x, self.scale, self.eps) + + +# jit script make it 1.4x faster and save GPU memory +@torch.jit.script +def snake_beta(x, alpha, beta): + return x + (1.0 / (beta + 0.000000001)) * pow(torch.sin(x * alpha), 2) + +# try: +# snake_beta = torch.compile(snake_beta) +# except RuntimeError: +# pass + + +# Adapted from https://github.com/NVIDIA/BigVGAN/blob/main/activations.py under MIT license +# License available in LICENSES/LICENSE_NVIDIA.txt +class SnakeBeta(nn.Module): + + def __init__(self, in_features, alpha=1.0, alpha_trainable=True, alpha_logscale=True): + super(SnakeBeta, self).__init__() + self.in_features = in_features + + # initialize alpha + self.alpha_logscale = alpha_logscale + if self.alpha_logscale: + # log scale alphas initialized to zeros + self.alpha = nn.Parameter(torch.zeros(in_features) * alpha) + self.beta = nn.Parameter(torch.zeros(in_features) * alpha) + else: + # linear scale alphas initialized to ones + self.alpha = nn.Parameter(torch.ones(in_features) * alpha) + self.beta = nn.Parameter(torch.ones(in_features) * alpha) + + self.alpha.requires_grad = alpha_trainable + self.beta.requires_grad = alpha_trainable + + # self.no_div_by_zero = 0.000000001 + + def forward(self, x): + alpha = self.alpha.unsqueeze(0).unsqueeze(-1) + # line up with x to [B, C, T] + beta = self.beta.unsqueeze(0).unsqueeze(-1) + if self.alpha_logscale: + alpha = torch.exp(alpha) + beta = torch.exp(beta) + x = snake_beta(x, alpha, beta) + + return x \ No newline at end of file diff --git a/src/modules/stable_vae/models/bottleneck.py b/src/modules/stable_vae/models/bottleneck.py new file mode 100644 index 0000000000000000000000000000000000000000..df88c5f1b1f5fa3675c1a42f42e5e31e27d00ed3 --- /dev/null +++ b/src/modules/stable_vae/models/bottleneck.py @@ -0,0 +1,346 @@ +import torch +from torch import nn +from torch.nn import functional as F + +from einops import rearrange +from vector_quantize_pytorch import ResidualVQ, FSQ +from .nn.quantize import ResidualVectorQuantize as DACResidualVQ + + +class Bottleneck(nn.Module): + def __init__(self, is_discrete: bool = False): + super().__init__() + + self.is_discrete = is_discrete + + def encode(self, x, return_info=False, **kwargs): + raise NotImplementedError + + def decode(self, x): + raise NotImplementedError + + +class DiscreteBottleneck(Bottleneck): + def __init__(self, num_quantizers, codebook_size, tokens_id): + super().__init__(is_discrete=True) + + self.num_quantizers = num_quantizers + self.codebook_size = codebook_size + self.tokens_id = tokens_id + + def decode_tokens(self, codes, **kwargs): + raise NotImplementedError + + +class TanhBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + self.tanh = nn.Tanh() + + def encode(self, x, return_info=False): + info = {} + + x = torch.tanh(x) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + +@torch.jit.script +def vae_sample_kl(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + var = stdev * stdev + logvar = torch.log(var) + latents = torch.randn_like(mean) * stdev + mean + + kl = (mean * mean 
+ var - logvar - 1).sum(1).mean() + + return latents, kl + + +@torch.jit.script +def vae_sample(mean, scale): + stdev = nn.functional.softplus(scale) + 1e-4 + latents = torch.randn_like(mean) * stdev + mean + return latents + + +class VAEBottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False, **kwargs): + mean, scale = x.chunk(2, dim=1) + + if return_info: + info = {} + x, kl = vae_sample_kl(mean, scale) + info["kl"] = kl + return x, info + else: + x = vae_sample(mean, scale) + return x + + def decode(self, x): + return x + + +def compute_mean_kernel(x, y): + kernel_input = (x[:, None] - y[None]).pow(2).mean(2) / x.shape[-1] + return torch.exp(-kernel_input).mean() + + +def compute_mmd(latents): + latents_reshaped = latents.permute(0, 2, 1).reshape(-1, latents.shape[1]) + noise = torch.randn_like(latents_reshaped) + + latents_kernel = compute_mean_kernel(latents_reshaped, latents_reshaped) + noise_kernel = compute_mean_kernel(noise, noise) + latents_noise_kernel = compute_mean_kernel(latents_reshaped, noise) + + mmd = latents_kernel + noise_kernel - 2 * latents_noise_kernel + return mmd.mean() + + +class WassersteinBottleneck(Bottleneck): + def __init__(self, noise_augment_dim: int = 0): + super().__init__(is_discrete=False) + + self.noise_augment_dim = noise_augment_dim + + def encode(self, x, return_info=False): + info = {} + + if self.training and return_info: + mmd = compute_mmd(x) + info["mmd"] = mmd + + if return_info: + return x, info + + return x + + def decode(self, x): + + if self.noise_augment_dim > 0: + noise = torch.randn(x.shape[0], self.noise_augment_dim, + x.shape[-1]).type_as(x) + x = torch.cat([x, noise], dim=1) + + return x + + +class L2Bottleneck(Bottleneck): + def __init__(self): + super().__init__(is_discrete=False) + + def encode(self, x, return_info=False): + info = {} + + x = F.normalize(x, dim=1) + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return F.normalize(x, dim=1) + + +class RVQBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False, **kwargs): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + + +class RVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["num_quantizers"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "quantizer_indices") + self.quantizer = ResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["num_quantizers"] + + def encode(self, x, return_info=False): + info = {} + + x, kl = vae_sample(*x.chunk(2, dim=1)) + + info["kl"] = kl + + x = rearrange(x, "b c n -> b n c") + x, indices, loss = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + info["quantizer_loss"] = 
loss.mean() + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, codes, **kwargs): + latents = self.quantizer.get_outputs_from_indices(codes) + + return self.decode(latents, **kwargs) + + +class DACRVQBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, **kwargs): + info = {} + + info["pre_quantizer"] = x + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, **kwargs) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + + +class DACRVQVAEBottleneck(DiscreteBottleneck): + def __init__(self, quantize_on_decode=False, **quantizer_kwargs): + super().__init__(num_quantizers = quantizer_kwargs["n_codebooks"], codebook_size = quantizer_kwargs["codebook_size"], tokens_id = "codes") + self.quantizer = DACResidualVQ(**quantizer_kwargs) + self.num_quantizers = quantizer_kwargs["n_codebooks"] + self.quantize_on_decode = quantize_on_decode + + def encode(self, x, return_info=False, n_quantizers: int = None): + info = {} + + mean, scale = x.chunk(2, dim=1) + + x, kl = vae_sample(mean, scale) + + info["pre_quantizer"] = x + info["kl"] = kl + + if self.quantize_on_decode: + return x, info if return_info else x + + z, codes, latents, commitment_loss, codebook_loss = self.quantizer(x, n_quantizers=n_quantizers) + + output = { + "z": z, + "codes": codes, + "latents": latents, + "vq/commitment_loss": commitment_loss, + "vq/codebook_loss": codebook_loss, + } + + output["vq/commitment_loss"] /= self.num_quantizers + output["vq/codebook_loss"] /= self.num_quantizers + + info.update(output) + + if return_info: + return output["z"], info + + return output["z"] + + def decode(self, x): + + if self.quantize_on_decode: + x = self.quantizer(x)[0] + + return x + + def decode_tokens(self, codes, **kwargs): + latents, _, _ = self.quantizer.from_codes(codes) + + return self.decode(latents, **kwargs) + + +class FSQBottleneck(DiscreteBottleneck): + def __init__(self, dim, levels): + super().__init__(num_quantizers = 1, codebook_size = levels ** dim, tokens_id = "quantizer_indices") + self.quantizer = FSQ(levels=[levels] * dim) + + def encode(self, x, return_info=False): + info = {} + + x = rearrange(x, "b c n -> b n c") + x, indices = self.quantizer(x) + x = rearrange(x, "b n c -> b c n") + + info["quantizer_indices"] = indices + + if return_info: + return x, info + else: + return x + + def decode(self, x): + return x + + def decode_tokens(self, tokens, **kwargs): + latents = self.quantizer.indices_to_codes(tokens) + + return self.decode(latents, **kwargs) \ No 
newline at end of file diff --git a/src/modules/stable_vae/models/factory.py b/src/modules/stable_vae/models/factory.py new file mode 100644 index 0000000000000000000000000000000000000000..4188703000ee176342c7f329342f18d6fe747b04 --- /dev/null +++ b/src/modules/stable_vae/models/factory.py @@ -0,0 +1,153 @@ +import json + +def create_model_from_config(model_config): + model_type = model_config.get('model_type', None) + + assert model_type is not None, 'model_type must be specified in model config' + + if model_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + return create_autoencoder_from_config(model_config) + elif model_type == 'diffusion_uncond': + from .diffusion import create_diffusion_uncond_from_config + return create_diffusion_uncond_from_config(model_config) + elif model_type == 'diffusion_cond' or model_type == 'diffusion_cond_inpaint' or model_type == "diffusion_prior": + from .diffusion import create_diffusion_cond_from_config + return create_diffusion_cond_from_config(model_config) + elif model_type == 'diffusion_autoencoder': + from .autoencoders import create_diffAE_from_config + return create_diffAE_from_config(model_config) + elif model_type == 'lm': + from .lm import create_audio_lm_from_config + return create_audio_lm_from_config(model_config) + else: + raise NotImplementedError(f'Unknown model type: {model_type}') + +def create_model_from_config_path(model_config_path): + with open(model_config_path) as f: + model_config = json.load(f) + + return create_model_from_config(model_config) + +def create_pretransform_from_config(pretransform_config, sample_rate): + pretransform_type = pretransform_config.get('type', None) + + assert pretransform_type is not None, 'type must be specified in pretransform config' + + if pretransform_type == 'autoencoder': + from .autoencoders import create_autoencoder_from_config + from .pretransforms import AutoencoderPretransform + + # Create fake top-level config to pass sample rate to autoencoder constructor + # This is a bit of a hack but it keeps us from re-defining the sample rate in the config + autoencoder_config = {"sample_rate": sample_rate, "model": pretransform_config["config"]} + autoencoder = create_autoencoder_from_config(autoencoder_config) + + scale = pretransform_config.get("scale", 1.0) + model_half = pretransform_config.get("model_half", False) + iterate_batch = pretransform_config.get("iterate_batch", False) + chunked = pretransform_config.get("chunked", False) + + pretransform = AutoencoderPretransform(autoencoder, scale=scale, model_half=model_half, iterate_batch=iterate_batch, chunked=chunked) + elif pretransform_type == 'wavelet': + from .pretransforms import WaveletPretransform + + wavelet_config = pretransform_config["config"] + channels = wavelet_config["channels"] + levels = wavelet_config["levels"] + wavelet = wavelet_config["wavelet"] + + pretransform = WaveletPretransform(channels, levels, wavelet) + elif pretransform_type == 'pqmf': + from .pretransforms import PQMFPretransform + pqmf_config = pretransform_config["config"] + pretransform = PQMFPretransform(**pqmf_config) + elif pretransform_type == 'dac_pretrained': + from .pretransforms import PretrainedDACPretransform + pretrained_dac_config = pretransform_config["config"] + pretransform = PretrainedDACPretransform(**pretrained_dac_config) + elif pretransform_type == "audiocraft_pretrained": + from .pretransforms import AudiocraftCompressionPretransform + + audiocraft_config = pretransform_config["config"] + pretransform = 
AudiocraftCompressionPretransform(**audiocraft_config) + else: + raise NotImplementedError(f'Unknown pretransform type: {pretransform_type}') + + enable_grad = pretransform_config.get('enable_grad', False) + pretransform.enable_grad = enable_grad + + pretransform.eval().requires_grad_(pretransform.enable_grad) + + return pretransform + +def create_bottleneck_from_config(bottleneck_config): + bottleneck_type = bottleneck_config.get('type', None) + + assert bottleneck_type is not None, 'type must be specified in bottleneck config' + + if bottleneck_type == 'tanh': + from .bottleneck import TanhBottleneck + bottleneck = TanhBottleneck() + elif bottleneck_type == 'vae': + from .bottleneck import VAEBottleneck + bottleneck = VAEBottleneck() + elif bottleneck_type == 'rvq': + from .bottleneck import RVQBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQBottleneck(**quantizer_params) + elif bottleneck_type == "dac_rvq": + from .bottleneck import DACRVQBottleneck + + bottleneck = DACRVQBottleneck(**bottleneck_config["config"]) + + elif bottleneck_type == 'rvq_vae': + from .bottleneck import RVQVAEBottleneck + + quantizer_params = { + "dim": 128, + "codebook_size": 1024, + "num_quantizers": 8, + "decay": 0.99, + "kmeans_init": True, + "kmeans_iters": 50, + "threshold_ema_dead_code": 2, + } + + quantizer_params.update(bottleneck_config["config"]) + + bottleneck = RVQVAEBottleneck(**quantizer_params) + + elif bottleneck_type == 'dac_rvq_vae': + from .bottleneck import DACRVQVAEBottleneck + bottleneck = DACRVQVAEBottleneck(**bottleneck_config["config"]) + elif bottleneck_type == 'l2_norm': + from .bottleneck import L2Bottleneck + bottleneck = L2Bottleneck() + elif bottleneck_type == "wasserstein": + from .bottleneck import WassersteinBottleneck + bottleneck = WassersteinBottleneck(**bottleneck_config.get("config", {})) + elif bottleneck_type == "fsq": + from .bottleneck import FSQBottleneck + bottleneck = FSQBottleneck(**bottleneck_config["config"]) + else: + raise NotImplementedError(f'Unknown bottleneck type: {bottleneck_type}') + + requires_grad = bottleneck_config.get('requires_grad', True) + if not requires_grad: + for param in bottleneck.parameters(): + param.requires_grad = False + + return bottleneck diff --git a/src/modules/stable_vae/models/nn/.ipynb_checkpoints/__init__-checkpoint.py b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..6718c8b1a3d36c31655b030f4c515a144cde4db7 --- /dev/null +++ b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1,3 @@ +from . import layers +from . import loss +from . 
import quantize diff --git a/src/modules/stable_vae/models/nn/.ipynb_checkpoints/layers-checkpoint.py b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/layers-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94 --- /dev/null +++ b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/layers-checkpoint.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/src/modules/stable_vae/models/nn/.ipynb_checkpoints/loss-checkpoint.py b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/loss-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b --- /dev/null +++ b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/loss-checkpoint.py @@ -0,0 +1,368 @@ +import typing +from typing import List + +import torch +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import STFTParams +from torch import nn + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. + """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. + + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. 
Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. 
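# --- Editor's sketch (annotation, not part of this diff) ---
# Typical call pattern for the multi-scale STFT loss, using placeholder
# 44.1 kHz AudioSignals (`torch` and `AudioSignal` are imported at the top
# of this file).
def _example_multiscale_stft_loss():
    estimate = AudioSignal(torch.randn(1, 1, 44100), 44100)
    reference = AudioSignal(torch.randn(1, 1, 44100), 44100)
    loss_fn = MultiScaleSTFTLoss(window_lengths=[2048, 512])
    return loss_fn(estimate, reference)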
+ + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. 
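# --- Editor's sketch (annotation, not part of this diff) ---
# Same call pattern for the mel loss; one n_mels / mel_fmin / mel_fmax entry
# is given per STFT scale (here the two default window lengths).
def _example_mel_loss():
    estimate = AudioSignal(torch.randn(1, 1, 44100), 44100)
    reference = AudioSignal(torch.randn(1, 1, 44100), 44100)
    loss_fn = MelSpectrogramLoss(
        n_mels=[150, 80], mel_fmin=[0.0, 0.0], mel_fmax=[None, None]
    )
    return loss_fn(estimate, reference)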
+ + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/src/modules/stable_vae/models/nn/.ipynb_checkpoints/quantize-checkpoint.py b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/quantize-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..bc462bc56ab5191b65d58ebd8bcaff4af7fb5927 --- /dev/null +++ b/src/modules/stable_vae/models/nn/.ipynb_checkpoints/quantize-checkpoint.py @@ -0,0 +1,262 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from .layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. 
of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. 
+ + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/src/modules/stable_vae/models/nn/__init__.py b/src/modules/stable_vae/models/nn/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..6718c8b1a3d36c31655b030f4c515a144cde4db7 --- /dev/null +++ b/src/modules/stable_vae/models/nn/__init__.py @@ -0,0 +1,3 @@ +from . import layers +from . import loss +from . import quantize diff --git a/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-310.pyc b/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..fd83e18bd22222ca6b9ce0f0ab056cc026747bb3 Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-311.pyc b/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..613da213e9ce976569f03e20476c405e0a68b0cc Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-310.pyc b/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..57a85d0e28029662f1ecc2790e44f71caa09cd0c Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-311.pyc b/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c93f6bedcfe519c094304705d6dd5033cc9a7b45 Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/layers.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-310.pyc b/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..47d3c0daa35a755146c3ebf4f49d430469fe0c6a Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-311.pyc b/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..96bfb48894b2b72431b76c07dcd7187e322b1415 Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/loss.cpython-311.pyc differ diff --git 
a/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-310.pyc b/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..d960f1c6e8f02f6b1c3b72b9d60543ceadb619cf Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-310.pyc differ diff --git a/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-311.pyc b/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..c59b1da7d3336c046e199092b83bcb281049336b Binary files /dev/null and b/src/modules/stable_vae/models/nn/__pycache__/quantize.cpython-311.pyc differ diff --git a/src/modules/stable_vae/models/nn/layers.py b/src/modules/stable_vae/models/nn/layers.py new file mode 100644 index 0000000000000000000000000000000000000000..44fbc2929715e11d843b24195d7042a528969a94 --- /dev/null +++ b/src/modules/stable_vae/models/nn/layers.py @@ -0,0 +1,33 @@ +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + + +def WNConv1d(*args, **kwargs): + return weight_norm(nn.Conv1d(*args, **kwargs)) + + +def WNConvTranspose1d(*args, **kwargs): + return weight_norm(nn.ConvTranspose1d(*args, **kwargs)) + + +# Scripting this brings model speed up 1.4x +@torch.jit.script +def snake(x, alpha): + shape = x.shape + x = x.reshape(shape[0], shape[1], -1) + x = x + (alpha + 1e-9).reciprocal() * torch.sin(alpha * x).pow(2) + x = x.reshape(shape) + return x + + +class Snake1d(nn.Module): + def __init__(self, channels): + super().__init__() + self.alpha = nn.Parameter(torch.ones(1, channels, 1)) + + def forward(self, x): + return snake(x, self.alpha) diff --git a/src/modules/stable_vae/models/nn/loss.py b/src/modules/stable_vae/models/nn/loss.py new file mode 100644 index 0000000000000000000000000000000000000000..9bb3dd6ce08a7a24f18f941eeb5b68fe9461e86b --- /dev/null +++ b/src/modules/stable_vae/models/nn/loss.py @@ -0,0 +1,368 @@ +import typing +from typing import List + +import torch +import torch.nn.functional as F +from audiotools import AudioSignal +from audiotools import STFTParams +from torch import nn + + +class L1Loss(nn.L1Loss): + """L1 Loss between AudioSignals. Defaults + to comparing ``audio_data``, but any + attribute of an AudioSignal can be used. + + Parameters + ---------- + attribute : str, optional + Attribute of signal to compare, defaults to ``audio_data``. + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs): + self.attribute = attribute + self.weight = weight + super().__init__(**kwargs) + + def forward(self, x: AudioSignal, y: AudioSignal): + """ + Parameters + ---------- + x : AudioSignal + Estimate AudioSignal + y : AudioSignal + Reference AudioSignal + + Returns + ------- + torch.Tensor + L1 loss between AudioSignal attributes. + """ + if isinstance(x, AudioSignal): + x = getattr(x, self.attribute) + y = getattr(y, self.attribute) + return super().forward(x, y) + + +class SISDRLoss(nn.Module): + """ + Computes the Scale-Invariant Source-to-Distortion Ratio between a batch + of estimated and reference audio signals or aligned features. 
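+
+    The returned value is the negative SI-SDR in decibels, so lower is better.
+    As a rough sketch (with ``scaling`` enabled), for reference ``s`` and
+    estimate ``s_hat`` it computes
+    ``-10 * log10(||a * s||^2 / ||s_hat - a * s||^2)`` where
+    ``a = <s_hat, s> / ||s||^2``, i.e. the estimate is first projected onto
+    the reference before measuring the residual.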
+ + Parameters + ---------- + scaling : int, optional + Whether to use scale-invariant (True) or + signal-to-noise ratio (False), by default True + reduction : str, optional + How to reduce across the batch (either 'mean', + 'sum', or none).], by default ' mean' + zero_mean : int, optional + Zero mean the references and estimates before + computing the loss, by default True + clip_min : int, optional + The minimum possible loss value. Helps network + to not focus on making already good examples better, by default None + weight : float, optional + Weight of this loss, defaults to 1.0. + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/distance.py + """ + + def __init__( + self, + scaling: int = True, + reduction: str = "mean", + zero_mean: int = True, + clip_min: int = None, + weight: float = 1.0, + ): + self.scaling = scaling + self.reduction = reduction + self.zero_mean = zero_mean + self.clip_min = clip_min + self.weight = weight + super().__init__() + + def forward(self, x: AudioSignal, y: AudioSignal): + eps = 1e-8 + # nb, nc, nt + if isinstance(x, AudioSignal): + references = x.audio_data + estimates = y.audio_data + else: + references = x + estimates = y + + nb = references.shape[0] + references = references.reshape(nb, 1, -1).permute(0, 2, 1) + estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1) + + # samples now on axis 1 + if self.zero_mean: + mean_reference = references.mean(dim=1, keepdim=True) + mean_estimate = estimates.mean(dim=1, keepdim=True) + else: + mean_reference = 0 + mean_estimate = 0 + + _references = references - mean_reference + _estimates = estimates - mean_estimate + + references_projection = (_references**2).sum(dim=-2) + eps + references_on_estimates = (_estimates * _references).sum(dim=-2) + eps + + scale = ( + (references_on_estimates / references_projection).unsqueeze(1) + if self.scaling + else 1 + ) + + e_true = scale * _references + e_res = _estimates - e_true + + signal = (e_true**2).sum(dim=1) + noise = (e_res**2).sum(dim=1) + sdr = -10 * torch.log10(signal / noise + eps) + + if self.clip_min is not None: + sdr = torch.clamp(sdr, min=self.clip_min) + + if self.reduction == "mean": + sdr = sdr.mean() + elif self.reduction == "sum": + sdr = sdr.sum() + return sdr + + +class MultiScaleSTFTLoss(nn.Module): + """Computes the multi-scale STFT loss from [1]. + + Parameters + ---------- + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + References + ---------- + + 1. Engel, Jesse, Chenjie Gu, and Adam Roberts. + "DDSP: Differentiable Digital Signal Processing." + International Conference on Learning Representations. 2019. 
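+
+    A minimal usage sketch (the two random signals below are placeholders):
+
+    >>> stft_loss = MultiScaleSTFTLoss()
+    >>> x = AudioSignal(torch.randn(1, 1, 44100), 44100)
+    >>> y = AudioSignal(torch.randn(1, 1, 44100), 44100)
+    >>> loss = stft_loss(x, y)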
+ + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.loss_fn = loss_fn + self.log_weight = log_weight + self.mag_weight = mag_weight + self.clamp_eps = clamp_eps + self.weight = weight + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes multi-scale STFT between an estimate and a reference + signal. + + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Multi-scale STFT loss. + """ + loss = 0.0 + for s in self.stft_params: + x.stft(s.window_length, s.hop_length, s.window_type) + y.stft(s.window_length, s.hop_length, s.window_type) + loss += self.log_weight * self.loss_fn( + x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude) + return loss + + +class MelSpectrogramLoss(nn.Module): + """Compute distance between mel spectrograms. Can be used + in a multi-scale way. + + Parameters + ---------- + n_mels : List[int] + Number of mels per STFT, by default [150, 80], + window_lengths : List[int], optional + Length of each window of each STFT, by default [2048, 512] + loss_fn : typing.Callable, optional + How to compare each loss, by default nn.L1Loss() + clamp_eps : float, optional + Clamp on the log magnitude, below, by default 1e-5 + mag_weight : float, optional + Weight of raw magnitude portion of loss, by default 1.0 + log_weight : float, optional + Weight of log magnitude portion of loss, by default 1.0 + pow : float, optional + Power to raise magnitude to before taking log, by default 2.0 + weight : float, optional + Weight of this loss, by default 1.0 + match_stride : bool, optional + Whether to match the stride of convolutional layers, by default False + + Implementation copied from: https://github.com/descriptinc/lyrebird-audiotools/blob/961786aa1a9d628cca0c0486e5885a457fe70c1a/audiotools/metrics/spectral.py + """ + + def __init__( + self, + n_mels: List[int] = [150, 80], + window_lengths: List[int] = [2048, 512], + loss_fn: typing.Callable = nn.L1Loss(), + clamp_eps: float = 1e-5, + mag_weight: float = 1.0, + log_weight: float = 1.0, + pow: float = 2.0, + weight: float = 1.0, + match_stride: bool = False, + mel_fmin: List[float] = [0.0, 0.0], + mel_fmax: List[float] = [None, None], + window_type: str = None, + ): + super().__init__() + self.stft_params = [ + STFTParams( + window_length=w, + hop_length=w // 4, + match_stride=match_stride, + window_type=window_type, + ) + for w in window_lengths + ] + self.n_mels = n_mels + self.loss_fn = loss_fn + self.clamp_eps = clamp_eps + self.log_weight = log_weight + self.mag_weight = mag_weight + self.weight = weight + self.mel_fmin = mel_fmin + self.mel_fmax = mel_fmax + self.pow = pow + + def forward(self, x: AudioSignal, y: AudioSignal): + """Computes mel loss between an estimate and a reference + signal. 
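+
+        For each scale, the total adds a ``log_weight``-scaled ``loss_fn``
+        distance between ``log10(mel.clamp(clamp_eps) ** pow)`` of the two
+        signals, plus a ``mag_weight``-scaled ``loss_fn`` distance between
+        the raw mel spectrograms.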
+ + Parameters + ---------- + x : AudioSignal + Estimate signal + y : AudioSignal + Reference signal + + Returns + ------- + torch.Tensor + Mel loss. + """ + loss = 0.0 + for n_mels, fmin, fmax, s in zip( + self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params + ): + kwargs = { + "window_length": s.window_length, + "hop_length": s.hop_length, + "window_type": s.window_type, + } + x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs) + + loss += self.log_weight * self.loss_fn( + x_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + y_mels.clamp(self.clamp_eps).pow(self.pow).log10(), + ) + loss += self.mag_weight * self.loss_fn(x_mels, y_mels) + return loss + + +class GANLoss(nn.Module): + """ + Computes a discriminator loss, given a discriminator on + generated waveforms/spectrograms compared to ground truth + waveforms/spectrograms. Computes the loss for both the + discriminator and the generator in separate functions. + """ + + def __init__(self, discriminator): + super().__init__() + self.discriminator = discriminator + + def forward(self, fake, real): + d_fake = self.discriminator(fake.audio_data) + d_real = self.discriminator(real.audio_data) + return d_fake, d_real + + def discriminator_loss(self, fake, real): + d_fake, d_real = self.forward(fake.clone().detach(), real) + + loss_d = 0 + for x_fake, x_real in zip(d_fake, d_real): + loss_d += torch.mean(x_fake[-1] ** 2) + loss_d += torch.mean((1 - x_real[-1]) ** 2) + return loss_d + + def generator_loss(self, fake, real): + d_fake, d_real = self.forward(fake, real) + + loss_g = 0 + for x_fake in d_fake: + loss_g += torch.mean((1 - x_fake[-1]) ** 2) + + loss_feature = 0 + + for i in range(len(d_fake)): + for j in range(len(d_fake[i]) - 1): + loss_feature += F.l1_loss(d_fake[i][j], d_real[i][j].detach()) + return loss_g, loss_feature diff --git a/src/modules/stable_vae/models/nn/quantize.py b/src/modules/stable_vae/models/nn/quantize.py new file mode 100644 index 0000000000000000000000000000000000000000..bc462bc56ab5191b65d58ebd8bcaff4af7fb5927 --- /dev/null +++ b/src/modules/stable_vae/models/nn/quantize.py @@ -0,0 +1,262 @@ +from typing import Union + +import numpy as np +import torch +import torch.nn as nn +import torch.nn.functional as F +from einops import rearrange +from torch.nn.utils import weight_norm + +from .layers import WNConv1d + + +class VectorQuantize(nn.Module): + """ + Implementation of VQ similar to Karpathy's repo: + https://github.com/karpathy/deep-vector-quantization + Additionally uses following tricks from Improved VQGAN + (https://arxiv.org/pdf/2110.04627.pdf): + 1. Factorized codes: Perform nearest neighbor lookup in low-dimensional space + for improved codebook usage + 2. 
l2-normalized codes: Converts euclidean distance to cosine similarity which + improves training stability + """ + + def __init__(self, input_dim: int, codebook_size: int, codebook_dim: int): + super().__init__() + self.codebook_size = codebook_size + self.codebook_dim = codebook_dim + + self.in_proj = WNConv1d(input_dim, codebook_dim, kernel_size=1) + self.out_proj = WNConv1d(codebook_dim, input_dim, kernel_size=1) + self.codebook = nn.Embedding(codebook_size, codebook_dim) + + def forward(self, z): + """Quantized the input tensor using a fixed codebook and returns + the corresponding codebook vectors + + Parameters + ---------- + z : Tensor[B x D x T] + + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + Tensor[1] + Codebook loss to update the codebook + Tensor[B x T] + Codebook indices (quantized discrete representation of input) + Tensor[B x D x T] + Projected latents (continuous representation of input before quantization) + """ + + # Factorized codes (ViT-VQGAN) Project input into low-dimensional space + z_e = self.in_proj(z) # z_e : (B x D x T) + z_q, indices = self.decode_latents(z_e) + + commitment_loss = F.mse_loss(z_e, z_q.detach(), reduction="none").mean([1, 2]) + codebook_loss = F.mse_loss(z_q, z_e.detach(), reduction="none").mean([1, 2]) + + z_q = ( + z_e + (z_q - z_e).detach() + ) # noop in forward pass, straight-through gradient estimator in backward pass + + z_q = self.out_proj(z_q) + + return z_q, commitment_loss, codebook_loss, indices, z_e + + def embed_code(self, embed_id): + return F.embedding(embed_id, self.codebook.weight) + + def decode_code(self, embed_id): + return self.embed_code(embed_id).transpose(1, 2) + + def decode_latents(self, latents): + encodings = rearrange(latents, "b d t -> (b t) d") + codebook = self.codebook.weight # codebook: (N x D) + + # L2 normalize encodings and codebook (ViT-VQGAN) + encodings = F.normalize(encodings) + codebook = F.normalize(codebook) + + # Compute euclidean distance with codebook + dist = ( + encodings.pow(2).sum(1, keepdim=True) + - 2 * encodings @ codebook.t() + + codebook.pow(2).sum(1, keepdim=True).t() + ) + indices = rearrange((-dist).max(1)[1], "(b t) -> b t", b=latents.size(0)) + z_q = self.decode_code(indices) + return z_q, indices + + +class ResidualVectorQuantize(nn.Module): + """ + Introduced in SoundStream: An end2end neural audio codec + https://arxiv.org/abs/2107.03312 + """ + + def __init__( + self, + input_dim: int = 512, + n_codebooks: int = 9, + codebook_size: int = 1024, + codebook_dim: Union[int, list] = 8, + quantizer_dropout: float = 0.0, + ): + super().__init__() + if isinstance(codebook_dim, int): + codebook_dim = [codebook_dim for _ in range(n_codebooks)] + + self.n_codebooks = n_codebooks + self.codebook_dim = codebook_dim + self.codebook_size = codebook_size + + self.quantizers = nn.ModuleList( + [ + VectorQuantize(input_dim, codebook_size, codebook_dim[i]) + for i in range(n_codebooks) + ] + ) + self.quantizer_dropout = quantizer_dropout + + def forward(self, z, n_quantizers: int = None): + """Quantized the input tensor using a fixed set of `n` codebooks and returns + the corresponding codebook vectors + Parameters + ---------- + z : Tensor[B x D x T] + n_quantizers : int, optional + No. 
of quantizers to use + (n_quantizers < self.n_codebooks ex: for quantizer dropout) + Note: if `self.quantizer_dropout` is True, this argument is ignored + when in training mode, and a random number of quantizers is used. + Returns + ------- + dict + A dictionary with the following keys: + + "z" : Tensor[B x D x T] + Quantized continuous representation of input + "codes" : Tensor[B x N x T] + Codebook indices for each codebook + (quantized discrete representation of input) + "latents" : Tensor[B x N*D x T] + Projected latents (continuous representation of input before quantization) + "vq/commitment_loss" : Tensor[1] + Commitment loss to train encoder to predict vectors closer to codebook + entries + "vq/codebook_loss" : Tensor[1] + Codebook loss to update the codebook + """ + z_q = 0 + residual = z + commitment_loss = 0 + codebook_loss = 0 + + codebook_indices = [] + latents = [] + + if n_quantizers is None: + n_quantizers = self.n_codebooks + if self.training: + n_quantizers = torch.ones((z.shape[0],)) * self.n_codebooks + 1 + dropout = torch.randint(1, self.n_codebooks + 1, (z.shape[0],)) + n_dropout = int(z.shape[0] * self.quantizer_dropout) + n_quantizers[:n_dropout] = dropout[:n_dropout] + n_quantizers = n_quantizers.to(z.device) + + for i, quantizer in enumerate(self.quantizers): + if self.training is False and i >= n_quantizers: + break + + z_q_i, commitment_loss_i, codebook_loss_i, indices_i, z_e_i = quantizer( + residual + ) + + # Create mask to apply quantizer dropout + mask = ( + torch.full((z.shape[0],), fill_value=i, device=z.device) < n_quantizers + ) + z_q = z_q + z_q_i * mask[:, None, None] + residual = residual - z_q_i + + # Sum losses + commitment_loss += (commitment_loss_i * mask).mean() + codebook_loss += (codebook_loss_i * mask).mean() + + codebook_indices.append(indices_i) + latents.append(z_e_i) + + codes = torch.stack(codebook_indices, dim=1) + latents = torch.cat(latents, dim=1) + + return z_q, codes, latents, commitment_loss, codebook_loss + + def from_codes(self, codes: torch.Tensor): + """Given the quantized codes, reconstruct the continuous representation + Parameters + ---------- + codes : Tensor[B x N x T] + Quantized discrete representation of input + Returns + ------- + Tensor[B x D x T] + Quantized continuous representation of input + """ + z_q = 0.0 + z_p = [] + n_codebooks = codes.shape[1] + for i in range(n_codebooks): + z_p_i = self.quantizers[i].decode_code(codes[:, i, :]) + z_p.append(z_p_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + return z_q, torch.cat(z_p, dim=1), codes + + def from_latents(self, latents: torch.Tensor): + """Given the unquantized latents, reconstruct the + continuous representation after quantization. 
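+
+        The channel axis of ``latents`` is assumed to be the concatenation of
+        every quantizer's projected slice, so with the default
+        ``codebook_dim=8`` quantizer ``i`` reads ``latents[:, 8*i:8*(i+1), :]``.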
+ + Parameters + ---------- + latents : Tensor[B x N x T] + Continuous representation of input after projection + + Returns + ------- + Tensor[B x D x T] + Quantized representation of full-projected space + Tensor[B x D x T] + Quantized representation of latent space + """ + z_q = 0 + z_p = [] + codes = [] + dims = np.cumsum([0] + [q.codebook_dim for q in self.quantizers]) + + n_codebooks = np.where(dims <= latents.shape[1])[0].max(axis=0, keepdims=True)[ + 0 + ] + for i in range(n_codebooks): + j, k = dims[i], dims[i + 1] + z_p_i, codes_i = self.quantizers[i].decode_latents(latents[:, j:k, :]) + z_p.append(z_p_i) + codes.append(codes_i) + + z_q_i = self.quantizers[i].out_proj(z_p_i) + z_q = z_q + z_q_i + + return z_q, torch.cat(z_p, dim=1), torch.stack(codes, dim=1) + + +if __name__ == "__main__": + rvq = ResidualVectorQuantize(quantizer_dropout=True) + x = torch.randn(16, 512, 80) + y = rvq(x) + print(y["latents"].shape) diff --git a/src/modules/stable_vae/models/pretransforms.py b/src/modules/stable_vae/models/pretransforms.py new file mode 100644 index 0000000000000000000000000000000000000000..c9942db59908ce8135f44e45090e928f6c60393a --- /dev/null +++ b/src/modules/stable_vae/models/pretransforms.py @@ -0,0 +1,258 @@ +import torch +from einops import rearrange +from torch import nn + +class Pretransform(nn.Module): + def __init__(self, enable_grad, io_channels, is_discrete): + super().__init__() + + self.is_discrete = is_discrete + self.io_channels = io_channels + self.encoded_channels = None + self.downsampling_ratio = None + + self.enable_grad = enable_grad + + def encode(self, x): + raise NotImplementedError + + def decode(self, z): + raise NotImplementedError + + def tokenize(self, x): + raise NotImplementedError + + def decode_tokens(self, tokens): + raise NotImplementedError + +class AutoencoderPretransform(Pretransform): + def __init__(self, model, scale=1.0, model_half=False, iterate_batch=False, chunked=False): + super().__init__(enable_grad=False, io_channels=model.io_channels, is_discrete=model.bottleneck is not None and model.bottleneck.is_discrete) + self.model = model + self.model.requires_grad_(False).eval() + self.scale=scale + self.downsampling_ratio = model.downsampling_ratio + self.io_channels = model.io_channels + self.sample_rate = model.sample_rate + + self.model_half = model_half + self.iterate_batch = iterate_batch + + self.encoded_channels = model.latent_dim + + self.chunked = chunked + self.num_quantizers = model.bottleneck.num_quantizers if model.bottleneck is not None and model.bottleneck.is_discrete else None + self.codebook_size = model.bottleneck.codebook_size if model.bottleneck is not None and model.bottleneck.is_discrete else None + + if self.model_half: + self.model.half() + + def encode(self, x, **kwargs): + + if self.model_half: + x = x.half() + self.model.to(torch.float16) + + encoded = self.model.encode_audio(x, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + encoded = encoded.float() + + return encoded / self.scale + + def decode(self, z, **kwargs): + z = z * self.scale + + if self.model_half: + z = z.half() + self.model.to(torch.float16) + + decoded = self.model.decode_audio(z, chunked=self.chunked, iterate_batch=self.iterate_batch, **kwargs) + + if self.model_half: + decoded = decoded.float() + + return decoded + + def tokenize(self, x, **kwargs): + assert self.model.is_discrete, "Cannot tokenize with a continuous model" + + _, info = self.model.encode(x, return_info = True, **kwargs) + + return 
info[self.model.bottleneck.tokens_id] + + def decode_tokens(self, tokens, **kwargs): + assert self.model.is_discrete, "Cannot decode tokens with a continuous model" + + return self.model.decode_tokens(tokens, **kwargs) + + def load_state_dict(self, state_dict, strict=True): + self.model.load_state_dict(state_dict, strict=strict) + +class WaveletPretransform(Pretransform): + def __init__(self, channels, levels, wavelet): + super().__init__(enable_grad=False, io_channels=channels, is_discrete=False) + + from .wavelets import WaveletEncode1d, WaveletDecode1d + + self.encoder = WaveletEncode1d(channels, levels, wavelet) + self.decoder = WaveletDecode1d(channels, levels, wavelet) + + self.downsampling_ratio = 2 ** levels + self.io_channels = channels + self.encoded_channels = channels * self.downsampling_ratio + + def encode(self, x): + return self.encoder(x) + + def decode(self, z): + return self.decoder(z) + +class PQMFPretransform(Pretransform): + def __init__(self, attenuation=100, num_bands=16): + # TODO: Fix PQMF to take in in-channels + super().__init__(enable_grad=False, io_channels=1, is_discrete=False) + from .pqmf import PQMF + self.pqmf = PQMF(attenuation, num_bands) + + + def encode(self, x): + # x is (Batch x Channels x Time) + x = self.pqmf.forward(x) + # pqmf.forward returns (Batch x Channels x Bands x Time) + # but Pretransform needs Batch x Channels x Time + # so concatenate channels and bands into one axis + return rearrange(x, "b c n t -> b (c n) t") + + def decode(self, x): + # x is (Batch x (Channels Bands) x Time), convert back to (Batch x Channels x Bands x Time) + x = rearrange(x, "b (c n) t -> b c n t", n=self.pqmf.num_bands) + # returns (Batch x Channels x Time) + return self.pqmf.inverse(x) + +class PretrainedDACPretransform(Pretransform): + def __init__(self, model_type="44khz", model_bitrate="8kbps", scale=1.0, quantize_on_decode: bool = True, chunked=True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + import dac + + model_path = dac.utils.download(model_type=model_type, model_bitrate=model_bitrate) + + self.model = dac.DAC.load(model_path) + + self.quantize_on_decode = quantize_on_decode + + if model_type == "44khz": + self.downsampling_ratio = 512 + else: + self.downsampling_ratio = 320 + + self.io_channels = 1 + + self.scale = scale + + self.chunked = chunked + + self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.n_codebooks + + self.codebook_size = self.model.codebook_size + + def encode(self, x): + + latents = self.model.encoder(x) + + if self.quantize_on_decode: + output = latents + else: + z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + output = z + + if self.scale != 1.0: + output = output / self.scale + + return output + + def decode(self, z): + + if self.scale != 1.0: + z = z * self.scale + + if self.quantize_on_decode: + z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + return self.model.decode(z) + + def tokenize(self, x): + return self.model.encode(x)[1] + + def decode_tokens(self, tokens): + latents = self.model.quantizer.from_codes(tokens) + return self.model.decode(latents) + +class AudiocraftCompressionPretransform(Pretransform): + def __init__(self, model_type="facebook/encodec_32khz", scale=1.0, quantize_on_decode: bool = True): + super().__init__(enable_grad=False, io_channels=1, is_discrete=True) + + try: + from audiocraft.models import CompressionModel + except ImportError: + raise ImportError("Audiocraft is not 
installed. Please install audiocraft to use Audiocraft models.") + + self.model = CompressionModel.get_pretrained(model_type) + + self.quantize_on_decode = quantize_on_decode + + self.downsampling_ratio = round(self.model.sample_rate / self.model.frame_rate) + + self.sample_rate = self.model.sample_rate + + self.io_channels = self.model.channels + + self.scale = scale + + #self.encoded_channels = self.model.latent_dim + + self.num_quantizers = self.model.num_codebooks + + self.codebook_size = self.model.cardinality + + self.model.to(torch.float16).eval().requires_grad_(False) + + def encode(self, x): + + assert False, "Audiocraft compression models do not support continuous encoding" + + # latents = self.model.encoder(x) + + # if self.quantize_on_decode: + # output = latents + # else: + # z, _, _, _, _ = self.model.quantizer(latents, n_quantizers=self.model.n_codebooks) + # output = z + + # if self.scale != 1.0: + # output = output / self.scale + + # return output + + def decode(self, z): + + assert False, "Audiocraft compression models do not support continuous decoding" + + # if self.scale != 1.0: + # z = z * self.scale + + # if self.quantize_on_decode: + # z, _, _, _, _ = self.model.quantizer(z, n_quantizers=self.model.n_codebooks) + + # return self.model.decode(z) + + def tokenize(self, x): + with torch.cuda.amp.autocast(enabled=False): + return self.model.encode(x.to(torch.float16))[0] + + def decode_tokens(self, tokens): + with torch.cuda.amp.autocast(enabled=False): + return self.model.decode(tokens) diff --git a/src/modules/stable_vae/models/utils.py b/src/modules/stable_vae/models/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..ec8eeaf773d47db2c000a3b2237d88d310214dcf --- /dev/null +++ b/src/modules/stable_vae/models/utils.py @@ -0,0 +1,51 @@ +import torch +import torch.nn as nn +from torchaudio import transforms as T + + +class PadCrop(nn.Module): + def __init__(self, n_samples, randomize=True): + super().__init__() + self.n_samples = n_samples + self.randomize = randomize + + def __call__(self, signal): + n, s = signal.shape + start = 0 if (not self.randomize) else torch.randint(0, max(0, s - self.n_samples) + 1, []).item() + end = start + self.n_samples + output = signal.new_zeros([n, self.n_samples]) + output[:, :min(s, self.n_samples)] = signal[:, start:end] + return output + + +def set_audio_channels(audio, target_channels): + if target_channels == 1: + # Convert to mono + audio = audio.mean(1, keepdim=True) + elif target_channels == 2: + # Convert to stereo + if audio.shape[1] == 1: + audio = audio.repeat(1, 2, 1) + elif audio.shape[1] > 2: + audio = audio[:, :2, :] + return audio + +def prepare_audio(audio, in_sr, target_sr, target_length, target_channels, device): + + audio = audio.to(device) + + if in_sr != target_sr: + resample_tf = T.Resample(in_sr, target_sr).to(device) + audio = resample_tf(audio) + + audio = PadCrop(target_length, randomize=False)(audio) + + # Add batch dimension + if audio.dim() == 1: + audio = audio.unsqueeze(0).unsqueeze(0) + elif audio.dim() == 2: + audio = audio.unsqueeze(0) + + audio = set_audio_channels(audio, target_channels) + + return audio \ No newline at end of file diff --git a/src/utils/.ipynb_checkpoints/__init__-checkpoint.py b/src/utils/.ipynb_checkpoints/__init__-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..90f60fdd89ad8575faafe45188bd1d968852fc67 --- /dev/null +++ b/src/utils/.ipynb_checkpoints/__init__-checkpoint.py @@ -0,0 +1 @@ +from .utils import * \ No 
newline at end of file diff --git a/src/utils/.ipynb_checkpoints/utils-checkpoint.py b/src/utils/.ipynb_checkpoints/utils-checkpoint.py new file mode 100644 index 0000000000000000000000000000000000000000..82a6bc56e9e341e54dc6a136f1f78261dde0f655 --- /dev/null +++ b/src/utils/.ipynb_checkpoints/utils-checkpoint.py @@ -0,0 +1,94 @@ +import torch +import numpy as np +import yaml +import os + + +def load_yaml_with_includes(yaml_file): + def loader_with_include(loader, node): + # Load the included file + include_path = os.path.join(os.path.dirname(yaml_file), loader.construct_scalar(node)) + with open(include_path, 'r') as f: + return yaml.load(f, Loader=yaml.FullLoader) + + yaml.add_constructor('!include', loader_with_include, Loader=yaml.FullLoader) + + with open(yaml_file, 'r') as f: + return yaml.load(f, Loader=yaml.FullLoader) + + +def scale_shift(x, scale, shift): + return (x+shift) * scale + + +def scale_shift_re(x, scale, shift): + return (x/scale) - shift + + +def align_seq(source, target_length, mapping_method='hard'): + source_len = source.shape[1] + if mapping_method == 'hard': + mapping_idx = np.round(np.arange(target_length) * source_len / target_length) + output = source[:, mapping_idx] + else: + # TBD + raise NotImplementedError + + return output + + +def customized_lr_scheduler(optimizer, warmup_steps=-1): + from torch.optim.lr_scheduler import LambdaLR + + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'customized': + return customized_lr_scheduler(optimizer, **kwargs) + elif name == 'cosine': + from torch.optim.lr_scheduler import CosineAnnealingLR + return CosineAnnealingLR(optimizer, **kwargs) + else: + raise NotImplementedError(name) + + +def compute_snr(noise_scheduler, timesteps): + """ + Computes SNR as per + https://github.com/TiankaiHang/Min-SNR-Diffusion + Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion + # Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. 
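+    # alpha = sqrt(alphas_cumprod[t]) and sigma = sqrt(1 - alphas_cumprod[t]),
+    # so (alpha / sigma) ** 2 equals alphas_cumprod[t] / (1 - alphas_cumprod[t]).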
+ snr = (alpha / sigma) ** 2 + return snr + + +if __name__ == "__main__": + + a = torch.rand(2, 10) + target_len = 15 + + b = align_seq(a, target_len) \ No newline at end of file diff --git a/src/utils/__init__.py b/src/utils/__init__.py new file mode 100644 index 0000000000000000000000000000000000000000..90f60fdd89ad8575faafe45188bd1d968852fc67 --- /dev/null +++ b/src/utils/__init__.py @@ -0,0 +1 @@ +from .utils import * \ No newline at end of file diff --git a/src/utils/__pycache__/__init__.cpython-310.pyc b/src/utils/__pycache__/__init__.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..4f5bf998d1743a35df9c2834fa91f5bd74a593f6 Binary files /dev/null and b/src/utils/__pycache__/__init__.cpython-310.pyc differ diff --git a/src/utils/__pycache__/__init__.cpython-311.pyc b/src/utils/__pycache__/__init__.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..23b6f8b4f31f08d986faf16cf2f3fc26440c5aca Binary files /dev/null and b/src/utils/__pycache__/__init__.cpython-311.pyc differ diff --git a/src/utils/__pycache__/utils.cpython-310.pyc b/src/utils/__pycache__/utils.cpython-310.pyc new file mode 100644 index 0000000000000000000000000000000000000000..a7b4f963755252cd0393b6ec4a9b70d930937fe8 Binary files /dev/null and b/src/utils/__pycache__/utils.cpython-310.pyc differ diff --git a/src/utils/__pycache__/utils.cpython-311.pyc b/src/utils/__pycache__/utils.cpython-311.pyc new file mode 100644 index 0000000000000000000000000000000000000000..e6023a70f40be0ead3861f81fb870c4623a7b2e1 Binary files /dev/null and b/src/utils/__pycache__/utils.cpython-311.pyc differ diff --git a/src/utils/utils.py b/src/utils/utils.py new file mode 100644 index 0000000000000000000000000000000000000000..82a6bc56e9e341e54dc6a136f1f78261dde0f655 --- /dev/null +++ b/src/utils/utils.py @@ -0,0 +1,94 @@ +import torch +import numpy as np +import yaml +import os + + +def load_yaml_with_includes(yaml_file): + def loader_with_include(loader, node): + # Load the included file + include_path = os.path.join(os.path.dirname(yaml_file), loader.construct_scalar(node)) + with open(include_path, 'r') as f: + return yaml.load(f, Loader=yaml.FullLoader) + + yaml.add_constructor('!include', loader_with_include, Loader=yaml.FullLoader) + + with open(yaml_file, 'r') as f: + return yaml.load(f, Loader=yaml.FullLoader) + + +def scale_shift(x, scale, shift): + return (x+shift) * scale + + +def scale_shift_re(x, scale, shift): + return (x/scale) - shift + + +def align_seq(source, target_length, mapping_method='hard'): + source_len = source.shape[1] + if mapping_method == 'hard': + mapping_idx = np.round(np.arange(target_length) * source_len / target_length) + output = source[:, mapping_idx] + else: + # TBD + raise NotImplementedError + + return output + + +def customized_lr_scheduler(optimizer, warmup_steps=-1): + from torch.optim.lr_scheduler import LambdaLR + + def fn(step): + if warmup_steps > 0: + return min(step / warmup_steps, 1) + else: + return 1 + return LambdaLR(optimizer, fn) + + +def get_lr_scheduler(optimizer, name, **kwargs): + if name == 'customized': + return customized_lr_scheduler(optimizer, **kwargs) + elif name == 'cosine': + from torch.optim.lr_scheduler import CosineAnnealingLR + return CosineAnnealingLR(optimizer, **kwargs) + else: + raise NotImplementedError(name) + + +def compute_snr(noise_scheduler, timesteps): + """ + Computes SNR as per + https://github.com/TiankaiHang/Min-SNR-Diffusion + 
Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849 + """ + alphas_cumprod = noise_scheduler.alphas_cumprod + sqrt_alphas_cumprod = alphas_cumprod**0.5 + sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5 + + # Expand the tensors. + # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion + # Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026 + sqrt_alphas_cumprod = sqrt_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_alphas_cumprod = sqrt_alphas_cumprod[..., None] + alpha = sqrt_alphas_cumprod.expand(timesteps.shape) + + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod.to(device=timesteps.device)[timesteps].float() + while len(sqrt_one_minus_alphas_cumprod.shape) < len(timesteps.shape): + sqrt_one_minus_alphas_cumprod = sqrt_one_minus_alphas_cumprod[..., None] + sigma = sqrt_one_minus_alphas_cumprod.expand(timesteps.shape) + + # Compute SNR. + snr = (alpha / sigma) ** 2 + return snr + + +if __name__ == "__main__": + + a = torch.rand(2, 10) + target_len = 15 + + b = align_seq(a, target_len) \ No newline at end of file
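A minimal, self-contained sketch of how `compute_snr` above could be exercised. The stand-in scheduler object and the linear beta schedule below are illustrative assumptions, not part of this repository; any object exposing an `alphas_cumprod` tensor would work the same way.

    from types import SimpleNamespace

    import torch

    # Stand-in scheduler: only the `alphas_cumprod` attribute is needed.
    betas = torch.linspace(1e-4, 0.02, 1000)
    scheduler = SimpleNamespace(alphas_cumprod=torch.cumprod(1.0 - betas, dim=0))

    timesteps = torch.randint(0, 1000, (8,))
    snr = compute_snr(scheduler, timesteps)  # one SNR value per sampled timestep, shape (8,)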