OpenSound commited on
Commit
9d3cb0a
1 Parent(s): 5c3a213

Upload 211 files

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. audiotools/__init__.py +10 -0
  2. audiotools/core/__init__.py +4 -0
  3. audiotools/core/audio_signal.py +1682 -0
  4. audiotools/core/display.py +194 -0
  5. audiotools/core/dsp.py +390 -0
  6. audiotools/core/effects.py +647 -0
  7. audiotools/core/ffmpeg.py +204 -0
  8. audiotools/core/loudness.py +320 -0
  9. audiotools/core/playback.py +252 -0
  10. audiotools/core/templates/__init__.py +0 -0
  11. audiotools/core/templates/headers.html +322 -0
  12. audiotools/core/templates/pandoc.css +407 -0
  13. audiotools/core/templates/widget.html +52 -0
  14. audiotools/core/util.py +671 -0
  15. audiotools/core/whisper.py +97 -0
  16. audiotools/data/__init__.py +3 -0
  17. audiotools/data/datasets.py +517 -0
  18. audiotools/data/preprocess.py +81 -0
  19. audiotools/data/transforms.py +1592 -0
  20. audiotools/metrics/__init__.py +6 -0
  21. audiotools/metrics/distance.py +131 -0
  22. audiotools/metrics/quality.py +159 -0
  23. audiotools/metrics/spectral.py +247 -0
  24. audiotools/ml/__init__.py +5 -0
  25. audiotools/ml/accelerator.py +184 -0
  26. audiotools/ml/decorators.py +440 -0
  27. audiotools/ml/experiment.py +90 -0
  28. audiotools/ml/layers/__init__.py +2 -0
  29. audiotools/ml/layers/base.py +328 -0
  30. audiotools/ml/layers/spectral_gate.py +127 -0
  31. audiotools/post.py +140 -0
  32. audiotools/preference.py +600 -0
  33. src/inference.py +169 -0
  34. src/inference_controlnet.py +129 -0
  35. src/models/.ipynb_checkpoints/blocks-checkpoint.py +325 -0
  36. src/models/.ipynb_checkpoints/conditioners-checkpoint.py +183 -0
  37. src/models/.ipynb_checkpoints/controlnet-checkpoint.py +318 -0
  38. src/models/.ipynb_checkpoints/udit-checkpoint.py +365 -0
  39. src/models/__pycache__/attention.cpython-311.pyc +0 -0
  40. src/models/__pycache__/blocks.cpython-310.pyc +0 -0
  41. src/models/__pycache__/blocks.cpython-311.pyc +0 -0
  42. src/models/__pycache__/conditioners.cpython-310.pyc +0 -0
  43. src/models/__pycache__/conditioners.cpython-311.pyc +0 -0
  44. src/models/__pycache__/controlnet.cpython-311.pyc +0 -0
  45. src/models/__pycache__/modules.cpython-311.pyc +0 -0
  46. src/models/__pycache__/rotary.cpython-311.pyc +0 -0
  47. src/models/__pycache__/timm.cpython-311.pyc +0 -0
  48. src/models/__pycache__/udit.cpython-310.pyc +0 -0
  49. src/models/__pycache__/udit.cpython-311.pyc +0 -0
  50. src/models/blocks.py +325 -0
audiotools/__init__.py ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ __version__ = "0.7.3"
2
+ from .core import AudioSignal
3
+ from .core import STFTParams
4
+ from .core import Meter
5
+ from .core import util
6
+ from . import metrics
7
+ from . import data
8
+ from . import ml
9
+ from .data import datasets
10
+ from .data import transforms
audiotools/core/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ from . import util
2
+ from .audio_signal import AudioSignal
3
+ from .audio_signal import STFTParams
4
+ from .loudness import Meter
audiotools/core/audio_signal.py ADDED
@@ -0,0 +1,1682 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ import functools
3
+ import hashlib
4
+ import math
5
+ import pathlib
6
+ import tempfile
7
+ import typing
8
+ import warnings
9
+ from collections import namedtuple
10
+ from pathlib import Path
11
+
12
+ import julius
13
+ import numpy as np
14
+ import soundfile
15
+ import torch
16
+
17
+ from . import util
18
+ from .display import DisplayMixin
19
+ from .dsp import DSPMixin
20
+ from .effects import EffectMixin
21
+ from .effects import ImpulseResponseMixin
22
+ from .ffmpeg import FFMPEGMixin
23
+ from .loudness import LoudnessMixin
24
+ from .playback import PlayMixin
25
+ from .whisper import WhisperMixin
26
+
27
+
28
+ STFTParams = namedtuple(
29
+ "STFTParams",
30
+ ["window_length", "hop_length", "window_type", "match_stride", "padding_type"],
31
+ )
32
+ """
33
+ STFTParams object is a container that holds STFT parameters - window_length,
34
+ hop_length, and window_type. Not all parameters need to be specified. Ones that
35
+ are not specified will be inferred by the AudioSignal parameters.
36
+
37
+ Parameters
38
+ ----------
39
+ window_length : int, optional
40
+ Window length of STFT, by default ``0.032 * self.sample_rate``.
41
+ hop_length : int, optional
42
+ Hop length of STFT, by default ``window_length // 4``.
43
+ window_type : str, optional
44
+ Type of window to use, by default ``sqrt\_hann``.
45
+ match_stride : bool, optional
46
+ Whether to match the stride of convolutional layers, by default False
47
+ padding_type : str, optional
48
+ Type of padding to use, by default 'reflect'
49
+ """
50
+ STFTParams.__new__.__defaults__ = (None, None, None, None, None)
51
+
52
+
53
+ class AudioSignal(
54
+ EffectMixin,
55
+ LoudnessMixin,
56
+ PlayMixin,
57
+ ImpulseResponseMixin,
58
+ DSPMixin,
59
+ DisplayMixin,
60
+ FFMPEGMixin,
61
+ WhisperMixin,
62
+ ):
63
+ """This is the core object of this library. Audio is always
64
+ loaded into an AudioSignal, which then enables all the features
65
+ of this library, including audio augmentations, I/O, playback,
66
+ and more.
67
+
68
+ The structure of this object is that the base functionality
69
+ is defined in ``core/audio_signal.py``, while extensions to
70
+ that functionality are defined in the other ``core/*.py``
71
+ files. For example, all the display-based functionality
72
+ (e.g. plot spectrograms, waveforms, write to tensorboard)
73
+ are in ``core/display.py``.
74
+
75
+ Parameters
76
+ ----------
77
+ audio_path_or_array : typing.Union[torch.Tensor, str, Path, np.ndarray]
78
+ Object to create AudioSignal from. Can be a tensor, numpy array,
79
+ or a path to a file. The file is always reshaped to
80
+ sample_rate : int, optional
81
+ Sample rate of the audio. If different from underlying file, resampling is
82
+ performed. If passing in an array or tensor, this must be defined,
83
+ by default None
84
+ stft_params : STFTParams, optional
85
+ Parameters of STFT to use. , by default None
86
+ offset : float, optional
87
+ Offset in seconds to read from file, by default 0
88
+ duration : float, optional
89
+ Duration in seconds to read from file, by default None
90
+ device : str, optional
91
+ Device to load audio onto, by default None
92
+
93
+ Examples
94
+ --------
95
+ Loading an AudioSignal from an array, at a sample rate of
96
+ 44100.
97
+
98
+ >>> signal = AudioSignal(torch.randn(5*44100), 44100)
99
+
100
+ Note, the signal is reshaped to have a batch size, and one
101
+ audio channel:
102
+
103
+ >>> print(signal.shape)
104
+ (1, 1, 44100)
105
+
106
+ You can treat AudioSignals like tensors, and many of the same
107
+ functions you might use on tensors are defined for AudioSignals
108
+ as well:
109
+
110
+ >>> signal.to("cuda")
111
+ >>> signal.cuda()
112
+ >>> signal.clone()
113
+ >>> signal.detach()
114
+
115
+ Indexing AudioSignals returns an AudioSignal:
116
+
117
+ >>> signal[..., 3*44100:4*44100]
118
+
119
+ The above signal is 1 second long, and is also an AudioSignal.
120
+ """
121
+
122
+ def __init__(
123
+ self,
124
+ audio_path_or_array: typing.Union[torch.Tensor, str, Path, np.ndarray],
125
+ sample_rate: int = None,
126
+ stft_params: STFTParams = None,
127
+ offset: float = 0,
128
+ duration: float = None,
129
+ device: str = None,
130
+ ):
131
+ audio_path = None
132
+ audio_array = None
133
+
134
+ if isinstance(audio_path_or_array, str):
135
+ audio_path = audio_path_or_array
136
+ elif isinstance(audio_path_or_array, pathlib.Path):
137
+ audio_path = audio_path_or_array
138
+ elif isinstance(audio_path_or_array, np.ndarray):
139
+ audio_array = audio_path_or_array
140
+ elif torch.is_tensor(audio_path_or_array):
141
+ audio_array = audio_path_or_array
142
+ else:
143
+ raise ValueError(
144
+ "audio_path_or_array must be either a Path, "
145
+ "string, numpy array, or torch Tensor!"
146
+ )
147
+
148
+ self.path_to_file = None
149
+
150
+ self.audio_data = None
151
+ self.sources = None # List of AudioSignal objects.
152
+ self.stft_data = None
153
+ if audio_path is not None:
154
+ self.load_from_file(
155
+ audio_path, offset=offset, duration=duration, device=device
156
+ )
157
+ elif audio_array is not None:
158
+ assert sample_rate is not None, "Must set sample rate!"
159
+ self.load_from_array(audio_array, sample_rate, device=device)
160
+
161
+ self.window = None
162
+ self.stft_params = stft_params
163
+
164
+ self.metadata = {
165
+ "offset": offset,
166
+ "duration": duration,
167
+ }
168
+
169
+ @property
170
+ def path_to_input_file(
171
+ self,
172
+ ):
173
+ """
174
+ Path to input file, if it exists.
175
+ Alias to ``path_to_file`` for backwards compatibility
176
+ """
177
+ return self.path_to_file
178
+
179
+ @classmethod
180
+ def excerpt(
181
+ cls,
182
+ audio_path: typing.Union[str, Path],
183
+ offset: float = None,
184
+ duration: float = None,
185
+ state: typing.Union[np.random.RandomState, int] = None,
186
+ **kwargs,
187
+ ):
188
+ """Randomly draw an excerpt of ``duration`` seconds from an
189
+ audio file specified at ``audio_path``, between ``offset`` seconds
190
+ and end of file. ``state`` can be used to seed the random draw.
191
+
192
+ Parameters
193
+ ----------
194
+ audio_path : typing.Union[str, Path]
195
+ Path to audio file to grab excerpt from.
196
+ offset : float, optional
197
+ Lower bound for the start time, in seconds drawn from
198
+ the file, by default None.
199
+ duration : float, optional
200
+ Duration of excerpt, in seconds, by default None
201
+ state : typing.Union[np.random.RandomState, int], optional
202
+ RandomState or seed of random state, by default None
203
+
204
+ Returns
205
+ -------
206
+ AudioSignal
207
+ AudioSignal containing excerpt.
208
+
209
+ Examples
210
+ --------
211
+ >>> signal = AudioSignal.excerpt("path/to/audio", duration=5)
212
+ """
213
+ info = util.info(audio_path)
214
+ total_duration = info.duration
215
+
216
+ state = util.random_state(state)
217
+ lower_bound = 0 if offset is None else offset
218
+ upper_bound = max(total_duration - duration, 0)
219
+ offset = state.uniform(lower_bound, upper_bound)
220
+
221
+ signal = cls(audio_path, offset=offset, duration=duration, **kwargs)
222
+ signal.metadata["offset"] = offset
223
+ signal.metadata["duration"] = duration
224
+
225
+ return signal
226
+
227
+ @classmethod
228
+ def salient_excerpt(
229
+ cls,
230
+ audio_path: typing.Union[str, Path],
231
+ loudness_cutoff: float = None,
232
+ num_tries: int = 8,
233
+ state: typing.Union[np.random.RandomState, int] = None,
234
+ **kwargs,
235
+ ):
236
+ """Similar to AudioSignal.excerpt, except it extracts excerpts only
237
+ if they are above a specified loudness threshold, which is computed via
238
+ a fast LUFS routine.
239
+
240
+ Parameters
241
+ ----------
242
+ audio_path : typing.Union[str, Path]
243
+ Path to audio file to grab excerpt from.
244
+ loudness_cutoff : float, optional
245
+ Loudness threshold in dB. Typical values are ``-40, -60``,
246
+ etc, by default None
247
+ num_tries : int, optional
248
+ Number of tries to grab an excerpt above the threshold
249
+ before giving up, by default 8.
250
+ state : typing.Union[np.random.RandomState, int], optional
251
+ RandomState or seed of random state, by default None
252
+ kwargs : dict
253
+ Keyword arguments to AudioSignal.excerpt
254
+
255
+ Returns
256
+ -------
257
+ AudioSignal
258
+ AudioSignal containing excerpt.
259
+
260
+
261
+ .. warning::
262
+ if ``num_tries`` is set to None, ``salient_excerpt`` may try forever, which can
263
+ result in an infinite loop if ``audio_path`` does not have
264
+ any loud enough excerpts.
265
+
266
+ Examples
267
+ --------
268
+ >>> signal = AudioSignal.salient_excerpt(
269
+ "path/to/audio",
270
+ loudness_cutoff=-40,
271
+ duration=5
272
+ )
273
+ """
274
+ state = util.random_state(state)
275
+ if loudness_cutoff is None:
276
+ excerpt = cls.excerpt(audio_path, state=state, **kwargs)
277
+ else:
278
+ loudness = -np.inf
279
+ num_try = 0
280
+ while loudness <= loudness_cutoff:
281
+ excerpt = cls.excerpt(audio_path, state=state, **kwargs)
282
+ loudness = excerpt.loudness()
283
+ num_try += 1
284
+ if num_tries is not None and num_try >= num_tries:
285
+ break
286
+ return excerpt
287
+
288
+ @classmethod
289
+ def zeros(
290
+ cls,
291
+ duration: float,
292
+ sample_rate: int,
293
+ num_channels: int = 1,
294
+ batch_size: int = 1,
295
+ **kwargs,
296
+ ):
297
+ """Helper function create an AudioSignal of all zeros.
298
+
299
+ Parameters
300
+ ----------
301
+ duration : float
302
+ Duration of AudioSignal
303
+ sample_rate : int
304
+ Sample rate of AudioSignal
305
+ num_channels : int, optional
306
+ Number of channels, by default 1
307
+ batch_size : int, optional
308
+ Batch size, by default 1
309
+
310
+ Returns
311
+ -------
312
+ AudioSignal
313
+ AudioSignal containing all zeros.
314
+
315
+ Examples
316
+ --------
317
+ Generate 5 seconds of all zeros at a sample rate of 44100.
318
+
319
+ >>> signal = AudioSignal.zeros(5.0, 44100)
320
+ """
321
+ n_samples = int(duration * sample_rate)
322
+ return cls(
323
+ torch.zeros(batch_size, num_channels, n_samples), sample_rate, **kwargs
324
+ )
325
+
326
+ @classmethod
327
+ def wave(
328
+ cls,
329
+ frequency: float,
330
+ duration: float,
331
+ sample_rate: int,
332
+ num_channels: int = 1,
333
+ shape: str = "sine",
334
+ **kwargs,
335
+ ):
336
+ """
337
+ Generate a waveform of a given frequency and shape.
338
+
339
+ Parameters
340
+ ----------
341
+ frequency : float
342
+ Frequency of the waveform
343
+ duration : float
344
+ Duration of the waveform
345
+ sample_rate : int
346
+ Sample rate of the waveform
347
+ num_channels : int, optional
348
+ Number of channels, by default 1
349
+ shape : str, optional
350
+ Shape of the waveform, by default "saw"
351
+ One of "sawtooth", "square", "sine", "triangle"
352
+ kwargs : dict
353
+ Keyword arguments to AudioSignal
354
+ """
355
+ n_samples = int(duration * sample_rate)
356
+ t = torch.linspace(0, duration, n_samples)
357
+ if shape == "sawtooth":
358
+ from scipy.signal import sawtooth
359
+
360
+ wave_data = sawtooth(2 * np.pi * frequency * t, 0.5)
361
+ elif shape == "square":
362
+ from scipy.signal import square
363
+
364
+ wave_data = square(2 * np.pi * frequency * t)
365
+ elif shape == "sine":
366
+ wave_data = np.sin(2 * np.pi * frequency * t)
367
+ elif shape == "triangle":
368
+ from scipy.signal import sawtooth
369
+
370
+ # frequency is doubled by the abs call, so omit the 2 in 2pi
371
+ wave_data = sawtooth(np.pi * frequency * t, 0.5)
372
+ wave_data = -np.abs(wave_data) * 2 + 1
373
+ else:
374
+ raise ValueError(f"Invalid shape {shape}")
375
+
376
+ wave_data = torch.tensor(wave_data, dtype=torch.float32)
377
+ wave_data = wave_data.unsqueeze(0).unsqueeze(0).repeat(1, num_channels, 1)
378
+ return cls(wave_data, sample_rate, **kwargs)
379
+
380
+ @classmethod
381
+ def batch(
382
+ cls,
383
+ audio_signals: list,
384
+ pad_signals: bool = False,
385
+ truncate_signals: bool = False,
386
+ resample: bool = False,
387
+ dim: int = 0,
388
+ ):
389
+ """Creates a batched AudioSignal from a list of AudioSignals.
390
+
391
+ Parameters
392
+ ----------
393
+ audio_signals : list[AudioSignal]
394
+ List of AudioSignal objects
395
+ pad_signals : bool, optional
396
+ Whether to pad signals to length of the maximum length
397
+ AudioSignal in the list, by default False
398
+ truncate_signals : bool, optional
399
+ Whether to truncate signals to length of shortest length
400
+ AudioSignal in the list, by default False
401
+ resample : bool, optional
402
+ Whether to resample AudioSignal to the sample rate of
403
+ the first AudioSignal in the list, by default False
404
+ dim : int, optional
405
+ Dimension along which to batch the signals.
406
+
407
+ Returns
408
+ -------
409
+ AudioSignal
410
+ Batched AudioSignal.
411
+
412
+ Raises
413
+ ------
414
+ RuntimeError
415
+ If not all AudioSignals are the same sample rate, and
416
+ ``resample=False``, an error is raised.
417
+ RuntimeError
418
+ If not all AudioSignals are the same the length, and
419
+ both ``pad_signals=False`` and ``truncate_signals=False``,
420
+ an error is raised.
421
+
422
+ Examples
423
+ --------
424
+ Batching a bunch of random signals:
425
+
426
+ >>> signal_list = [AudioSignal(torch.randn(44100), 44100) for _ in range(10)]
427
+ >>> signal = AudioSignal.batch(signal_list)
428
+ >>> print(signal.shape)
429
+ (10, 1, 44100)
430
+
431
+ """
432
+ signal_lengths = [x.signal_length for x in audio_signals]
433
+ sample_rates = [x.sample_rate for x in audio_signals]
434
+
435
+ if len(set(sample_rates)) != 1:
436
+ if resample:
437
+ for x in audio_signals:
438
+ x.resample(sample_rates[0])
439
+ else:
440
+ raise RuntimeError(
441
+ f"Not all signals had the same sample rate! Got {sample_rates}. "
442
+ f"All signals must have the same sample rate, or resample must be True. "
443
+ )
444
+
445
+ if len(set(signal_lengths)) != 1:
446
+ if pad_signals:
447
+ max_length = max(signal_lengths)
448
+ for x in audio_signals:
449
+ pad_len = max_length - x.signal_length
450
+ x.zero_pad(0, pad_len)
451
+ elif truncate_signals:
452
+ min_length = min(signal_lengths)
453
+ for x in audio_signals:
454
+ x.truncate_samples(min_length)
455
+ else:
456
+ raise RuntimeError(
457
+ f"Not all signals had the same length! Got {signal_lengths}. "
458
+ f"All signals must be the same length, or pad_signals/truncate_signals "
459
+ f"must be True. "
460
+ )
461
+ # Concatenate along the specified dimension (default 0)
462
+ audio_data = torch.cat([x.audio_data for x in audio_signals], dim=dim)
463
+ audio_paths = [x.path_to_file for x in audio_signals]
464
+
465
+ batched_signal = cls(
466
+ audio_data,
467
+ sample_rate=audio_signals[0].sample_rate,
468
+ )
469
+ batched_signal.path_to_file = audio_paths
470
+ return batched_signal
471
+
472
+ # I/O
473
+ def load_from_file(
474
+ self,
475
+ audio_path: typing.Union[str, Path],
476
+ offset: float,
477
+ duration: float,
478
+ device: str = "cpu",
479
+ ):
480
+ """Loads data from file. Used internally when AudioSignal
481
+ is instantiated with a path to a file.
482
+
483
+ Parameters
484
+ ----------
485
+ audio_path : typing.Union[str, Path]
486
+ Path to file
487
+ offset : float
488
+ Offset in seconds
489
+ duration : float
490
+ Duration in seconds
491
+ device : str, optional
492
+ Device to put AudioSignal on, by default "cpu"
493
+
494
+ Returns
495
+ -------
496
+ AudioSignal
497
+ AudioSignal loaded from file
498
+ """
499
+ import librosa
500
+
501
+ data, sample_rate = librosa.load(
502
+ audio_path,
503
+ offset=offset,
504
+ duration=duration,
505
+ sr=None,
506
+ mono=False,
507
+ )
508
+ data = util.ensure_tensor(data)
509
+ if data.shape[-1] == 0:
510
+ raise RuntimeError(
511
+ f"Audio file {audio_path} with offset {offset} and duration {duration} is empty!"
512
+ )
513
+
514
+ if data.ndim < 2:
515
+ data = data.unsqueeze(0)
516
+ if data.ndim < 3:
517
+ data = data.unsqueeze(0)
518
+ self.audio_data = data
519
+
520
+ self.original_signal_length = self.signal_length
521
+
522
+ self.sample_rate = sample_rate
523
+ self.path_to_file = audio_path
524
+ return self.to(device)
525
+
526
+ def load_from_array(
527
+ self,
528
+ audio_array: typing.Union[torch.Tensor, np.ndarray],
529
+ sample_rate: int,
530
+ device: str = "cpu",
531
+ ):
532
+ """Loads data from array, reshaping it to be exactly 3
533
+ dimensions. Used internally when AudioSignal is called
534
+ with a tensor or an array.
535
+
536
+ Parameters
537
+ ----------
538
+ audio_array : typing.Union[torch.Tensor, np.ndarray]
539
+ Array/tensor of audio of samples.
540
+ sample_rate : int
541
+ Sample rate of audio
542
+ device : str, optional
543
+ Device to move audio onto, by default "cpu"
544
+
545
+ Returns
546
+ -------
547
+ AudioSignal
548
+ AudioSignal loaded from array
549
+ """
550
+ audio_data = util.ensure_tensor(audio_array)
551
+
552
+ if audio_data.dtype == torch.double:
553
+ audio_data = audio_data.float()
554
+
555
+ if audio_data.ndim < 2:
556
+ audio_data = audio_data.unsqueeze(0)
557
+ if audio_data.ndim < 3:
558
+ audio_data = audio_data.unsqueeze(0)
559
+ self.audio_data = audio_data
560
+
561
+ self.original_signal_length = self.signal_length
562
+
563
+ self.sample_rate = sample_rate
564
+ return self.to(device)
565
+
566
+ def write(self, audio_path: typing.Union[str, Path]):
567
+ """Writes audio to a file. Only writes the audio
568
+ that is in the very first item of the batch. To write other items
569
+ in the batch, index the signal along the batch dimension
570
+ before writing. After writing, the signal's ``path_to_file``
571
+ attribute is updated to the new path.
572
+
573
+ Parameters
574
+ ----------
575
+ audio_path : typing.Union[str, Path]
576
+ Path to write audio to.
577
+
578
+ Returns
579
+ -------
580
+ AudioSignal
581
+ Returns original AudioSignal, so you can use this in a fluent
582
+ interface.
583
+
584
+ Examples
585
+ --------
586
+ Creating and writing a signal to disk:
587
+
588
+ >>> signal = AudioSignal(torch.randn(10, 1, 44100), 44100)
589
+ >>> signal.write("/tmp/out.wav")
590
+
591
+ Writing a different element of the batch:
592
+
593
+ >>> signal[5].write("/tmp/out.wav")
594
+
595
+ Using this in a fluent interface:
596
+
597
+ >>> signal.write("/tmp/original.wav").low_pass(4000).write("/tmp/lowpass.wav")
598
+
599
+ """
600
+ if self.audio_data[0].abs().max() > 1:
601
+ warnings.warn("Audio amplitude > 1 clipped when saving")
602
+ soundfile.write(str(audio_path), self.audio_data[0].numpy().T, self.sample_rate)
603
+
604
+ self.path_to_file = audio_path
605
+ return self
606
+
607
+ def deepcopy(self):
608
+ """Copies the signal and all of its attributes.
609
+
610
+ Returns
611
+ -------
612
+ AudioSignal
613
+ Deep copy of the audio signal.
614
+ """
615
+ return copy.deepcopy(self)
616
+
617
+ def copy(self):
618
+ """Shallow copy of signal.
619
+
620
+ Returns
621
+ -------
622
+ AudioSignal
623
+ Shallow copy of the audio signal.
624
+ """
625
+ return copy.copy(self)
626
+
627
+ def clone(self):
628
+ """Clones all tensors contained in the AudioSignal,
629
+ and returns a copy of the signal with everything
630
+ cloned. Useful when using AudioSignal within autograd
631
+ computation graphs.
632
+
633
+ Relevant attributes are the stft data, the audio data,
634
+ and the loudness of the file.
635
+
636
+ Returns
637
+ -------
638
+ AudioSignal
639
+ Clone of AudioSignal.
640
+ """
641
+ clone = type(self)(
642
+ self.audio_data.clone(),
643
+ self.sample_rate,
644
+ stft_params=self.stft_params,
645
+ )
646
+ if self.stft_data is not None:
647
+ clone.stft_data = self.stft_data.clone()
648
+ if self._loudness is not None:
649
+ clone._loudness = self._loudness.clone()
650
+ clone.path_to_file = copy.deepcopy(self.path_to_file)
651
+ clone.metadata = copy.deepcopy(self.metadata)
652
+ return clone
653
+
654
+ def detach(self):
655
+ """Detaches tensors contained in AudioSignal.
656
+
657
+ Relevant attributes are the stft data, the audio data,
658
+ and the loudness of the file.
659
+
660
+ Returns
661
+ -------
662
+ AudioSignal
663
+ Same signal, but with all tensors detached.
664
+ """
665
+ if self._loudness is not None:
666
+ self._loudness = self._loudness.detach()
667
+ if self.stft_data is not None:
668
+ self.stft_data = self.stft_data.detach()
669
+
670
+ self.audio_data = self.audio_data.detach()
671
+ return self
672
+
673
+ def hash(self):
674
+ """Writes the audio data to a temporary file, and then
675
+ hashes it using hashlib. Useful for creating a file
676
+ name based on the audio content.
677
+
678
+ Returns
679
+ -------
680
+ str
681
+ Hash of audio data.
682
+
683
+ Examples
684
+ --------
685
+ Creating a signal, and writing it to a unique file name:
686
+
687
+ >>> signal = AudioSignal(torch.randn(44100), 44100)
688
+ >>> hash = signal.hash()
689
+ >>> signal.write(f"{hash}.wav")
690
+
691
+ """
692
+ with tempfile.NamedTemporaryFile(suffix=".wav") as f:
693
+ self.write(f.name)
694
+ h = hashlib.sha256()
695
+ b = bytearray(128 * 1024)
696
+ mv = memoryview(b)
697
+ with open(f.name, "rb", buffering=0) as f:
698
+ for n in iter(lambda: f.readinto(mv), 0):
699
+ h.update(mv[:n])
700
+ file_hash = h.hexdigest()
701
+ return file_hash
702
+
703
+ # Signal operations
704
+ def to_mono(self):
705
+ """Converts audio data to mono audio, by taking the mean
706
+ along the channels dimension.
707
+
708
+ Returns
709
+ -------
710
+ AudioSignal
711
+ AudioSignal with mean of channels.
712
+ """
713
+ self.audio_data = self.audio_data.mean(1, keepdim=True)
714
+ return self
715
+
716
+ def resample(self, sample_rate: int):
717
+ """Resamples the audio, using sinc interpolation. This works on both
718
+ cpu and gpu, and is much faster on gpu.
719
+
720
+ Parameters
721
+ ----------
722
+ sample_rate : int
723
+ Sample rate to resample to.
724
+
725
+ Returns
726
+ -------
727
+ AudioSignal
728
+ Resampled AudioSignal
729
+ """
730
+ if sample_rate == self.sample_rate:
731
+ return self
732
+ self.audio_data = julius.resample_frac(
733
+ self.audio_data, self.sample_rate, sample_rate
734
+ )
735
+ self.sample_rate = sample_rate
736
+ return self
737
+
738
+ # Tensor operations
739
+ def to(self, device: str):
740
+ """Moves all tensors contained in signal to the specified device.
741
+
742
+ Parameters
743
+ ----------
744
+ device : str
745
+ Device to move AudioSignal onto. Typical values are
746
+ "cuda", "cpu", or "cuda:n" to specify the nth gpu.
747
+
748
+ Returns
749
+ -------
750
+ AudioSignal
751
+ AudioSignal with all tensors moved to specified device.
752
+ """
753
+ if self._loudness is not None:
754
+ self._loudness = self._loudness.to(device)
755
+ if self.stft_data is not None:
756
+ self.stft_data = self.stft_data.to(device)
757
+ if self.audio_data is not None:
758
+ self.audio_data = self.audio_data.to(device)
759
+ return self
760
+
761
+ def float(self):
762
+ """Calls ``.float()`` on ``self.audio_data``.
763
+
764
+ Returns
765
+ -------
766
+ AudioSignal
767
+ """
768
+ self.audio_data = self.audio_data.float()
769
+ return self
770
+
771
+ def cpu(self):
772
+ """Moves AudioSignal to cpu.
773
+
774
+ Returns
775
+ -------
776
+ AudioSignal
777
+ """
778
+ return self.to("cpu")
779
+
780
+ def cuda(self): # pragma: no cover
781
+ """Moves AudioSignal to cuda.
782
+
783
+ Returns
784
+ -------
785
+ AudioSignal
786
+ """
787
+ return self.to("cuda")
788
+
789
+ def numpy(self):
790
+ """Detaches ``self.audio_data``, moves to cpu, and converts to numpy.
791
+
792
+ Returns
793
+ -------
794
+ np.ndarray
795
+ Audio data as a numpy array.
796
+ """
797
+ return self.audio_data.detach().cpu().numpy()
798
+
799
+ def zero_pad(self, before: int, after: int):
800
+ """Zero pads the audio_data tensor before and after.
801
+
802
+ Parameters
803
+ ----------
804
+ before : int
805
+ How many zeros to prepend to audio.
806
+ after : int
807
+ How many zeros to append to audio.
808
+
809
+ Returns
810
+ -------
811
+ AudioSignal
812
+ AudioSignal with padding applied.
813
+ """
814
+ self.audio_data = torch.nn.functional.pad(self.audio_data, (before, after))
815
+ return self
816
+
817
+ def zero_pad_to(self, length: int, mode: str = "after"):
818
+ """Pad with zeros to a specified length, either before or after
819
+ the audio data.
820
+
821
+ Parameters
822
+ ----------
823
+ length : int
824
+ Length to pad to
825
+ mode : str, optional
826
+ Whether to prepend or append zeros to signal, by default "after"
827
+
828
+ Returns
829
+ -------
830
+ AudioSignal
831
+ AudioSignal with padding applied.
832
+ """
833
+ if mode == "before":
834
+ self.zero_pad(max(length - self.signal_length, 0), 0)
835
+ elif mode == "after":
836
+ self.zero_pad(0, max(length - self.signal_length, 0))
837
+ return self
838
+
839
+ def trim(self, before: int, after: int):
840
+ """Trims the audio_data tensor before and after.
841
+
842
+ Parameters
843
+ ----------
844
+ before : int
845
+ How many samples to trim from beginning.
846
+ after : int
847
+ How many samples to trim from end.
848
+
849
+ Returns
850
+ -------
851
+ AudioSignal
852
+ AudioSignal with trimming applied.
853
+ """
854
+ if after == 0:
855
+ self.audio_data = self.audio_data[..., before:]
856
+ else:
857
+ self.audio_data = self.audio_data[..., before:-after]
858
+ return self
859
+
860
+ def truncate_samples(self, length_in_samples: int):
861
+ """Truncate signal to specified length.
862
+
863
+ Parameters
864
+ ----------
865
+ length_in_samples : int
866
+ Truncate to this many samples.
867
+
868
+ Returns
869
+ -------
870
+ AudioSignal
871
+ AudioSignal with truncation applied.
872
+ """
873
+ self.audio_data = self.audio_data[..., :length_in_samples]
874
+ return self
875
+
876
+ @property
877
+ def device(self):
878
+ """Get device that AudioSignal is on.
879
+
880
+ Returns
881
+ -------
882
+ torch.device
883
+ Device that AudioSignal is on.
884
+ """
885
+ if self.audio_data is not None:
886
+ device = self.audio_data.device
887
+ elif self.stft_data is not None:
888
+ device = self.stft_data.device
889
+ return device
890
+
891
+ # Properties
892
+ @property
893
+ def audio_data(self):
894
+ """Returns the audio data tensor in the object.
895
+
896
+ Audio data is always of the shape
897
+ (batch_size, num_channels, num_samples). If value has less
898
+ than 3 dims (e.g. is (num_channels, num_samples)), then it will
899
+ be reshaped to (1, num_channels, num_samples) - a batch size of 1.
900
+
901
+ Parameters
902
+ ----------
903
+ data : typing.Union[torch.Tensor, np.ndarray]
904
+ Audio data to set.
905
+
906
+ Returns
907
+ -------
908
+ torch.Tensor
909
+ Audio samples.
910
+ """
911
+ return self._audio_data
912
+
913
+ @audio_data.setter
914
+ def audio_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
915
+ if data is not None:
916
+ assert torch.is_tensor(data), "audio_data should be torch.Tensor"
917
+ assert data.ndim == 3, "audio_data should be 3-dim (B, C, T)"
918
+ self._audio_data = data
919
+ # Old loudness value not guaranteed to be right, reset it.
920
+ self._loudness = None
921
+ return
922
+
923
+ # alias for audio_data
924
+ samples = audio_data
925
+
926
+ @property
927
+ def stft_data(self):
928
+ """Returns the STFT data inside the signal. Shape is
929
+ (batch, channels, frequencies, time).
930
+
931
+ Returns
932
+ -------
933
+ torch.Tensor
934
+ Complex spectrogram data.
935
+ """
936
+ return self._stft_data
937
+
938
+ @stft_data.setter
939
+ def stft_data(self, data: typing.Union[torch.Tensor, np.ndarray]):
940
+ if data is not None:
941
+ assert torch.is_tensor(data) and torch.is_complex(data)
942
+ if self.stft_data is not None and self.stft_data.shape != data.shape:
943
+ warnings.warn("stft_data changed shape")
944
+ self._stft_data = data
945
+ return
946
+
947
+ @property
948
+ def batch_size(self):
949
+ """Batch size of audio signal.
950
+
951
+ Returns
952
+ -------
953
+ int
954
+ Batch size of signal.
955
+ """
956
+ return self.audio_data.shape[0]
957
+
958
+ @property
959
+ def signal_length(self):
960
+ """Length of audio signal.
961
+
962
+ Returns
963
+ -------
964
+ int
965
+ Length of signal in samples.
966
+ """
967
+ return self.audio_data.shape[-1]
968
+
969
+ # alias for signal_length
970
+ length = signal_length
971
+
972
+ @property
973
+ def shape(self):
974
+ """Shape of audio data.
975
+
976
+ Returns
977
+ -------
978
+ tuple
979
+ Shape of audio data.
980
+ """
981
+ return self.audio_data.shape
982
+
983
+ @property
984
+ def signal_duration(self):
985
+ """Length of audio signal in seconds.
986
+
987
+ Returns
988
+ -------
989
+ float
990
+ Length of signal in seconds.
991
+ """
992
+ return self.signal_length / self.sample_rate
993
+
994
+ # alias for signal_duration
995
+ duration = signal_duration
996
+
997
+ @property
998
+ def num_channels(self):
999
+ """Number of audio channels.
1000
+
1001
+ Returns
1002
+ -------
1003
+ int
1004
+ Number of audio channels.
1005
+ """
1006
+ return self.audio_data.shape[1]
1007
+
1008
+ # STFT
1009
+ @staticmethod
1010
+ @functools.lru_cache(None)
1011
+ def get_window(window_type: str, window_length: int, device: str):
1012
+ """Wrapper around scipy.signal.get_window so one can also get the
1013
+ popular sqrt-hann window. This function caches for efficiency
1014
+ using functools.lru\_cache.
1015
+
1016
+ Parameters
1017
+ ----------
1018
+ window_type : str
1019
+ Type of window to get
1020
+ window_length : int
1021
+ Length of the window
1022
+ device : str
1023
+ Device to put window onto.
1024
+
1025
+ Returns
1026
+ -------
1027
+ torch.Tensor
1028
+ Window returned by scipy.signal.get_window, as a tensor.
1029
+ """
1030
+ from scipy import signal
1031
+
1032
+ if window_type == "average":
1033
+ window = np.ones(window_length) / window_length
1034
+ elif window_type == "sqrt_hann":
1035
+ window = np.sqrt(signal.get_window("hann", window_length))
1036
+ else:
1037
+ window = signal.get_window(window_type, window_length)
1038
+ window = torch.from_numpy(window).to(device).float()
1039
+ return window
1040
+
1041
+ @property
1042
+ def stft_params(self):
1043
+ """Returns STFTParams object, which can be re-used to other
1044
+ AudioSignals.
1045
+
1046
+ This property can be set as well. If values are not defined in STFTParams,
1047
+ they are inferred automatically from the signal properties. The default is to use
1048
+ 32ms windows, with 8ms hop length, and the square root of the hann window.
1049
+
1050
+ Returns
1051
+ -------
1052
+ STFTParams
1053
+ STFT parameters for the AudioSignal.
1054
+
1055
+ Examples
1056
+ --------
1057
+ >>> stft_params = STFTParams(128, 32)
1058
+ >>> signal1 = AudioSignal(torch.randn(44100), 44100, stft_params=stft_params)
1059
+ >>> signal2 = AudioSignal(torch.randn(44100), 44100, stft_params=signal1.stft_params)
1060
+ >>> signal1.stft_params = STFTParams() # Defaults
1061
+ """
1062
+ return self._stft_params
1063
+
1064
+ @stft_params.setter
1065
+ def stft_params(self, value: STFTParams):
1066
+ default_win_len = int(2 ** (np.ceil(np.log2(0.032 * self.sample_rate))))
1067
+ default_hop_len = default_win_len // 4
1068
+ default_win_type = "hann"
1069
+ default_match_stride = False
1070
+ default_padding_type = "reflect"
1071
+
1072
+ default_stft_params = STFTParams(
1073
+ window_length=default_win_len,
1074
+ hop_length=default_hop_len,
1075
+ window_type=default_win_type,
1076
+ match_stride=default_match_stride,
1077
+ padding_type=default_padding_type,
1078
+ )._asdict()
1079
+
1080
+ value = value._asdict() if value else default_stft_params
1081
+
1082
+ for key in default_stft_params:
1083
+ if value[key] is None:
1084
+ value[key] = default_stft_params[key]
1085
+
1086
+ self._stft_params = STFTParams(**value)
1087
+ self.stft_data = None
1088
+
1089
+ def compute_stft_padding(
1090
+ self, window_length: int, hop_length: int, match_stride: bool
1091
+ ):
1092
+ """Compute how the STFT should be padded, based on match\_stride.
1093
+
1094
+ Parameters
1095
+ ----------
1096
+ window_length : int
1097
+ Window length of STFT.
1098
+ hop_length : int
1099
+ Hop length of STFT.
1100
+ match_stride : bool
1101
+ Whether or not to match stride, making the STFT have the same alignment as
1102
+ convolutional layers.
1103
+
1104
+ Returns
1105
+ -------
1106
+ tuple
1107
+ Amount to pad on either side of audio.
1108
+ """
1109
+ length = self.signal_length
1110
+
1111
+ if match_stride:
1112
+ assert (
1113
+ hop_length == window_length // 4
1114
+ ), "For match_stride, hop must equal n_fft // 4"
1115
+ right_pad = math.ceil(length / hop_length) * hop_length - length
1116
+ pad = (window_length - hop_length) // 2
1117
+ else:
1118
+ right_pad = 0
1119
+ pad = 0
1120
+
1121
+ return right_pad, pad
1122
+
1123
+ def stft(
1124
+ self,
1125
+ window_length: int = None,
1126
+ hop_length: int = None,
1127
+ window_type: str = None,
1128
+ match_stride: bool = None,
1129
+ padding_type: str = None,
1130
+ ):
1131
+ """Computes the short-time Fourier transform of the audio data,
1132
+ with specified STFT parameters.
1133
+
1134
+ Parameters
1135
+ ----------
1136
+ window_length : int, optional
1137
+ Window length of STFT, by default ``0.032 * self.sample_rate``.
1138
+ hop_length : int, optional
1139
+ Hop length of STFT, by default ``window_length // 4``.
1140
+ window_type : str, optional
1141
+ Type of window to use, by default ``sqrt\_hann``.
1142
+ match_stride : bool, optional
1143
+ Whether to match the stride of convolutional layers, by default False
1144
+ padding_type : str, optional
1145
+ Type of padding to use, by default 'reflect'
1146
+
1147
+ Returns
1148
+ -------
1149
+ torch.Tensor
1150
+ STFT of audio data.
1151
+
1152
+ Examples
1153
+ --------
1154
+ Compute the STFT of an AudioSignal:
1155
+
1156
+ >>> signal = AudioSignal(torch.randn(44100), 44100)
1157
+ >>> signal.stft()
1158
+
1159
+ Vary the window and hop length:
1160
+
1161
+ >>> stft_params = [STFTParams(128, 32), STFTParams(512, 128)]
1162
+ >>> for stft_param in stft_params:
1163
+ >>> signal.stft_params = stft_params
1164
+ >>> signal.stft()
1165
+
1166
+ """
1167
+ window_length = (
1168
+ self.stft_params.window_length
1169
+ if window_length is None
1170
+ else int(window_length)
1171
+ )
1172
+ hop_length = (
1173
+ self.stft_params.hop_length if hop_length is None else int(hop_length)
1174
+ )
1175
+ window_type = (
1176
+ self.stft_params.window_type if window_type is None else window_type
1177
+ )
1178
+ match_stride = (
1179
+ self.stft_params.match_stride if match_stride is None else match_stride
1180
+ )
1181
+ padding_type = (
1182
+ self.stft_params.padding_type if padding_type is None else padding_type
1183
+ )
1184
+
1185
+ window = self.get_window(window_type, window_length, self.audio_data.device)
1186
+ window = window.to(self.audio_data.device)
1187
+
1188
+ audio_data = self.audio_data
1189
+ right_pad, pad = self.compute_stft_padding(
1190
+ window_length, hop_length, match_stride
1191
+ )
1192
+ audio_data = torch.nn.functional.pad(
1193
+ audio_data, (pad, pad + right_pad), padding_type
1194
+ )
1195
+ stft_data = torch.stft(
1196
+ audio_data.reshape(-1, audio_data.shape[-1]),
1197
+ n_fft=window_length,
1198
+ hop_length=hop_length,
1199
+ window=window,
1200
+ return_complex=True,
1201
+ center=True,
1202
+ )
1203
+ _, nf, nt = stft_data.shape
1204
+ stft_data = stft_data.reshape(self.batch_size, self.num_channels, nf, nt)
1205
+
1206
+ if match_stride:
1207
+ # Drop first two and last two frames, which are added
1208
+ # because of padding. Now num_frames * hop_length = num_samples.
1209
+ stft_data = stft_data[..., 2:-2]
1210
+ self.stft_data = stft_data
1211
+
1212
+ return stft_data
1213
+
1214
+ def istft(
1215
+ self,
1216
+ window_length: int = None,
1217
+ hop_length: int = None,
1218
+ window_type: str = None,
1219
+ match_stride: bool = None,
1220
+ length: int = None,
1221
+ ):
1222
+ """Computes inverse STFT and sets it to audio\_data.
1223
+
1224
+ Parameters
1225
+ ----------
1226
+ window_length : int, optional
1227
+ Window length of STFT, by default ``0.032 * self.sample_rate``.
1228
+ hop_length : int, optional
1229
+ Hop length of STFT, by default ``window_length // 4``.
1230
+ window_type : str, optional
1231
+ Type of window to use, by default ``sqrt\_hann``.
1232
+ match_stride : bool, optional
1233
+ Whether to match the stride of convolutional layers, by default False
1234
+ length : int, optional
1235
+ Original length of signal, by default None
1236
+
1237
+ Returns
1238
+ -------
1239
+ AudioSignal
1240
+ AudioSignal with istft applied.
1241
+
1242
+ Raises
1243
+ ------
1244
+ RuntimeError
1245
+ Raises an error if stft was not called prior to istft on the signal,
1246
+ or if stft_data is not set.
1247
+ """
1248
+ if self.stft_data is None:
1249
+ raise RuntimeError("Cannot do inverse STFT without self.stft_data!")
1250
+
1251
+ window_length = (
1252
+ self.stft_params.window_length
1253
+ if window_length is None
1254
+ else int(window_length)
1255
+ )
1256
+ hop_length = (
1257
+ self.stft_params.hop_length if hop_length is None else int(hop_length)
1258
+ )
1259
+ window_type = (
1260
+ self.stft_params.window_type if window_type is None else window_type
1261
+ )
1262
+ match_stride = (
1263
+ self.stft_params.match_stride if match_stride is None else match_stride
1264
+ )
1265
+
1266
+ window = self.get_window(window_type, window_length, self.stft_data.device)
1267
+
1268
+ nb, nch, nf, nt = self.stft_data.shape
1269
+ stft_data = self.stft_data.reshape(nb * nch, nf, nt)
1270
+ right_pad, pad = self.compute_stft_padding(
1271
+ window_length, hop_length, match_stride
1272
+ )
1273
+
1274
+ if length is None:
1275
+ length = self.original_signal_length
1276
+ length = length + 2 * pad + right_pad
1277
+
1278
+ if match_stride:
1279
+ # Zero-pad the STFT on either side, putting back the frames that were
1280
+ # dropped in stft().
1281
+ stft_data = torch.nn.functional.pad(stft_data, (2, 2))
1282
+
1283
+ audio_data = torch.istft(
1284
+ stft_data,
1285
+ n_fft=window_length,
1286
+ hop_length=hop_length,
1287
+ window=window,
1288
+ length=length,
1289
+ center=True,
1290
+ )
1291
+ audio_data = audio_data.reshape(nb, nch, -1)
1292
+ if match_stride:
1293
+ audio_data = audio_data[..., pad : -(pad + right_pad)]
1294
+ self.audio_data = audio_data
1295
+
1296
+ return self
1297
+
1298
+ @staticmethod
1299
+ @functools.lru_cache(None)
1300
+ def get_mel_filters(
1301
+ sr: int, n_fft: int, n_mels: int, fmin: float = 0.0, fmax: float = None
1302
+ ):
1303
+ """Create a Filterbank matrix to combine FFT bins into Mel-frequency bins.
1304
+
1305
+ Parameters
1306
+ ----------
1307
+ sr : int
1308
+ Sample rate of audio
1309
+ n_fft : int
1310
+ Number of FFT bins
1311
+ n_mels : int
1312
+ Number of mels
1313
+ fmin : float, optional
1314
+ Lowest frequency, in Hz, by default 0.0
1315
+ fmax : float, optional
1316
+ Highest frequency, by default None
1317
+
1318
+ Returns
1319
+ -------
1320
+ np.ndarray [shape=(n_mels, 1 + n_fft/2)]
1321
+ Mel transform matrix
1322
+ """
1323
+ from librosa.filters import mel as librosa_mel_fn
1324
+
1325
+ return librosa_mel_fn(
1326
+ sr=sr,
1327
+ n_fft=n_fft,
1328
+ n_mels=n_mels,
1329
+ fmin=fmin,
1330
+ fmax=fmax,
1331
+ )
1332
+
1333
+ def mel_spectrogram(
1334
+ self, n_mels: int = 80, mel_fmin: float = 0.0, mel_fmax: float = None, **kwargs
1335
+ ):
1336
+ """Computes a Mel spectrogram.
1337
+
1338
+ Parameters
1339
+ ----------
1340
+ n_mels : int, optional
1341
+ Number of mels, by default 80
1342
+ mel_fmin : float, optional
1343
+ Lowest frequency, in Hz, by default 0.0
1344
+ mel_fmax : float, optional
1345
+ Highest frequency, by default None
1346
+ kwargs : dict, optional
1347
+ Keyword arguments to self.stft().
1348
+
1349
+ Returns
1350
+ -------
1351
+ torch.Tensor [shape=(batch, channels, mels, time)]
1352
+ Mel spectrogram.
1353
+ """
1354
+ stft = self.stft(**kwargs)
1355
+ magnitude = torch.abs(stft)
1356
+
1357
+ nf = magnitude.shape[2]
1358
+ mel_basis = self.get_mel_filters(
1359
+ sr=self.sample_rate,
1360
+ n_fft=2 * (nf - 1),
1361
+ n_mels=n_mels,
1362
+ fmin=mel_fmin,
1363
+ fmax=mel_fmax,
1364
+ )
1365
+ mel_basis = torch.from_numpy(mel_basis).to(self.device)
1366
+
1367
+ mel_spectrogram = magnitude.transpose(2, -1) @ mel_basis.T
1368
+ mel_spectrogram = mel_spectrogram.transpose(-1, 2)
1369
+ return mel_spectrogram
1370
+
1371
+ @staticmethod
1372
+ @functools.lru_cache(None)
1373
+ def get_dct(n_mfcc: int, n_mels: int, norm: str = "ortho", device: str = None):
1374
+ """Create a discrete cosine transform (DCT) transformation matrix with shape (``n_mels``, ``n_mfcc``),
1375
+ it can be normalized depending on norm. For more information about dct:
1376
+ http://en.wikipedia.org/wiki/Discrete_cosine_transform#DCT-II
1377
+
1378
+ Parameters
1379
+ ----------
1380
+ n_mfcc : int
1381
+ Number of mfccs
1382
+ n_mels : int
1383
+ Number of mels
1384
+ norm : str
1385
+ Use "ortho" to get a orthogonal matrix or None, by default "ortho"
1386
+ device : str, optional
1387
+ Device to load the transformation matrix on, by default None
1388
+
1389
+ Returns
1390
+ -------
1391
+ torch.Tensor [shape=(n_mels, n_mfcc)] T
1392
+ The dct transformation matrix.
1393
+ """
1394
+ from torchaudio.functional import create_dct
1395
+
1396
+ return create_dct(n_mfcc, n_mels, norm).to(device)
1397
+
1398
+ def mfcc(
1399
+ self, n_mfcc: int = 40, n_mels: int = 80, log_offset: float = 1e-6, **kwargs
1400
+ ):
1401
+ """Computes mel-frequency cepstral coefficients (MFCCs).
1402
+
1403
+ Parameters
1404
+ ----------
1405
+ n_mfcc : int, optional
1406
+ Number of mels, by default 40
1407
+ n_mels : int, optional
1408
+ Number of mels, by default 80
1409
+ log_offset: float, optional
1410
+ Small value to prevent numerical issues when trying to compute log(0), by default 1e-6
1411
+ kwargs : dict, optional
1412
+ Keyword arguments to self.mel_spectrogram(), note that some of them will be used for self.stft()
1413
+
1414
+ Returns
1415
+ -------
1416
+ torch.Tensor [shape=(batch, channels, mfccs, time)]
1417
+ MFCCs.
1418
+ """
1419
+
1420
+ mel_spectrogram = self.mel_spectrogram(n_mels, **kwargs)
1421
+ mel_spectrogram = torch.log(mel_spectrogram + log_offset)
1422
+ dct_mat = self.get_dct(n_mfcc, n_mels, "ortho", self.device)
1423
+
1424
+ mfcc = mel_spectrogram.transpose(-1, -2) @ dct_mat
1425
+ mfcc = mfcc.transpose(-1, -2)
1426
+ return mfcc
1427
+
1428
+ @property
1429
+ def magnitude(self):
1430
+ """Computes and returns the absolute value of the STFT, which
1431
+ is the magnitude. This value can also be set to some tensor.
1432
+ When set, ``self.stft_data`` is manipulated so that its magnitude
1433
+ matches what this is set to, and modulated by the phase.
1434
+
1435
+ Returns
1436
+ -------
1437
+ torch.Tensor
1438
+ Magnitude of STFT.
1439
+
1440
+ Examples
1441
+ --------
1442
+ >>> signal = AudioSignal(torch.randn(44100), 44100)
1443
+ >>> magnitude = signal.magnitude # Computes stft if not computed
1444
+ >>> magnitude[magnitude < magnitude.mean()] = 0
1445
+ >>> signal.magnitude = magnitude
1446
+ >>> signal.istft()
1447
+ """
1448
+ if self.stft_data is None:
1449
+ self.stft()
1450
+ return torch.abs(self.stft_data)
1451
+
1452
+ @magnitude.setter
1453
+ def magnitude(self, value):
1454
+ self.stft_data = value * torch.exp(1j * self.phase)
1455
+ return
1456
+
1457
+ def log_magnitude(
1458
+ self, ref_value: float = 1.0, amin: float = 1e-5, top_db: float = 80.0
1459
+ ):
1460
+ """Computes the log-magnitude of the spectrogram.
1461
+
1462
+ Parameters
1463
+ ----------
1464
+ ref_value : float, optional
1465
+ The magnitude is scaled relative to ``ref``: ``20 * log10(S / ref)``.
1466
+ Zeros in the output correspond to positions where ``S == ref``,
1467
+ by default 1.0
1468
+ amin : float, optional
1469
+ Minimum threshold for ``S`` and ``ref``, by default 1e-5
1470
+ top_db : float, optional
1471
+ Threshold the output at ``top_db`` below the peak:
1472
+ ``max(10 * log10(S/ref)) - top_db``, by default -80.0
1473
+
1474
+ Returns
1475
+ -------
1476
+ torch.Tensor
1477
+ Log-magnitude spectrogram
1478
+ """
1479
+ magnitude = self.magnitude
1480
+
1481
+ amin = amin**2
1482
+ log_spec = 10.0 * torch.log10(magnitude.pow(2).clamp(min=amin))
1483
+ log_spec -= 10.0 * np.log10(np.maximum(amin, ref_value))
1484
+
1485
+ if top_db is not None:
1486
+ log_spec = torch.maximum(log_spec, log_spec.max() - top_db)
1487
+ return log_spec
1488
+
1489
+ @property
1490
+ def phase(self):
1491
+ """Computes and returns the phase of the STFT.
1492
+ This value can also be set to some tensor.
1493
+ When set, ``self.stft_data`` is manipulated so that its phase
1494
+ matches what this is set to, we original magnitudeith th.
1495
+
1496
+ Returns
1497
+ -------
1498
+ torch.Tensor
1499
+ Phase of STFT.
1500
+
1501
+ Examples
1502
+ --------
1503
+ >>> signal = AudioSignal(torch.randn(44100), 44100)
1504
+ >>> phase = signal.phase # Computes stft if not computed
1505
+ >>> phase[phase < phase.mean()] = 0
1506
+ >>> signal.phase = phase
1507
+ >>> signal.istft()
1508
+ """
1509
+ if self.stft_data is None:
1510
+ self.stft()
1511
+ return torch.angle(self.stft_data)
1512
+
1513
+ @phase.setter
1514
+ def phase(self, value):
1515
+ self.stft_data = self.magnitude * torch.exp(1j * value)
1516
+ return
1517
+
1518
+ # Operator overloading
1519
+ def __add__(self, other):
1520
+ new_signal = self.clone()
1521
+ new_signal.audio_data += util._get_value(other)
1522
+ return new_signal
1523
+
1524
+ def __iadd__(self, other):
1525
+ self.audio_data += util._get_value(other)
1526
+ return self
1527
+
1528
+ def __radd__(self, other):
1529
+ return self + other
1530
+
1531
+ def __sub__(self, other):
1532
+ new_signal = self.clone()
1533
+ new_signal.audio_data -= util._get_value(other)
1534
+ return new_signal
1535
+
1536
+ def __isub__(self, other):
1537
+ self.audio_data -= util._get_value(other)
1538
+ return self
1539
+
1540
+ def __mul__(self, other):
1541
+ new_signal = self.clone()
1542
+ new_signal.audio_data *= util._get_value(other)
1543
+ return new_signal
1544
+
1545
+ def __imul__(self, other):
1546
+ self.audio_data *= util._get_value(other)
1547
+ return self
1548
+
1549
+ def __rmul__(self, other):
1550
+ return self * other
1551
+
1552
+ # Representation
1553
+ def _info(self):
1554
+ dur = f"{self.signal_duration:0.3f}" if self.signal_duration else "[unknown]"
1555
+ info = {
1556
+ "duration": f"{dur} seconds",
1557
+ "batch_size": self.batch_size,
1558
+ "path": self.path_to_file if self.path_to_file else "path unknown",
1559
+ "sample_rate": self.sample_rate,
1560
+ "num_channels": self.num_channels if self.num_channels else "[unknown]",
1561
+ "audio_data.shape": self.audio_data.shape,
1562
+ "stft_params": self.stft_params,
1563
+ "device": self.device,
1564
+ }
1565
+
1566
+ return info
1567
+
1568
+ def markdown(self):
1569
+ """Produces a markdown representation of AudioSignal, in a markdown table.
1570
+
1571
+ Returns
1572
+ -------
1573
+ str
1574
+ Markdown representation of AudioSignal.
1575
+
1576
+ Examples
1577
+ --------
1578
+ >>> signal = AudioSignal(torch.randn(44100), 44100)
1579
+ >>> print(signal.markdown())
1580
+ | Key | Value
1581
+ |---|---
1582
+ | duration | 1.000 seconds |
1583
+ | batch_size | 1 |
1584
+ | path | path unknown |
1585
+ | sample_rate | 44100 |
1586
+ | num_channels | 1 |
1587
+ | audio_data.shape | torch.Size([1, 1, 44100]) |
1588
+ | stft_params | STFTParams(window_length=2048, hop_length=512, window_type='sqrt_hann', match_stride=False) |
1589
+ | device | cpu |
1590
+ """
1591
+ info = self._info()
1592
+
1593
+ FORMAT = "| Key | Value \n" "|---|--- \n"
1594
+ for k, v in info.items():
1595
+ row = f"| {k} | {v} |\n"
1596
+ FORMAT += row
1597
+ return FORMAT
1598
+
1599
+ def __str__(self):
1600
+ info = self._info()
1601
+
1602
+ desc = ""
1603
+ for k, v in info.items():
1604
+ desc += f"{k}: {v}\n"
1605
+ return desc
1606
+
1607
+ def __rich__(self):
1608
+ from rich.table import Table
1609
+
1610
+ info = self._info()
1611
+
1612
+ table = Table(title=f"{self.__class__.__name__}")
1613
+ table.add_column("Key", style="green")
1614
+ table.add_column("Value", style="cyan")
1615
+
1616
+ for k, v in info.items():
1617
+ table.add_row(k, str(v))
1618
+ return table
1619
+
1620
+ # Comparison
1621
+ def __eq__(self, other):
1622
+ for k, v in list(self.__dict__.items()):
1623
+ if torch.is_tensor(v):
1624
+ if not torch.allclose(v, other.__dict__[k], atol=1e-6):
1625
+ max_error = (v - other.__dict__[k]).abs().max()
1626
+ print(f"Max abs error for {k}: {max_error}")
1627
+ return False
1628
+ return True
1629
+
1630
+ # Indexing
1631
+ def __getitem__(self, key):
1632
+ if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
1633
+ assert self.batch_size == 1
1634
+ audio_data = self.audio_data
1635
+ _loudness = self._loudness
1636
+ stft_data = self.stft_data
1637
+
1638
+ elif isinstance(key, (bool, int, list, slice, tuple)) or (
1639
+ torch.is_tensor(key) and key.ndim <= 1
1640
+ ):
1641
+ # Indexing only on the batch dimension.
1642
+ # Then let's copy over relevant stuff.
1643
+ # Future work: make this work for time-indexing
1644
+ # as well, using the hop length.
1645
+ audio_data = self.audio_data[key]
1646
+ _loudness = self._loudness[key] if self._loudness is not None else None
1647
+ stft_data = self.stft_data[key] if self.stft_data is not None else None
1648
+
1649
+ sources = None
1650
+
1651
+ copy = type(self)(audio_data, self.sample_rate, stft_params=self.stft_params)
1652
+ copy._loudness = _loudness
1653
+ copy._stft_data = stft_data
1654
+ copy.sources = sources
1655
+
1656
+ return copy
1657
+
1658
+ def __setitem__(self, key, value):
1659
+ if not isinstance(value, type(self)):
1660
+ self.audio_data[key] = value
1661
+ return
1662
+
1663
+ if torch.is_tensor(key) and key.ndim == 0 and key.item() is True:
1664
+ assert self.batch_size == 1
1665
+ self.audio_data = value.audio_data
1666
+ self._loudness = value._loudness
1667
+ self.stft_data = value.stft_data
1668
+ return
1669
+
1670
+ elif isinstance(key, (bool, int, list, slice, tuple)) or (
1671
+ torch.is_tensor(key) and key.ndim <= 1
1672
+ ):
1673
+ if self.audio_data is not None and value.audio_data is not None:
1674
+ self.audio_data[key] = value.audio_data
1675
+ if self._loudness is not None and value._loudness is not None:
1676
+ self._loudness[key] = value._loudness
1677
+ if self.stft_data is not None and value.stft_data is not None:
1678
+ self.stft_data[key] = value.stft_data
1679
+ return
1680
+
1681
+ def __ne__(self, other):
1682
+ return not self == other
audiotools/core/display.py ADDED
@@ -0,0 +1,194 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import typing
3
+ from functools import wraps
4
+
5
+ from . import util
6
+
7
+
8
+ def format_figure(func):
9
+ """Decorator for formatting figures produced by the code below.
10
+ See :py:func:`audiotools.core.util.format_figure` for more.
11
+
12
+ Parameters
13
+ ----------
14
+ func : Callable
15
+ Plotting function that is decorated by this function.
16
+
17
+ """
18
+
19
+ @wraps(func)
20
+ def wrapper(*args, **kwargs):
21
+ f_keys = inspect.signature(util.format_figure).parameters.keys()
22
+ f_kwargs = {}
23
+ for k, v in list(kwargs.items()):
24
+ if k in f_keys:
25
+ kwargs.pop(k)
26
+ f_kwargs[k] = v
27
+ func(*args, **kwargs)
28
+ util.format_figure(**f_kwargs)
29
+
30
+ return wrapper
31
+
32
+
33
+ class DisplayMixin:
34
+ @format_figure
35
+ def specshow(
36
+ self,
37
+ preemphasis: bool = False,
38
+ x_axis: str = "time",
39
+ y_axis: str = "linear",
40
+ n_mels: int = 128,
41
+ **kwargs,
42
+ ):
43
+ """Displays a spectrogram, using ``librosa.display.specshow``.
44
+
45
+ Parameters
46
+ ----------
47
+ preemphasis : bool, optional
48
+ Whether or not to apply preemphasis, which makes high
49
+ frequency detail easier to see, by default False
50
+ x_axis : str, optional
51
+ How to label the x axis, by default "time"
52
+ y_axis : str, optional
53
+ How to label the y axis, by default "linear"
54
+ n_mels : int, optional
55
+ If displaying a mel spectrogram with ``y_axis = "mel"``,
56
+ this controls the number of mels, by default 128.
57
+ kwargs : dict, optional
58
+ Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
59
+ """
60
+ import librosa
61
+ import librosa.display
62
+
63
+ # Always re-compute the STFT data before showing it, in case
64
+ # it changed.
65
+ signal = self.clone()
66
+ signal.stft_data = None
67
+
68
+ if preemphasis:
69
+ signal.preemphasis()
70
+
71
+ ref = signal.magnitude.max()
72
+ log_mag = signal.log_magnitude(ref_value=ref)
73
+
74
+ if y_axis == "mel":
75
+ log_mag = 20 * signal.mel_spectrogram(n_mels).clamp(1e-5).log10()
76
+ log_mag -= log_mag.max()
77
+
78
+ librosa.display.specshow(
79
+ log_mag.numpy()[0].mean(axis=0),
80
+ x_axis=x_axis,
81
+ y_axis=y_axis,
82
+ sr=signal.sample_rate,
83
+ **kwargs,
84
+ )
85
+
86
+ @format_figure
87
+ def waveplot(self, x_axis: str = "time", **kwargs):
88
+ """Displays a waveform plot, using ``librosa.display.waveshow``.
89
+
90
+ Parameters
91
+ ----------
92
+ x_axis : str, optional
93
+ How to label the x axis, by default "time"
94
+ kwargs : dict, optional
95
+ Keyword arguments to :py:func:`audiotools.core.util.format_figure`.
96
+ """
97
+ import librosa
98
+ import librosa.display
99
+
100
+ audio_data = self.audio_data[0].mean(dim=0)
101
+ audio_data = audio_data.cpu().numpy()
102
+
103
+ plot_fn = "waveshow" if hasattr(librosa.display, "waveshow") else "waveplot"
104
+ wave_plot_fn = getattr(librosa.display, plot_fn)
105
+ wave_plot_fn(audio_data, x_axis=x_axis, sr=self.sample_rate, **kwargs)
106
+
107
+ @format_figure
108
+ def wavespec(self, x_axis: str = "time", **kwargs):
109
+ """Displays a waveform plot, using ``librosa.display.waveshow``.
110
+
111
+ Parameters
112
+ ----------
113
+ x_axis : str, optional
114
+ How to label the x axis, by default "time"
115
+ kwargs : dict, optional
116
+ Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow`.
117
+ """
118
+ import matplotlib.pyplot as plt
119
+ from matplotlib.gridspec import GridSpec
120
+
121
+ gs = GridSpec(6, 1)
122
+ plt.subplot(gs[0, :])
123
+ self.waveplot(x_axis=x_axis)
124
+ plt.subplot(gs[1:, :])
125
+ self.specshow(x_axis=x_axis, **kwargs)
126
+
127
+ def write_audio_to_tb(
128
+ self,
129
+ tag: str,
130
+ writer,
131
+ step: int = None,
132
+ plot_fn: typing.Union[typing.Callable, str] = "specshow",
133
+ **kwargs,
134
+ ):
135
+ """Writes a signal and its spectrogram to Tensorboard. Will show up
136
+ under the Audio and Images tab in Tensorboard.
137
+
138
+ Parameters
139
+ ----------
140
+ tag : str
141
+ Tag to write signal to (e.g. ``clean/sample_0.wav``). The image will be
142
+ written to the corresponding ``.png`` file (e.g. ``clean/sample_0.png``).
143
+ writer : SummaryWriter
144
+ A SummaryWriter object from PyTorch library.
145
+ step : int, optional
146
+ The step to write the signal to, by default None
147
+ plot_fn : typing.Union[typing.Callable, str], optional
148
+ How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
149
+ kwargs : dict, optional
150
+ Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
151
+ whatever ``plot_fn`` is set to.
152
+ """
153
+ import matplotlib.pyplot as plt
154
+
155
+ audio_data = self.audio_data[0, 0].detach().cpu()
156
+ sample_rate = self.sample_rate
157
+ writer.add_audio(tag, audio_data, step, sample_rate)
158
+
159
+ if plot_fn is not None:
160
+ if isinstance(plot_fn, str):
161
+ plot_fn = getattr(self, plot_fn)
162
+ fig = plt.figure()
163
+ plt.clf()
164
+ plot_fn(**kwargs)
165
+ writer.add_figure(tag.replace("wav", "png"), fig, step)
166
+
167
+ def save_image(
168
+ self,
169
+ image_path: str,
170
+ plot_fn: typing.Union[typing.Callable, str] = "specshow",
171
+ **kwargs,
172
+ ):
173
+ """Save AudioSignal spectrogram (or whatever ``plot_fn`` is set to) to
174
+ a specified file.
175
+
176
+ Parameters
177
+ ----------
178
+ image_path : str
179
+ Where to save the file to.
180
+ plot_fn : typing.Union[typing.Callable, str], optional
181
+ How to create the image. Set to ``None`` to avoid plotting, by default "specshow"
182
+ kwargs : dict, optional
183
+ Keyword arguments to :py:func:`audiotools.core.display.DisplayMixin.specshow` or
184
+ whatever ``plot_fn`` is set to.
185
+ """
186
+ import matplotlib.pyplot as plt
187
+
188
+ if isinstance(plot_fn, str):
189
+ plot_fn = getattr(self, plot_fn)
190
+
191
+ plt.clf()
192
+ plot_fn(**kwargs)
193
+ plt.savefig(image_path, bbox_inches="tight", pad_inches=0)
194
+ plt.close()
audiotools/core/dsp.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+
3
+ import julius
4
+ import numpy as np
5
+ import torch
6
+
7
+ from . import util
8
+
9
+
10
+ class DSPMixin:
11
+ _original_batch_size = None
12
+ _original_num_channels = None
13
+ _padded_signal_length = None
14
+
15
+ def _preprocess_signal_for_windowing(self, window_duration, hop_duration):
16
+ self._original_batch_size = self.batch_size
17
+ self._original_num_channels = self.num_channels
18
+
19
+ window_length = int(window_duration * self.sample_rate)
20
+ hop_length = int(hop_duration * self.sample_rate)
21
+
22
+ if window_length % hop_length != 0:
23
+ factor = window_length // hop_length
24
+ window_length = factor * hop_length
25
+
26
+ self.zero_pad(hop_length, hop_length)
27
+ self._padded_signal_length = self.signal_length
28
+
29
+ return window_length, hop_length
30
+
31
+ def windows(
32
+ self, window_duration: float, hop_duration: float, preprocess: bool = True
33
+ ):
34
+ """Generator which yields windows of specified duration from signal with a specified
35
+ hop length.
36
+
37
+ Parameters
38
+ ----------
39
+ window_duration : float
40
+ Duration of every window in seconds.
41
+ hop_duration : float
42
+ Hop between windows in seconds.
43
+ preprocess : bool, optional
44
+ Whether to preprocess the signal, so that the first sample is in
45
+ the middle of the first window, by default True
46
+
47
+ Yields
48
+ ------
49
+ AudioSignal
50
+ Each window is returned as an AudioSignal.
51
+ """
52
+ if preprocess:
53
+ window_length, hop_length = self._preprocess_signal_for_windowing(
54
+ window_duration, hop_duration
55
+ )
56
+
57
+ self.audio_data = self.audio_data.reshape(-1, 1, self.signal_length)
58
+
59
+ for b in range(self.batch_size):
60
+ i = 0
61
+ start_idx = i * hop_length
62
+ while True:
63
+ start_idx = i * hop_length
64
+ i += 1
65
+ end_idx = start_idx + window_length
66
+ if end_idx > self.signal_length:
67
+ break
68
+ yield self[b, ..., start_idx:end_idx]
69
+
70
+ def collect_windows(
71
+ self, window_duration: float, hop_duration: float, preprocess: bool = True
72
+ ):
73
+ """Reshapes signal into windows of specified duration from signal with a specified
74
+ hop length. Window are placed along the batch dimension. Use with
75
+ :py:func:`audiotools.core.dsp.DSPMixin.overlap_and_add` to reconstruct the
76
+ original signal.
77
+
78
+ Parameters
79
+ ----------
80
+ window_duration : float
81
+ Duration of every window in seconds.
82
+ hop_duration : float
83
+ Hop between windows in seconds.
84
+ preprocess : bool, optional
85
+ Whether to preprocess the signal, so that the first sample is in
86
+ the middle of the first window, by default True
87
+
88
+ Returns
89
+ -------
90
+ AudioSignal
91
+ AudioSignal unfolded with shape ``(nb * nch * num_windows, 1, window_length)``
92
+ """
93
+ if preprocess:
94
+ window_length, hop_length = self._preprocess_signal_for_windowing(
95
+ window_duration, hop_duration
96
+ )
97
+
98
+ # self.audio_data: (nb, nch, nt).
99
+ unfolded = torch.nn.functional.unfold(
100
+ self.audio_data.reshape(-1, 1, 1, self.signal_length),
101
+ kernel_size=(1, window_length),
102
+ stride=(1, hop_length),
103
+ )
104
+ # unfolded: (nb * nch, window_length, num_windows).
105
+ # -> (nb * nch * num_windows, 1, window_length)
106
+ unfolded = unfolded.permute(0, 2, 1).reshape(-1, 1, window_length)
107
+ self.audio_data = unfolded
108
+ return self
109
+
110
+ def overlap_and_add(self, hop_duration: float):
111
+ """Function which takes a list of windows and overlap adds them into a
112
+ signal the same length as ``audio_signal``.
113
+
114
+ Parameters
115
+ ----------
116
+ hop_duration : float
117
+ How much to shift for each window
118
+ (overlap is window_duration - hop_duration) in seconds.
119
+
120
+ Returns
121
+ -------
122
+ AudioSignal
123
+ overlap-and-added signal.
124
+ """
125
+ hop_length = int(hop_duration * self.sample_rate)
126
+ window_length = self.signal_length
127
+
128
+ nb, nch = self._original_batch_size, self._original_num_channels
129
+
130
+ unfolded = self.audio_data.reshape(nb * nch, -1, window_length).permute(0, 2, 1)
131
+ folded = torch.nn.functional.fold(
132
+ unfolded,
133
+ output_size=(1, self._padded_signal_length),
134
+ kernel_size=(1, window_length),
135
+ stride=(1, hop_length),
136
+ )
137
+
138
+ norm = torch.ones_like(unfolded, device=unfolded.device)
139
+ norm = torch.nn.functional.fold(
140
+ norm,
141
+ output_size=(1, self._padded_signal_length),
142
+ kernel_size=(1, window_length),
143
+ stride=(1, hop_length),
144
+ )
145
+
146
+ folded = folded / norm
147
+
148
+ folded = folded.reshape(nb, nch, -1)
149
+ self.audio_data = folded
150
+ self.trim(hop_length, hop_length)
151
+ return self
152
+
153
+ def low_pass(
154
+ self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
155
+ ):
156
+ """Low-passes the signal in-place. Each item in the batch
157
+ can have a different low-pass cutoff, if the input
158
+ to this signal is an array or tensor. If a float, all
159
+ items are given the same low-pass filter.
160
+
161
+ Parameters
162
+ ----------
163
+ cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
164
+ Cutoff in Hz of low-pass filter.
165
+ zeros : int, optional
166
+ Number of taps to use in low-pass filter, by default 51
167
+
168
+ Returns
169
+ -------
170
+ AudioSignal
171
+ Low-passed AudioSignal.
172
+ """
173
+ cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
174
+ cutoffs = cutoffs / self.sample_rate
175
+ filtered = torch.empty_like(self.audio_data)
176
+
177
+ for i, cutoff in enumerate(cutoffs):
178
+ lp_filter = julius.LowPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
179
+ filtered[i] = lp_filter(self.audio_data[i])
180
+
181
+ self.audio_data = filtered
182
+ self.stft_data = None
183
+ return self
184
+
185
+ def high_pass(
186
+ self, cutoffs: typing.Union[torch.Tensor, np.ndarray, float], zeros: int = 51
187
+ ):
188
+ """High-passes the signal in-place. Each item in the batch
189
+ can have a different high-pass cutoff, if the input
190
+ to this signal is an array or tensor. If a float, all
191
+ items are given the same high-pass filter.
192
+
193
+ Parameters
194
+ ----------
195
+ cutoffs : typing.Union[torch.Tensor, np.ndarray, float]
196
+ Cutoff in Hz of high-pass filter.
197
+ zeros : int, optional
198
+ Number of taps to use in high-pass filter, by default 51
199
+
200
+ Returns
201
+ -------
202
+ AudioSignal
203
+ High-passed AudioSignal.
204
+ """
205
+ cutoffs = util.ensure_tensor(cutoffs, 2, self.batch_size)
206
+ cutoffs = cutoffs / self.sample_rate
207
+ filtered = torch.empty_like(self.audio_data)
208
+
209
+ for i, cutoff in enumerate(cutoffs):
210
+ hp_filter = julius.HighPassFilter(cutoff.cpu(), zeros=zeros).to(self.device)
211
+ filtered[i] = hp_filter(self.audio_data[i])
212
+
213
+ self.audio_data = filtered
214
+ self.stft_data = None
215
+ return self
216
+
217
+ def mask_frequencies(
218
+ self,
219
+ fmin_hz: typing.Union[torch.Tensor, np.ndarray, float],
220
+ fmax_hz: typing.Union[torch.Tensor, np.ndarray, float],
221
+ val: float = 0.0,
222
+ ):
223
+ """Masks frequencies between ``fmin_hz`` and ``fmax_hz``, and fills them
224
+ with the value specified by ``val``. Useful for implementing SpecAug.
225
+ The min and max can be different for every item in the batch.
226
+
227
+ Parameters
228
+ ----------
229
+ fmin_hz : typing.Union[torch.Tensor, np.ndarray, float]
230
+ Lower end of band to mask out.
231
+ fmax_hz : typing.Union[torch.Tensor, np.ndarray, float]
232
+ Upper end of band to mask out.
233
+ val : float, optional
234
+ Value to fill in, by default 0.0
235
+
236
+ Returns
237
+ -------
238
+ AudioSignal
239
+ Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
240
+ masked audio data.
241
+ """
242
+ # SpecAug
243
+ mag, phase = self.magnitude, self.phase
244
+ fmin_hz = util.ensure_tensor(fmin_hz, ndim=mag.ndim)
245
+ fmax_hz = util.ensure_tensor(fmax_hz, ndim=mag.ndim)
246
+ assert torch.all(fmin_hz < fmax_hz)
247
+
248
+ # build mask
249
+ nbins = mag.shape[-2]
250
+ bins_hz = torch.linspace(0, self.sample_rate / 2, nbins, device=self.device)
251
+ bins_hz = bins_hz[None, None, :, None].repeat(
252
+ self.batch_size, 1, 1, mag.shape[-1]
253
+ )
254
+ mask = (fmin_hz <= bins_hz) & (bins_hz < fmax_hz)
255
+ mask = mask.to(self.device)
256
+
257
+ mag = mag.masked_fill(mask, val)
258
+ phase = phase.masked_fill(mask, val)
259
+ self.stft_data = mag * torch.exp(1j * phase)
260
+ return self
261
+
262
+ def mask_timesteps(
263
+ self,
264
+ tmin_s: typing.Union[torch.Tensor, np.ndarray, float],
265
+ tmax_s: typing.Union[torch.Tensor, np.ndarray, float],
266
+ val: float = 0.0,
267
+ ):
268
+ """Masks timesteps between ``tmin_s`` and ``tmax_s``, and fills them
269
+ with the value specified by ``val``. Useful for implementing SpecAug.
270
+ The min and max can be different for every item in the batch.
271
+
272
+ Parameters
273
+ ----------
274
+ tmin_s : typing.Union[torch.Tensor, np.ndarray, float]
275
+ Lower end of timesteps to mask out.
276
+ tmax_s : typing.Union[torch.Tensor, np.ndarray, float]
277
+ Upper end of timesteps to mask out.
278
+ val : float, optional
279
+ Value to fill in, by default 0.0
280
+
281
+ Returns
282
+ -------
283
+ AudioSignal
284
+ Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
285
+ masked audio data.
286
+ """
287
+ # SpecAug
288
+ mag, phase = self.magnitude, self.phase
289
+ tmin_s = util.ensure_tensor(tmin_s, ndim=mag.ndim)
290
+ tmax_s = util.ensure_tensor(tmax_s, ndim=mag.ndim)
291
+
292
+ assert torch.all(tmin_s < tmax_s)
293
+
294
+ # build mask
295
+ nt = mag.shape[-1]
296
+ bins_t = torch.linspace(0, self.signal_duration, nt, device=self.device)
297
+ bins_t = bins_t[None, None, None, :].repeat(
298
+ self.batch_size, 1, mag.shape[-2], 1
299
+ )
300
+ mask = (tmin_s <= bins_t) & (bins_t < tmax_s)
301
+
302
+ mag = mag.masked_fill(mask, val)
303
+ phase = phase.masked_fill(mask, val)
304
+ self.stft_data = mag * torch.exp(1j * phase)
305
+ return self
306
+
307
+ def mask_low_magnitudes(
308
+ self, db_cutoff: typing.Union[torch.Tensor, np.ndarray, float], val: float = 0.0
309
+ ):
310
+ """Mask away magnitudes below a specified threshold, which
311
+ can be different for every item in the batch.
312
+
313
+ Parameters
314
+ ----------
315
+ db_cutoff : typing.Union[torch.Tensor, np.ndarray, float]
316
+ Decibel value for which things below it will be masked away.
317
+ val : float, optional
318
+ Value to fill in for masked portions, by default 0.0
319
+
320
+ Returns
321
+ -------
322
+ AudioSignal
323
+ Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
324
+ masked audio data.
325
+ """
326
+ mag = self.magnitude
327
+ log_mag = self.log_magnitude()
328
+
329
+ db_cutoff = util.ensure_tensor(db_cutoff, ndim=mag.ndim)
330
+ mask = log_mag < db_cutoff
331
+ mag = mag.masked_fill(mask, val)
332
+
333
+ self.magnitude = mag
334
+ return self
335
+
336
+ def shift_phase(self, shift: typing.Union[torch.Tensor, np.ndarray, float]):
337
+ """Shifts the phase by a constant value.
338
+
339
+ Parameters
340
+ ----------
341
+ shift : typing.Union[torch.Tensor, np.ndarray, float]
342
+ What to shift the phase by.
343
+
344
+ Returns
345
+ -------
346
+ AudioSignal
347
+ Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
348
+ masked audio data.
349
+ """
350
+ shift = util.ensure_tensor(shift, ndim=self.phase.ndim)
351
+ self.phase = self.phase + shift
352
+ return self
353
+
354
+ def corrupt_phase(self, scale: typing.Union[torch.Tensor, np.ndarray, float]):
355
+ """Corrupts the phase randomly by some scaled value.
356
+
357
+ Parameters
358
+ ----------
359
+ scale : typing.Union[torch.Tensor, np.ndarray, float]
360
+ Standard deviation of noise to add to the phase.
361
+
362
+ Returns
363
+ -------
364
+ AudioSignal
365
+ Signal with ``stft_data`` manipulated. Apply ``.istft()`` to get the
366
+ masked audio data.
367
+ """
368
+ scale = util.ensure_tensor(scale, ndim=self.phase.ndim)
369
+ self.phase = self.phase + scale * torch.randn_like(self.phase)
370
+ return self
371
+
372
+ def preemphasis(self, coef: float = 0.85):
373
+ """Applies pre-emphasis to audio signal.
374
+
375
+ Parameters
376
+ ----------
377
+ coef : float, optional
378
+ How much pre-emphasis to apply, lower values do less. 0 does nothing.
379
+ by default 0.85
380
+
381
+ Returns
382
+ -------
383
+ AudioSignal
384
+ Pre-emphasized signal.
385
+ """
386
+ kernel = torch.tensor([1, -coef, 0]).view(1, 1, -1).to(self.device)
387
+ x = self.audio_data.reshape(-1, 1, self.signal_length)
388
+ x = torch.nn.functional.conv1d(x, kernel, padding=1)
389
+ self.audio_data = x.reshape(*self.audio_data.shape)
390
+ return self
audiotools/core/effects.py ADDED
@@ -0,0 +1,647 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+
3
+ import julius
4
+ import numpy as np
5
+ import torch
6
+ import torchaudio
7
+
8
+ from . import util
9
+
10
+
11
+ class EffectMixin:
12
+ GAIN_FACTOR = np.log(10) / 20
13
+ """Gain factor for converting between amplitude and decibels."""
14
+ CODEC_PRESETS = {
15
+ "8-bit": {"format": "wav", "encoding": "ULAW", "bits_per_sample": 8},
16
+ "GSM-FR": {"format": "gsm"},
17
+ "MP3": {"format": "mp3", "compression": -9},
18
+ "Vorbis": {"format": "vorbis", "compression": -1},
19
+ "Ogg": {
20
+ "format": "ogg",
21
+ "compression": -1,
22
+ },
23
+ "Amr-nb": {"format": "amr-nb"},
24
+ }
25
+ """Presets for applying codecs via torchaudio."""
26
+
27
+ def mix(
28
+ self,
29
+ other,
30
+ snr: typing.Union[torch.Tensor, np.ndarray, float] = 10,
31
+ other_eq: typing.Union[torch.Tensor, np.ndarray] = None,
32
+ ):
33
+ """Mixes noise with signal at specified
34
+ signal-to-noise ratio. Optionally, the
35
+ other signal can be equalized in-place.
36
+
37
+
38
+ Parameters
39
+ ----------
40
+ other : AudioSignal
41
+ AudioSignal object to mix with.
42
+ snr : typing.Union[torch.Tensor, np.ndarray, float], optional
43
+ Signal to noise ratio, by default 10
44
+ other_eq : typing.Union[torch.Tensor, np.ndarray], optional
45
+ EQ curve to apply to other signal, if any, by default None
46
+
47
+ Returns
48
+ -------
49
+ AudioSignal
50
+ In-place modification of AudioSignal.
51
+ """
52
+ snr = util.ensure_tensor(snr).to(self.device)
53
+
54
+ pad_len = max(0, self.signal_length - other.signal_length)
55
+ other.zero_pad(0, pad_len)
56
+ other.truncate_samples(self.signal_length)
57
+ if other_eq is not None:
58
+ other = other.equalizer(other_eq)
59
+
60
+ tgt_loudness = self.loudness() - snr
61
+ other = other.normalize(tgt_loudness)
62
+
63
+ self.audio_data = self.audio_data + other.audio_data
64
+ return self
65
+
66
+ def convolve(self, other, start_at_max: bool = True):
67
+ """Convolves self with other.
68
+ This function uses FFTs to do the convolution.
69
+
70
+ Parameters
71
+ ----------
72
+ other : AudioSignal
73
+ Signal to convolve with.
74
+ start_at_max : bool, optional
75
+ Whether to start at the max value of other signal, to
76
+ avoid inducing delays, by default True
77
+
78
+ Returns
79
+ -------
80
+ AudioSignal
81
+ Convolved signal, in-place.
82
+ """
83
+ from . import AudioSignal
84
+
85
+ pad_len = self.signal_length - other.signal_length
86
+
87
+ if pad_len > 0:
88
+ other.zero_pad(0, pad_len)
89
+ else:
90
+ other.truncate_samples(self.signal_length)
91
+
92
+ if start_at_max:
93
+ # Use roll to rotate over the max for every item
94
+ # so that the impulse responses don't induce any
95
+ # delay.
96
+ idx = other.audio_data.abs().argmax(axis=-1)
97
+ irs = torch.zeros_like(other.audio_data)
98
+ for i in range(other.batch_size):
99
+ irs[i] = torch.roll(other.audio_data[i], -idx[i].item(), -1)
100
+ other = AudioSignal(irs, other.sample_rate)
101
+
102
+ delta = torch.zeros_like(other.audio_data)
103
+ delta[..., 0] = 1
104
+
105
+ length = self.signal_length
106
+ delta_fft = torch.fft.rfft(delta, length)
107
+ other_fft = torch.fft.rfft(other.audio_data, length)
108
+ self_fft = torch.fft.rfft(self.audio_data, length)
109
+
110
+ convolved_fft = other_fft * self_fft
111
+ convolved_audio = torch.fft.irfft(convolved_fft, length)
112
+
113
+ delta_convolved_fft = other_fft * delta_fft
114
+ delta_audio = torch.fft.irfft(delta_convolved_fft, length)
115
+
116
+ # Use the delta to rescale the audio exactly as needed.
117
+ delta_max = delta_audio.abs().max(dim=-1, keepdims=True)[0]
118
+ scale = 1 / delta_max.clamp(1e-5)
119
+ convolved_audio = convolved_audio * scale
120
+
121
+ self.audio_data = convolved_audio
122
+
123
+ return self
124
+
125
+ def apply_ir(
126
+ self,
127
+ ir,
128
+ drr: typing.Union[torch.Tensor, np.ndarray, float] = None,
129
+ ir_eq: typing.Union[torch.Tensor, np.ndarray] = None,
130
+ use_original_phase: bool = False,
131
+ ):
132
+ """Applies an impulse response to the signal. If ` is`ir_eq``
133
+ is specified, the impulse response is equalized before
134
+ it is applied, using the given curve.
135
+
136
+ Parameters
137
+ ----------
138
+ ir : AudioSignal
139
+ Impulse response to convolve with.
140
+ drr : typing.Union[torch.Tensor, np.ndarray, float], optional
141
+ Direct-to-reverberant ratio that impulse response will be
142
+ altered to, if specified, by default None
143
+ ir_eq : typing.Union[torch.Tensor, np.ndarray], optional
144
+ Equalization that will be applied to impulse response
145
+ if specified, by default None
146
+ use_original_phase : bool, optional
147
+ Whether to use the original phase, instead of the convolved
148
+ phase, by default False
149
+
150
+ Returns
151
+ -------
152
+ AudioSignal
153
+ Signal with impulse response applied to it
154
+ """
155
+ if ir_eq is not None:
156
+ ir = ir.equalizer(ir_eq)
157
+ if drr is not None:
158
+ ir = ir.alter_drr(drr)
159
+
160
+ # Save the peak before
161
+ max_spk = self.audio_data.abs().max(dim=-1, keepdims=True).values
162
+
163
+ # Augment the impulse response to simulate microphone effects
164
+ # and with varying direct-to-reverberant ratio.
165
+ phase = self.phase
166
+ self.convolve(ir)
167
+
168
+ # Use the input phase
169
+ if use_original_phase:
170
+ self.stft()
171
+ self.stft_data = self.magnitude * torch.exp(1j * phase)
172
+ self.istft()
173
+
174
+ # Rescale to the input's amplitude
175
+ max_transformed = self.audio_data.abs().max(dim=-1, keepdims=True).values
176
+ scale_factor = max_spk.clamp(1e-8) / max_transformed.clamp(1e-8)
177
+ self = self * scale_factor
178
+
179
+ return self
180
+
181
+ def ensure_max_of_audio(self, max: float = 1.0):
182
+ """Ensures that ``abs(audio_data) <= max``.
183
+
184
+ Parameters
185
+ ----------
186
+ max : float, optional
187
+ Max absolute value of signal, by default 1.0
188
+
189
+ Returns
190
+ -------
191
+ AudioSignal
192
+ Signal with values scaled between -max and max.
193
+ """
194
+ peak = self.audio_data.abs().max(dim=-1, keepdims=True)[0]
195
+ peak_gain = torch.ones_like(peak)
196
+ peak_gain[peak > max] = max / peak[peak > max]
197
+ self.audio_data = self.audio_data * peak_gain
198
+ return self
199
+
200
+ def normalize(self, db: typing.Union[torch.Tensor, np.ndarray, float] = -24.0):
201
+ """Normalizes the signal's volume to the specified db, in LUFS.
202
+ This is GPU-compatible, making for very fast loudness normalization.
203
+
204
+ Parameters
205
+ ----------
206
+ db : typing.Union[torch.Tensor, np.ndarray, float], optional
207
+ Loudness to normalize to, by default -24.0
208
+
209
+ Returns
210
+ -------
211
+ AudioSignal
212
+ Normalized audio signal.
213
+ """
214
+ db = util.ensure_tensor(db).to(self.device)
215
+ ref_db = self.loudness()
216
+ gain = db - ref_db
217
+ gain = torch.exp(gain * self.GAIN_FACTOR)
218
+
219
+ self.audio_data = self.audio_data * gain[:, None, None]
220
+ return self
221
+
222
+ def volume_change(self, db: typing.Union[torch.Tensor, np.ndarray, float]):
223
+ """Change volume of signal by some amount, in dB.
224
+
225
+ Parameters
226
+ ----------
227
+ db : typing.Union[torch.Tensor, np.ndarray, float]
228
+ Amount to change volume by.
229
+
230
+ Returns
231
+ -------
232
+ AudioSignal
233
+ Signal at new volume.
234
+ """
235
+ db = util.ensure_tensor(db, ndim=1).to(self.device)
236
+ gain = torch.exp(db * self.GAIN_FACTOR)
237
+ self.audio_data = self.audio_data * gain[:, None, None]
238
+ return self
239
+
240
+ def _to_2d(self):
241
+ waveform = self.audio_data.reshape(-1, self.signal_length)
242
+ return waveform
243
+
244
+ def _to_3d(self, waveform):
245
+ return waveform.reshape(self.batch_size, self.num_channels, -1)
246
+
247
+ def pitch_shift(self, n_semitones: int, quick: bool = True):
248
+ """Pitch shift the signal. All items in the batch
249
+ get the same pitch shift.
250
+
251
+ Parameters
252
+ ----------
253
+ n_semitones : int
254
+ How many semitones to shift the signal by.
255
+ quick : bool, optional
256
+ Using quick pitch shifting, by default True
257
+
258
+ Returns
259
+ -------
260
+ AudioSignal
261
+ Pitch shifted audio signal.
262
+ """
263
+ device = self.device
264
+ effects = [
265
+ ["pitch", str(n_semitones * 100)],
266
+ ["rate", str(self.sample_rate)],
267
+ ]
268
+ if quick:
269
+ effects[0].insert(1, "-q")
270
+
271
+ waveform = self._to_2d().cpu()
272
+ waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
273
+ waveform, self.sample_rate, effects, channels_first=True
274
+ )
275
+ self.sample_rate = sample_rate
276
+ self.audio_data = self._to_3d(waveform)
277
+ return self.to(device)
278
+
279
+ def time_stretch(self, factor: float, quick: bool = True):
280
+ """Time stretch the audio signal.
281
+
282
+ Parameters
283
+ ----------
284
+ factor : float
285
+ Factor by which to stretch the AudioSignal. Typically
286
+ between 0.8 and 1.2.
287
+ quick : bool, optional
288
+ Whether to use quick time stretching, by default True
289
+
290
+ Returns
291
+ -------
292
+ AudioSignal
293
+ Time-stretched AudioSignal.
294
+ """
295
+ device = self.device
296
+ effects = [
297
+ ["tempo", str(factor)],
298
+ ["rate", str(self.sample_rate)],
299
+ ]
300
+ if quick:
301
+ effects[0].insert(1, "-q")
302
+
303
+ waveform = self._to_2d().cpu()
304
+ waveform, sample_rate = torchaudio.sox_effects.apply_effects_tensor(
305
+ waveform, self.sample_rate, effects, channels_first=True
306
+ )
307
+ self.sample_rate = sample_rate
308
+ self.audio_data = self._to_3d(waveform)
309
+ return self.to(device)
310
+
311
+ def apply_codec(
312
+ self,
313
+ preset: str = None,
314
+ format: str = "wav",
315
+ encoding: str = None,
316
+ bits_per_sample: int = None,
317
+ compression: int = None,
318
+ ): # pragma: no cover
319
+ """Applies an audio codec to the signal.
320
+
321
+ Parameters
322
+ ----------
323
+ preset : str, optional
324
+ One of the keys in ``self.CODEC_PRESETS``, by default None
325
+ format : str, optional
326
+ Format for audio codec, by default "wav"
327
+ encoding : str, optional
328
+ Encoding to use, by default None
329
+ bits_per_sample : int, optional
330
+ How many bits per sample, by default None
331
+ compression : int, optional
332
+ Compression amount of codec, by default None
333
+
334
+ Returns
335
+ -------
336
+ AudioSignal
337
+ AudioSignal with codec applied.
338
+
339
+ Raises
340
+ ------
341
+ ValueError
342
+ If preset is not in ``self.CODEC_PRESETS``, an error
343
+ is thrown.
344
+ """
345
+ torchaudio_version_070 = "0.7" in torchaudio.__version__
346
+ if torchaudio_version_070:
347
+ return self
348
+
349
+ kwargs = {
350
+ "format": format,
351
+ "encoding": encoding,
352
+ "bits_per_sample": bits_per_sample,
353
+ "compression": compression,
354
+ }
355
+
356
+ if preset is not None:
357
+ if preset in self.CODEC_PRESETS:
358
+ kwargs = self.CODEC_PRESETS[preset]
359
+ else:
360
+ raise ValueError(
361
+ f"Unknown preset: {preset}. "
362
+ f"Known presets: {list(self.CODEC_PRESETS.keys())}"
363
+ )
364
+
365
+ waveform = self._to_2d()
366
+ if kwargs["format"] in ["vorbis", "mp3", "ogg", "amr-nb"]:
367
+ # Apply it in a for loop
368
+ augmented = torch.cat(
369
+ [
370
+ torchaudio.functional.apply_codec(
371
+ waveform[i][None, :], self.sample_rate, **kwargs
372
+ )
373
+ for i in range(waveform.shape[0])
374
+ ],
375
+ dim=0,
376
+ )
377
+ else:
378
+ augmented = torchaudio.functional.apply_codec(
379
+ waveform, self.sample_rate, **kwargs
380
+ )
381
+ augmented = self._to_3d(augmented)
382
+
383
+ self.audio_data = augmented
384
+ return self
385
+
386
+ def mel_filterbank(self, n_bands: int):
387
+ """Breaks signal into mel bands.
388
+
389
+ Parameters
390
+ ----------
391
+ n_bands : int
392
+ Number of mel bands to use.
393
+
394
+ Returns
395
+ -------
396
+ torch.Tensor
397
+ Mel-filtered bands, with last axis being the band index.
398
+ """
399
+ filterbank = (
400
+ julius.SplitBands(self.sample_rate, n_bands).float().to(self.device)
401
+ )
402
+ filtered = filterbank(self.audio_data)
403
+ return filtered.permute(1, 2, 3, 0)
404
+
405
+ def equalizer(self, db: typing.Union[torch.Tensor, np.ndarray]):
406
+ """Applies a mel-spaced equalizer to the audio signal.
407
+
408
+ Parameters
409
+ ----------
410
+ db : typing.Union[torch.Tensor, np.ndarray]
411
+ EQ curve to apply.
412
+
413
+ Returns
414
+ -------
415
+ AudioSignal
416
+ AudioSignal with equalization applied.
417
+ """
418
+ db = util.ensure_tensor(db)
419
+ n_bands = db.shape[-1]
420
+ fbank = self.mel_filterbank(n_bands)
421
+
422
+ # If there's a batch dimension, make sure it's the same.
423
+ if db.ndim == 2:
424
+ if db.shape[0] != 1:
425
+ assert db.shape[0] == fbank.shape[0]
426
+ else:
427
+ db = db.unsqueeze(0)
428
+
429
+ weights = (10**db).to(self.device).float()
430
+ fbank = fbank * weights[:, None, None, :]
431
+ eq_audio_data = fbank.sum(-1)
432
+ self.audio_data = eq_audio_data
433
+ return self
434
+
435
+ def clip_distortion(
436
+ self, clip_percentile: typing.Union[torch.Tensor, np.ndarray, float]
437
+ ):
438
+ """Clips the signal at a given percentile. The higher it is,
439
+ the lower the threshold for clipping.
440
+
441
+ Parameters
442
+ ----------
443
+ clip_percentile : typing.Union[torch.Tensor, np.ndarray, float]
444
+ Values are between 0.0 to 1.0. Typical values are 0.1 or below.
445
+
446
+ Returns
447
+ -------
448
+ AudioSignal
449
+ Audio signal with clipped audio data.
450
+ """
451
+ clip_percentile = util.ensure_tensor(clip_percentile, ndim=1)
452
+ min_thresh = torch.quantile(self.audio_data, clip_percentile / 2, dim=-1)
453
+ max_thresh = torch.quantile(self.audio_data, 1 - (clip_percentile / 2), dim=-1)
454
+
455
+ nc = self.audio_data.shape[1]
456
+ min_thresh = min_thresh[:, :nc, :]
457
+ max_thresh = max_thresh[:, :nc, :]
458
+
459
+ self.audio_data = self.audio_data.clamp(min_thresh, max_thresh)
460
+
461
+ return self
462
+
463
+ def quantization(
464
+ self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
465
+ ):
466
+ """Applies quantization to the input waveform.
467
+
468
+ Parameters
469
+ ----------
470
+ quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
471
+ Number of evenly spaced quantization channels to quantize
472
+ to.
473
+
474
+ Returns
475
+ -------
476
+ AudioSignal
477
+ Quantized AudioSignal.
478
+ """
479
+ quantization_channels = util.ensure_tensor(quantization_channels, ndim=3)
480
+
481
+ x = self.audio_data
482
+ x = (x + 1) / 2
483
+ x = x * quantization_channels
484
+ x = x.floor()
485
+ x = x / quantization_channels
486
+ x = 2 * x - 1
487
+
488
+ residual = (self.audio_data - x).detach()
489
+ self.audio_data = self.audio_data - residual
490
+ return self
491
+
492
+ def mulaw_quantization(
493
+ self, quantization_channels: typing.Union[torch.Tensor, np.ndarray, int]
494
+ ):
495
+ """Applies mu-law quantization to the input waveform.
496
+
497
+ Parameters
498
+ ----------
499
+ quantization_channels : typing.Union[torch.Tensor, np.ndarray, int]
500
+ Number of mu-law spaced quantization channels to quantize
501
+ to.
502
+
503
+ Returns
504
+ -------
505
+ AudioSignal
506
+ Quantized AudioSignal.
507
+ """
508
+ mu = quantization_channels - 1.0
509
+ mu = util.ensure_tensor(mu, ndim=3)
510
+
511
+ x = self.audio_data
512
+
513
+ # quantize
514
+ x = torch.sign(x) * torch.log1p(mu * torch.abs(x)) / torch.log1p(mu)
515
+ x = ((x + 1) / 2 * mu + 0.5).to(torch.int64)
516
+
517
+ # unquantize
518
+ x = (x / mu) * 2 - 1.0
519
+ x = torch.sign(x) * (torch.exp(torch.abs(x) * torch.log1p(mu)) - 1.0) / mu
520
+
521
+ residual = (self.audio_data - x).detach()
522
+ self.audio_data = self.audio_data - residual
523
+ return self
524
+
525
+ def __matmul__(self, other):
526
+ return self.convolve(other)
527
+
528
+
529
+ class ImpulseResponseMixin:
530
+ """These functions are generally only used with AudioSignals that are derived
531
+ from impulse responses, not other sources like music or speech. These methods
532
+ are used to replicate the data augmentation described in [1].
533
+
534
+ 1. Bryan, Nicholas J. "Impulse response data augmentation and deep
535
+ neural networks for blind room acoustic parameter estimation."
536
+ ICASSP 2020-2020 IEEE International Conference on Acoustics,
537
+ Speech and Signal Processing (ICASSP). IEEE, 2020.
538
+ """
539
+
540
+ def decompose_ir(self):
541
+ """Decomposes an impulse response into early and late
542
+ field responses.
543
+ """
544
+ # Equations 1 and 2
545
+ # -----------------
546
+ # Breaking up into early
547
+ # response + late field response.
548
+
549
+ td = torch.argmax(self.audio_data, dim=-1, keepdim=True)
550
+ t0 = int(self.sample_rate * 0.0025)
551
+
552
+ idx = torch.arange(self.audio_data.shape[-1], device=self.device)[None, None, :]
553
+ idx = idx.expand(self.batch_size, -1, -1)
554
+ early_idx = (idx >= td - t0) * (idx <= td + t0)
555
+
556
+ early_response = torch.zeros_like(self.audio_data, device=self.device)
557
+ early_response[early_idx] = self.audio_data[early_idx]
558
+
559
+ late_idx = ~early_idx
560
+ late_field = torch.zeros_like(self.audio_data, device=self.device)
561
+ late_field[late_idx] = self.audio_data[late_idx]
562
+
563
+ # Equation 4
564
+ # ----------
565
+ # Decompose early response into windowed
566
+ # direct path and windowed residual.
567
+
568
+ window = torch.zeros_like(self.audio_data, device=self.device)
569
+ for idx in range(self.batch_size):
570
+ window_idx = early_idx[idx, 0].nonzero()
571
+ window[idx, ..., window_idx] = self.get_window(
572
+ "hann", window_idx.shape[-1], self.device
573
+ )
574
+ return early_response, late_field, window
575
+
576
+ def measure_drr(self):
577
+ """Measures the direct-to-reverberant ratio of the impulse
578
+ response.
579
+
580
+ Returns
581
+ -------
582
+ float
583
+ Direct-to-reverberant ratio
584
+ """
585
+ early_response, late_field, _ = self.decompose_ir()
586
+ num = (early_response**2).sum(dim=-1)
587
+ den = (late_field**2).sum(dim=-1)
588
+ drr = 10 * torch.log10(num / den)
589
+ return drr
590
+
591
+ @staticmethod
592
+ def solve_alpha(early_response, late_field, wd, target_drr):
593
+ """Used to solve for the alpha value, which is used
594
+ to alter the drr.
595
+ """
596
+ # Equation 5
597
+ # ----------
598
+ # Apply the good ol' quadratic formula.
599
+
600
+ wd_sq = wd**2
601
+ wd_sq_1 = (1 - wd) ** 2
602
+ e_sq = early_response**2
603
+ l_sq = late_field**2
604
+ a = (wd_sq * e_sq).sum(dim=-1)
605
+ b = (2 * (1 - wd) * wd * e_sq).sum(dim=-1)
606
+ c = (wd_sq_1 * e_sq).sum(dim=-1) - torch.pow(10, target_drr / 10) * l_sq.sum(
607
+ dim=-1
608
+ )
609
+
610
+ expr = ((b**2) - 4 * a * c).sqrt()
611
+ alpha = torch.maximum(
612
+ (-b - expr) / (2 * a),
613
+ (-b + expr) / (2 * a),
614
+ )
615
+ return alpha
616
+
617
+ def alter_drr(self, drr: typing.Union[torch.Tensor, np.ndarray, float]):
618
+ """Alters the direct-to-reverberant ratio of the impulse response.
619
+
620
+ Parameters
621
+ ----------
622
+ drr : typing.Union[torch.Tensor, np.ndarray, float]
623
+ Direct-to-reverberant ratio that impulse response will be
624
+ altered to, if specified, by default None
625
+
626
+ Returns
627
+ -------
628
+ AudioSignal
629
+ Altered impulse response.
630
+ """
631
+ drr = util.ensure_tensor(drr, 2, self.batch_size).to(self.device)
632
+
633
+ early_response, late_field, window = self.decompose_ir()
634
+ alpha = self.solve_alpha(early_response, late_field, window, drr)
635
+ min_alpha = (
636
+ late_field.abs().max(dim=-1)[0] / early_response.abs().max(dim=-1)[0]
637
+ )
638
+ alpha = torch.maximum(alpha, min_alpha)[..., None]
639
+
640
+ aug_ir_data = (
641
+ alpha * window * early_response
642
+ + ((1 - window) * early_response)
643
+ + late_field
644
+ )
645
+ self.audio_data = aug_ir_data
646
+ self.ensure_max_of_audio()
647
+ return self
audiotools/core/ffmpeg.py ADDED
@@ -0,0 +1,204 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import json
2
+ import shlex
3
+ import subprocess
4
+ import tempfile
5
+ from pathlib import Path
6
+ from typing import Tuple
7
+
8
+ import ffmpy
9
+ import numpy as np
10
+ import torch
11
+
12
+
13
+ def r128stats(filepath: str, quiet: bool):
14
+ """Takes a path to an audio file, returns a dict with the loudness
15
+ stats computed by the ffmpeg ebur128 filter.
16
+
17
+ Parameters
18
+ ----------
19
+ filepath : str
20
+ Path to compute loudness stats on.
21
+ quiet : bool
22
+ Whether to show FFMPEG output during computation.
23
+
24
+ Returns
25
+ -------
26
+ dict
27
+ Dictionary containing loudness stats.
28
+ """
29
+ ffargs = [
30
+ "ffmpeg",
31
+ "-nostats",
32
+ "-i",
33
+ filepath,
34
+ "-filter_complex",
35
+ "ebur128",
36
+ "-f",
37
+ "null",
38
+ "-",
39
+ ]
40
+ if quiet:
41
+ ffargs += ["-hide_banner"]
42
+ proc = subprocess.Popen(ffargs, stderr=subprocess.PIPE, universal_newlines=True)
43
+ stats = proc.communicate()[1]
44
+ summary_index = stats.rfind("Summary:")
45
+
46
+ summary_list = stats[summary_index:].split()
47
+ i_lufs = float(summary_list[summary_list.index("I:") + 1])
48
+ i_thresh = float(summary_list[summary_list.index("I:") + 4])
49
+ lra = float(summary_list[summary_list.index("LRA:") + 1])
50
+ lra_thresh = float(summary_list[summary_list.index("LRA:") + 4])
51
+ lra_low = float(summary_list[summary_list.index("low:") + 1])
52
+ lra_high = float(summary_list[summary_list.index("high:") + 1])
53
+ stats_dict = {
54
+ "I": i_lufs,
55
+ "I Threshold": i_thresh,
56
+ "LRA": lra,
57
+ "LRA Threshold": lra_thresh,
58
+ "LRA Low": lra_low,
59
+ "LRA High": lra_high,
60
+ }
61
+
62
+ return stats_dict
63
+
64
+
65
+ def ffprobe_offset_and_codec(path: str) -> Tuple[float, str]:
66
+ """Given a path to a file, returns the start time offset and codec of
67
+ the first audio stream.
68
+ """
69
+ ff = ffmpy.FFprobe(
70
+ inputs={path: None},
71
+ global_options="-show_entries format=start_time:stream=duration,start_time,codec_type,codec_name,start_pts,time_base -of json -v quiet",
72
+ )
73
+ streams = json.loads(ff.run(stdout=subprocess.PIPE)[0])["streams"]
74
+ seconds_offset = 0.0
75
+ codec = None
76
+
77
+ # Get the offset and codec of the first audio stream we find
78
+ # and return its start time, if it has one.
79
+ for stream in streams:
80
+ if stream["codec_type"] == "audio":
81
+ seconds_offset = stream.get("start_time", 0.0)
82
+ codec = stream.get("codec_name")
83
+ break
84
+ return float(seconds_offset), codec
85
+
86
+
87
+ class FFMPEGMixin:
88
+ _loudness = None
89
+
90
+ def ffmpeg_loudness(self, quiet: bool = True):
91
+ """Computes loudness of audio file using FFMPEG.
92
+
93
+ Parameters
94
+ ----------
95
+ quiet : bool, optional
96
+ Whether to show FFMPEG output during computation,
97
+ by default True
98
+
99
+ Returns
100
+ -------
101
+ torch.Tensor
102
+ Loudness of every item in the batch, computed via
103
+ FFMPEG.
104
+ """
105
+ loudness = []
106
+
107
+ with tempfile.NamedTemporaryFile(suffix=".wav") as f:
108
+ for i in range(self.batch_size):
109
+ self[i].write(f.name)
110
+ loudness_stats = r128stats(f.name, quiet=quiet)
111
+ loudness.append(loudness_stats["I"])
112
+
113
+ self._loudness = torch.from_numpy(np.array(loudness)).float()
114
+ return self.loudness()
115
+
116
+ def ffmpeg_resample(self, sample_rate: int, quiet: bool = True):
117
+ """Resamples AudioSignal using FFMPEG. More memory-efficient
118
+ than using julius.resample for long audio files.
119
+
120
+ Parameters
121
+ ----------
122
+ sample_rate : int
123
+ Sample rate to resample to.
124
+ quiet : bool, optional
125
+ Whether to show FFMPEG output during computation,
126
+ by default True
127
+
128
+ Returns
129
+ -------
130
+ AudioSignal
131
+ Resampled AudioSignal.
132
+ """
133
+ from audiotools import AudioSignal
134
+
135
+ if sample_rate == self.sample_rate:
136
+ return self
137
+
138
+ with tempfile.NamedTemporaryFile(suffix=".wav") as f:
139
+ self.write(f.name)
140
+ f_out = f.name.replace("wav", "rs.wav")
141
+ command = f"ffmpeg -i {f.name} -ar {sample_rate} {f_out}"
142
+ if quiet:
143
+ command += " -hide_banner -loglevel error"
144
+ subprocess.check_call(shlex.split(command))
145
+ resampled = AudioSignal(f_out)
146
+ Path.unlink(Path(f_out))
147
+ return resampled
148
+
149
+ @classmethod
150
+ def load_from_file_with_ffmpeg(cls, audio_path: str, quiet: bool = True, **kwargs):
151
+ """Loads AudioSignal object after decoding it to a wav file using FFMPEG.
152
+ Useful for loading audio that isn't covered by librosa's loading mechanism. Also
153
+ useful for loading mp3 files, without any offset.
154
+
155
+ Parameters
156
+ ----------
157
+ audio_path : str
158
+ Path to load AudioSignal from.
159
+ quiet : bool, optional
160
+ Whether to show FFMPEG output during computation,
161
+ by default True
162
+
163
+ Returns
164
+ -------
165
+ AudioSignal
166
+ AudioSignal loaded from file with FFMPEG.
167
+ """
168
+ audio_path = str(audio_path)
169
+ with tempfile.TemporaryDirectory() as d:
170
+ wav_file = str(Path(d) / "extracted.wav")
171
+ padded_wav = str(Path(d) / "padded.wav")
172
+
173
+ global_options = "-y"
174
+ if quiet:
175
+ global_options += " -loglevel error"
176
+
177
+ ff = ffmpy.FFmpeg(
178
+ inputs={audio_path: None},
179
+ outputs={wav_file: None},
180
+ global_options=global_options,
181
+ )
182
+ ff.run()
183
+
184
+ # We pad the file using the start time offset in case it's an audio
185
+ # stream starting at some offset in a video container.
186
+ pad, codec = ffprobe_offset_and_codec(audio_path)
187
+
188
+ # For mp3s, don't pad files with discrepancies less than 0.027s -
189
+ # it's likely due to codec latency. The amount of latency introduced
190
+ # by mp3 is 1152, which is 0.0261 44khz. So we set the threshold
191
+ # here slightly above that.
192
+ # Source: https://lame.sourceforge.io/tech-FAQ.txt.
193
+ if codec == "mp3" and pad < 0.027:
194
+ pad = 0.0
195
+ ff = ffmpy.FFmpeg(
196
+ inputs={wav_file: None},
197
+ outputs={padded_wav: f"-af 'adelay={pad*1000}:all=true'"},
198
+ global_options=global_options,
199
+ )
200
+ ff.run()
201
+
202
+ signal = cls(padded_wav, **kwargs)
203
+
204
+ return signal
audiotools/core/loudness.py ADDED
@@ -0,0 +1,320 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+
3
+ import julius
4
+ import numpy as np
5
+ import scipy
6
+ import torch
7
+ import torch.nn.functional as F
8
+ import torchaudio
9
+
10
+
11
+ class Meter(torch.nn.Module):
12
+ """Tensorized version of pyloudnorm.Meter. Works with batched audio tensors.
13
+
14
+ Parameters
15
+ ----------
16
+ rate : int
17
+ Sample rate of audio.
18
+ filter_class : str, optional
19
+ Class of weighting filter used.
20
+ K-weighting' (default), 'Fenton/Lee 1'
21
+ 'Fenton/Lee 2', 'Dash et al.'
22
+ by default "K-weighting"
23
+ block_size : float, optional
24
+ Gating block size in seconds, by default 0.400
25
+ zeros : int, optional
26
+ Number of zeros to use in FIR approximation of
27
+ IIR filters, by default 512
28
+ use_fir : bool, optional
29
+ Whether to use FIR approximation or exact IIR formulation.
30
+ If computing on GPU, ``use_fir=True`` will be used, as its
31
+ much faster, by default False
32
+ """
33
+
34
+ def __init__(
35
+ self,
36
+ rate: int,
37
+ filter_class: str = "K-weighting",
38
+ block_size: float = 0.400,
39
+ zeros: int = 512,
40
+ use_fir: bool = False,
41
+ ):
42
+ super().__init__()
43
+
44
+ self.rate = rate
45
+ self.filter_class = filter_class
46
+ self.block_size = block_size
47
+ self.use_fir = use_fir
48
+
49
+ G = torch.from_numpy(np.array([1.0, 1.0, 1.0, 1.41, 1.41]))
50
+ self.register_buffer("G", G)
51
+
52
+ # Compute impulse responses so that filtering is fast via
53
+ # a convolution at runtime, on GPU, unlike lfilter.
54
+ impulse = np.zeros((zeros,))
55
+ impulse[..., 0] = 1.0
56
+
57
+ firs = np.zeros((len(self._filters), 1, zeros))
58
+ passband_gain = torch.zeros(len(self._filters))
59
+
60
+ for i, (_, filter_stage) in enumerate(self._filters.items()):
61
+ firs[i] = scipy.signal.lfilter(filter_stage.b, filter_stage.a, impulse)
62
+ passband_gain[i] = filter_stage.passband_gain
63
+
64
+ firs = torch.from_numpy(firs[..., ::-1].copy()).float()
65
+
66
+ self.register_buffer("firs", firs)
67
+ self.register_buffer("passband_gain", passband_gain)
68
+
69
+ def apply_filter_gpu(self, data: torch.Tensor):
70
+ """Performs FIR approximation of loudness computation.
71
+
72
+ Parameters
73
+ ----------
74
+ data : torch.Tensor
75
+ Audio data of shape (nb, nch, nt).
76
+
77
+ Returns
78
+ -------
79
+ torch.Tensor
80
+ Filtered audio data.
81
+ """
82
+ # Data is of shape (nb, nch, nt)
83
+ # Reshape to (nb*nch, 1, nt)
84
+ nb, nt, nch = data.shape
85
+ data = data.permute(0, 2, 1)
86
+ data = data.reshape(nb * nch, 1, nt)
87
+
88
+ # Apply padding
89
+ pad_length = self.firs.shape[-1]
90
+
91
+ # Apply filtering in sequence
92
+ for i in range(self.firs.shape[0]):
93
+ data = F.pad(data, (pad_length, pad_length))
94
+ data = julius.fftconv.fft_conv1d(data, self.firs[i, None, ...])
95
+ data = self.passband_gain[i] * data
96
+ data = data[..., 1 : nt + 1]
97
+
98
+ data = data.permute(0, 2, 1)
99
+ data = data[:, :nt, :]
100
+ return data
101
+
102
+ def apply_filter_cpu(self, data: torch.Tensor):
103
+ """Performs IIR formulation of loudness computation.
104
+
105
+ Parameters
106
+ ----------
107
+ data : torch.Tensor
108
+ Audio data of shape (nb, nch, nt).
109
+
110
+ Returns
111
+ -------
112
+ torch.Tensor
113
+ Filtered audio data.
114
+ """
115
+ for _, filter_stage in self._filters.items():
116
+ passband_gain = filter_stage.passband_gain
117
+
118
+ a_coeffs = torch.from_numpy(filter_stage.a).float().to(data.device)
119
+ b_coeffs = torch.from_numpy(filter_stage.b).float().to(data.device)
120
+
121
+ _data = data.permute(0, 2, 1)
122
+ filtered = torchaudio.functional.lfilter(
123
+ _data, a_coeffs, b_coeffs, clamp=False
124
+ )
125
+ data = passband_gain * filtered.permute(0, 2, 1)
126
+ return data
127
+
128
+ def apply_filter(self, data: torch.Tensor):
129
+ """Applies filter on either CPU or GPU, depending
130
+ on if the audio is on GPU or is on CPU, or if
131
+ ``self.use_fir`` is True.
132
+
133
+ Parameters
134
+ ----------
135
+ data : torch.Tensor
136
+ Audio data of shape (nb, nch, nt).
137
+
138
+ Returns
139
+ -------
140
+ torch.Tensor
141
+ Filtered audio data.
142
+ """
143
+ if data.is_cuda or self.use_fir:
144
+ data = self.apply_filter_gpu(data)
145
+ else:
146
+ data = self.apply_filter_cpu(data)
147
+ return data
148
+
149
+ def forward(self, data: torch.Tensor):
150
+ """Computes integrated loudness of data.
151
+
152
+ Parameters
153
+ ----------
154
+ data : torch.Tensor
155
+ Audio data of shape (nb, nch, nt).
156
+
157
+ Returns
158
+ -------
159
+ torch.Tensor
160
+ Filtered audio data.
161
+ """
162
+ return self.integrated_loudness(data)
163
+
164
+ def _unfold(self, input_data):
165
+ T_g = self.block_size
166
+ overlap = 0.75 # overlap of 75% of the block duration
167
+ step = 1.0 - overlap # step size by percentage
168
+
169
+ kernel_size = int(T_g * self.rate)
170
+ stride = int(T_g * self.rate * step)
171
+ unfolded = julius.core.unfold(input_data.permute(0, 2, 1), kernel_size, stride)
172
+ unfolded = unfolded.transpose(-1, -2)
173
+
174
+ return unfolded
175
+
176
+ def integrated_loudness(self, data: torch.Tensor):
177
+ """Computes integrated loudness of data.
178
+
179
+ Parameters
180
+ ----------
181
+ data : torch.Tensor
182
+ Audio data of shape (nb, nch, nt).
183
+
184
+ Returns
185
+ -------
186
+ torch.Tensor
187
+ Filtered audio data.
188
+ """
189
+ if not torch.is_tensor(data):
190
+ data = torch.from_numpy(data).float()
191
+ else:
192
+ data = data.float()
193
+
194
+ input_data = copy.copy(data)
195
+ # Data always has a batch and channel dimension.
196
+ # Is of shape (nb, nt, nch)
197
+ if input_data.ndim < 2:
198
+ input_data = input_data.unsqueeze(-1)
199
+ if input_data.ndim < 3:
200
+ input_data = input_data.unsqueeze(0)
201
+
202
+ nb, nt, nch = input_data.shape
203
+
204
+ # Apply frequency weighting filters - account
205
+ # for the acoustic respose of the head and auditory system
206
+ input_data = self.apply_filter(input_data)
207
+
208
+ G = self.G # channel gains
209
+ T_g = self.block_size # 400 ms gating block standard
210
+ Gamma_a = -70.0 # -70 LKFS = absolute loudness threshold
211
+
212
+ unfolded = self._unfold(input_data)
213
+
214
+ z = (1.0 / (T_g * self.rate)) * unfolded.square().sum(2)
215
+ l = -0.691 + 10.0 * torch.log10((G[None, :nch, None] * z).sum(1, keepdim=True))
216
+ l = l.expand_as(z)
217
+
218
+ # find gating block indices above absolute threshold
219
+ z_avg_gated = z
220
+ z_avg_gated[l <= Gamma_a] = 0
221
+ masked = l > Gamma_a
222
+ z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
223
+
224
+ # calculate the relative threshold value (see eq. 6)
225
+ Gamma_r = (
226
+ -0.691 + 10.0 * torch.log10((z_avg_gated * G[None, :nch]).sum(-1)) - 10.0
227
+ )
228
+ Gamma_r = Gamma_r[:, None, None]
229
+ Gamma_r = Gamma_r.expand(nb, nch, l.shape[-1])
230
+
231
+ # find gating block indices above relative and absolute thresholds (end of eq. 7)
232
+ z_avg_gated = z
233
+ z_avg_gated[l <= Gamma_a] = 0
234
+ z_avg_gated[l <= Gamma_r] = 0
235
+ masked = (l > Gamma_a) * (l > Gamma_r)
236
+ z_avg_gated = z_avg_gated.sum(2) / masked.sum(2)
237
+
238
+ # # Cannot use nan_to_num (pytorch 1.8 does not come with GCP-supported cuda version)
239
+ # z_avg_gated = torch.nan_to_num(z_avg_gated)
240
+ z_avg_gated = torch.where(
241
+ z_avg_gated.isnan(), torch.zeros_like(z_avg_gated), z_avg_gated
242
+ )
243
+ z_avg_gated[z_avg_gated == float("inf")] = float(np.finfo(np.float32).max)
244
+ z_avg_gated[z_avg_gated == -float("inf")] = float(np.finfo(np.float32).min)
245
+
246
+ LUFS = -0.691 + 10.0 * torch.log10((G[None, :nch] * z_avg_gated).sum(1))
247
+ return LUFS.float()
248
+
249
+ @property
250
+ def filter_class(self):
251
+ return self._filter_class
252
+
253
+ @filter_class.setter
254
+ def filter_class(self, value):
255
+ from pyloudnorm import Meter
256
+
257
+ meter = Meter(self.rate)
258
+ meter.filter_class = value
259
+ self._filter_class = value
260
+ self._filters = meter._filters
261
+
262
+
263
+ class LoudnessMixin:
264
+ _loudness = None
265
+ MIN_LOUDNESS = -70
266
+ """Minimum loudness possible."""
267
+
268
+ def loudness(
269
+ self, filter_class: str = "K-weighting", block_size: float = 0.400, **kwargs
270
+ ):
271
+ """Calculates loudness using an implementation of ITU-R BS.1770-4.
272
+ Allows control over gating block size and frequency weighting filters for
273
+ additional control. Measure the integrated gated loudness of a signal.
274
+
275
+ API is derived from PyLoudnorm, but this implementation is ported to PyTorch
276
+ and is tensorized across batches. When on GPU, an FIR approximation of the IIR
277
+ filters is used to compute loudness for speed.
278
+
279
+ Uses the weighting filters and block size defined by the meter
280
+ the integrated loudness is measured based upon the gating algorithm
281
+ defined in the ITU-R BS.1770-4 specification.
282
+
283
+ Parameters
284
+ ----------
285
+ filter_class : str, optional
286
+ Class of weighting filter used.
287
+ K-weighting' (default), 'Fenton/Lee 1'
288
+ 'Fenton/Lee 2', 'Dash et al.'
289
+ by default "K-weighting"
290
+ block_size : float, optional
291
+ Gating block size in seconds, by default 0.400
292
+ kwargs : dict, optional
293
+ Keyword arguments to :py:func:`audiotools.core.loudness.Meter`.
294
+
295
+ Returns
296
+ -------
297
+ torch.Tensor
298
+ Loudness of audio data.
299
+ """
300
+ if self._loudness is not None:
301
+ return self._loudness.to(self.device)
302
+ original_length = self.signal_length
303
+ if self.signal_duration < 0.5:
304
+ pad_len = int((0.5 - self.signal_duration) * self.sample_rate)
305
+ self.zero_pad(0, pad_len)
306
+
307
+ # create BS.1770 meter
308
+ meter = Meter(
309
+ self.sample_rate, filter_class=filter_class, block_size=block_size, **kwargs
310
+ )
311
+ meter = meter.to(self.device)
312
+ # measure loudness
313
+ loudness = meter.integrated_loudness(self.audio_data.permute(0, 2, 1))
314
+ self.truncate_samples(original_length)
315
+ min_loudness = (
316
+ torch.ones_like(loudness, device=loudness.device) * self.MIN_LOUDNESS
317
+ )
318
+ self._loudness = torch.maximum(loudness, min_loudness)
319
+
320
+ return self._loudness.to(self.device)
audiotools/core/playback.py ADDED
@@ -0,0 +1,252 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ These are utilities that allow one to embed an AudioSignal
3
+ as a playable object in a Jupyter notebook, or to play audio from
4
+ the terminal, etc.
5
+ """ # fmt: skip
6
+ import base64
7
+ import io
8
+ import random
9
+ import string
10
+ import subprocess
11
+ from tempfile import NamedTemporaryFile
12
+
13
+ import importlib_resources as pkg_resources
14
+
15
+ from . import templates
16
+ from .util import _close_temp_files
17
+ from .util import format_figure
18
+
19
+ headers = pkg_resources.files(templates).joinpath("headers.html").read_text()
20
+ widget = pkg_resources.files(templates).joinpath("widget.html").read_text()
21
+
22
+ DEFAULT_EXTENSION = ".wav"
23
+
24
+
25
+ def _check_imports(): # pragma: no cover
26
+ try:
27
+ import ffmpy
28
+ except:
29
+ ffmpy = False
30
+
31
+ try:
32
+ import IPython
33
+ except:
34
+ raise ImportError("IPython must be installed in order to use this function!")
35
+ return ffmpy, IPython
36
+
37
+
38
+ class PlayMixin:
39
+ def embed(self, ext: str = None, display: bool = True, return_html: bool = False):
40
+ """Embeds audio as a playable audio embed in a notebook, or HTML
41
+ document, etc.
42
+
43
+ Parameters
44
+ ----------
45
+ ext : str, optional
46
+ Extension to use when saving the audio, by default ".wav"
47
+ display : bool, optional
48
+ This controls whether or not to display the audio when called. This
49
+ is used when the embed is the last line in a Jupyter cell, to prevent
50
+ the audio from being embedded twice, by default True
51
+ return_html : bool, optional
52
+ Whether to return the data wrapped in an HTML audio element, by default False
53
+
54
+ Returns
55
+ -------
56
+ str
57
+ Either the element for display, or the HTML string of it.
58
+ """
59
+ if ext is None:
60
+ ext = DEFAULT_EXTENSION
61
+ ext = f".{ext}" if not ext.startswith(".") else ext
62
+ ffmpy, IPython = _check_imports()
63
+ sr = self.sample_rate
64
+ tmpfiles = []
65
+
66
+ with _close_temp_files(tmpfiles):
67
+ tmp_wav = NamedTemporaryFile(mode="w+", suffix=".wav", delete=False)
68
+ tmpfiles.append(tmp_wav)
69
+ self.write(tmp_wav.name)
70
+ if ext != ".wav" and ffmpy:
71
+ tmp_converted = NamedTemporaryFile(mode="w+", suffix=ext, delete=False)
72
+ tmpfiles.append(tmp_wav)
73
+ ff = ffmpy.FFmpeg(
74
+ inputs={tmp_wav.name: None},
75
+ outputs={
76
+ tmp_converted.name: "-write_xing 0 -codec:a libmp3lame -b:a 128k -y -hide_banner -loglevel error"
77
+ },
78
+ )
79
+ ff.run()
80
+ else:
81
+ tmp_converted = tmp_wav
82
+
83
+ audio_element = IPython.display.Audio(data=tmp_converted.name, rate=sr)
84
+ if display:
85
+ IPython.display.display(audio_element)
86
+
87
+ if return_html:
88
+ audio_element = (
89
+ f"<audio "
90
+ f" controls "
91
+ f" src='{audio_element.src_attr()}'> "
92
+ f"</audio> "
93
+ )
94
+ return audio_element
95
+
96
+ def widget(
97
+ self,
98
+ title: str = None,
99
+ ext: str = ".wav",
100
+ add_headers: bool = True,
101
+ player_width: str = "100%",
102
+ margin: str = "10px",
103
+ plot_fn: str = "specshow",
104
+ return_html: bool = False,
105
+ **kwargs,
106
+ ):
107
+ """Creates a playable widget with spectrogram. Inspired (heavily) by
108
+ https://sjvasquez.github.io/blog/melnet/.
109
+
110
+ Parameters
111
+ ----------
112
+ title : str, optional
113
+ Title of plot, placed in upper right of top-most axis.
114
+ ext : str, optional
115
+ Extension for embedding, by default ".mp3"
116
+ add_headers : bool, optional
117
+ Whether or not to add headers (use for first embed, False for later embeds), by default True
118
+ player_width : str, optional
119
+ Width of the player, as a string in a CSS rule, by default "100%"
120
+ margin : str, optional
121
+ Margin on all sides of player, by default "10px"
122
+ plot_fn : function, optional
123
+ Plotting function to use (by default self.specshow).
124
+ return_html : bool, optional
125
+ Whether to return the data wrapped in an HTML audio element, by default False
126
+ kwargs : dict, optional
127
+ Keyword arguments to plot_fn (by default self.specshow).
128
+
129
+ Returns
130
+ -------
131
+ HTML
132
+ HTML object.
133
+ """
134
+ import matplotlib.pyplot as plt
135
+
136
+ def _save_fig_to_tag():
137
+ buffer = io.BytesIO()
138
+
139
+ plt.savefig(buffer, bbox_inches="tight", pad_inches=0)
140
+ plt.close()
141
+
142
+ buffer.seek(0)
143
+ data_uri = base64.b64encode(buffer.read()).decode("ascii")
144
+ tag = "data:image/png;base64,{0}".format(data_uri)
145
+
146
+ return tag
147
+
148
+ _, IPython = _check_imports()
149
+
150
+ header_html = ""
151
+
152
+ if add_headers:
153
+ header_html = headers.replace("PLAYER_WIDTH", str(player_width))
154
+ header_html = header_html.replace("MARGIN", str(margin))
155
+ IPython.display.display(IPython.display.HTML(header_html))
156
+
157
+ widget_html = widget
158
+ if isinstance(plot_fn, str):
159
+ plot_fn = getattr(self, plot_fn)
160
+ kwargs["title"] = title
161
+ plot_fn(**kwargs)
162
+
163
+ fig = plt.gcf()
164
+ pixels = fig.get_size_inches() * fig.dpi
165
+
166
+ tag = _save_fig_to_tag()
167
+
168
+ # Make the source image for the levels
169
+ self.specshow()
170
+ format_figure((12, 1.5))
171
+ levels_tag = _save_fig_to_tag()
172
+
173
+ player_id = "".join(random.choice(string.ascii_uppercase) for _ in range(10))
174
+
175
+ audio_elem = self.embed(ext=ext, display=False)
176
+ widget_html = widget_html.replace("AUDIO_SRC", audio_elem.src_attr())
177
+ widget_html = widget_html.replace("IMAGE_SRC", tag)
178
+ widget_html = widget_html.replace("LEVELS_SRC", levels_tag)
179
+ widget_html = widget_html.replace("PLAYER_ID", player_id)
180
+
181
+ # Calculate width/height of figure based on figure size.
182
+ widget_html = widget_html.replace("PADDING_AMOUNT", f"{int(pixels[1])}px")
183
+ widget_html = widget_html.replace("MAX_WIDTH", f"{int(pixels[0])}px")
184
+
185
+ IPython.display.display(IPython.display.HTML(widget_html))
186
+
187
+ if return_html:
188
+ html = header_html if add_headers else ""
189
+ html += widget_html
190
+ return html
191
+
192
+ def play(self):
193
+ """
194
+ Plays an audio signal if ffplay from the ffmpeg suite of tools is installed.
195
+ Otherwise, will fail. The audio signal is written to a temporary file
196
+ and then played with ffplay.
197
+ """
198
+ tmpfiles = []
199
+ with _close_temp_files(tmpfiles):
200
+ tmp_wav = NamedTemporaryFile(suffix=".wav", delete=False)
201
+ tmpfiles.append(tmp_wav)
202
+ self.write(tmp_wav.name)
203
+ print(self)
204
+ subprocess.call(
205
+ [
206
+ "ffplay",
207
+ "-nodisp",
208
+ "-autoexit",
209
+ "-hide_banner",
210
+ "-loglevel",
211
+ "error",
212
+ tmp_wav.name,
213
+ ]
214
+ )
215
+ return self
216
+
217
+
218
+ if __name__ == "__main__": # pragma: no cover
219
+ from audiotools import AudioSignal
220
+
221
+ signal = AudioSignal(
222
+ "tests/audio/spk/f10_script4_produced.mp3", offset=5, duration=5
223
+ )
224
+
225
+ wave_html = signal.widget(
226
+ "Waveform",
227
+ plot_fn="waveplot",
228
+ return_html=True,
229
+ )
230
+
231
+ spec_html = signal.widget("Spectrogram", return_html=True, add_headers=False)
232
+
233
+ combined_html = signal.widget(
234
+ "Waveform + spectrogram",
235
+ plot_fn="wavespec",
236
+ return_html=True,
237
+ add_headers=False,
238
+ )
239
+
240
+ signal.low_pass(8000)
241
+ lowpass_html = signal.widget(
242
+ "Lowpassed audio",
243
+ plot_fn="wavespec",
244
+ return_html=True,
245
+ add_headers=False,
246
+ )
247
+
248
+ with open("/tmp/index.html", "w") as f:
249
+ f.write(wave_html)
250
+ f.write(spec_html)
251
+ f.write(combined_html)
252
+ f.write(lowpass_html)
audiotools/core/templates/__init__.py ADDED
File without changes
audiotools/core/templates/headers.html ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <style>
2
+ .player {
3
+ width: 100%;
4
+ /*border: 1px solid black;*/
5
+ margin: 10px;
6
+ }
7
+
8
+ .underlay img {
9
+ width: 100%;
10
+ height: 100%;
11
+ }
12
+
13
+ .spectrogram {
14
+ height: 0;
15
+ width: 100%;
16
+ position: relative;
17
+ }
18
+
19
+ .audio-controls {
20
+ width: 100%;
21
+ height: 54px;
22
+ display: flex;
23
+ /*border-top: 1px solid black;*/
24
+ /*background-color: rgb(241, 243, 244);*/
25
+ background-color: rgb(248, 248, 248);
26
+ background-color: rgb(253, 253, 254);
27
+ border: 1px solid rgb(205, 208, 211);
28
+ margin-top: 20px;
29
+ /*border: 1px solid black;*/
30
+ border-radius: 30px;
31
+
32
+ }
33
+
34
+ .play-img {
35
+ margin: auto;
36
+ height: 45%;
37
+ width: 45%;
38
+ display: block;
39
+ }
40
+
41
+ .download-img {
42
+ margin: auto;
43
+ height: 100%;
44
+ width: 100%;
45
+ display: block;
46
+ }
47
+
48
+ .pause-img {
49
+ margin: auto;
50
+ height: 45%;
51
+ width: 45%;
52
+ display: none
53
+ }
54
+
55
+ .playpause {
56
+ margin:11px 11px 11px 11px;
57
+ width: 32px;
58
+ min-width: 32px;
59
+ height: 32px;
60
+ /*background-color: rgb(241, 243, 244);*/
61
+ background-color: rgba(0, 0, 0, 0.0);
62
+ /*border-right: 1px solid black;*/
63
+ /*border: 1px solid red;*/
64
+ border-radius: 16px;
65
+ color: black;
66
+ transition: 0.25s;
67
+ box-sizing: border-box !important;
68
+ }
69
+
70
+ .download {
71
+ margin:11px 11px 11px 11px;
72
+ width: 32px;
73
+ min-width: 32px;
74
+ height: 32px;
75
+ /*background-color: rgb(241, 243, 244);*/
76
+ background-color: rgba(0, 0, 0, 0.0);
77
+ /*border-right: 1px solid black;*/
78
+ /*border: 1px solid red;*/
79
+ border-radius: 16px;
80
+ color: black;
81
+ transition: 0.25s;
82
+ box-sizing: border-box !important;
83
+ }
84
+
85
+ /*.playpause:disabled {
86
+ background-color: red;
87
+ }*/
88
+
89
+ .playpause:hover {
90
+ background-color: rgba(10, 20, 30, 0.03);
91
+ }
92
+
93
+ .playpause:focus {
94
+ outline:none;
95
+ }
96
+
97
+ .response {
98
+ padding:0px 20px 0px 0px;
99
+ width: calc(100% - 132px);
100
+ height: 100%;
101
+
102
+ /*border: 1px solid red;*/
103
+ /*border-bottom: 1px solid rgb(89, 89, 89);*/
104
+ }
105
+
106
+ .response-canvas {
107
+ height: 100%;
108
+ width: 100%;
109
+ }
110
+
111
+
112
+ .underlay {
113
+ height: 100%;
114
+ width: 100%;
115
+ position: absolute;
116
+ top: 0;
117
+ left: 0;
118
+ }
119
+
120
+ .overlay{
121
+ width: 0%;
122
+ height:100%;
123
+ top: 0;
124
+ right: 0px;
125
+
126
+ background:rgba(255, 255, 255, 0.15);
127
+ overflow:hidden;
128
+ position: absolute;
129
+ z-index: 10;
130
+ border-left: solid 1px rgba(0, 0, 0, 0.664);
131
+
132
+ position: absolute;
133
+ pointer-events: none;
134
+ }
135
+ </style>
136
+
137
+ <script>
138
+ !function(t){if("object"==typeof exports&&"undefined"!=typeof module)module.exports=t();else if("function"==typeof define&&define.amd)define([],t);else{("undefined"!=typeof window?window:"undefined"!=typeof global?global:"undefined"!=typeof self?self:this).pako=t()}}(function(){return function(){return function t(e,a,i){function n(s,o){if(!a[s]){if(!e[s]){var l="function"==typeof require&&require;if(!o&&l)return l(s,!0);if(r)return r(s,!0);var h=new Error("Cannot find module '"+s+"'");throw h.code="MODULE_NOT_FOUND",h}var d=a[s]={exports:{}};e[s][0].call(d.exports,function(t){return n(e[s][1][t]||t)},d,d.exports,t,e,a,i)}return a[s].exports}for(var r="function"==typeof require&&require,s=0;s<i.length;s++)n(i[s]);return n}}()({1:[function(t,e,a){"use strict";var i=t("./zlib/deflate"),n=t("./utils/common"),r=t("./utils/strings"),s=t("./zlib/messages"),o=t("./zlib/zstream"),l=Object.prototype.toString,h=0,d=-1,f=0,_=8;function u(t){if(!(this instanceof u))return new u(t);this.options=n.assign({level:d,method:_,chunkSize:16384,windowBits:15,memLevel:8,strategy:f,to:""},t||{});var e=this.options;e.raw&&e.windowBits>0?e.windowBits=-e.windowBits:e.gzip&&e.windowBits>0&&e.windowBits<16&&(e.windowBits+=16),this.err=0,this.msg="",this.ended=!1,this.chunks=[],this.strm=new o,this.strm.avail_out=0;var a=i.deflateInit2(this.strm,e.level,e.method,e.windowBits,e.memLevel,e.strategy);if(a!==h)throw new Error(s[a]);if(e.header&&i.deflateSetHeader(this.strm,e.header),e.dictionary){var c;if(c="string"==typeof e.dictionary?r.string2buf(e.dictionary):"[object ArrayBuffer]"===l.call(e.dictionary)?new Uint8Array(e.dictionary):e.dictionary,(a=i.deflateSetDictionary(this.strm,c))!==h)throw new Error(s[a]);this._dict_set=!0}}function c(t,e){var a=new u(e);if(a.push(t,!0),a.err)throw a.msg||s[a.err];return a.result}u.prototype.push=function(t,e){var a,s,o=this.strm,d=this.options.chunkSize;if(this.ended)return!1;s=e===~~e?e:!0===e?4:0,"string"==typeof t?o.input=r.string2buf(t):"[object ArrayBuffer]"===l.call(t)?o.input=new Uint8Array(t):o.input=t,o.next_in=0,o.avail_in=o.input.length;do{if(0===o.avail_out&&(o.output=new n.Buf8(d),o.next_out=0,o.avail_out=d),1!==(a=i.deflate(o,s))&&a!==h)return this.onEnd(a),this.ended=!0,!1;0!==o.avail_out&&(0!==o.avail_in||4!==s&&2!==s)||("string"===this.options.to?this.onData(r.buf2binstring(n.shrinkBuf(o.output,o.next_out))):this.onData(n.shrinkBuf(o.output,o.next_out)))}while((o.avail_in>0||0===o.avail_out)&&1!==a);return 4===s?(a=i.deflateEnd(this.strm),this.onEnd(a),this.ended=!0,a===h):2!==s||(this.onEnd(h),o.avail_out=0,!0)},u.prototype.onData=function(t){this.chunks.push(t)},u.prototype.onEnd=function(t){t===h&&("string"===this.options.to?this.result=this.chunks.join(""):this.result=n.flattenChunks(this.chunks)),this.chunks=[],this.err=t,this.msg=this.strm.msg},a.Deflate=u,a.deflate=c,a.deflateRaw=function(t,e){return(e=e||{}).raw=!0,c(t,e)},a.gzip=function(t,e){return(e=e||{}).gzip=!0,c(t,e)}},{"./utils/common":3,"./utils/strings":4,"./zlib/deflate":8,"./zlib/messages":13,"./zlib/zstream":15}],2:[function(t,e,a){"use strict";var i=t("./zlib/inflate"),n=t("./utils/common"),r=t("./utils/strings"),s=t("./zlib/constants"),o=t("./zlib/messages"),l=t("./zlib/zstream"),h=t("./zlib/gzheader"),d=Object.prototype.toString;function f(t){if(!(this instanceof f))return new f(t);this.options=n.assign({chunkSize:16384,windowBits:0,to:""},t||{});var e=this.options;e.raw&&e.windowBits>=0&&e.windowBits<16&&(e.windowBits=-e.windowBits,0===e.windowBits&&(e.windowBits=-15)),!(e.windowBits>=0&&e.windowBits<16)||t&&t.windowBits||(e.windowBits+=32),e.windowBits>15&&e.windowBits<48&&0==(15&e.windowBits)&&(e.windowBits|=15),this.err=0,this.msg="",this.ended=!1,this.chunks=[],this.strm=new l,this.strm.avail_out=0;var a=i.inflateInit2(this.strm,e.windowBits);if(a!==s.Z_OK)throw new Error(o[a]);if(this.header=new h,i.inflateGetHeader(this.strm,this.header),e.dictionary&&("string"==typeof e.dictionary?e.dictionary=r.string2buf(e.dictionary):"[object ArrayBuffer]"===d.call(e.dictionary)&&(e.dictionary=new Uint8Array(e.dictionary)),e.raw&&(a=i.inflateSetDictionary(this.strm,e.dictionary))!==s.Z_OK))throw new Error(o[a])}function _(t,e){var a=new f(e);if(a.push(t,!0),a.err)throw a.msg||o[a.err];return a.result}f.prototype.push=function(t,e){var a,o,l,h,f,_=this.strm,u=this.options.chunkSize,c=this.options.dictionary,b=!1;if(this.ended)return!1;o=e===~~e?e:!0===e?s.Z_FINISH:s.Z_NO_FLUSH,"string"==typeof t?_.input=r.binstring2buf(t):"[object ArrayBuffer]"===d.call(t)?_.input=new Uint8Array(t):_.input=t,_.next_in=0,_.avail_in=_.input.length;do{if(0===_.avail_out&&(_.output=new n.Buf8(u),_.next_out=0,_.avail_out=u),(a=i.inflate(_,s.Z_NO_FLUSH))===s.Z_NEED_DICT&&c&&(a=i.inflateSetDictionary(this.strm,c)),a===s.Z_BUF_ERROR&&!0===b&&(a=s.Z_OK,b=!1),a!==s.Z_STREAM_END&&a!==s.Z_OK)return this.onEnd(a),this.ended=!0,!1;_.next_out&&(0!==_.avail_out&&a!==s.Z_STREAM_END&&(0!==_.avail_in||o!==s.Z_FINISH&&o!==s.Z_SYNC_FLUSH)||("string"===this.options.to?(l=r.utf8border(_.output,_.next_out),h=_.next_out-l,f=r.buf2string(_.output,l),_.next_out=h,_.avail_out=u-h,h&&n.arraySet(_.output,_.output,l,h,0),this.onData(f)):this.onData(n.shrinkBuf(_.output,_.next_out)))),0===_.avail_in&&0===_.avail_out&&(b=!0)}while((_.avail_in>0||0===_.avail_out)&&a!==s.Z_STREAM_END);return a===s.Z_STREAM_END&&(o=s.Z_FINISH),o===s.Z_FINISH?(a=i.inflateEnd(this.strm),this.onEnd(a),this.ended=!0,a===s.Z_OK):o!==s.Z_SYNC_FLUSH||(this.onEnd(s.Z_OK),_.avail_out=0,!0)},f.prototype.onData=function(t){this.chunks.push(t)},f.prototype.onEnd=function(t){t===s.Z_OK&&("string"===this.options.to?this.result=this.chunks.join(""):this.result=n.flattenChunks(this.chunks)),this.chunks=[],this.err=t,this.msg=this.strm.msg},a.Inflate=f,a.inflate=_,a.inflateRaw=function(t,e){return(e=e||{}).raw=!0,_(t,e)},a.ungzip=_},{"./utils/common":3,"./utils/strings":4,"./zlib/constants":6,"./zlib/gzheader":9,"./zlib/inflate":11,"./zlib/messages":13,"./zlib/zstream":15}],3:[function(t,e,a){"use strict";var i="undefined"!=typeof Uint8Array&&"undefined"!=typeof Uint16Array&&"undefined"!=typeof Int32Array;function n(t,e){return Object.prototype.hasOwnProperty.call(t,e)}a.assign=function(t){for(var e=Array.prototype.slice.call(arguments,1);e.length;){var a=e.shift();if(a){if("object"!=typeof a)throw new TypeError(a+"must be non-object");for(var i in a)n(a,i)&&(t[i]=a[i])}}return t},a.shrinkBuf=function(t,e){return t.length===e?t:t.subarray?t.subarray(0,e):(t.length=e,t)};var r={arraySet:function(t,e,a,i,n){if(e.subarray&&t.subarray)t.set(e.subarray(a,a+i),n);else for(var r=0;r<i;r++)t[n+r]=e[a+r]},flattenChunks:function(t){var e,a,i,n,r,s;for(i=0,e=0,a=t.length;e<a;e++)i+=t[e].length;for(s=new Uint8Array(i),n=0,e=0,a=t.length;e<a;e++)r=t[e],s.set(r,n),n+=r.length;return s}},s={arraySet:function(t,e,a,i,n){for(var r=0;r<i;r++)t[n+r]=e[a+r]},flattenChunks:function(t){return[].concat.apply([],t)}};a.setTyped=function(t){t?(a.Buf8=Uint8Array,a.Buf16=Uint16Array,a.Buf32=Int32Array,a.assign(a,r)):(a.Buf8=Array,a.Buf16=Array,a.Buf32=Array,a.assign(a,s))},a.setTyped(i)},{}],4:[function(t,e,a){"use strict";var i=t("./common"),n=!0,r=!0;try{String.fromCharCode.apply(null,[0])}catch(t){n=!1}try{String.fromCharCode.apply(null,new Uint8Array(1))}catch(t){r=!1}for(var s=new i.Buf8(256),o=0;o<256;o++)s[o]=o>=252?6:o>=248?5:o>=240?4:o>=224?3:o>=192?2:1;function l(t,e){if(e<65534&&(t.subarray&&r||!t.subarray&&n))return String.fromCharCode.apply(null,i.shrinkBuf(t,e));for(var a="",s=0;s<e;s++)a+=String.fromCharCode(t[s]);return a}s[254]=s[254]=1,a.string2buf=function(t){var e,a,n,r,s,o=t.length,l=0;for(r=0;r<o;r++)55296==(64512&(a=t.charCodeAt(r)))&&r+1<o&&56320==(64512&(n=t.charCodeAt(r+1)))&&(a=65536+(a-55296<<10)+(n-56320),r++),l+=a<128?1:a<2048?2:a<65536?3:4;for(e=new i.Buf8(l),s=0,r=0;s<l;r++)55296==(64512&(a=t.charCodeAt(r)))&&r+1<o&&56320==(64512&(n=t.charCodeAt(r+1)))&&(a=65536+(a-55296<<10)+(n-56320),r++),a<128?e[s++]=a:a<2048?(e[s++]=192|a>>>6,e[s++]=128|63&a):a<65536?(e[s++]=224|a>>>12,e[s++]=128|a>>>6&63,e[s++]=128|63&a):(e[s++]=240|a>>>18,e[s++]=128|a>>>12&63,e[s++]=128|a>>>6&63,e[s++]=128|63&a);return e},a.buf2binstring=function(t){return l(t,t.length)},a.binstring2buf=function(t){for(var e=new i.Buf8(t.length),a=0,n=e.length;a<n;a++)e[a]=t.charCodeAt(a);return e},a.buf2string=function(t,e){var a,i,n,r,o=e||t.length,h=new Array(2*o);for(i=0,a=0;a<o;)if((n=t[a++])<128)h[i++]=n;else if((r=s[n])>4)h[i++]=65533,a+=r-1;else{for(n&=2===r?31:3===r?15:7;r>1&&a<o;)n=n<<6|63&t[a++],r--;r>1?h[i++]=65533:n<65536?h[i++]=n:(n-=65536,h[i++]=55296|n>>10&1023,h[i++]=56320|1023&n)}return l(h,i)},a.utf8border=function(t,e){var a;for((e=e||t.length)>t.length&&(e=t.length),a=e-1;a>=0&&128==(192&t[a]);)a--;return a<0?e:0===a?e:a+s[t[a]]>e?a:e}},{"./common":3}],5:[function(t,e,a){"use strict";e.exports=function(t,e,a,i){for(var n=65535&t|0,r=t>>>16&65535|0,s=0;0!==a;){a-=s=a>2e3?2e3:a;do{r=r+(n=n+e[i++]|0)|0}while(--s);n%=65521,r%=65521}return n|r<<16|0}},{}],6:[function(t,e,a){"use strict";e.exports={Z_NO_FLUSH:0,Z_PARTIAL_FLUSH:1,Z_SYNC_FLUSH:2,Z_FULL_FLUSH:3,Z_FINISH:4,Z_BLOCK:5,Z_TREES:6,Z_OK:0,Z_STREAM_END:1,Z_NEED_DICT:2,Z_ERRNO:-1,Z_STREAM_ERROR:-2,Z_DATA_ERROR:-3,Z_BUF_ERROR:-5,Z_NO_COMPRESSION:0,Z_BEST_SPEED:1,Z_BEST_COMPRESSION:9,Z_DEFAULT_COMPRESSION:-1,Z_FILTERED:1,Z_HUFFMAN_ONLY:2,Z_RLE:3,Z_FIXED:4,Z_DEFAULT_STRATEGY:0,Z_BINARY:0,Z_TEXT:1,Z_UNKNOWN:2,Z_DEFLATED:8}},{}],7:[function(t,e,a){"use strict";var i=function(){for(var t,e=[],a=0;a<256;a++){t=a;for(var i=0;i<8;i++)t=1&t?3988292384^t>>>1:t>>>1;e[a]=t}return e}();e.exports=function(t,e,a,n){var r=i,s=n+a;t^=-1;for(var o=n;o<s;o++)t=t>>>8^r[255&(t^e[o])];return-1^t}},{}],8:[function(t,e,a){"use strict";var i,n=t("../utils/common"),r=t("./trees"),s=t("./adler32"),o=t("./crc32"),l=t("./messages"),h=0,d=1,f=3,_=4,u=5,c=0,b=1,g=-2,m=-3,w=-5,p=-1,v=1,k=2,y=3,x=4,z=0,B=2,S=8,E=9,A=15,Z=8,R=286,C=30,N=19,O=2*R+1,D=15,I=3,U=258,T=U+I+1,F=32,L=42,H=69,j=73,K=91,M=103,P=113,Y=666,q=1,G=2,X=3,W=4,J=3;function Q(t,e){return t.msg=l[e],e}function V(t){return(t<<1)-(t>4?9:0)}function $(t){for(var e=t.length;--e>=0;)t[e]=0}function tt(t){var e=t.state,a=e.pending;a>t.avail_out&&(a=t.avail_out),0!==a&&(n.arraySet(t.output,e.pending_buf,e.pending_out,a,t.next_out),t.next_out+=a,e.pending_out+=a,t.total_out+=a,t.avail_out-=a,e.pending-=a,0===e.pending&&(e.pending_out=0))}function et(t,e){r._tr_flush_block(t,t.block_start>=0?t.block_start:-1,t.strstart-t.block_start,e),t.block_start=t.strstart,tt(t.strm)}function at(t,e){t.pending_buf[t.pending++]=e}function it(t,e){t.pending_buf[t.pending++]=e>>>8&255,t.pending_buf[t.pending++]=255&e}function nt(t,e){var a,i,n=t.max_chain_length,r=t.strstart,s=t.prev_length,o=t.nice_match,l=t.strstart>t.w_size-T?t.strstart-(t.w_size-T):0,h=t.window,d=t.w_mask,f=t.prev,_=t.strstart+U,u=h[r+s-1],c=h[r+s];t.prev_length>=t.good_match&&(n>>=2),o>t.lookahead&&(o=t.lookahead);do{if(h[(a=e)+s]===c&&h[a+s-1]===u&&h[a]===h[r]&&h[++a]===h[r+1]){r+=2,a++;do{}while(h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&h[++r]===h[++a]&&r<_);if(i=U-(_-r),r=_-U,i>s){if(t.match_start=e,s=i,i>=o)break;u=h[r+s-1],c=h[r+s]}}}while((e=f[e&d])>l&&0!=--n);return s<=t.lookahead?s:t.lookahead}function rt(t){var e,a,i,r,l,h,d,f,_,u,c=t.w_size;do{if(r=t.window_size-t.lookahead-t.strstart,t.strstart>=c+(c-T)){n.arraySet(t.window,t.window,c,c,0),t.match_start-=c,t.strstart-=c,t.block_start-=c,e=a=t.hash_size;do{i=t.head[--e],t.head[e]=i>=c?i-c:0}while(--a);e=a=c;do{i=t.prev[--e],t.prev[e]=i>=c?i-c:0}while(--a);r+=c}if(0===t.strm.avail_in)break;if(h=t.strm,d=t.window,f=t.strstart+t.lookahead,_=r,u=void 0,(u=h.avail_in)>_&&(u=_),a=0===u?0:(h.avail_in-=u,n.arraySet(d,h.input,h.next_in,u,f),1===h.state.wrap?h.adler=s(h.adler,d,u,f):2===h.state.wrap&&(h.adler=o(h.adler,d,u,f)),h.next_in+=u,h.total_in+=u,u),t.lookahead+=a,t.lookahead+t.insert>=I)for(l=t.strstart-t.insert,t.ins_h=t.window[l],t.ins_h=(t.ins_h<<t.hash_shift^t.window[l+1])&t.hash_mask;t.insert&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[l+I-1])&t.hash_mask,t.prev[l&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=l,l++,t.insert--,!(t.lookahead+t.insert<I)););}while(t.lookahead<T&&0!==t.strm.avail_in)}function st(t,e){for(var a,i;;){if(t.lookahead<T){if(rt(t),t.lookahead<T&&e===h)return q;if(0===t.lookahead)break}if(a=0,t.lookahead>=I&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart),0!==a&&t.strstart-a<=t.w_size-T&&(t.match_length=nt(t,a)),t.match_length>=I)if(i=r._tr_tally(t,t.strstart-t.match_start,t.match_length-I),t.lookahead-=t.match_length,t.match_length<=t.max_lazy_match&&t.lookahead>=I){t.match_length--;do{t.strstart++,t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart}while(0!=--t.match_length);t.strstart++}else t.strstart+=t.match_length,t.match_length=0,t.ins_h=t.window[t.strstart],t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+1])&t.hash_mask;else i=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++;if(i&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=t.strstart<I-1?t.strstart:I-1,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}function ot(t,e){for(var a,i,n;;){if(t.lookahead<T){if(rt(t),t.lookahead<T&&e===h)return q;if(0===t.lookahead)break}if(a=0,t.lookahead>=I&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart),t.prev_length=t.match_length,t.prev_match=t.match_start,t.match_length=I-1,0!==a&&t.prev_length<t.max_lazy_match&&t.strstart-a<=t.w_size-T&&(t.match_length=nt(t,a),t.match_length<=5&&(t.strategy===v||t.match_length===I&&t.strstart-t.match_start>4096)&&(t.match_length=I-1)),t.prev_length>=I&&t.match_length<=t.prev_length){n=t.strstart+t.lookahead-I,i=r._tr_tally(t,t.strstart-1-t.prev_match,t.prev_length-I),t.lookahead-=t.prev_length-1,t.prev_length-=2;do{++t.strstart<=n&&(t.ins_h=(t.ins_h<<t.hash_shift^t.window[t.strstart+I-1])&t.hash_mask,a=t.prev[t.strstart&t.w_mask]=t.head[t.ins_h],t.head[t.ins_h]=t.strstart)}while(0!=--t.prev_length);if(t.match_available=0,t.match_length=I-1,t.strstart++,i&&(et(t,!1),0===t.strm.avail_out))return q}else if(t.match_available){if((i=r._tr_tally(t,0,t.window[t.strstart-1]))&&et(t,!1),t.strstart++,t.lookahead--,0===t.strm.avail_out)return q}else t.match_available=1,t.strstart++,t.lookahead--}return t.match_available&&(i=r._tr_tally(t,0,t.window[t.strstart-1]),t.match_available=0),t.insert=t.strstart<I-1?t.strstart:I-1,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}function lt(t,e,a,i,n){this.good_length=t,this.max_lazy=e,this.nice_length=a,this.max_chain=i,this.func=n}function ht(){this.strm=null,this.status=0,this.pending_buf=null,this.pending_buf_size=0,this.pending_out=0,this.pending=0,this.wrap=0,this.gzhead=null,this.gzindex=0,this.method=S,this.last_flush=-1,this.w_size=0,this.w_bits=0,this.w_mask=0,this.window=null,this.window_size=0,this.prev=null,this.head=null,this.ins_h=0,this.hash_size=0,this.hash_bits=0,this.hash_mask=0,this.hash_shift=0,this.block_start=0,this.match_length=0,this.prev_match=0,this.match_available=0,this.strstart=0,this.match_start=0,this.lookahead=0,this.prev_length=0,this.max_chain_length=0,this.max_lazy_match=0,this.level=0,this.strategy=0,this.good_match=0,this.nice_match=0,this.dyn_ltree=new n.Buf16(2*O),this.dyn_dtree=new n.Buf16(2*(2*C+1)),this.bl_tree=new n.Buf16(2*(2*N+1)),$(this.dyn_ltree),$(this.dyn_dtree),$(this.bl_tree),this.l_desc=null,this.d_desc=null,this.bl_desc=null,this.bl_count=new n.Buf16(D+1),this.heap=new n.Buf16(2*R+1),$(this.heap),this.heap_len=0,this.heap_max=0,this.depth=new n.Buf16(2*R+1),$(this.depth),this.l_buf=0,this.lit_bufsize=0,this.last_lit=0,this.d_buf=0,this.opt_len=0,this.static_len=0,this.matches=0,this.insert=0,this.bi_buf=0,this.bi_valid=0}function dt(t){var e;return t&&t.state?(t.total_in=t.total_out=0,t.data_type=B,(e=t.state).pending=0,e.pending_out=0,e.wrap<0&&(e.wrap=-e.wrap),e.status=e.wrap?L:P,t.adler=2===e.wrap?0:1,e.last_flush=h,r._tr_init(e),c):Q(t,g)}function ft(t){var e,a=dt(t);return a===c&&((e=t.state).window_size=2*e.w_size,$(e.head),e.max_lazy_match=i[e.level].max_lazy,e.good_match=i[e.level].good_length,e.nice_match=i[e.level].nice_length,e.max_chain_length=i[e.level].max_chain,e.strstart=0,e.block_start=0,e.lookahead=0,e.insert=0,e.match_length=e.prev_length=I-1,e.match_available=0,e.ins_h=0),a}function _t(t,e,a,i,r,s){if(!t)return g;var o=1;if(e===p&&(e=6),i<0?(o=0,i=-i):i>15&&(o=2,i-=16),r<1||r>E||a!==S||i<8||i>15||e<0||e>9||s<0||s>x)return Q(t,g);8===i&&(i=9);var l=new ht;return t.state=l,l.strm=t,l.wrap=o,l.gzhead=null,l.w_bits=i,l.w_size=1<<l.w_bits,l.w_mask=l.w_size-1,l.hash_bits=r+7,l.hash_size=1<<l.hash_bits,l.hash_mask=l.hash_size-1,l.hash_shift=~~((l.hash_bits+I-1)/I),l.window=new n.Buf8(2*l.w_size),l.head=new n.Buf16(l.hash_size),l.prev=new n.Buf16(l.w_size),l.lit_bufsize=1<<r+6,l.pending_buf_size=4*l.lit_bufsize,l.pending_buf=new n.Buf8(l.pending_buf_size),l.d_buf=1*l.lit_bufsize,l.l_buf=3*l.lit_bufsize,l.level=e,l.strategy=s,l.method=a,ft(t)}i=[new lt(0,0,0,0,function(t,e){var a=65535;for(a>t.pending_buf_size-5&&(a=t.pending_buf_size-5);;){if(t.lookahead<=1){if(rt(t),0===t.lookahead&&e===h)return q;if(0===t.lookahead)break}t.strstart+=t.lookahead,t.lookahead=0;var i=t.block_start+a;if((0===t.strstart||t.strstart>=i)&&(t.lookahead=t.strstart-i,t.strstart=i,et(t,!1),0===t.strm.avail_out))return q;if(t.strstart-t.block_start>=t.w_size-T&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):(t.strstart>t.block_start&&(et(t,!1),t.strm.avail_out),q)}),new lt(4,4,8,4,st),new lt(4,5,16,8,st),new lt(4,6,32,32,st),new lt(4,4,16,16,ot),new lt(8,16,32,32,ot),new lt(8,16,128,128,ot),new lt(8,32,128,256,ot),new lt(32,128,258,1024,ot),new lt(32,258,258,4096,ot)],a.deflateInit=function(t,e){return _t(t,e,S,A,Z,z)},a.deflateInit2=_t,a.deflateReset=ft,a.deflateResetKeep=dt,a.deflateSetHeader=function(t,e){return t&&t.state?2!==t.state.wrap?g:(t.state.gzhead=e,c):g},a.deflate=function(t,e){var a,n,s,l;if(!t||!t.state||e>u||e<0)return t?Q(t,g):g;if(n=t.state,!t.output||!t.input&&0!==t.avail_in||n.status===Y&&e!==_)return Q(t,0===t.avail_out?w:g);if(n.strm=t,a=n.last_flush,n.last_flush=e,n.status===L)if(2===n.wrap)t.adler=0,at(n,31),at(n,139),at(n,8),n.gzhead?(at(n,(n.gzhead.text?1:0)+(n.gzhead.hcrc?2:0)+(n.gzhead.extra?4:0)+(n.gzhead.name?8:0)+(n.gzhead.comment?16:0)),at(n,255&n.gzhead.time),at(n,n.gzhead.time>>8&255),at(n,n.gzhead.time>>16&255),at(n,n.gzhead.time>>24&255),at(n,9===n.level?2:n.strategy>=k||n.level<2?4:0),at(n,255&n.gzhead.os),n.gzhead.extra&&n.gzhead.extra.length&&(at(n,255&n.gzhead.extra.length),at(n,n.gzhead.extra.length>>8&255)),n.gzhead.hcrc&&(t.adler=o(t.adler,n.pending_buf,n.pending,0)),n.gzindex=0,n.status=H):(at(n,0),at(n,0),at(n,0),at(n,0),at(n,0),at(n,9===n.level?2:n.strategy>=k||n.level<2?4:0),at(n,J),n.status=P);else{var m=S+(n.w_bits-8<<4)<<8;m|=(n.strategy>=k||n.level<2?0:n.level<6?1:6===n.level?2:3)<<6,0!==n.strstart&&(m|=F),m+=31-m%31,n.status=P,it(n,m),0!==n.strstart&&(it(n,t.adler>>>16),it(n,65535&t.adler)),t.adler=1}if(n.status===H)if(n.gzhead.extra){for(s=n.pending;n.gzindex<(65535&n.gzhead.extra.length)&&(n.pending!==n.pending_buf_size||(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending!==n.pending_buf_size));)at(n,255&n.gzhead.extra[n.gzindex]),n.gzindex++;n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),n.gzindex===n.gzhead.extra.length&&(n.gzindex=0,n.status=j)}else n.status=j;if(n.status===j)if(n.gzhead.name){s=n.pending;do{if(n.pending===n.pending_buf_size&&(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending===n.pending_buf_size)){l=1;break}l=n.gzindex<n.gzhead.name.length?255&n.gzhead.name.charCodeAt(n.gzindex++):0,at(n,l)}while(0!==l);n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),0===l&&(n.gzindex=0,n.status=K)}else n.status=K;if(n.status===K)if(n.gzhead.comment){s=n.pending;do{if(n.pending===n.pending_buf_size&&(n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),tt(t),s=n.pending,n.pending===n.pending_buf_size)){l=1;break}l=n.gzindex<n.gzhead.comment.length?255&n.gzhead.comment.charCodeAt(n.gzindex++):0,at(n,l)}while(0!==l);n.gzhead.hcrc&&n.pending>s&&(t.adler=o(t.adler,n.pending_buf,n.pending-s,s)),0===l&&(n.status=M)}else n.status=M;if(n.status===M&&(n.gzhead.hcrc?(n.pending+2>n.pending_buf_size&&tt(t),n.pending+2<=n.pending_buf_size&&(at(n,255&t.adler),at(n,t.adler>>8&255),t.adler=0,n.status=P)):n.status=P),0!==n.pending){if(tt(t),0===t.avail_out)return n.last_flush=-1,c}else if(0===t.avail_in&&V(e)<=V(a)&&e!==_)return Q(t,w);if(n.status===Y&&0!==t.avail_in)return Q(t,w);if(0!==t.avail_in||0!==n.lookahead||e!==h&&n.status!==Y){var p=n.strategy===k?function(t,e){for(var a;;){if(0===t.lookahead&&(rt(t),0===t.lookahead)){if(e===h)return q;break}if(t.match_length=0,a=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++,a&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}(n,e):n.strategy===y?function(t,e){for(var a,i,n,s,o=t.window;;){if(t.lookahead<=U){if(rt(t),t.lookahead<=U&&e===h)return q;if(0===t.lookahead)break}if(t.match_length=0,t.lookahead>=I&&t.strstart>0&&(i=o[n=t.strstart-1])===o[++n]&&i===o[++n]&&i===o[++n]){s=t.strstart+U;do{}while(i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&i===o[++n]&&n<s);t.match_length=U-(s-n),t.match_length>t.lookahead&&(t.match_length=t.lookahead)}if(t.match_length>=I?(a=r._tr_tally(t,1,t.match_length-I),t.lookahead-=t.match_length,t.strstart+=t.match_length,t.match_length=0):(a=r._tr_tally(t,0,t.window[t.strstart]),t.lookahead--,t.strstart++),a&&(et(t,!1),0===t.strm.avail_out))return q}return t.insert=0,e===_?(et(t,!0),0===t.strm.avail_out?X:W):t.last_lit&&(et(t,!1),0===t.strm.avail_out)?q:G}(n,e):i[n.level].func(n,e);if(p!==X&&p!==W||(n.status=Y),p===q||p===X)return 0===t.avail_out&&(n.last_flush=-1),c;if(p===G&&(e===d?r._tr_align(n):e!==u&&(r._tr_stored_block(n,0,0,!1),e===f&&($(n.head),0===n.lookahead&&(n.strstart=0,n.block_start=0,n.insert=0))),tt(t),0===t.avail_out))return n.last_flush=-1,c}return e!==_?c:n.wrap<=0?b:(2===n.wrap?(at(n,255&t.adler),at(n,t.adler>>8&255),at(n,t.adler>>16&255),at(n,t.adler>>24&255),at(n,255&t.total_in),at(n,t.total_in>>8&255),at(n,t.total_in>>16&255),at(n,t.total_in>>24&255)):(it(n,t.adler>>>16),it(n,65535&t.adler)),tt(t),n.wrap>0&&(n.wrap=-n.wrap),0!==n.pending?c:b)},a.deflateEnd=function(t){var e;return t&&t.state?(e=t.state.status)!==L&&e!==H&&e!==j&&e!==K&&e!==M&&e!==P&&e!==Y?Q(t,g):(t.state=null,e===P?Q(t,m):c):g},a.deflateSetDictionary=function(t,e){var a,i,r,o,l,h,d,f,_=e.length;if(!t||!t.state)return g;if(2===(o=(a=t.state).wrap)||1===o&&a.status!==L||a.lookahead)return g;for(1===o&&(t.adler=s(t.adler,e,_,0)),a.wrap=0,_>=a.w_size&&(0===o&&($(a.head),a.strstart=0,a.block_start=0,a.insert=0),f=new n.Buf8(a.w_size),n.arraySet(f,e,_-a.w_size,a.w_size,0),e=f,_=a.w_size),l=t.avail_in,h=t.next_in,d=t.input,t.avail_in=_,t.next_in=0,t.input=e,rt(a);a.lookahead>=I;){i=a.strstart,r=a.lookahead-(I-1);do{a.ins_h=(a.ins_h<<a.hash_shift^a.window[i+I-1])&a.hash_mask,a.prev[i&a.w_mask]=a.head[a.ins_h],a.head[a.ins_h]=i,i++}while(--r);a.strstart=i,a.lookahead=I-1,rt(a)}return a.strstart+=a.lookahead,a.block_start=a.strstart,a.insert=a.lookahead,a.lookahead=0,a.match_length=a.prev_length=I-1,a.match_available=0,t.next_in=h,t.input=d,t.avail_in=l,a.wrap=o,c},a.deflateInfo="pako deflate (from Nodeca project)"},{"../utils/common":3,"./adler32":5,"./crc32":7,"./messages":13,"./trees":14}],9:[function(t,e,a){"use strict";e.exports=function(){this.text=0,this.time=0,this.xflags=0,this.os=0,this.extra=null,this.extra_len=0,this.name="",this.comment="",this.hcrc=0,this.done=!1}},{}],10:[function(t,e,a){"use strict";e.exports=function(t,e){var a,i,n,r,s,o,l,h,d,f,_,u,c,b,g,m,w,p,v,k,y,x,z,B,S;a=t.state,i=t.next_in,B=t.input,n=i+(t.avail_in-5),r=t.next_out,S=t.output,s=r-(e-t.avail_out),o=r+(t.avail_out-257),l=a.dmax,h=a.wsize,d=a.whave,f=a.wnext,_=a.window,u=a.hold,c=a.bits,b=a.lencode,g=a.distcode,m=(1<<a.lenbits)-1,w=(1<<a.distbits)-1;t:do{c<15&&(u+=B[i++]<<c,c+=8,u+=B[i++]<<c,c+=8),p=b[u&m];e:for(;;){if(u>>>=v=p>>>24,c-=v,0===(v=p>>>16&255))S[r++]=65535&p;else{if(!(16&v)){if(0==(64&v)){p=b[(65535&p)+(u&(1<<v)-1)];continue e}if(32&v){a.mode=12;break t}t.msg="invalid literal/length code",a.mode=30;break t}k=65535&p,(v&=15)&&(c<v&&(u+=B[i++]<<c,c+=8),k+=u&(1<<v)-1,u>>>=v,c-=v),c<15&&(u+=B[i++]<<c,c+=8,u+=B[i++]<<c,c+=8),p=g[u&w];a:for(;;){if(u>>>=v=p>>>24,c-=v,!(16&(v=p>>>16&255))){if(0==(64&v)){p=g[(65535&p)+(u&(1<<v)-1)];continue a}t.msg="invalid distance code",a.mode=30;break t}if(y=65535&p,c<(v&=15)&&(u+=B[i++]<<c,(c+=8)<v&&(u+=B[i++]<<c,c+=8)),(y+=u&(1<<v)-1)>l){t.msg="invalid distance too far back",a.mode=30;break t}if(u>>>=v,c-=v,y>(v=r-s)){if((v=y-v)>d&&a.sane){t.msg="invalid distance too far back",a.mode=30;break t}if(x=0,z=_,0===f){if(x+=h-v,v<k){k-=v;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}}else if(f<v){if(x+=h+f-v,(v-=f)<k){k-=v;do{S[r++]=_[x++]}while(--v);if(x=0,f<k){k-=v=f;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}}}else if(x+=f-v,v<k){k-=v;do{S[r++]=_[x++]}while(--v);x=r-y,z=S}for(;k>2;)S[r++]=z[x++],S[r++]=z[x++],S[r++]=z[x++],k-=3;k&&(S[r++]=z[x++],k>1&&(S[r++]=z[x++]))}else{x=r-y;do{S[r++]=S[x++],S[r++]=S[x++],S[r++]=S[x++],k-=3}while(k>2);k&&(S[r++]=S[x++],k>1&&(S[r++]=S[x++]))}break}}break}}while(i<n&&r<o);i-=k=c>>3,u&=(1<<(c-=k<<3))-1,t.next_in=i,t.next_out=r,t.avail_in=i<n?n-i+5:5-(i-n),t.avail_out=r<o?o-r+257:257-(r-o),a.hold=u,a.bits=c}},{}],11:[function(t,e,a){"use strict";var i=t("../utils/common"),n=t("./adler32"),r=t("./crc32"),s=t("./inffast"),o=t("./inftrees"),l=0,h=1,d=2,f=4,_=5,u=6,c=0,b=1,g=2,m=-2,w=-3,p=-4,v=-5,k=8,y=1,x=2,z=3,B=4,S=5,E=6,A=7,Z=8,R=9,C=10,N=11,O=12,D=13,I=14,U=15,T=16,F=17,L=18,H=19,j=20,K=21,M=22,P=23,Y=24,q=25,G=26,X=27,W=28,J=29,Q=30,V=31,$=32,tt=852,et=592,at=15;function it(t){return(t>>>24&255)+(t>>>8&65280)+((65280&t)<<8)+((255&t)<<24)}function nt(){this.mode=0,this.last=!1,this.wrap=0,this.havedict=!1,this.flags=0,this.dmax=0,this.check=0,this.total=0,this.head=null,this.wbits=0,this.wsize=0,this.whave=0,this.wnext=0,this.window=null,this.hold=0,this.bits=0,this.length=0,this.offset=0,this.extra=0,this.lencode=null,this.distcode=null,this.lenbits=0,this.distbits=0,this.ncode=0,this.nlen=0,this.ndist=0,this.have=0,this.next=null,this.lens=new i.Buf16(320),this.work=new i.Buf16(288),this.lendyn=null,this.distdyn=null,this.sane=0,this.back=0,this.was=0}function rt(t){var e;return t&&t.state?(e=t.state,t.total_in=t.total_out=e.total=0,t.msg="",e.wrap&&(t.adler=1&e.wrap),e.mode=y,e.last=0,e.havedict=0,e.dmax=32768,e.head=null,e.hold=0,e.bits=0,e.lencode=e.lendyn=new i.Buf32(tt),e.distcode=e.distdyn=new i.Buf32(et),e.sane=1,e.back=-1,c):m}function st(t){var e;return t&&t.state?((e=t.state).wsize=0,e.whave=0,e.wnext=0,rt(t)):m}function ot(t,e){var a,i;return t&&t.state?(i=t.state,e<0?(a=0,e=-e):(a=1+(e>>4),e<48&&(e&=15)),e&&(e<8||e>15)?m:(null!==i.window&&i.wbits!==e&&(i.window=null),i.wrap=a,i.wbits=e,st(t))):m}function lt(t,e){var a,i;return t?(i=new nt,t.state=i,i.window=null,(a=ot(t,e))!==c&&(t.state=null),a):m}var ht,dt,ft=!0;function _t(t){if(ft){var e;for(ht=new i.Buf32(512),dt=new i.Buf32(32),e=0;e<144;)t.lens[e++]=8;for(;e<256;)t.lens[e++]=9;for(;e<280;)t.lens[e++]=7;for(;e<288;)t.lens[e++]=8;for(o(h,t.lens,0,288,ht,0,t.work,{bits:9}),e=0;e<32;)t.lens[e++]=5;o(d,t.lens,0,32,dt,0,t.work,{bits:5}),ft=!1}t.lencode=ht,t.lenbits=9,t.distcode=dt,t.distbits=5}function ut(t,e,a,n){var r,s=t.state;return null===s.window&&(s.wsize=1<<s.wbits,s.wnext=0,s.whave=0,s.window=new i.Buf8(s.wsize)),n>=s.wsize?(i.arraySet(s.window,e,a-s.wsize,s.wsize,0),s.wnext=0,s.whave=s.wsize):((r=s.wsize-s.wnext)>n&&(r=n),i.arraySet(s.window,e,a-n,r,s.wnext),(n-=r)?(i.arraySet(s.window,e,a-n,n,0),s.wnext=n,s.whave=s.wsize):(s.wnext+=r,s.wnext===s.wsize&&(s.wnext=0),s.whave<s.wsize&&(s.whave+=r))),0}a.inflateReset=st,a.inflateReset2=ot,a.inflateResetKeep=rt,a.inflateInit=function(t){return lt(t,at)},a.inflateInit2=lt,a.inflate=function(t,e){var a,tt,et,at,nt,rt,st,ot,lt,ht,dt,ft,ct,bt,gt,mt,wt,pt,vt,kt,yt,xt,zt,Bt,St=0,Et=new i.Buf8(4),At=[16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15];if(!t||!t.state||!t.output||!t.input&&0!==t.avail_in)return m;(a=t.state).mode===O&&(a.mode=D),nt=t.next_out,et=t.output,st=t.avail_out,at=t.next_in,tt=t.input,rt=t.avail_in,ot=a.hold,lt=a.bits,ht=rt,dt=st,xt=c;t:for(;;)switch(a.mode){case y:if(0===a.wrap){a.mode=D;break}for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(2&a.wrap&&35615===ot){a.check=0,Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0),ot=0,lt=0,a.mode=x;break}if(a.flags=0,a.head&&(a.head.done=!1),!(1&a.wrap)||(((255&ot)<<8)+(ot>>8))%31){t.msg="incorrect header check",a.mode=Q;break}if((15&ot)!==k){t.msg="unknown compression method",a.mode=Q;break}if(lt-=4,yt=8+(15&(ot>>>=4)),0===a.wbits)a.wbits=yt;else if(yt>a.wbits){t.msg="invalid window size",a.mode=Q;break}a.dmax=1<<yt,t.adler=a.check=1,a.mode=512&ot?C:O,ot=0,lt=0;break;case x:for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(a.flags=ot,(255&a.flags)!==k){t.msg="unknown compression method",a.mode=Q;break}if(57344&a.flags){t.msg="unknown header flags set",a.mode=Q;break}a.head&&(a.head.text=ot>>8&1),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0,a.mode=z;case z:for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.head&&(a.head.time=ot),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,Et[2]=ot>>>16&255,Et[3]=ot>>>24&255,a.check=r(a.check,Et,4,0)),ot=0,lt=0,a.mode=B;case B:for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.head&&(a.head.xflags=255&ot,a.head.os=ot>>8),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0,a.mode=S;case S:if(1024&a.flags){for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.length=ot,a.head&&(a.head.extra_len=ot),512&a.flags&&(Et[0]=255&ot,Et[1]=ot>>>8&255,a.check=r(a.check,Et,2,0)),ot=0,lt=0}else a.head&&(a.head.extra=null);a.mode=E;case E:if(1024&a.flags&&((ft=a.length)>rt&&(ft=rt),ft&&(a.head&&(yt=a.head.extra_len-a.length,a.head.extra||(a.head.extra=new Array(a.head.extra_len)),i.arraySet(a.head.extra,tt,at,ft,yt)),512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,a.length-=ft),a.length))break t;a.length=0,a.mode=A;case A:if(2048&a.flags){if(0===rt)break t;ft=0;do{yt=tt[at+ft++],a.head&&yt&&a.length<65536&&(a.head.name+=String.fromCharCode(yt))}while(yt&&ft<rt);if(512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,yt)break t}else a.head&&(a.head.name=null);a.length=0,a.mode=Z;case Z:if(4096&a.flags){if(0===rt)break t;ft=0;do{yt=tt[at+ft++],a.head&&yt&&a.length<65536&&(a.head.comment+=String.fromCharCode(yt))}while(yt&&ft<rt);if(512&a.flags&&(a.check=r(a.check,tt,ft,at)),rt-=ft,at+=ft,yt)break t}else a.head&&(a.head.comment=null);a.mode=R;case R:if(512&a.flags){for(;lt<16;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot!==(65535&a.check)){t.msg="header crc mismatch",a.mode=Q;break}ot=0,lt=0}a.head&&(a.head.hcrc=a.flags>>9&1,a.head.done=!0),t.adler=a.check=0,a.mode=O;break;case C:for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}t.adler=a.check=it(ot),ot=0,lt=0,a.mode=N;case N:if(0===a.havedict)return t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,g;t.adler=a.check=1,a.mode=O;case O:if(e===_||e===u)break t;case D:if(a.last){ot>>>=7&lt,lt-=7&lt,a.mode=X;break}for(;lt<3;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}switch(a.last=1&ot,lt-=1,3&(ot>>>=1)){case 0:a.mode=I;break;case 1:if(_t(a),a.mode=j,e===u){ot>>>=2,lt-=2;break t}break;case 2:a.mode=F;break;case 3:t.msg="invalid block type",a.mode=Q}ot>>>=2,lt-=2;break;case I:for(ot>>>=7&lt,lt-=7&lt;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if((65535&ot)!=(ot>>>16^65535)){t.msg="invalid stored block lengths",a.mode=Q;break}if(a.length=65535&ot,ot=0,lt=0,a.mode=U,e===u)break t;case U:a.mode=T;case T:if(ft=a.length){if(ft>rt&&(ft=rt),ft>st&&(ft=st),0===ft)break t;i.arraySet(et,tt,at,ft,nt),rt-=ft,at+=ft,st-=ft,nt+=ft,a.length-=ft;break}a.mode=O;break;case F:for(;lt<14;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(a.nlen=257+(31&ot),ot>>>=5,lt-=5,a.ndist=1+(31&ot),ot>>>=5,lt-=5,a.ncode=4+(15&ot),ot>>>=4,lt-=4,a.nlen>286||a.ndist>30){t.msg="too many length or distance symbols",a.mode=Q;break}a.have=0,a.mode=L;case L:for(;a.have<a.ncode;){for(;lt<3;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.lens[At[a.have++]]=7&ot,ot>>>=3,lt-=3}for(;a.have<19;)a.lens[At[a.have++]]=0;if(a.lencode=a.lendyn,a.lenbits=7,zt={bits:a.lenbits},xt=o(l,a.lens,0,19,a.lencode,0,a.work,zt),a.lenbits=zt.bits,xt){t.msg="invalid code lengths set",a.mode=Q;break}a.have=0,a.mode=H;case H:for(;a.have<a.nlen+a.ndist;){for(;mt=(St=a.lencode[ot&(1<<a.lenbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(wt<16)ot>>>=gt,lt-=gt,a.lens[a.have++]=wt;else{if(16===wt){for(Bt=gt+2;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot>>>=gt,lt-=gt,0===a.have){t.msg="invalid bit length repeat",a.mode=Q;break}yt=a.lens[a.have-1],ft=3+(3&ot),ot>>>=2,lt-=2}else if(17===wt){for(Bt=gt+3;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}lt-=gt,yt=0,ft=3+(7&(ot>>>=gt)),ot>>>=3,lt-=3}else{for(Bt=gt+7;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}lt-=gt,yt=0,ft=11+(127&(ot>>>=gt)),ot>>>=7,lt-=7}if(a.have+ft>a.nlen+a.ndist){t.msg="invalid bit length repeat",a.mode=Q;break}for(;ft--;)a.lens[a.have++]=yt}}if(a.mode===Q)break;if(0===a.lens[256]){t.msg="invalid code -- missing end-of-block",a.mode=Q;break}if(a.lenbits=9,zt={bits:a.lenbits},xt=o(h,a.lens,0,a.nlen,a.lencode,0,a.work,zt),a.lenbits=zt.bits,xt){t.msg="invalid literal/lengths set",a.mode=Q;break}if(a.distbits=6,a.distcode=a.distdyn,zt={bits:a.distbits},xt=o(d,a.lens,a.nlen,a.ndist,a.distcode,0,a.work,zt),a.distbits=zt.bits,xt){t.msg="invalid distances set",a.mode=Q;break}if(a.mode=j,e===u)break t;case j:a.mode=K;case K:if(rt>=6&&st>=258){t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,s(t,dt),nt=t.next_out,et=t.output,st=t.avail_out,at=t.next_in,tt=t.input,rt=t.avail_in,ot=a.hold,lt=a.bits,a.mode===O&&(a.back=-1);break}for(a.back=0;mt=(St=a.lencode[ot&(1<<a.lenbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(mt&&0==(240&mt)){for(pt=gt,vt=mt,kt=wt;mt=(St=a.lencode[kt+((ot&(1<<pt+vt)-1)>>pt)])>>>16&255,wt=65535&St,!(pt+(gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}ot>>>=pt,lt-=pt,a.back+=pt}if(ot>>>=gt,lt-=gt,a.back+=gt,a.length=wt,0===mt){a.mode=G;break}if(32&mt){a.back=-1,a.mode=O;break}if(64&mt){t.msg="invalid literal/length code",a.mode=Q;break}a.extra=15&mt,a.mode=M;case M:if(a.extra){for(Bt=a.extra;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.length+=ot&(1<<a.extra)-1,ot>>>=a.extra,lt-=a.extra,a.back+=a.extra}a.was=a.length,a.mode=P;case P:for(;mt=(St=a.distcode[ot&(1<<a.distbits)-1])>>>16&255,wt=65535&St,!((gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(0==(240&mt)){for(pt=gt,vt=mt,kt=wt;mt=(St=a.distcode[kt+((ot&(1<<pt+vt)-1)>>pt)])>>>16&255,wt=65535&St,!(pt+(gt=St>>>24)<=lt);){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}ot>>>=pt,lt-=pt,a.back+=pt}if(ot>>>=gt,lt-=gt,a.back+=gt,64&mt){t.msg="invalid distance code",a.mode=Q;break}a.offset=wt,a.extra=15&mt,a.mode=Y;case Y:if(a.extra){for(Bt=a.extra;lt<Bt;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}a.offset+=ot&(1<<a.extra)-1,ot>>>=a.extra,lt-=a.extra,a.back+=a.extra}if(a.offset>a.dmax){t.msg="invalid distance too far back",a.mode=Q;break}a.mode=q;case q:if(0===st)break t;if(ft=dt-st,a.offset>ft){if((ft=a.offset-ft)>a.whave&&a.sane){t.msg="invalid distance too far back",a.mode=Q;break}ft>a.wnext?(ft-=a.wnext,ct=a.wsize-ft):ct=a.wnext-ft,ft>a.length&&(ft=a.length),bt=a.window}else bt=et,ct=nt-a.offset,ft=a.length;ft>st&&(ft=st),st-=ft,a.length-=ft;do{et[nt++]=bt[ct++]}while(--ft);0===a.length&&(a.mode=K);break;case G:if(0===st)break t;et[nt++]=a.length,st--,a.mode=K;break;case X:if(a.wrap){for(;lt<32;){if(0===rt)break t;rt--,ot|=tt[at++]<<lt,lt+=8}if(dt-=st,t.total_out+=dt,a.total+=dt,dt&&(t.adler=a.check=a.flags?r(a.check,et,dt,nt-dt):n(a.check,et,dt,nt-dt)),dt=st,(a.flags?ot:it(ot))!==a.check){t.msg="incorrect data check",a.mode=Q;break}ot=0,lt=0}a.mode=W;case W:if(a.wrap&&a.flags){for(;lt<32;){if(0===rt)break t;rt--,ot+=tt[at++]<<lt,lt+=8}if(ot!==(4294967295&a.total)){t.msg="incorrect length check",a.mode=Q;break}ot=0,lt=0}a.mode=J;case J:xt=b;break t;case Q:xt=w;break t;case V:return p;case $:default:return m}return t.next_out=nt,t.avail_out=st,t.next_in=at,t.avail_in=rt,a.hold=ot,a.bits=lt,(a.wsize||dt!==t.avail_out&&a.mode<Q&&(a.mode<X||e!==f))&&ut(t,t.output,t.next_out,dt-t.avail_out)?(a.mode=V,p):(ht-=t.avail_in,dt-=t.avail_out,t.total_in+=ht,t.total_out+=dt,a.total+=dt,a.wrap&&dt&&(t.adler=a.check=a.flags?r(a.check,et,dt,t.next_out-dt):n(a.check,et,dt,t.next_out-dt)),t.data_type=a.bits+(a.last?64:0)+(a.mode===O?128:0)+(a.mode===j||a.mode===U?256:0),(0===ht&&0===dt||e===f)&&xt===c&&(xt=v),xt)},a.inflateEnd=function(t){if(!t||!t.state)return m;var e=t.state;return e.window&&(e.window=null),t.state=null,c},a.inflateGetHeader=function(t,e){var a;return t&&t.state?0==(2&(a=t.state).wrap)?m:(a.head=e,e.done=!1,c):m},a.inflateSetDictionary=function(t,e){var a,i=e.length;return t&&t.state?0!==(a=t.state).wrap&&a.mode!==N?m:a.mode===N&&n(1,e,i,0)!==a.check?w:ut(t,e,i,i)?(a.mode=V,p):(a.havedict=1,c):m},a.inflateInfo="pako inflate (from Nodeca project)"},{"../utils/common":3,"./adler32":5,"./crc32":7,"./inffast":10,"./inftrees":12}],12:[function(t,e,a){"use strict";var i=t("../utils/common"),n=[3,4,5,6,7,8,9,10,11,13,15,17,19,23,27,31,35,43,51,59,67,83,99,115,131,163,195,227,258,0,0],r=[16,16,16,16,16,16,16,16,17,17,17,17,18,18,18,18,19,19,19,19,20,20,20,20,21,21,21,21,16,72,78],s=[1,2,3,4,5,7,9,13,17,25,33,49,65,97,129,193,257,385,513,769,1025,1537,2049,3073,4097,6145,8193,12289,16385,24577,0,0],o=[16,16,16,16,17,17,18,18,19,19,20,20,21,21,22,22,23,23,24,24,25,25,26,26,27,27,28,28,29,29,64,64];e.exports=function(t,e,a,l,h,d,f,_){var u,c,b,g,m,w,p,v,k,y=_.bits,x=0,z=0,B=0,S=0,E=0,A=0,Z=0,R=0,C=0,N=0,O=null,D=0,I=new i.Buf16(16),U=new i.Buf16(16),T=null,F=0;for(x=0;x<=15;x++)I[x]=0;for(z=0;z<l;z++)I[e[a+z]]++;for(E=y,S=15;S>=1&&0===I[S];S--);if(E>S&&(E=S),0===S)return h[d++]=20971520,h[d++]=20971520,_.bits=1,0;for(B=1;B<S&&0===I[B];B++);for(E<B&&(E=B),R=1,x=1;x<=15;x++)if(R<<=1,(R-=I[x])<0)return-1;if(R>0&&(0===t||1!==S))return-1;for(U[1]=0,x=1;x<15;x++)U[x+1]=U[x]+I[x];for(z=0;z<l;z++)0!==e[a+z]&&(f[U[e[a+z]]++]=z);if(0===t?(O=T=f,w=19):1===t?(O=n,D-=257,T=r,F-=257,w=256):(O=s,T=o,w=-1),N=0,z=0,x=B,m=d,A=E,Z=0,b=-1,g=(C=1<<E)-1,1===t&&C>852||2===t&&C>592)return 1;for(;;){p=x-Z,f[z]<w?(v=0,k=f[z]):f[z]>w?(v=T[F+f[z]],k=O[D+f[z]]):(v=96,k=0),u=1<<x-Z,B=c=1<<A;do{h[m+(N>>Z)+(c-=u)]=p<<24|v<<16|k|0}while(0!==c);for(u=1<<x-1;N&u;)u>>=1;if(0!==u?(N&=u-1,N+=u):N=0,z++,0==--I[x]){if(x===S)break;x=e[a+f[z]]}if(x>E&&(N&g)!==b){for(0===Z&&(Z=E),m+=B,R=1<<(A=x-Z);A+Z<S&&!((R-=I[A+Z])<=0);)A++,R<<=1;if(C+=1<<A,1===t&&C>852||2===t&&C>592)return 1;h[b=N&g]=E<<24|A<<16|m-d|0}}return 0!==N&&(h[m+N]=x-Z<<24|64<<16|0),_.bits=E,0}},{"../utils/common":3}],13:[function(t,e,a){"use strict";e.exports={2:"need dictionary",1:"stream end",0:"","-1":"file error","-2":"stream error","-3":"data error","-4":"insufficient memory","-5":"buffer error","-6":"incompatible version"}},{}],14:[function(t,e,a){"use strict";var i=t("../utils/common"),n=4,r=0,s=1,o=2;function l(t){for(var e=t.length;--e>=0;)t[e]=0}var h=0,d=1,f=2,_=29,u=256,c=u+1+_,b=30,g=19,m=2*c+1,w=15,p=16,v=7,k=256,y=16,x=17,z=18,B=[0,0,0,0,0,0,0,0,1,1,1,1,2,2,2,2,3,3,3,3,4,4,4,4,5,5,5,5,0],S=[0,0,0,0,1,1,2,2,3,3,4,4,5,5,6,6,7,7,8,8,9,9,10,10,11,11,12,12,13,13],E=[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,2,3,7],A=[16,17,18,0,8,7,9,6,10,5,11,4,12,3,13,2,14,1,15],Z=new Array(2*(c+2));l(Z);var R=new Array(2*b);l(R);var C=new Array(512);l(C);var N=new Array(256);l(N);var O=new Array(_);l(O);var D,I,U,T=new Array(b);function F(t,e,a,i,n){this.static_tree=t,this.extra_bits=e,this.extra_base=a,this.elems=i,this.max_length=n,this.has_stree=t&&t.length}function L(t,e){this.dyn_tree=t,this.max_code=0,this.stat_desc=e}function H(t){return t<256?C[t]:C[256+(t>>>7)]}function j(t,e){t.pending_buf[t.pending++]=255&e,t.pending_buf[t.pending++]=e>>>8&255}function K(t,e,a){t.bi_valid>p-a?(t.bi_buf|=e<<t.bi_valid&65535,j(t,t.bi_buf),t.bi_buf=e>>p-t.bi_valid,t.bi_valid+=a-p):(t.bi_buf|=e<<t.bi_valid&65535,t.bi_valid+=a)}function M(t,e,a){K(t,a[2*e],a[2*e+1])}function P(t,e){var a=0;do{a|=1&t,t>>>=1,a<<=1}while(--e>0);return a>>>1}function Y(t,e,a){var i,n,r=new Array(w+1),s=0;for(i=1;i<=w;i++)r[i]=s=s+a[i-1]<<1;for(n=0;n<=e;n++){var o=t[2*n+1];0!==o&&(t[2*n]=P(r[o]++,o))}}function q(t){var e;for(e=0;e<c;e++)t.dyn_ltree[2*e]=0;for(e=0;e<b;e++)t.dyn_dtree[2*e]=0;for(e=0;e<g;e++)t.bl_tree[2*e]=0;t.dyn_ltree[2*k]=1,t.opt_len=t.static_len=0,t.last_lit=t.matches=0}function G(t){t.bi_valid>8?j(t,t.bi_buf):t.bi_valid>0&&(t.pending_buf[t.pending++]=t.bi_buf),t.bi_buf=0,t.bi_valid=0}function X(t,e,a,i){var n=2*e,r=2*a;return t[n]<t[r]||t[n]===t[r]&&i[e]<=i[a]}function W(t,e,a){for(var i=t.heap[a],n=a<<1;n<=t.heap_len&&(n<t.heap_len&&X(e,t.heap[n+1],t.heap[n],t.depth)&&n++,!X(e,i,t.heap[n],t.depth));)t.heap[a]=t.heap[n],a=n,n<<=1;t.heap[a]=i}function J(t,e,a){var i,n,r,s,o=0;if(0!==t.last_lit)do{i=t.pending_buf[t.d_buf+2*o]<<8|t.pending_buf[t.d_buf+2*o+1],n=t.pending_buf[t.l_buf+o],o++,0===i?M(t,n,e):(M(t,(r=N[n])+u+1,e),0!==(s=B[r])&&K(t,n-=O[r],s),M(t,r=H(--i),a),0!==(s=S[r])&&K(t,i-=T[r],s))}while(o<t.last_lit);M(t,k,e)}function Q(t,e){var a,i,n,r=e.dyn_tree,s=e.stat_desc.static_tree,o=e.stat_desc.has_stree,l=e.stat_desc.elems,h=-1;for(t.heap_len=0,t.heap_max=m,a=0;a<l;a++)0!==r[2*a]?(t.heap[++t.heap_len]=h=a,t.depth[a]=0):r[2*a+1]=0;for(;t.heap_len<2;)r[2*(n=t.heap[++t.heap_len]=h<2?++h:0)]=1,t.depth[n]=0,t.opt_len--,o&&(t.static_len-=s[2*n+1]);for(e.max_code=h,a=t.heap_len>>1;a>=1;a--)W(t,r,a);n=l;do{a=t.heap[1],t.heap[1]=t.heap[t.heap_len--],W(t,r,1),i=t.heap[1],t.heap[--t.heap_max]=a,t.heap[--t.heap_max]=i,r[2*n]=r[2*a]+r[2*i],t.depth[n]=(t.depth[a]>=t.depth[i]?t.depth[a]:t.depth[i])+1,r[2*a+1]=r[2*i+1]=n,t.heap[1]=n++,W(t,r,1)}while(t.heap_len>=2);t.heap[--t.heap_max]=t.heap[1],function(t,e){var a,i,n,r,s,o,l=e.dyn_tree,h=e.max_code,d=e.stat_desc.static_tree,f=e.stat_desc.has_stree,_=e.stat_desc.extra_bits,u=e.stat_desc.extra_base,c=e.stat_desc.max_length,b=0;for(r=0;r<=w;r++)t.bl_count[r]=0;for(l[2*t.heap[t.heap_max]+1]=0,a=t.heap_max+1;a<m;a++)(r=l[2*l[2*(i=t.heap[a])+1]+1]+1)>c&&(r=c,b++),l[2*i+1]=r,i>h||(t.bl_count[r]++,s=0,i>=u&&(s=_[i-u]),o=l[2*i],t.opt_len+=o*(r+s),f&&(t.static_len+=o*(d[2*i+1]+s)));if(0!==b){do{for(r=c-1;0===t.bl_count[r];)r--;t.bl_count[r]--,t.bl_count[r+1]+=2,t.bl_count[c]--,b-=2}while(b>0);for(r=c;0!==r;r--)for(i=t.bl_count[r];0!==i;)(n=t.heap[--a])>h||(l[2*n+1]!==r&&(t.opt_len+=(r-l[2*n+1])*l[2*n],l[2*n+1]=r),i--)}}(t,e),Y(r,h,t.bl_count)}function V(t,e,a){var i,n,r=-1,s=e[1],o=0,l=7,h=4;for(0===s&&(l=138,h=3),e[2*(a+1)+1]=65535,i=0;i<=a;i++)n=s,s=e[2*(i+1)+1],++o<l&&n===s||(o<h?t.bl_tree[2*n]+=o:0!==n?(n!==r&&t.bl_tree[2*n]++,t.bl_tree[2*y]++):o<=10?t.bl_tree[2*x]++:t.bl_tree[2*z]++,o=0,r=n,0===s?(l=138,h=3):n===s?(l=6,h=3):(l=7,h=4))}function $(t,e,a){var i,n,r=-1,s=e[1],o=0,l=7,h=4;for(0===s&&(l=138,h=3),i=0;i<=a;i++)if(n=s,s=e[2*(i+1)+1],!(++o<l&&n===s)){if(o<h)do{M(t,n,t.bl_tree)}while(0!=--o);else 0!==n?(n!==r&&(M(t,n,t.bl_tree),o--),M(t,y,t.bl_tree),K(t,o-3,2)):o<=10?(M(t,x,t.bl_tree),K(t,o-3,3)):(M(t,z,t.bl_tree),K(t,o-11,7));o=0,r=n,0===s?(l=138,h=3):n===s?(l=6,h=3):(l=7,h=4)}}l(T);var tt=!1;function et(t,e,a,n){K(t,(h<<1)+(n?1:0),3),function(t,e,a,n){G(t),n&&(j(t,a),j(t,~a)),i.arraySet(t.pending_buf,t.window,e,a,t.pending),t.pending+=a}(t,e,a,!0)}a._tr_init=function(t){tt||(function(){var t,e,a,i,n,r=new Array(w+1);for(a=0,i=0;i<_-1;i++)for(O[i]=a,t=0;t<1<<B[i];t++)N[a++]=i;for(N[a-1]=i,n=0,i=0;i<16;i++)for(T[i]=n,t=0;t<1<<S[i];t++)C[n++]=i;for(n>>=7;i<b;i++)for(T[i]=n<<7,t=0;t<1<<S[i]-7;t++)C[256+n++]=i;for(e=0;e<=w;e++)r[e]=0;for(t=0;t<=143;)Z[2*t+1]=8,t++,r[8]++;for(;t<=255;)Z[2*t+1]=9,t++,r[9]++;for(;t<=279;)Z[2*t+1]=7,t++,r[7]++;for(;t<=287;)Z[2*t+1]=8,t++,r[8]++;for(Y(Z,c+1,r),t=0;t<b;t++)R[2*t+1]=5,R[2*t]=P(t,5);D=new F(Z,B,u+1,c,w),I=new F(R,S,0,b,w),U=new F(new Array(0),E,0,g,v)}(),tt=!0),t.l_desc=new L(t.dyn_ltree,D),t.d_desc=new L(t.dyn_dtree,I),t.bl_desc=new L(t.bl_tree,U),t.bi_buf=0,t.bi_valid=0,q(t)},a._tr_stored_block=et,a._tr_flush_block=function(t,e,a,i){var l,h,_=0;t.level>0?(t.strm.data_type===o&&(t.strm.data_type=function(t){var e,a=4093624447;for(e=0;e<=31;e++,a>>>=1)if(1&a&&0!==t.dyn_ltree[2*e])return r;if(0!==t.dyn_ltree[18]||0!==t.dyn_ltree[20]||0!==t.dyn_ltree[26])return s;for(e=32;e<u;e++)if(0!==t.dyn_ltree[2*e])return s;return r}(t)),Q(t,t.l_desc),Q(t,t.d_desc),_=function(t){var e;for(V(t,t.dyn_ltree,t.l_desc.max_code),V(t,t.dyn_dtree,t.d_desc.max_code),Q(t,t.bl_desc),e=g-1;e>=3&&0===t.bl_tree[2*A[e]+1];e--);return t.opt_len+=3*(e+1)+5+5+4,e}(t),l=t.opt_len+3+7>>>3,(h=t.static_len+3+7>>>3)<=l&&(l=h)):l=h=a+5,a+4<=l&&-1!==e?et(t,e,a,i):t.strategy===n||h===l?(K(t,(d<<1)+(i?1:0),3),J(t,Z,R)):(K(t,(f<<1)+(i?1:0),3),function(t,e,a,i){var n;for(K(t,e-257,5),K(t,a-1,5),K(t,i-4,4),n=0;n<i;n++)K(t,t.bl_tree[2*A[n]+1],3);$(t,t.dyn_ltree,e-1),$(t,t.dyn_dtree,a-1)}(t,t.l_desc.max_code+1,t.d_desc.max_code+1,_+1),J(t,t.dyn_ltree,t.dyn_dtree)),q(t),i&&G(t)},a._tr_tally=function(t,e,a){return t.pending_buf[t.d_buf+2*t.last_lit]=e>>>8&255,t.pending_buf[t.d_buf+2*t.last_lit+1]=255&e,t.pending_buf[t.l_buf+t.last_lit]=255&a,t.last_lit++,0===e?t.dyn_ltree[2*a]++:(t.matches++,e--,t.dyn_ltree[2*(N[a]+u+1)]++,t.dyn_dtree[2*H(e)]++),t.last_lit===t.lit_bufsize-1},a._tr_align=function(t){K(t,d<<1,3),M(t,k,Z),function(t){16===t.bi_valid?(j(t,t.bi_buf),t.bi_buf=0,t.bi_valid=0):t.bi_valid>=8&&(t.pending_buf[t.pending++]=255&t.bi_buf,t.bi_buf>>=8,t.bi_valid-=8)}(t)}},{"../utils/common":3}],15:[function(t,e,a){"use strict";e.exports=function(){this.input=null,this.next_in=0,this.avail_in=0,this.total_in=0,this.output=null,this.next_out=0,this.avail_out=0,this.total_out=0,this.msg="",this.state=null,this.data_type=2,this.adler=0}},{}],"/":[function(t,e,a){"use strict";var i={};(0,t("./lib/utils/common").assign)(i,t("./lib/deflate"),t("./lib/inflate"),t("./lib/zlib/constants")),e.exports=i},{"./lib/deflate":1,"./lib/inflate":2,"./lib/utils/common":3,"./lib/zlib/constants":6}]},{},[])("/")});
139
+ </script>
140
+ <script>
141
+ !function(){var e={};"object"==typeof module?module.exports=e:window.UPNG=e,function(e,r){e.toRGBA8=function(r){var t=r.width,n=r.height;if(null==r.tabs.acTL)return[e.toRGBA8.decodeImage(r.data,t,n,r).buffer];var i=[];null==r.frames[0].data&&(r.frames[0].data=r.data);for(var a,f=new Uint8Array(t*n*4),o=0;o<r.frames.length;o++){var s=r.frames[o],l=s.rect.x,c=s.rect.y,u=s.rect.width,d=s.rect.height,h=e.toRGBA8.decodeImage(s.data,u,d,r);if(0==o?a=h:0==s.blend?e._copyTile(h,u,d,a,t,n,l,c,0):1==s.blend&&e._copyTile(h,u,d,a,t,n,l,c,1),i.push(a.buffer),a=a.slice(0),0==s.dispose);else if(1==s.dispose)e._copyTile(f,u,d,a,t,n,l,c,0);else if(2==s.dispose){for(var v=o-1;2==r.frames[v].dispose;)v--;a=new Uint8Array(i[v]).slice(0)}}return i},e.toRGBA8.decodeImage=function(r,t,n,i){var a=t*n,f=e.decode._getBPP(i),o=Math.ceil(t*f/8),s=new Uint8Array(4*a),l=new Uint32Array(s.buffer),c=i.ctype,u=i.depth,d=e._bin.readUshort;if(6==c){var h=a<<2;if(8==u)for(var v=0;v<h;v++)s[v]=r[v];if(16==u)for(v=0;v<h;v++)s[v]=r[v<<1]}else if(2==c){var p=i.tabs.tRNS,b=-1,g=-1,m=-1;if(p&&(b=p[0],g=p[1],m=p[2]),8==u)for(v=0;v<a;v++){var y=3*v;s[M=v<<2]=r[y],s[M+1]=r[y+1],s[M+2]=r[y+2],s[M+3]=255,-1!=b&&r[y]==b&&r[y+1]==g&&r[y+2]==m&&(s[M+3]=0)}if(16==u)for(v=0;v<a;v++){y=6*v;s[M=v<<2]=r[y],s[M+1]=r[y+2],s[M+2]=r[y+4],s[M+3]=255,-1!=b&&d(r,y)==b&&d(r,y+2)==g&&d(r,y+4)==m&&(s[M+3]=0)}}else if(3==c){var w=i.tabs.PLTE,A=i.tabs.tRNS,U=A?A.length:0;if(1==u)for(var _=0;_<n;_++){var q=_*o,I=_*t;for(v=0;v<t;v++){var M=I+v<<2,T=3*(z=r[q+(v>>3)]>>7-((7&v)<<0)&1);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}}if(2==u)for(_=0;_<n;_++)for(q=_*o,I=_*t,v=0;v<t;v++){M=I+v<<2,T=3*(z=r[q+(v>>2)]>>6-((3&v)<<1)&3);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}if(4==u)for(_=0;_<n;_++)for(q=_*o,I=_*t,v=0;v<t;v++){M=I+v<<2,T=3*(z=r[q+(v>>1)]>>4-((1&v)<<2)&15);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}if(8==u)for(v=0;v<a;v++){var z;M=v<<2,T=3*(z=r[v]);s[M]=w[T],s[M+1]=w[T+1],s[M+2]=w[T+2],s[M+3]=z<U?A[z]:255}}else if(4==c){if(8==u)for(v=0;v<a;v++){M=v<<2;var R=r[N=v<<1];s[M]=R,s[M+1]=R,s[M+2]=R,s[M+3]=r[N+1]}if(16==u)for(v=0;v<a;v++){var N;M=v<<2,R=r[N=v<<2];s[M]=R,s[M+1]=R,s[M+2]=R,s[M+3]=r[N+2]}}else if(0==c){b=i.tabs.tRNS?i.tabs.tRNS:-1;if(1==u)for(v=0;v<a;v++){var L=(R=255*(r[v>>3]>>7-(7&v)&1))==255*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(2==u)for(v=0;v<a;v++){L=(R=85*(r[v>>2]>>6-((3&v)<<1)&3))==85*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(4==u)for(v=0;v<a;v++){L=(R=17*(r[v>>1]>>4-((1&v)<<2)&15))==17*b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(8==u)for(v=0;v<a;v++){L=(R=r[v])==b?0:255;l[v]=L<<24|R<<16|R<<8|R}if(16==u)for(v=0;v<a;v++){R=r[v<<1],L=d(r,v<<1)==b?0:255;l[v]=L<<24|R<<16|R<<8|R}}return s},e.decode=function(r){for(var t,n=new Uint8Array(r),i=8,a=e._bin,f=a.readUshort,o=a.readUint,s={tabs:{},frames:[]},l=new Uint8Array(n.length),c=0,u=0,d=[137,80,78,71,13,10,26,10],h=0;h<8;h++)if(n[h]!=d[h])throw"The input is not a PNG file!";for(;i<n.length;){var v=a.readUint(n,i);i+=4;var p=a.readASCII(n,i,4);if(i+=4,"IHDR"==p)e.decode._IHDR(n,i,s);else if("IDAT"==p){for(h=0;h<v;h++)l[c+h]=n[i+h];c+=v}else if("acTL"==p)s.tabs[p]={num_frames:o(n,i),num_plays:o(n,i+4)},t=new Uint8Array(n.length);else if("fcTL"==p){var b;if(0!=u)(b=s.frames[s.frames.length-1]).data=e.decode._decompress(s,t.slice(0,u),b.rect.width,b.rect.height),u=0;var g={x:o(n,i+12),y:o(n,i+16),width:o(n,i+4),height:o(n,i+8)},m=f(n,i+22);m=f(n,i+20)/(0==m?100:m);var y={rect:g,delay:Math.round(1e3*m),dispose:n[i+24],blend:n[i+25]};s.frames.push(y)}else if("fdAT"==p){for(h=0;h<v-4;h++)t[u+h]=n[i+h+4];u+=v-4}else if("pHYs"==p)s.tabs[p]=[a.readUint(n,i),a.readUint(n,i+4),n[i+8]];else if("cHRM"==p){s.tabs[p]=[];for(h=0;h<8;h++)s.tabs[p].push(a.readUint(n,i+4*h))}else if("tEXt"==p){null==s.tabs[p]&&(s.tabs[p]={});var w=a.nextZero(n,i),A=a.readASCII(n,i,w-i),U=a.readASCII(n,w+1,i+v-w-1);s.tabs[p][A]=U}else if("iTXt"==p){null==s.tabs[p]&&(s.tabs[p]={});w=0;var _=i;w=a.nextZero(n,_);A=a.readASCII(n,_,w-_),n[_=w+1],n[_+1];_+=2,w=a.nextZero(n,_);a.readASCII(n,_,w-_);_=w+1,w=a.nextZero(n,_);a.readUTF8(n,_,w-_);_=w+1;U=a.readUTF8(n,_,v-(_-i));s.tabs[p][A]=U}else if("PLTE"==p)s.tabs[p]=a.readBytes(n,i,v);else if("hIST"==p){var q=s.tabs.PLTE.length/3;s.tabs[p]=[];for(h=0;h<q;h++)s.tabs[p].push(f(n,i+2*h))}else if("tRNS"==p)3==s.ctype?s.tabs[p]=a.readBytes(n,i,v):0==s.ctype?s.tabs[p]=f(n,i):2==s.ctype&&(s.tabs[p]=[f(n,i),f(n,i+2),f(n,i+4)]);else if("gAMA"==p)s.tabs[p]=a.readUint(n,i)/1e5;else if("sRGB"==p)s.tabs[p]=n[i];else if("bKGD"==p)0==s.ctype||4==s.ctype?s.tabs[p]=[f(n,i)]:2==s.ctype||6==s.ctype?s.tabs[p]=[f(n,i),f(n,i+2),f(n,i+4)]:3==s.ctype&&(s.tabs[p]=n[i]);else if("IEND"==p)break;i+=v;a.readUint(n,i);i+=4}0!=u&&((b=s.frames[s.frames.length-1]).data=e.decode._decompress(s,t.slice(0,u),b.rect.width,b.rect.height),u=0);return s.data=e.decode._decompress(s,l,s.width,s.height),delete s.compress,delete s.interlace,delete s.filter,s},e.decode._decompress=function(r,t,n,i){return 0==r.compress&&(t=e.decode._inflate(t)),0==r.interlace?t=e.decode._filterZero(t,r,0,n,i):1==r.interlace&&(t=e.decode._readInterlace(t,r)),t},e.decode._inflate=function(e){return r.inflate(e)},e.decode._readInterlace=function(r,t){for(var n=t.width,i=t.height,a=e.decode._getBPP(t),f=a>>3,o=Math.ceil(n*a/8),s=new Uint8Array(i*o),l=0,c=[0,0,4,0,2,0,1],u=[0,4,0,2,0,1,0],d=[8,8,8,4,4,2,2],h=[8,8,4,4,2,2,1],v=0;v<7;){for(var p=d[v],b=h[v],g=0,m=0,y=c[v];y<i;)y+=p,m++;for(var w=u[v];w<n;)w+=b,g++;var A=Math.ceil(g*a/8);e.decode._filterZero(r,t,l,g,m);for(var U=0,_=c[v];_<i;){for(var q=u[v],I=l+U*A<<3;q<n;){var M;if(1==a)M=(M=r[I>>3])>>7-(7&I)&1,s[_*o+(q>>3)]|=M<<7-((3&q)<<0);if(2==a)M=(M=r[I>>3])>>6-(7&I)&3,s[_*o+(q>>2)]|=M<<6-((3&q)<<1);if(4==a)M=(M=r[I>>3])>>4-(7&I)&15,s[_*o+(q>>1)]|=M<<4-((1&q)<<2);if(a>=8)for(var T=_*o+q*f,z=0;z<f;z++)s[T+z]=r[(I>>3)+z];I+=a,q+=b}U++,_+=p}g*m!=0&&(l+=m*(1+A)),v+=1}return s},e.decode._getBPP=function(e){return[1,null,3,1,2,null,4][e.ctype]*e.depth},e.decode._filterZero=function(r,t,n,i,a){var f=e.decode._getBPP(t),o=Math.ceil(i*f/8),s=e.decode._paeth;f=Math.ceil(f/8);for(var l=0;l<a;l++){var c=n+l*o,u=c+l+1,d=r[u-1];if(0==d)for(var h=0;h<o;h++)r[c+h]=r[u+h];else if(1==d){for(h=0;h<f;h++)r[c+h]=r[u+h];for(h=f;h<o;h++)r[c+h]=r[u+h]+r[c+h-f]&255}else if(0==l){for(h=0;h<f;h++)r[c+h]=r[u+h];if(2==d)for(h=f;h<o;h++)r[c+h]=255&r[u+h];if(3==d)for(h=f;h<o;h++)r[c+h]=r[u+h]+(r[c+h-f]>>1)&255;if(4==d)for(h=f;h<o;h++)r[c+h]=r[u+h]+s(r[c+h-f],0,0)&255}else{if(2==d)for(h=0;h<o;h++)r[c+h]=r[u+h]+r[c+h-o]&255;if(3==d){for(h=0;h<f;h++)r[c+h]=r[u+h]+(r[c+h-o]>>1)&255;for(h=f;h<o;h++)r[c+h]=r[u+h]+(r[c+h-o]+r[c+h-f]>>1)&255}if(4==d){for(h=0;h<f;h++)r[c+h]=r[u+h]+s(0,r[c+h-o],0)&255;for(h=f;h<o;h++)r[c+h]=r[u+h]+s(r[c+h-f],r[c+h-o],r[c+h-f-o])&255}}}return r},e.decode._paeth=function(e,r,t){var n=e+r-t,i=Math.abs(n-e),a=Math.abs(n-r),f=Math.abs(n-t);return i<=a&&i<=f?e:a<=f?r:t},e.decode._IHDR=function(r,t,n){var i=e._bin;n.width=i.readUint(r,t),t+=4,n.height=i.readUint(r,t),t+=4,n.depth=r[t],t++,n.ctype=r[t],t++,n.compress=r[t],t++,n.filter=r[t],t++,n.interlace=r[t],t++},e._bin={nextZero:function(e,r){for(;0!=e[r];)r++;return r},readUshort:function(e,r){return e[r]<<8|e[r+1]},writeUshort:function(e,r,t){e[r]=t>>8&255,e[r+1]=255&t},readUint:function(e,r){return 16777216*e[r]+(e[r+1]<<16|e[r+2]<<8|e[r+3])},writeUint:function(e,r,t){e[r]=t>>24&255,e[r+1]=t>>16&255,e[r+2]=t>>8&255,e[r+3]=255&t},readASCII:function(e,r,t){for(var n="",i=0;i<t;i++)n+=String.fromCharCode(e[r+i]);return n},writeASCII:function(e,r,t){for(var n=0;n<t.length;n++)e[r+n]=t.charCodeAt(n)},readBytes:function(e,r,t){for(var n=[],i=0;i<t;i++)n.push(e[r+i]);return n},pad:function(e){return e.length<2?"0"+e:e},readUTF8:function(r,t,n){for(var i,a="",f=0;f<n;f++)a+="%"+e._bin.pad(r[t+f].toString(16));try{i=decodeURIComponent(a)}catch(i){return e._bin.readASCII(r,t,n)}return i}},e._copyTile=function(e,r,t,n,i,a,f,o,s){for(var l=Math.min(r,i),c=Math.min(t,a),u=0,d=0,h=0;h<c;h++)for(var v=0;v<l;v++)if(f>=0&&o>=0?(u=h*r+v<<2,d=(o+h)*i+f+v<<2):(u=(-o+h)*r-f+v<<2,d=h*i+v<<2),0==s)n[d]=e[u],n[d+1]=e[u+1],n[d+2]=e[u+2],n[d+3]=e[u+3];else if(1==s){var p=e[u+3]*(1/255),b=e[u]*p,g=e[u+1]*p,m=e[u+2]*p,y=n[d+3]*(1/255),w=n[d]*y,A=n[d+1]*y,U=n[d+2]*y,_=1-p,q=p+y*_,I=0==q?0:1/q;n[d+3]=255*q,n[d+0]=(b+w*_)*I,n[d+1]=(g+A*_)*I,n[d+2]=(m+U*_)*I}else if(2==s){p=e[u+3],b=e[u],g=e[u+1],m=e[u+2],y=n[d+3],w=n[d],A=n[d+1],U=n[d+2];p==y&&b==w&&g==A&&m==U?(n[d]=0,n[d+1]=0,n[d+2]=0,n[d+3]=0):(n[d]=b,n[d+1]=g,n[d+2]=m,n[d+3]=p)}else if(3==s){p=e[u+3],b=e[u],g=e[u+1],m=e[u+2],y=n[d+3],w=n[d],A=n[d+1],U=n[d+2];if(p==y&&b==w&&g==A&&m==U)continue;if(p<220&&y>20)return!1}return!0},e.encode=function(r,t,n,i,a,f){null==i&&(i=0),null==f&&(f=!1);var o=e.encode.compress(r,t,n,i,!1,f);return e.encode.compressPNG(o,-1),e.encode._main(o,t,n,a)},e.encodeLL=function(r,t,n,i,a,f,o){for(var s={ctype:0+(1==i?0:2)+(0==a?0:4),depth:f,frames:[]},l=(i+a)*f,c=l*t,u=0;u<r.length;u++)s.frames.push({rect:{x:0,y:0,width:t,height:n},img:new Uint8Array(r[u]),blend:0,dispose:1,bpp:Math.ceil(l/8),bpl:Math.ceil(c/8)});return e.encode.compressPNG(s,4),e.encode._main(s,t,n,o)},e.encode._main=function(r,t,n,i){var a=e.crc.crc,f=e._bin.writeUint,o=e._bin.writeUshort,s=e._bin.writeASCII,l=8,c=r.frames.length>1,u=!1,d=46+(c?20:0);if(3==r.ctype){for(var h=r.plte.length,v=0;v<h;v++)r.plte[v]>>>24!=255&&(u=!0);d+=8+3*h+4+(u?8+1*h+4:0)}for(var p=0;p<r.frames.length;p++){c&&(d+=38),d+=(q=r.frames[p]).cimg.length+12,0!=p&&(d+=4)}d+=12;var b=new Uint8Array(d),g=[137,80,78,71,13,10,26,10];for(v=0;v<8;v++)b[v]=g[v];if(f(b,l,13),s(b,l+=4,"IHDR"),f(b,l+=4,t),f(b,l+=4,n),b[l+=4]=r.depth,b[++l]=r.ctype,b[++l]=0,b[++l]=0,b[++l]=0,f(b,++l,a(b,l-17,17)),f(b,l+=4,1),s(b,l+=4,"sRGB"),b[l+=4]=1,f(b,++l,a(b,l-5,5)),l+=4,c&&(f(b,l,8),s(b,l+=4,"acTL"),f(b,l+=4,r.frames.length),f(b,l+=4,0),f(b,l+=4,a(b,l-12,12)),l+=4),3==r.ctype){f(b,l,3*(h=r.plte.length)),s(b,l+=4,"PLTE"),l+=4;for(v=0;v<h;v++){var m=3*v,y=r.plte[v],w=255&y,A=y>>>8&255,U=y>>>16&255;b[l+m+0]=w,b[l+m+1]=A,b[l+m+2]=U}if(f(b,l+=3*h,a(b,l-3*h-4,3*h+4)),l+=4,u){f(b,l,h),s(b,l+=4,"tRNS"),l+=4;for(v=0;v<h;v++)b[l+v]=r.plte[v]>>>24&255;f(b,l+=h,a(b,l-h-4,h+4)),l+=4}}var _=0;for(p=0;p<r.frames.length;p++){var q=r.frames[p];c&&(f(b,l,26),s(b,l+=4,"fcTL"),f(b,l+=4,_++),f(b,l+=4,q.rect.width),f(b,l+=4,q.rect.height),f(b,l+=4,q.rect.x),f(b,l+=4,q.rect.y),o(b,l+=4,i[p]),o(b,l+=2,1e3),b[l+=2]=q.dispose,b[++l]=q.blend,f(b,++l,a(b,l-30,30)),l+=4);var I=q.cimg;f(b,l,(h=I.length)+(0==p?0:4));var M=l+=4;s(b,l,0==p?"IDAT":"fdAT"),l+=4,0!=p&&(f(b,l,_++),l+=4);for(v=0;v<h;v++)b[l+v]=I[v];f(b,l+=h,a(b,M,l-M)),l+=4}return f(b,l,0),s(b,l+=4,"IEND"),f(b,l+=4,a(b,l-4,4)),l+=4,b.buffer},e.encode.compressPNG=function(r,t){for(var n=0;n<r.frames.length;n++){var i=r.frames[n],a=(i.rect.width,i.rect.height),f=new Uint8Array(a*i.bpl+a);i.cimg=e.encode._filterZero(i.img,a,i.bpp,i.bpl,f,t)}},e.encode.compress=function(r,t,n,i,a,f){null==f&&(f=!1);for(var o=6,s=8,l=255,c=0;c<r.length;c++)for(var u=new Uint8Array(r[c]),d=u.length,h=0;h<d;h+=4)l&=u[h+3];var v=255!=l,p=v&&a,b=e.encode.framize(r,t,n,a,p),g={},m=[],y=[];if(0!=i){var w=[];for(h=0;h<b.length;h++)w.push(b[h].img.buffer);var A=e.encode.concatRGBA(w,a),U=e.quantize(A,i),_=0,q=new Uint8Array(U.abuf);for(h=0;h<b.length;h++){var I=(F=b[h].img).length;y.push(new Uint8Array(U.inds.buffer,_>>2,I>>2));for(c=0;c<I;c+=4)F[c]=q[_+c],F[c+1]=q[_+c+1],F[c+2]=q[_+c+2],F[c+3]=q[_+c+3];_+=I}for(h=0;h<U.plte.length;h++)m.push(U.plte[h].est.rgba)}else for(c=0;c<b.length;c++){var M=b[c],T=new Uint32Array(M.img.buffer),z=M.rect.width,R=(d=T.length,new Uint8Array(d));y.push(R);for(h=0;h<d;h++){var N=T[h];if(0!=h&&N==T[h-1])R[h]=R[h-1];else if(h>z&&N==T[h-z])R[h]=R[h-z];else{var L=g[N];if(null==L&&(g[N]=L=m.length,m.push(N),m.length>=300))break;R[h]=L}}}var P=m.length;P<=256&&0==f&&(s=P<=2?1:P<=4?2:P<=16?4:8,a&&(s=8));for(c=0;c<b.length;c++){(M=b[c]).rect.x,M.rect.y,z=M.rect.width;var S=M.rect.height,D=M.img,B=(new Uint32Array(D.buffer),4*z),x=4;if(P<=256&&0==f){B=Math.ceil(s*z/8);for(var C=new Uint8Array(B*S),G=y[c],Z=0;Z<S;Z++){h=Z*B;var k=Z*z;if(8==s)for(var E=0;E<z;E++)C[h+E]=G[k+E];else if(4==s)for(E=0;E<z;E++)C[h+(E>>1)]|=G[k+E]<<4-4*(1&E);else if(2==s)for(E=0;E<z;E++)C[h+(E>>2)]|=G[k+E]<<6-2*(3&E);else if(1==s)for(E=0;E<z;E++)C[h+(E>>3)]|=G[k+E]<<7-1*(7&E)}D=C,o=3,x=1}else if(0==v&&1==b.length){C=new Uint8Array(z*S*3);var H=z*S;for(h=0;h<H;h++){var F,K=4*h;C[F=3*h]=D[K],C[F+1]=D[K+1],C[F+2]=D[K+2]}D=C,o=2,x=3,B=3*z}M.img=D,M.bpl=B,M.bpp=x}return{ctype:o,depth:s,plte:m,frames:b}},e.encode.framize=function(r,t,n,i,a){for(var f=[],o=0;o<r.length;o++){var s=new Uint8Array(r[o]),l=new Uint32Array(s.buffer),c=0,u=0,d=t,h=n,v=0;if(0==o||a)s=s.slice(0);else{for(var p=i||1==o||2==f[f.length-2].dispose?1:2,b=0,g=1e9,m=0;m<p;m++){for(var y=new Uint8Array(r[o-1-m]),w=new Uint32Array(r[o-1-m]),A=t,U=n,_=-1,q=-1,I=0;I<n;I++)for(var M=0;M<t;M++){var T=I*t+M;l[T]!=w[T]&&(M<A&&(A=M),M>_&&(_=M),I<U&&(U=I),I>q&&(q=I))}var z=-1==_?1:(_-A+1)*(q-U+1);z<g&&(g=z,b=m,-1==_?(c=u=0,d=h=1):(c=A,u=U,d=_-A+1,h=q-U+1))}y=new Uint8Array(r[o-1-b]);1==b&&(f[f.length-1].dispose=2);var R=new Uint8Array(d*h*4);new Uint32Array(R.buffer);e._copyTile(y,t,n,R,d,h,-c,-u,0),e._copyTile(s,t,n,R,d,h,-c,-u,3)?(e._copyTile(s,t,n,R,d,h,-c,-u,2),v=1):(e._copyTile(s,t,n,R,d,h,-c,-u,0),v=0),s=R}f.push({rect:{x:c,y:u,width:d,height:h},img:s,blend:v,dispose:a?1:0})}return f},e.encode._filterZero=function(t,n,i,a,f,o){if(-1!=o){for(var s=0;s<n;s++)e.encode._filterLine(f,t,s,a,i,o);return r.deflate(f)}for(var l=[],c=0;c<5;c++)if(!(n*a>5e5)||2!=c&&3!=c&&4!=c){for(s=0;s<n;s++)e.encode._filterLine(f,t,s,a,i,c);if(l.push(r.deflate(f)),1==i)break}for(var u,d=1e9,h=0;h<l.length;h++)l[h].length<d&&(u=h,d=l[h].length);return l[u]},e.encode._filterLine=function(r,t,n,i,a,f){var o=n*i,s=o+n,l=e.decode._paeth;if(r[s]=f,s++,0==f)for(var c=0;c<i;c++)r[s+c]=t[o+c];else if(1==f){for(c=0;c<a;c++)r[s+c]=t[o+c];for(c=a;c<i;c++)r[s+c]=t[o+c]-t[o+c-a]+256&255}else if(0==n){for(c=0;c<a;c++)r[s+c]=t[o+c];if(2==f)for(c=a;c<i;c++)r[s+c]=t[o+c];if(3==f)for(c=a;c<i;c++)r[s+c]=t[o+c]-(t[o+c-a]>>1)+256&255;if(4==f)for(c=a;c<i;c++)r[s+c]=t[o+c]-l(t[o+c-a],0,0)+256&255}else{if(2==f)for(c=0;c<i;c++)r[s+c]=t[o+c]+256-t[o+c-i]&255;if(3==f){for(c=0;c<a;c++)r[s+c]=t[o+c]+256-(t[o+c-i]>>1)&255;for(c=a;c<i;c++)r[s+c]=t[o+c]+256-(t[o+c-i]+t[o+c-a]>>1)&255}if(4==f){for(c=0;c<a;c++)r[s+c]=t[o+c]+256-l(0,t[o+c-i],0)&255;for(c=a;c<i;c++)r[s+c]=t[o+c]+256-l(t[o+c-a],t[o+c-i],t[o+c-a-i])&255}}},e.crc={table:function(){for(var e=new Uint32Array(256),r=0;r<256;r++){for(var t=r,n=0;n<8;n++)1&t?t=3988292384^t>>>1:t>>>=1;e[r]=t}return e}(),update:function(r,t,n,i){for(var a=0;a<i;a++)r=e.crc.table[255&(r^t[n+a])]^r>>>8;return r},crc:function(r,t,n){return 4294967295^e.crc.update(4294967295,r,t,n)}},e.quantize=function(r,t){for(var n=new Uint8Array(r),i=n.slice(0),a=new Uint32Array(i.buffer),f=e.quantize.getKDtree(i,t),o=f[0],s=f[1],l=(e.quantize.planeDst,n),c=a,u=l.length,d=new Uint8Array(n.length>>2),h=0;h<u;h+=4){var v=l[h]*(1/255),p=l[h+1]*(1/255),b=l[h+2]*(1/255),g=l[h+3]*(1/255),m=e.quantize.getNearest(o,v,p,b,g);d[h>>2]=m.ind,c[h>>2]=m.est.rgba}return{abuf:i.buffer,inds:d,plte:s}},e.quantize.getKDtree=function(r,t,n){null==n&&(n=1e-4);var i=new Uint32Array(r.buffer),a={i0:0,i1:r.length,bst:null,est:null,tdst:0,left:null,right:null};a.bst=e.quantize.stats(r,a.i0,a.i1),a.est=e.quantize.estats(a.bst);for(var f=[a];f.length<t;){for(var o=0,s=0,l=0;l<f.length;l++)f[l].est.L>o&&(o=f[l].est.L,s=l);if(o<n)break;var c=f[s],u=e.quantize.splitPixels(r,i,c.i0,c.i1,c.est.e,c.est.eMq255);if(c.i0>=u||c.i1<=u)c.est.L=0;else{var d={i0:c.i0,i1:u,bst:null,est:null,tdst:0,left:null,right:null};d.bst=e.quantize.stats(r,d.i0,d.i1),d.est=e.quantize.estats(d.bst);var h={i0:u,i1:c.i1,bst:null,est:null,tdst:0,left:null,right:null};h.bst={R:[],m:[],N:c.bst.N-d.bst.N};for(l=0;l<16;l++)h.bst.R[l]=c.bst.R[l]-d.bst.R[l];for(l=0;l<4;l++)h.bst.m[l]=c.bst.m[l]-d.bst.m[l];h.est=e.quantize.estats(h.bst),c.left=d,c.right=h,f[s]=d,f.push(h)}}f.sort(function(e,r){return r.bst.N-e.bst.N});for(l=0;l<f.length;l++)f[l].ind=l;return[a,f]},e.quantize.getNearest=function(r,t,n,i,a){if(null==r.left)return r.tdst=e.quantize.dist(r.est.q,t,n,i,a),r;var f=e.quantize.planeDst(r.est,t,n,i,a),o=r.left,s=r.right;f>0&&(o=r.right,s=r.left);var l=e.quantize.getNearest(o,t,n,i,a);if(l.tdst<=f*f)return l;var c=e.quantize.getNearest(s,t,n,i,a);return c.tdst<l.tdst?c:l},e.quantize.planeDst=function(e,r,t,n,i){var a=e.e;return a[0]*r+a[1]*t+a[2]*n+a[3]*i-e.eMq},e.quantize.dist=function(e,r,t,n,i){var a=r-e[0],f=t-e[1],o=n-e[2],s=i-e[3];return a*a+f*f+o*o+s*s},e.quantize.splitPixels=function(r,t,n,i,a,f){var o=e.quantize.vecDot;i-=4;for(;n<i;){for(;o(r,n,a)<=f;)n+=4;for(;o(r,i,a)>f;)i-=4;if(n>=i)break;var s=t[n>>2];t[n>>2]=t[i>>2],t[i>>2]=s,n+=4,i-=4}for(;o(r,n,a)>f;)n-=4;return n+4},e.quantize.vecDot=function(e,r,t){return e[r]*t[0]+e[r+1]*t[1]+e[r+2]*t[2]+e[r+3]*t[3]},e.quantize.stats=function(e,r,t){for(var n=[0,0,0,0,0,0,0,0,0,0,0,0,0,0,0,0],i=[0,0,0,0],a=t-r>>2,f=r;f<t;f+=4){var o=e[f]*(1/255),s=e[f+1]*(1/255),l=e[f+2]*(1/255),c=e[f+3]*(1/255);i[0]+=o,i[1]+=s,i[2]+=l,i[3]+=c,n[0]+=o*o,n[1]+=o*s,n[2]+=o*l,n[3]+=o*c,n[5]+=s*s,n[6]+=s*l,n[7]+=s*c,n[10]+=l*l,n[11]+=l*c,n[15]+=c*c}return n[4]=n[1],n[8]=n[2],n[9]=n[6],n[12]=n[3],n[13]=n[7],n[14]=n[11],{R:n,m:i,N:a}},e.quantize.estats=function(r){var t=r.R,n=r.m,i=r.N,a=n[0],f=n[1],o=n[2],s=n[3],l=0==i?0:1/i,c=[t[0]-a*a*l,t[1]-a*f*l,t[2]-a*o*l,t[3]-a*s*l,t[4]-f*a*l,t[5]-f*f*l,t[6]-f*o*l,t[7]-f*s*l,t[8]-o*a*l,t[9]-o*f*l,t[10]-o*o*l,t[11]-o*s*l,t[12]-s*a*l,t[13]-s*f*l,t[14]-s*o*l,t[15]-s*s*l],u=c,d=e.M4,h=[.5,.5,.5,.5],v=0,p=0;if(0!=i)for(var b=0;b<10&&(h=d.multVec(u,h),p=Math.sqrt(d.dot(h,h)),h=d.sml(1/p,h),!(Math.abs(p-v)<1e-9));b++)v=p;var g=[a*l,f*l,o*l,s*l];return{Cov:c,q:g,e:h,L:v,eMq255:d.dot(d.sml(255,g),h),eMq:d.dot(h,g),rgba:(Math.round(255*g[3])<<24|Math.round(255*g[2])<<16|Math.round(255*g[1])<<8|Math.round(255*g[0])<<0)>>>0}},e.M4={multVec:function(e,r){return[e[0]*r[0]+e[1]*r[1]+e[2]*r[2]+e[3]*r[3],e[4]*r[0]+e[5]*r[1]+e[6]*r[2]+e[7]*r[3],e[8]*r[0]+e[9]*r[1]+e[10]*r[2]+e[11]*r[3],e[12]*r[0]+e[13]*r[1]+e[14]*r[2]+e[15]*r[3]]},dot:function(e,r){return e[0]*r[0]+e[1]*r[1]+e[2]*r[2]+e[3]*r[3]},sml:function(e,r){return[e*r[0],e*r[1],e*r[2],e*r[3]]}},e.encode.concatRGBA=function(e,r){for(var t=0,n=0;n<e.length;n++)t+=e[n].byteLength;var i=new Uint8Array(t),a=0;for(n=0;n<e.length;n++){for(var f=new Uint8Array(e[n]),o=f.length,s=0;s<o;s+=4){var l=f[s],c=f[s+1],u=f[s+2],d=f[s+3];r&&(d=0==(128&d)?0:255),0==d&&(l=c=u=0),i[a+s]=l,i[a+s+1]=c,i[a+s+2]=u,i[a+s+3]=d}a+=o}return i.buffer}}(e,"function"==typeof require?require("pako"):window.pako)}();
142
+ </script>
143
+
144
+ <script>
145
+ class Player {
146
+
147
+ constructor(container) {
148
+ this.container = container
149
+ this.global_frac = 0.0
150
+ this.container = document.getElementById(container)
151
+ this.progress = null;
152
+ this.mat = [[]]
153
+
154
+ this.player = this.container.querySelector('audio')
155
+ this.demo_img = this.container.querySelector('.underlay > img')
156
+ this.overlay = this.container.querySelector('.overlay')
157
+ this.playpause = this.container.querySelector(".playpause");
158
+ this.download = this.container.querySelector(".download");
159
+ this.play_img = this.container.querySelector('.play-img')
160
+ this.pause_img = this.container.querySelector('.pause-img')
161
+ this.canvas = this.container.querySelector('.response-canvas')
162
+ this.response_container = this.container.querySelector('.response')
163
+ this.context = this.canvas.getContext('2d');
164
+
165
+ // console.log(this.player.duration)
166
+ var togglePlayPause = () => {
167
+ if (this.player.networkState !== 1) {
168
+ return
169
+ }
170
+ if (this.player.paused || this.player.ended) {
171
+ this.play()
172
+ } else {
173
+ this.pause()
174
+ }
175
+ }
176
+
177
+ this.update = () => {
178
+ this.global_frac = this.player.currentTime / this.player.duration
179
+ // this.global_frac = frac
180
+ // console.log(this.player.currentTime, this.player.duration, this.global_frac)
181
+ this.overlay.style.width = (100*(1.0 - this.global_frac)).toString() + '%'
182
+ this.redraw()
183
+ }
184
+
185
+ // var start = null;
186
+ this.updateLoop = (timestamp) => {
187
+ // if (!start) start = timestamp;
188
+ // var progress = timestamp - start;
189
+ this.update()
190
+ // this.progress = setTimeout(this.updateLoop, 10)
191
+ this.progress = window.requestAnimationFrame(this.updateLoop)
192
+ }
193
+
194
+ this.seek = (e) => {
195
+ this.global_frac = e.offsetX / this.demo_img.width
196
+ this.player.currentTime = this.global_frac * this.player.duration
197
+ // console.log(this.global_frac)
198
+ this.overlay.style.width = (100*(1.0 - this.global_frac)).toString() + '%'
199
+ this.redraw()
200
+ }
201
+
202
+ var download_audio = () => {
203
+ var url = this.player.querySelector('#src').src
204
+ const a = document.createElement('a')
205
+ a.href = url
206
+ a.download = "download"
207
+ document.body.appendChild(a)
208
+ a.click()
209
+ document.body.removeChild(a)
210
+ }
211
+
212
+ this.demo_img.onclick = this.seek;
213
+ this.playpause.disabled = true
214
+ this.player.onplay = this.updateLoop
215
+ this.player.onpause = () => {
216
+ window.cancelAnimationFrame(this.progress)
217
+ this.update();
218
+ }
219
+ this.player.onended = () => {this.pause()}
220
+ this.playpause.onclick = togglePlayPause;
221
+ this.download.onclick = download_audio;
222
+ }
223
+
224
+ load(audio_fname, img_fname, levels_fname) {
225
+ this.pause()
226
+ window.cancelAnimationFrame(this.progress)
227
+ this.playpause.disabled = true
228
+
229
+ this.player.querySelector('#src').setAttribute("src", audio_fname)
230
+ this.player.load()
231
+ this.demo_img.setAttribute("src", img_fname)
232
+ this.overlay.style.width = '0%'
233
+
234
+ fetch(levels_fname)
235
+ .then(response => response.arrayBuffer())
236
+ .then(text => {
237
+ this.mat = this.parse(text);
238
+ this.playpause.disabled = false;
239
+ this.redraw();
240
+ })
241
+ }
242
+
243
+ parse(buffer) {
244
+ var img = UPNG.decode(buffer)
245
+ var dat = UPNG.toRGBA8(img)[0]
246
+ var view = new DataView(dat)
247
+ var data = new Array(img.width).fill(0).map(() => new Array(img.height).fill(0));
248
+
249
+ var min =100
250
+ var max = -100
251
+ var idx = 0
252
+ for (let i=0; i < img.height*img.width*4; i+=4) {
253
+ var rgba = [view.getUint8(i, 1) / 255, view.getUint8(i + 1, 1) / 255, view.getUint8(i + 2, 1) / 255, view.getUint8(i + 3, 1) / 255]
254
+ var norm = Math.pow(Math.pow(rgba[0], 2) + Math.pow(rgba[1], 2) + Math.pow(rgba[2], 2), 0.5)
255
+ data[idx % img.width][img.height - Math.floor(idx / img.width) - 1] = norm
256
+
257
+ idx += 1
258
+ min = Math.min(min, norm)
259
+ max = Math.max(max, norm)
260
+ }
261
+ for (let i = 0; i < data.length; i++) {
262
+ for (let j = 0; j < data[i].length; j++) {
263
+ data[i][j] = Math.pow((data[i][j] - min) / (max - min), 1.5)
264
+ }
265
+ }
266
+ var data3 = new Array(img.width).fill(0).map(() => new Array(img.height).fill(0));
267
+ for (let i = 0; i < data.length; i++) {
268
+ for (let j = 0; j < data[i].length; j++) {
269
+ if (i == 0 || i == (data.length - 1)) {
270
+ data3[i][j] = data[i][j]
271
+ } else{
272
+ data3[i][j] = 0.33*(data[i - 1][j]) + 0.33*(data[i][j]) + 0.33*(data[i + 1][j])
273
+ // data3[i][j] = 0.00*(data[i - 1][j]) + 1.00*(data[i][j]) + 0.00*(data[i + 1][j])
274
+ }
275
+ }
276
+ }
277
+
278
+ var scale = 5
279
+ var data2 = new Array(scale*img.width).fill(0).map(() => new Array(img.height).fill(0));
280
+ for (let j = 0; j < data[0].length; j++) {
281
+ for (let i = 0; i < data.length - 1; i++) {
282
+ for (let k = 0; k < scale; k++) {
283
+ data2[scale*i + k][j] = (1.0 - (k/scale))*data3[i][j] + (k / scale)*data3[i + 1][j]
284
+ }
285
+ }
286
+ }
287
+ return data2
288
+ }
289
+
290
+ play() {
291
+ this.player.play();
292
+ this.play_img.style.display = 'none'
293
+ this.pause_img.style.display = 'block'
294
+ }
295
+
296
+ pause() {
297
+ this.player.pause();
298
+ this.pause_img.style.display = 'none'
299
+ this.play_img.style.display = 'block'
300
+ }
301
+
302
+ redraw() {
303
+ this.canvas.width = window.devicePixelRatio*this.response_container.offsetWidth;
304
+ this.canvas.height = window.devicePixelRatio*this.response_container.offsetHeight;
305
+
306
+ this.context.clearRect(0, 0, this.canvas.width, this.canvas.height)
307
+ this.canvas.style.width = (this.canvas.width / window.devicePixelRatio).toString() + "px";
308
+ this.canvas.style.height = (this.canvas.height / window.devicePixelRatio).toString() + "px";
309
+
310
+ var f = this.global_frac*this.mat.length
311
+ var tstep = Math.min(Math.floor(f), this.mat.length - 2)
312
+ var heights = this.mat[tstep]
313
+ var bar_width = (this.canvas.width / heights.length) - 1
314
+
315
+ for (let k = 0; k < heights.length - 1; k++) {
316
+ var height = Math.max(Math.round((heights[k])*this.canvas.height), 3)
317
+ this.context.fillStyle = '#696f7b';
318
+ this.context.fillRect(k*(bar_width + 1), (this.canvas.height - height) / 2, bar_width, height);
319
+ }
320
+ }
321
+ }
322
+ </script>
audiotools/core/templates/pandoc.css ADDED
@@ -0,0 +1,407 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ /*
2
+ Copyright (c) 2017 Chris Patuzzo
3
+ https://twitter.com/chrispatuzzo
4
+
5
+ Permission is hereby granted, free of charge, to any person obtaining a copy
6
+ of this software and associated documentation files (the "Software"), to deal
7
+ in the Software without restriction, including without limitation the rights
8
+ to use, copy, modify, merge, publish, distribute, sublicense, and/or sell
9
+ copies of the Software, and to permit persons to whom the Software is
10
+ furnished to do so, subject to the following conditions:
11
+
12
+ The above copyright notice and this permission notice shall be included in all
13
+ copies or substantial portions of the Software.
14
+
15
+ THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR
16
+ IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY,
17
+ FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE
18
+ AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER
19
+ LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM,
20
+ OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE
21
+ SOFTWARE.
22
+ */
23
+
24
+ body {
25
+ font-family: Helvetica, arial, sans-serif;
26
+ font-size: 14px;
27
+ line-height: 1.6;
28
+ padding-top: 10px;
29
+ padding-bottom: 10px;
30
+ background-color: white;
31
+ padding: 30px;
32
+ color: #333;
33
+ }
34
+
35
+ body > *:first-child {
36
+ margin-top: 0 !important;
37
+ }
38
+
39
+ body > *:last-child {
40
+ margin-bottom: 0 !important;
41
+ }
42
+
43
+ a {
44
+ color: #4183C4;
45
+ text-decoration: none;
46
+ }
47
+
48
+ a.absent {
49
+ color: #cc0000;
50
+ }
51
+
52
+ a.anchor {
53
+ display: block;
54
+ padding-left: 30px;
55
+ margin-left: -30px;
56
+ cursor: pointer;
57
+ position: absolute;
58
+ top: 0;
59
+ left: 0;
60
+ bottom: 0;
61
+ }
62
+
63
+ h1, h2, h3, h4, h5, h6 {
64
+ margin: 20px 0 10px;
65
+ padding: 0;
66
+ font-weight: bold;
67
+ -webkit-font-smoothing: antialiased;
68
+ cursor: text;
69
+ position: relative;
70
+ }
71
+
72
+ h2:first-child, h1:first-child, h1:first-child + h2, h3:first-child, h4:first-child, h5:first-child, h6:first-child {
73
+ margin-top: 0;
74
+ padding-top: 0;
75
+ }
76
+
77
+ h1:hover a.anchor, h2:hover a.anchor, h3:hover a.anchor, h4:hover a.anchor, h5:hover a.anchor, h6:hover a.anchor {
78
+ text-decoration: none;
79
+ }
80
+
81
+ h1 tt, h1 code {
82
+ font-size: inherit;
83
+ }
84
+
85
+ h2 tt, h2 code {
86
+ font-size: inherit;
87
+ }
88
+
89
+ h3 tt, h3 code {
90
+ font-size: inherit;
91
+ }
92
+
93
+ h4 tt, h4 code {
94
+ font-size: inherit;
95
+ }
96
+
97
+ h5 tt, h5 code {
98
+ font-size: inherit;
99
+ }
100
+
101
+ h6 tt, h6 code {
102
+ font-size: inherit;
103
+ }
104
+
105
+ h1 {
106
+ font-size: 28px;
107
+ color: black;
108
+ }
109
+
110
+ h2 {
111
+ font-size: 24px;
112
+ border-bottom: 1px solid #cccccc;
113
+ color: black;
114
+ }
115
+
116
+ h3 {
117
+ font-size: 18px;
118
+ }
119
+
120
+ h4 {
121
+ font-size: 16px;
122
+ }
123
+
124
+ h5 {
125
+ font-size: 14px;
126
+ }
127
+
128
+ h6 {
129
+ color: #777777;
130
+ font-size: 14px;
131
+ }
132
+
133
+ p, blockquote, ul, ol, dl, li, table, pre {
134
+ margin: 15px 0;
135
+ }
136
+
137
+ hr {
138
+ border: 0 none;
139
+ color: #cccccc;
140
+ height: 4px;
141
+ padding: 0;
142
+ }
143
+
144
+ body > h2:first-child {
145
+ margin-top: 0;
146
+ padding-top: 0;
147
+ }
148
+
149
+ body > h1:first-child {
150
+ margin-top: 0;
151
+ padding-top: 0;
152
+ }
153
+
154
+ body > h1:first-child + h2 {
155
+ margin-top: 0;
156
+ padding-top: 0;
157
+ }
158
+
159
+ body > h3:first-child, body > h4:first-child, body > h5:first-child, body > h6:first-child {
160
+ margin-top: 0;
161
+ padding-top: 0;
162
+ }
163
+
164
+ a:first-child h1, a:first-child h2, a:first-child h3, a:first-child h4, a:first-child h5, a:first-child h6 {
165
+ margin-top: 0;
166
+ padding-top: 0;
167
+ }
168
+
169
+ h1 p, h2 p, h3 p, h4 p, h5 p, h6 p {
170
+ margin-top: 0;
171
+ }
172
+
173
+ li p.first {
174
+ display: inline-block;
175
+ }
176
+
177
+ ul, ol {
178
+ padding-left: 30px;
179
+ }
180
+
181
+ ul :first-child, ol :first-child {
182
+ margin-top: 0;
183
+ }
184
+
185
+ ul :last-child, ol :last-child {
186
+ margin-bottom: 0;
187
+ }
188
+
189
+ dl {
190
+ padding: 0;
191
+ }
192
+
193
+ dl dt {
194
+ font-size: 14px;
195
+ font-weight: bold;
196
+ font-style: italic;
197
+ padding: 0;
198
+ margin: 15px 0 5px;
199
+ }
200
+
201
+ dl dt:first-child {
202
+ padding: 0;
203
+ }
204
+
205
+ dl dt > :first-child {
206
+ margin-top: 0;
207
+ }
208
+
209
+ dl dt > :last-child {
210
+ margin-bottom: 0;
211
+ }
212
+
213
+ dl dd {
214
+ margin: 0 0 15px;
215
+ padding: 0 15px;
216
+ }
217
+
218
+ dl dd > :first-child {
219
+ margin-top: 0;
220
+ }
221
+
222
+ dl dd > :last-child {
223
+ margin-bottom: 0;
224
+ }
225
+
226
+ blockquote {
227
+ border-left: 4px solid #dddddd;
228
+ padding: 0 15px;
229
+ color: #777777;
230
+ }
231
+
232
+ blockquote > :first-child {
233
+ margin-top: 0;
234
+ }
235
+
236
+ blockquote > :last-child {
237
+ margin-bottom: 0;
238
+ }
239
+
240
+ table {
241
+ padding: 0;
242
+ }
243
+ table tr {
244
+ border-top: 1px solid #cccccc;
245
+ background-color: white;
246
+ margin: 0;
247
+ padding: 0;
248
+ }
249
+
250
+ table tr:nth-child(2n) {
251
+ background-color: #f8f8f8;
252
+ }
253
+
254
+ table tr th {
255
+ font-weight: bold;
256
+ border: 1px solid #cccccc;
257
+ text-align: left;
258
+ margin: 0;
259
+ padding: 6px 13px;
260
+ }
261
+
262
+ table tr td {
263
+ border: 1px solid #cccccc;
264
+ text-align: left;
265
+ margin: 0;
266
+ padding: 6px 13px;
267
+ }
268
+
269
+ table tr th :first-child, table tr td :first-child {
270
+ margin-top: 0;
271
+ }
272
+
273
+ table tr th :last-child, table tr td :last-child {
274
+ margin-bottom: 0;
275
+ }
276
+
277
+ img {
278
+ max-width: 100%;
279
+ }
280
+
281
+ span.frame {
282
+ display: block;
283
+ overflow: hidden;
284
+ }
285
+
286
+ span.frame > span {
287
+ border: 1px solid #dddddd;
288
+ display: block;
289
+ float: left;
290
+ overflow: hidden;
291
+ margin: 13px 0 0;
292
+ padding: 7px;
293
+ width: auto;
294
+ }
295
+
296
+ span.frame span img {
297
+ display: block;
298
+ float: left;
299
+ }
300
+
301
+ span.frame span span {
302
+ clear: both;
303
+ color: #333333;
304
+ display: block;
305
+ padding: 5px 0 0;
306
+ }
307
+
308
+ span.align-center {
309
+ display: block;
310
+ overflow: hidden;
311
+ clear: both;
312
+ }
313
+
314
+ span.align-center > span {
315
+ display: block;
316
+ overflow: hidden;
317
+ margin: 13px auto 0;
318
+ text-align: center;
319
+ }
320
+
321
+ span.align-center span img {
322
+ margin: 0 auto;
323
+ text-align: center;
324
+ }
325
+
326
+ span.align-right {
327
+ display: block;
328
+ overflow: hidden;
329
+ clear: both;
330
+ }
331
+
332
+ span.align-right > span {
333
+ display: block;
334
+ overflow: hidden;
335
+ margin: 13px 0 0;
336
+ text-align: right;
337
+ }
338
+
339
+ span.align-right span img {
340
+ margin: 0;
341
+ text-align: right;
342
+ }
343
+
344
+ span.float-left {
345
+ display: block;
346
+ margin-right: 13px;
347
+ overflow: hidden;
348
+ float: left;
349
+ }
350
+
351
+ span.float-left span {
352
+ margin: 13px 0 0;
353
+ }
354
+
355
+ span.float-right {
356
+ display: block;
357
+ margin-left: 13px;
358
+ overflow: hidden;
359
+ float: right;
360
+ }
361
+
362
+ span.float-right > span {
363
+ display: block;
364
+ overflow: hidden;
365
+ margin: 13px auto 0;
366
+ text-align: right;
367
+ }
368
+
369
+ code, tt {
370
+ margin: 0 2px;
371
+ padding: 0 5px;
372
+ white-space: nowrap;
373
+ border-radius: 3px;
374
+ }
375
+
376
+ pre code {
377
+ margin: 0;
378
+ padding: 0;
379
+ white-space: pre;
380
+ border: none;
381
+ background: transparent;
382
+ }
383
+
384
+ .highlight pre {
385
+ font-size: 13px;
386
+ line-height: 19px;
387
+ overflow: auto;
388
+ padding: 6px 10px;
389
+ border-radius: 3px;
390
+ }
391
+
392
+ pre {
393
+ font-size: 13px;
394
+ line-height: 19px;
395
+ overflow: auto;
396
+ padding: 6px 10px;
397
+ border-radius: 3px;
398
+ }
399
+
400
+ pre code, pre tt {
401
+ background-color: transparent;
402
+ border: none;
403
+ }
404
+
405
+ body {
406
+ max-width: 600px;
407
+ }
audiotools/core/templates/widget.html ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div id='PLAYER_ID' class='player' style="max-width: MAX_WIDTH;">
2
+ <div class='spectrogram' style="padding-top: PADDING_AMOUNT;">
3
+ <div class='overlay'></div>
4
+ <div class='underlay'>
5
+ <img>
6
+ </div>
7
+ </div>
8
+
9
+ <div class='audio-controls'>
10
+ <button id="playpause" disabled class='playpause' title="play">
11
+ <svg class='play-img' width="14px" height="19px" viewBox="0 0 14 19">
12
+ <polygon id="Triangle" fill="#000000" transform="translate(9, 9.5) rotate(90) translate(-7, -9.5) " points="7 2.5 16.5 16.5 -2.5 16.5"></polygon>
13
+ </svg>
14
+ <svg class='pause-img' width="16px" height="19px" viewBox="0 0 16 19">
15
+ <g fill="#000000" stroke="#000000">
16
+ <rect id="Rectangle" x="0.5" y="0.5" width="4" height="18"></rect>
17
+ <rect id="Rectangle" x="11.5" y="0.5" width="4" height="18"></rect>
18
+ </g>
19
+ </svg>
20
+ </button>
21
+
22
+ <audio class='play'>
23
+ <source id='src'>
24
+ </audio>
25
+ <div class='response'>
26
+ <canvas class='response-canvas'></canvas>
27
+ </div>
28
+
29
+ <button id="download" class='download' title="download">
30
+ <svg class='download-img' x="0px" y="0px" viewBox="0 0 29.978 29.978" style="enable-background:new 0 0 29.978 29.978;" xml:space="preserve">
31
+ <g>
32
+ <path d="M25.462,19.105v6.848H4.515v-6.848H0.489v8.861c0,1.111,0.9,2.012,2.016,2.012h24.967c1.115,0,2.016-0.9,2.016-2.012
33
+ v-8.861H25.462z"/>
34
+ <path d="M14.62,18.426l-5.764-6.965c0,0-0.877-0.828,0.074-0.828s3.248,0,3.248,0s0-0.557,0-1.416c0-2.449,0-6.906,0-8.723
35
+ c0,0-0.129-0.494,0.615-0.494c0.75,0,4.035,0,4.572,0c0.536,0,0.524,0.416,0.524,0.416c0,1.762,0,6.373,0,8.742
36
+ c0,0.768,0,1.266,0,1.266s1.842,0,2.998,0c1.154,0,0.285,0.867,0.285,0.867s-4.904,6.51-5.588,7.193
37
+ C15.092,18.979,14.62,18.426,14.62,18.426z"/>
38
+ </g>
39
+ </svg>
40
+ </button>
41
+ </div>
42
+ </div>
43
+
44
+ <script>
45
+ var PLAYER_ID = new Player('PLAYER_ID')
46
+ PLAYER_ID.load(
47
+ "AUDIO_SRC",
48
+ "IMAGE_SRC",
49
+ "LEVELS_SRC"
50
+ )
51
+ window.addEventListener("resize", function() {PLAYER_ID.redraw()})
52
+ </script>
audiotools/core/util.py ADDED
@@ -0,0 +1,671 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import glob
3
+ import math
4
+ import numbers
5
+ import os
6
+ import random
7
+ import typing
8
+ from contextlib import contextmanager
9
+ from dataclasses import dataclass
10
+ from pathlib import Path
11
+ from typing import Dict
12
+ from typing import List
13
+
14
+ import numpy as np
15
+ import torch
16
+ import torchaudio
17
+ from flatten_dict import flatten
18
+ from flatten_dict import unflatten
19
+
20
+
21
+ @dataclass
22
+ class Info:
23
+ """Shim for torchaudio.info API changes."""
24
+
25
+ sample_rate: float
26
+ num_frames: int
27
+
28
+ @property
29
+ def duration(self) -> float:
30
+ return self.num_frames / self.sample_rate
31
+
32
+
33
+ def info(audio_path: str):
34
+ """Shim for torchaudio.info to make 0.7.2 API match 0.8.0.
35
+
36
+ Parameters
37
+ ----------
38
+ audio_path : str
39
+ Path to audio file.
40
+ """
41
+ # try default backend first, then fallback to soundfile
42
+ try:
43
+ info = torchaudio.info(str(audio_path))
44
+ except: # pragma: no cover
45
+ info = torchaudio.backend.soundfile_backend.info(str(audio_path))
46
+
47
+ if isinstance(info, tuple): # pragma: no cover
48
+ signal_info = info[0]
49
+ info = Info(sample_rate=signal_info.rate, num_frames=signal_info.length)
50
+ else:
51
+ info = Info(sample_rate=info.sample_rate, num_frames=info.num_frames)
52
+
53
+ return info
54
+
55
+
56
+ def ensure_tensor(
57
+ x: typing.Union[np.ndarray, torch.Tensor, float, int],
58
+ ndim: int = None,
59
+ batch_size: int = None,
60
+ ):
61
+ """Ensures that the input ``x`` is a tensor of specified
62
+ dimensions and batch size.
63
+
64
+ Parameters
65
+ ----------
66
+ x : typing.Union[np.ndarray, torch.Tensor, float, int]
67
+ Data that will become a tensor on its way out.
68
+ ndim : int, optional
69
+ How many dimensions should be in the output, by default None
70
+ batch_size : int, optional
71
+ The batch size of the output, by default None
72
+
73
+ Returns
74
+ -------
75
+ torch.Tensor
76
+ Modified version of ``x`` as a tensor.
77
+ """
78
+ if not torch.is_tensor(x):
79
+ x = torch.as_tensor(x)
80
+ if ndim is not None:
81
+ assert x.ndim <= ndim
82
+ while x.ndim < ndim:
83
+ x = x.unsqueeze(-1)
84
+ if batch_size is not None:
85
+ if x.shape[0] != batch_size:
86
+ shape = list(x.shape)
87
+ shape[0] = batch_size
88
+ x = x.expand(*shape)
89
+ return x
90
+
91
+
92
+ def _get_value(other):
93
+ from . import AudioSignal
94
+
95
+ if isinstance(other, AudioSignal):
96
+ return other.audio_data
97
+ return other
98
+
99
+
100
+ def hz_to_bin(hz: torch.Tensor, n_fft: int, sample_rate: int):
101
+ """Closest frequency bin given a frequency, number
102
+ of bins, and a sampling rate.
103
+
104
+ Parameters
105
+ ----------
106
+ hz : torch.Tensor
107
+ Tensor of frequencies in Hz.
108
+ n_fft : int
109
+ Number of FFT bins.
110
+ sample_rate : int
111
+ Sample rate of audio.
112
+
113
+ Returns
114
+ -------
115
+ torch.Tensor
116
+ Closest bins to the data.
117
+ """
118
+ shape = hz.shape
119
+ hz = hz.flatten()
120
+ freqs = torch.linspace(0, sample_rate / 2, 2 + n_fft // 2)
121
+ hz[hz > sample_rate / 2] = sample_rate / 2
122
+
123
+ closest = (hz[None, :] - freqs[:, None]).abs()
124
+ closest_bins = closest.min(dim=0).indices
125
+
126
+ return closest_bins.reshape(*shape)
127
+
128
+
129
+ def random_state(seed: typing.Union[int, np.random.RandomState]):
130
+ """
131
+ Turn seed into a np.random.RandomState instance.
132
+
133
+ Parameters
134
+ ----------
135
+ seed : typing.Union[int, np.random.RandomState] or None
136
+ If seed is None, return the RandomState singleton used by np.random.
137
+ If seed is an int, return a new RandomState instance seeded with seed.
138
+ If seed is already a RandomState instance, return it.
139
+ Otherwise raise ValueError.
140
+
141
+ Returns
142
+ -------
143
+ np.random.RandomState
144
+ Random state object.
145
+
146
+ Raises
147
+ ------
148
+ ValueError
149
+ If seed is not valid, an error is thrown.
150
+ """
151
+ if seed is None or seed is np.random:
152
+ return np.random.mtrand._rand
153
+ elif isinstance(seed, (numbers.Integral, np.integer, int)):
154
+ return np.random.RandomState(seed)
155
+ elif isinstance(seed, np.random.RandomState):
156
+ return seed
157
+ else:
158
+ raise ValueError(
159
+ "%r cannot be used to seed a numpy.random.RandomState" " instance" % seed
160
+ )
161
+
162
+
163
+ def seed(random_seed, set_cudnn=False):
164
+ """
165
+ Seeds all random states with the same random seed
166
+ for reproducibility. Seeds ``numpy``, ``random`` and ``torch``
167
+ random generators.
168
+ For full reproducibility, two further options must be set
169
+ according to the torch documentation:
170
+ https://pytorch.org/docs/stable/notes/randomness.html
171
+ To do this, ``set_cudnn`` must be True. It defaults to
172
+ False, since setting it to True results in a performance
173
+ hit.
174
+
175
+ Args:
176
+ random_seed (int): integer corresponding to random seed to
177
+ use.
178
+ set_cudnn (bool): Whether or not to set cudnn into determinstic
179
+ mode and off of benchmark mode. Defaults to False.
180
+ """
181
+
182
+ torch.manual_seed(random_seed)
183
+ np.random.seed(random_seed)
184
+ random.seed(random_seed)
185
+
186
+ if set_cudnn:
187
+ torch.backends.cudnn.deterministic = True
188
+ torch.backends.cudnn.benchmark = False
189
+
190
+
191
+ @contextmanager
192
+ def _close_temp_files(tmpfiles: list):
193
+ """Utility function for creating a context and closing all temporary files
194
+ once the context is exited. For correct functionality, all temporary file
195
+ handles created inside the context must be appended to the ```tmpfiles```
196
+ list.
197
+
198
+ This function is taken wholesale from Scaper.
199
+
200
+ Parameters
201
+ ----------
202
+ tmpfiles : list
203
+ List of temporary file handles
204
+ """
205
+
206
+ def _close():
207
+ for t in tmpfiles:
208
+ try:
209
+ t.close()
210
+ os.unlink(t.name)
211
+ except:
212
+ pass
213
+
214
+ try:
215
+ yield
216
+ except: # pragma: no cover
217
+ _close()
218
+ raise
219
+ _close()
220
+
221
+
222
+ AUDIO_EXTENSIONS = [".wav", ".flac", ".mp3", ".mp4"]
223
+
224
+
225
+ def find_audio(folder: str, ext: List[str] = AUDIO_EXTENSIONS):
226
+ """Finds all audio files in a directory recursively.
227
+ Returns a list.
228
+
229
+ Parameters
230
+ ----------
231
+ folder : str
232
+ Folder to look for audio files in, recursively.
233
+ ext : List[str], optional
234
+ Extensions to look for without the ., by default
235
+ ``['.wav', '.flac', '.mp3', '.mp4']``.
236
+ """
237
+ folder = Path(folder)
238
+ # Take care of case where user has passed in an audio file directly
239
+ # into one of the calling functions.
240
+ if str(folder).endswith(tuple(ext)):
241
+ # if, however, there's a glob in the path, we need to
242
+ # return the glob, not the file.
243
+ if "*" in str(folder):
244
+ return glob.glob(str(folder), recursive=("**" in str(folder)))
245
+ else:
246
+ return [folder]
247
+
248
+ files = []
249
+ for x in ext:
250
+ files += folder.glob(f"**/*{x}")
251
+ return files
252
+
253
+
254
+ def read_sources(
255
+ sources: List[str],
256
+ remove_empty: bool = True,
257
+ relative_path: str = "",
258
+ ext: List[str] = AUDIO_EXTENSIONS,
259
+ ):
260
+ """Reads audio sources that can either be folders
261
+ full of audio files, or CSV files that contain paths
262
+ to audio files. CSV files that adhere to the expected
263
+ format can be generated by
264
+ :py:func:`audiotools.data.preprocess.create_csv`.
265
+
266
+ Parameters
267
+ ----------
268
+ sources : List[str]
269
+ List of audio sources to be converted into a
270
+ list of lists of audio files.
271
+ remove_empty : bool, optional
272
+ Whether or not to remove rows with an empty "path"
273
+ from each CSV file, by default True.
274
+
275
+ Returns
276
+ -------
277
+ list
278
+ List of lists of rows of CSV files.
279
+ """
280
+ files = []
281
+ relative_path = Path(relative_path)
282
+ for source in sources:
283
+ source = str(source)
284
+ _files = []
285
+ if source.endswith(".csv"):
286
+ with open(source, "r") as f:
287
+ reader = csv.DictReader(f)
288
+ for x in reader:
289
+ if remove_empty and x["path"] == "":
290
+ continue
291
+ if x["path"] != "":
292
+ x["path"] = str(relative_path / x["path"])
293
+ _files.append(x)
294
+ else:
295
+ for x in find_audio(source, ext=ext):
296
+ x = str(relative_path / x)
297
+ _files.append({"path": x})
298
+ files.append(sorted(_files, key=lambda x: x["path"]))
299
+ return files
300
+
301
+
302
+ def choose_from_list_of_lists(
303
+ state: np.random.RandomState, list_of_lists: list, p: float = None
304
+ ):
305
+ """Choose a single item from a list of lists.
306
+
307
+ Parameters
308
+ ----------
309
+ state : np.random.RandomState
310
+ Random state to use when choosing an item.
311
+ list_of_lists : list
312
+ A list of lists from which items will be drawn.
313
+ p : float, optional
314
+ Probabilities of each list, by default None
315
+
316
+ Returns
317
+ -------
318
+ typing.Any
319
+ An item from the list of lists.
320
+ """
321
+ source_idx = state.choice(list(range(len(list_of_lists))), p=p)
322
+ item_idx = state.randint(len(list_of_lists[source_idx]))
323
+ return list_of_lists[source_idx][item_idx], source_idx, item_idx
324
+
325
+
326
+ @contextmanager
327
+ def chdir(newdir: typing.Union[Path, str]):
328
+ """
329
+ Context manager for switching directories to run a
330
+ function. Useful for when you want to use relative
331
+ paths to different runs.
332
+
333
+ Parameters
334
+ ----------
335
+ newdir : typing.Union[Path, str]
336
+ Directory to switch to.
337
+ """
338
+ curdir = os.getcwd()
339
+ try:
340
+ os.chdir(newdir)
341
+ yield
342
+ finally:
343
+ os.chdir(curdir)
344
+
345
+
346
+ def prepare_batch(batch: typing.Union[dict, list, torch.Tensor], device: str = "cpu"):
347
+ """Moves items in a batch (typically generated by a DataLoader as a list
348
+ or a dict) to the specified device. This works even if dictionaries
349
+ are nested.
350
+
351
+ Parameters
352
+ ----------
353
+ batch : typing.Union[dict, list, torch.Tensor]
354
+ Batch, typically generated by a dataloader, that will be moved to
355
+ the device.
356
+ device : str, optional
357
+ Device to move batch to, by default "cpu"
358
+
359
+ Returns
360
+ -------
361
+ typing.Union[dict, list, torch.Tensor]
362
+ Batch with all values moved to the specified device.
363
+ """
364
+ if isinstance(batch, dict):
365
+ batch = flatten(batch)
366
+ for key, val in batch.items():
367
+ try:
368
+ batch[key] = val.to(device)
369
+ except:
370
+ pass
371
+ batch = unflatten(batch)
372
+ elif torch.is_tensor(batch):
373
+ batch = batch.to(device)
374
+ elif isinstance(batch, list):
375
+ for i in range(len(batch)):
376
+ try:
377
+ batch[i] = batch[i].to(device)
378
+ except:
379
+ pass
380
+ return batch
381
+
382
+
383
+ def sample_from_dist(dist_tuple: tuple, state: np.random.RandomState = None):
384
+ """Samples from a distribution defined by a tuple. The first
385
+ item in the tuple is the distribution type, and the rest of the
386
+ items are arguments to that distribution. The distribution function
387
+ is gotten from the ``np.random.RandomState`` object.
388
+
389
+ Parameters
390
+ ----------
391
+ dist_tuple : tuple
392
+ Distribution tuple
393
+ state : np.random.RandomState, optional
394
+ Random state, or seed to use, by default None
395
+
396
+ Returns
397
+ -------
398
+ typing.Union[float, int, str]
399
+ Draw from the distribution.
400
+
401
+ Examples
402
+ --------
403
+ Sample from a uniform distribution:
404
+
405
+ >>> dist_tuple = ("uniform", 0, 1)
406
+ >>> sample_from_dist(dist_tuple)
407
+
408
+ Sample from a constant distribution:
409
+
410
+ >>> dist_tuple = ("const", 0)
411
+ >>> sample_from_dist(dist_tuple)
412
+
413
+ Sample from a normal distribution:
414
+
415
+ >>> dist_tuple = ("normal", 0, 0.5)
416
+ >>> sample_from_dist(dist_tuple)
417
+
418
+ """
419
+ if dist_tuple[0] == "const":
420
+ return dist_tuple[1]
421
+ state = random_state(state)
422
+ dist_fn = getattr(state, dist_tuple[0])
423
+ return dist_fn(*dist_tuple[1:])
424
+
425
+
426
+ def collate(list_of_dicts: list, n_splits: int = None):
427
+ """Collates a list of dictionaries (e.g. as returned by a
428
+ dataloader) into a dictionary with batched values. This routine
429
+ uses the default torch collate function for everything
430
+ except AudioSignal objects, which are handled by the
431
+ :py:func:`audiotools.core.audio_signal.AudioSignal.batch`
432
+ function.
433
+
434
+ This function takes n_splits to enable splitting a batch
435
+ into multiple sub-batches for the purposes of gradient accumulation,
436
+ etc.
437
+
438
+ Parameters
439
+ ----------
440
+ list_of_dicts : list
441
+ List of dictionaries to be collated.
442
+ n_splits : int
443
+ Number of splits to make when creating the batches (split into
444
+ sub-batches). Useful for things like gradient accumulation.
445
+
446
+ Returns
447
+ -------
448
+ dict
449
+ Dictionary containing batched data.
450
+ """
451
+
452
+ from . import AudioSignal
453
+
454
+ batches = []
455
+ list_len = len(list_of_dicts)
456
+
457
+ return_list = False if n_splits is None else True
458
+ n_splits = 1 if n_splits is None else n_splits
459
+ n_items = int(math.ceil(list_len / n_splits))
460
+
461
+ for i in range(0, list_len, n_items):
462
+ # Flatten the dictionaries to avoid recursion.
463
+ list_of_dicts_ = [flatten(d) for d in list_of_dicts[i : i + n_items]]
464
+ dict_of_lists = {
465
+ k: [dic[k] for dic in list_of_dicts_] for k in list_of_dicts_[0]
466
+ }
467
+
468
+ batch = {}
469
+ for k, v in dict_of_lists.items():
470
+ if isinstance(v, list):
471
+ if all(isinstance(s, AudioSignal) for s in v):
472
+ batch[k] = AudioSignal.batch(v, pad_signals=True)
473
+ else:
474
+ # Borrow the default collate fn from torch.
475
+ batch[k] = torch.utils.data._utils.collate.default_collate(v)
476
+ batches.append(unflatten(batch))
477
+
478
+ batches = batches[0] if not return_list else batches
479
+ return batches
480
+
481
+
482
+ BASE_SIZE = 864
483
+ DEFAULT_FIG_SIZE = (9, 3)
484
+
485
+
486
+ def format_figure(
487
+ fig_size: tuple = None,
488
+ title: str = None,
489
+ fig=None,
490
+ format_axes: bool = True,
491
+ format: bool = True,
492
+ font_color: str = "white",
493
+ ):
494
+ """Prettifies the spectrogram and waveform plots. A title
495
+ can be inset into the top right corner, and the axes can be
496
+ inset into the figure, allowing the data to take up the entire
497
+ image. Used in
498
+
499
+ - :py:func:`audiotools.core.display.DisplayMixin.specshow`
500
+ - :py:func:`audiotools.core.display.DisplayMixin.waveplot`
501
+ - :py:func:`audiotools.core.display.DisplayMixin.wavespec`
502
+
503
+ Parameters
504
+ ----------
505
+ fig_size : tuple, optional
506
+ Size of figure, by default (9, 3)
507
+ title : str, optional
508
+ Title to inset in top right, by default None
509
+ fig : matplotlib.figure.Figure, optional
510
+ Figure object, if None ``plt.gcf()`` will be used, by default None
511
+ format_axes : bool, optional
512
+ Format the axes to be inside the figure, by default True
513
+ format : bool, optional
514
+ This formatting can be skipped entirely by passing ``format=False``
515
+ to any of the plotting functions that use this formater, by default True
516
+ font_color : str, optional
517
+ Color of font of axes, by default "white"
518
+ """
519
+ import matplotlib
520
+ import matplotlib.pyplot as plt
521
+
522
+ if fig_size is None:
523
+ fig_size = DEFAULT_FIG_SIZE
524
+ if not format:
525
+ return
526
+ if fig is None:
527
+ fig = plt.gcf()
528
+ fig.set_size_inches(*fig_size)
529
+ axs = fig.axes
530
+
531
+ pixels = (fig.get_size_inches() * fig.dpi)[0]
532
+ font_scale = pixels / BASE_SIZE
533
+
534
+ if format_axes:
535
+ axs = fig.axes
536
+
537
+ for ax in axs:
538
+ ymin, _ = ax.get_ylim()
539
+ xmin, _ = ax.get_xlim()
540
+
541
+ ticks = ax.get_yticks()
542
+ for t in ticks[2:-1]:
543
+ t = axs[0].annotate(
544
+ f"{(t / 1000):2.1f}k",
545
+ xy=(xmin, t),
546
+ xycoords="data",
547
+ xytext=(5, -5),
548
+ textcoords="offset points",
549
+ ha="left",
550
+ va="top",
551
+ color=font_color,
552
+ fontsize=12 * font_scale,
553
+ alpha=0.75,
554
+ )
555
+
556
+ ticks = ax.get_xticks()[2:]
557
+ for t in ticks[:-1]:
558
+ t = axs[0].annotate(
559
+ f"{t:2.1f}s",
560
+ xy=(t, ymin),
561
+ xycoords="data",
562
+ xytext=(5, 5),
563
+ textcoords="offset points",
564
+ ha="center",
565
+ va="bottom",
566
+ color=font_color,
567
+ fontsize=12 * font_scale,
568
+ alpha=0.75,
569
+ )
570
+
571
+ ax.margins(0, 0)
572
+ ax.set_axis_off()
573
+ ax.xaxis.set_major_locator(plt.NullLocator())
574
+ ax.yaxis.set_major_locator(plt.NullLocator())
575
+
576
+ plt.subplots_adjust(top=1, bottom=0, right=1, left=0, hspace=0, wspace=0)
577
+
578
+ if title is not None:
579
+ t = axs[0].annotate(
580
+ title,
581
+ xy=(1, 1),
582
+ xycoords="axes fraction",
583
+ fontsize=20 * font_scale,
584
+ xytext=(-5, -5),
585
+ textcoords="offset points",
586
+ ha="right",
587
+ va="top",
588
+ color="white",
589
+ )
590
+ t.set_bbox(dict(facecolor="black", alpha=0.5, edgecolor="black"))
591
+
592
+
593
+ def generate_chord_dataset(
594
+ max_voices: int = 8,
595
+ sample_rate: int = 44100,
596
+ num_items: int = 5,
597
+ duration: float = 1.0,
598
+ min_note: str = "C2",
599
+ max_note: str = "C6",
600
+ output_dir: Path = "chords",
601
+ ):
602
+ """
603
+ Generates a toy multitrack dataset of chords, synthesized from sine waves.
604
+
605
+
606
+ Parameters
607
+ ----------
608
+ max_voices : int, optional
609
+ Maximum number of voices in a chord, by default 8
610
+ sample_rate : int, optional
611
+ Sample rate of audio, by default 44100
612
+ num_items : int, optional
613
+ Number of items to generate, by default 5
614
+ duration : float, optional
615
+ Duration of each item, by default 1.0
616
+ min_note : str, optional
617
+ Minimum note in the dataset, by default "C2"
618
+ max_note : str, optional
619
+ Maximum note in the dataset, by default "C6"
620
+ output_dir : Path, optional
621
+ Directory to save the dataset, by default "chords"
622
+
623
+ """
624
+ import librosa
625
+ from . import AudioSignal
626
+ from ..data.preprocess import create_csv
627
+
628
+ min_midi = librosa.note_to_midi(min_note)
629
+ max_midi = librosa.note_to_midi(max_note)
630
+
631
+ tracks = []
632
+ for idx in range(num_items):
633
+ track = {}
634
+ # figure out how many voices to put in this track
635
+ num_voices = random.randint(1, max_voices)
636
+ for voice_idx in range(num_voices):
637
+ # choose some random params
638
+ midinote = random.randint(min_midi, max_midi)
639
+ dur = random.uniform(0.85 * duration, duration)
640
+
641
+ sig = AudioSignal.wave(
642
+ frequency=librosa.midi_to_hz(midinote),
643
+ duration=dur,
644
+ sample_rate=sample_rate,
645
+ shape="sine",
646
+ )
647
+ track[f"voice_{voice_idx}"] = sig
648
+ tracks.append(track)
649
+
650
+ # save the tracks to disk
651
+ output_dir = Path(output_dir)
652
+ output_dir.mkdir(exist_ok=True)
653
+ for idx, track in enumerate(tracks):
654
+ track_dir = output_dir / f"track_{idx}"
655
+ track_dir.mkdir(exist_ok=True)
656
+ for voice_name, sig in track.items():
657
+ sig.write(track_dir / f"{voice_name}.wav")
658
+
659
+ all_voices = list(set([k for track in tracks for k in track.keys()]))
660
+ voice_lists = {voice: [] for voice in all_voices}
661
+ for track in tracks:
662
+ for voice_name in all_voices:
663
+ if voice_name in track:
664
+ voice_lists[voice_name].append(track[voice_name].path_to_file)
665
+ else:
666
+ voice_lists[voice_name].append("")
667
+
668
+ for voice_name, paths in voice_lists.items():
669
+ create_csv(paths, output_dir / f"{voice_name}.csv", loudness=True)
670
+
671
+ return output_dir
audiotools/core/whisper.py ADDED
@@ -0,0 +1,97 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+
3
+
4
+ class WhisperMixin:
5
+ is_initialized = False
6
+
7
+ def setup_whisper(
8
+ self,
9
+ pretrained_model_name_or_path: str = "openai/whisper-base.en",
10
+ device: str = torch.device("cuda" if torch.cuda.is_available() else "cpu"),
11
+ ):
12
+ from transformers import WhisperForConditionalGeneration
13
+ from transformers import WhisperProcessor
14
+
15
+ self.whisper_device = device
16
+ self.whisper_processor = WhisperProcessor.from_pretrained(
17
+ pretrained_model_name_or_path
18
+ )
19
+ self.whisper_model = WhisperForConditionalGeneration.from_pretrained(
20
+ pretrained_model_name_or_path
21
+ ).to(self.whisper_device)
22
+ self.is_initialized = True
23
+
24
+ def get_whisper_features(self) -> torch.Tensor:
25
+ """Preprocess audio signal as per the whisper model's training config.
26
+
27
+ Returns
28
+ -------
29
+ torch.Tensor
30
+ The prepinput features of the audio signal. Shape: (1, channels, seq_len)
31
+ """
32
+ import torch
33
+
34
+ if not self.is_initialized:
35
+ self.setup_whisper()
36
+
37
+ signal = self.to(self.device)
38
+ raw_speech = list(
39
+ (
40
+ signal.clone()
41
+ .resample(self.whisper_processor.feature_extractor.sampling_rate)
42
+ .audio_data[:, 0, :]
43
+ .numpy()
44
+ )
45
+ )
46
+
47
+ with torch.inference_mode():
48
+ input_features = self.whisper_processor(
49
+ raw_speech,
50
+ sampling_rate=self.whisper_processor.feature_extractor.sampling_rate,
51
+ return_tensors="pt",
52
+ ).input_features
53
+
54
+ return input_features
55
+
56
+ def get_whisper_transcript(self) -> str:
57
+ """Get the transcript of the audio signal using the whisper model.
58
+
59
+ Returns
60
+ -------
61
+ str
62
+ The transcript of the audio signal, including special tokens such as <|startoftranscript|> and <|endoftext|>.
63
+ """
64
+
65
+ if not self.is_initialized:
66
+ self.setup_whisper()
67
+
68
+ input_features = self.get_whisper_features()
69
+
70
+ with torch.inference_mode():
71
+ input_features = input_features.to(self.whisper_device)
72
+ generated_ids = self.whisper_model.generate(inputs=input_features)
73
+
74
+ transcription = self.whisper_processor.batch_decode(generated_ids)
75
+ return transcription[0]
76
+
77
+ def get_whisper_embeddings(self) -> torch.Tensor:
78
+ """Get the last hidden state embeddings of the audio signal using the whisper model.
79
+
80
+ Returns
81
+ -------
82
+ torch.Tensor
83
+ The Whisper embeddings of the audio signal. Shape: (1, seq_len, hidden_size)
84
+ """
85
+ import torch
86
+
87
+ if not self.is_initialized:
88
+ self.setup_whisper()
89
+
90
+ input_features = self.get_whisper_features()
91
+ encoder = self.whisper_model.get_encoder()
92
+
93
+ with torch.inference_mode():
94
+ input_features = input_features.to(self.whisper_device)
95
+ embeddings = encoder(input_features)
96
+
97
+ return embeddings.last_hidden_state
audiotools/data/__init__.py ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ from . import datasets
2
+ from . import preprocess
3
+ from . import transforms
audiotools/data/datasets.py ADDED
@@ -0,0 +1,517 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from pathlib import Path
2
+ from typing import Callable
3
+ from typing import Dict
4
+ from typing import List
5
+ from typing import Union
6
+
7
+ import numpy as np
8
+ from torch.utils.data import SequentialSampler
9
+ from torch.utils.data.distributed import DistributedSampler
10
+
11
+ from ..core import AudioSignal
12
+ from ..core import util
13
+
14
+
15
+ class AudioLoader:
16
+ """Loads audio endlessly from a list of audio sources
17
+ containing paths to audio files. Audio sources can be
18
+ folders full of audio files (which are found via file
19
+ extension) or by providing a CSV file which contains paths
20
+ to audio files.
21
+
22
+ Parameters
23
+ ----------
24
+ sources : List[str], optional
25
+ Sources containing folders, or CSVs with
26
+ paths to audio files, by default None
27
+ weights : List[float], optional
28
+ Weights to sample audio files from each source, by default None
29
+ relative_path : str, optional
30
+ Path audio should be loaded relative to, by default ""
31
+ transform : Callable, optional
32
+ Transform to instantiate alongside audio sample,
33
+ by default None
34
+ ext : List[str]
35
+ List of extensions to find audio within each source by. Can
36
+ also be a file name (e.g. "vocals.wav"). by default
37
+ ``['.wav', '.flac', '.mp3', '.mp4']``.
38
+ shuffle: bool
39
+ Whether to shuffle the files within the dataloader. Defaults to True.
40
+ shuffle_state: int
41
+ State to use to seed the shuffle of the files.
42
+ """
43
+
44
+ def __init__(
45
+ self,
46
+ sources: List[str] = None,
47
+ weights: List[float] = None,
48
+ transform: Callable = None,
49
+ relative_path: str = "",
50
+ ext: List[str] = util.AUDIO_EXTENSIONS,
51
+ shuffle: bool = True,
52
+ shuffle_state: int = 0,
53
+ ):
54
+ self.audio_lists = util.read_sources(
55
+ sources, relative_path=relative_path, ext=ext
56
+ )
57
+
58
+ self.audio_indices = [
59
+ (src_idx, item_idx)
60
+ for src_idx, src in enumerate(self.audio_lists)
61
+ for item_idx in range(len(src))
62
+ ]
63
+ if shuffle:
64
+ state = util.random_state(shuffle_state)
65
+ state.shuffle(self.audio_indices)
66
+
67
+ self.sources = sources
68
+ self.weights = weights
69
+ self.transform = transform
70
+
71
+ def __call__(
72
+ self,
73
+ state,
74
+ sample_rate: int,
75
+ duration: float,
76
+ loudness_cutoff: float = -40,
77
+ num_channels: int = 1,
78
+ offset: float = None,
79
+ source_idx: int = None,
80
+ item_idx: int = None,
81
+ global_idx: int = None,
82
+ ):
83
+ if source_idx is not None and item_idx is not None:
84
+ try:
85
+ audio_info = self.audio_lists[source_idx][item_idx]
86
+ except:
87
+ audio_info = {"path": "none"}
88
+ elif global_idx is not None:
89
+ source_idx, item_idx = self.audio_indices[
90
+ global_idx % len(self.audio_indices)
91
+ ]
92
+ audio_info = self.audio_lists[source_idx][item_idx]
93
+ else:
94
+ audio_info, source_idx, item_idx = util.choose_from_list_of_lists(
95
+ state, self.audio_lists, p=self.weights
96
+ )
97
+
98
+ path = audio_info["path"]
99
+ signal = AudioSignal.zeros(duration, sample_rate, num_channels)
100
+
101
+ if path != "none":
102
+ if offset is None:
103
+ signal = AudioSignal.salient_excerpt(
104
+ path,
105
+ duration=duration,
106
+ state=state,
107
+ loudness_cutoff=loudness_cutoff,
108
+ )
109
+ else:
110
+ signal = AudioSignal(
111
+ path,
112
+ offset=offset,
113
+ duration=duration,
114
+ )
115
+
116
+ if num_channels == 1:
117
+ signal = signal.to_mono()
118
+ signal = signal.resample(sample_rate)
119
+
120
+ if signal.duration < duration:
121
+ signal = signal.zero_pad_to(int(duration * sample_rate))
122
+
123
+ for k, v in audio_info.items():
124
+ signal.metadata[k] = v
125
+
126
+ item = {
127
+ "signal": signal,
128
+ "source_idx": source_idx,
129
+ "item_idx": item_idx,
130
+ "source": str(self.sources[source_idx]),
131
+ "path": str(path),
132
+ }
133
+ if self.transform is not None:
134
+ item["transform_args"] = self.transform.instantiate(state, signal=signal)
135
+ return item
136
+
137
+
138
+ def default_matcher(x, y):
139
+ return Path(x).parent == Path(y).parent
140
+
141
+
142
+ def align_lists(lists, matcher: Callable = default_matcher):
143
+ longest_list = lists[np.argmax([len(l) for l in lists])]
144
+ for i, x in enumerate(longest_list):
145
+ for l in lists:
146
+ if i >= len(l):
147
+ l.append({"path": "none"})
148
+ elif not matcher(l[i]["path"], x["path"]):
149
+ l.insert(i, {"path": "none"})
150
+ return lists
151
+
152
+
153
+ class AudioDataset:
154
+ """Loads audio from multiple loaders (with associated transforms)
155
+ for a specified number of samples. Excerpts are drawn randomly
156
+ of the specified duration, above a specified loudness threshold
157
+ and are resampled on the fly to the desired sample rate
158
+ (if it is different from the audio source sample rate).
159
+
160
+ This takes either a single AudioLoader object,
161
+ a dictionary of AudioLoader objects, or a dictionary of AudioLoader
162
+ objects. Each AudioLoader is called by the dataset, and the
163
+ result is placed in the output dictionary. A transform can also be
164
+ specified for the entire dataset, rather than for each specific
165
+ loader. This transform can be applied to the output of all the
166
+ loaders if desired.
167
+
168
+ AudioLoader objects can be specified as aligned, which means the
169
+ loaders correspond to multitrack audio (e.g. a vocals, bass,
170
+ drums, and other loader for multitrack music mixtures).
171
+
172
+
173
+ Parameters
174
+ ----------
175
+ loaders : Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]]
176
+ AudioLoaders to sample audio from.
177
+ sample_rate : int
178
+ Desired sample rate.
179
+ n_examples : int, optional
180
+ Number of examples (length of dataset), by default 1000
181
+ duration : float, optional
182
+ Duration of audio samples, by default 0.5
183
+ loudness_cutoff : float, optional
184
+ Loudness cutoff threshold for audio samples, by default -40
185
+ num_channels : int, optional
186
+ Number of channels in output audio, by default 1
187
+ transform : Callable, optional
188
+ Transform to instantiate alongside each dataset item, by default None
189
+ aligned : bool, optional
190
+ Whether the loaders should be sampled in an aligned manner (e.g. same
191
+ offset, duration, and matched file name), by default False
192
+ shuffle_loaders : bool, optional
193
+ Whether to shuffle the loaders before sampling from them, by default False
194
+ matcher : Callable
195
+ How to match files from adjacent audio lists (e.g. for a multitrack audio loader),
196
+ by default uses the parent directory of each file.
197
+ without_replacement : bool
198
+ Whether to choose files with or without replacement, by default True.
199
+
200
+
201
+ Examples
202
+ --------
203
+ >>> from audiotools.data.datasets import AudioLoader
204
+ >>> from audiotools.data.datasets import AudioDataset
205
+ >>> from audiotools import transforms as tfm
206
+ >>> import numpy as np
207
+ >>>
208
+ >>> loaders = [
209
+ >>> AudioLoader(
210
+ >>> sources=[f"tests/audio/spk"],
211
+ >>> transform=tfm.Equalizer(),
212
+ >>> ext=["wav"],
213
+ >>> )
214
+ >>> for i in range(5)
215
+ >>> ]
216
+ >>>
217
+ >>> dataset = AudioDataset(
218
+ >>> loaders = loaders,
219
+ >>> sample_rate = 44100,
220
+ >>> duration = 1.0,
221
+ >>> transform = tfm.RescaleAudio(),
222
+ >>> )
223
+ >>>
224
+ >>> item = dataset[np.random.randint(len(dataset))]
225
+ >>>
226
+ >>> for i in range(len(loaders)):
227
+ >>> item[i]["signal"] = loaders[i].transform(
228
+ >>> item[i]["signal"], **item[i]["transform_args"]
229
+ >>> )
230
+ >>> item[i]["signal"].widget(i)
231
+ >>>
232
+ >>> mix = sum([item[i]["signal"] for i in range(len(loaders))])
233
+ >>> mix = dataset.transform(mix, **item["transform_args"])
234
+ >>> mix.widget("mix")
235
+
236
+ Below is an example of how one could load MUSDB multitrack data:
237
+
238
+ >>> import audiotools as at
239
+ >>> from pathlib import Path
240
+ >>> from audiotools import transforms as tfm
241
+ >>> import numpy as np
242
+ >>> import torch
243
+ >>>
244
+ >>> def build_dataset(
245
+ >>> sample_rate: int = 44100,
246
+ >>> duration: float = 5.0,
247
+ >>> musdb_path: str = "~/.data/musdb/",
248
+ >>> ):
249
+ >>> musdb_path = Path(musdb_path).expanduser()
250
+ >>> loaders = {
251
+ >>> src: at.datasets.AudioLoader(
252
+ >>> sources=[musdb_path],
253
+ >>> transform=tfm.Compose(
254
+ >>> tfm.VolumeNorm(("uniform", -20, -10)),
255
+ >>> tfm.Silence(prob=0.1),
256
+ >>> ),
257
+ >>> ext=[f"{src}.wav"],
258
+ >>> )
259
+ >>> for src in ["vocals", "bass", "drums", "other"]
260
+ >>> }
261
+ >>>
262
+ >>> dataset = at.datasets.AudioDataset(
263
+ >>> loaders=loaders,
264
+ >>> sample_rate=sample_rate,
265
+ >>> duration=duration,
266
+ >>> num_channels=1,
267
+ >>> aligned=True,
268
+ >>> transform=tfm.RescaleAudio(),
269
+ >>> shuffle_loaders=True,
270
+ >>> )
271
+ >>> return dataset, list(loaders.keys())
272
+ >>>
273
+ >>> train_data, sources = build_dataset()
274
+ >>> dataloader = torch.utils.data.DataLoader(
275
+ >>> train_data,
276
+ >>> batch_size=16,
277
+ >>> num_workers=0,
278
+ >>> collate_fn=train_data.collate,
279
+ >>> )
280
+ >>> batch = next(iter(dataloader))
281
+ >>>
282
+ >>> for k in sources:
283
+ >>> src = batch[k]
284
+ >>> src["transformed"] = train_data.loaders[k].transform(
285
+ >>> src["signal"].clone(), **src["transform_args"]
286
+ >>> )
287
+ >>>
288
+ >>> mixture = sum(batch[k]["transformed"] for k in sources)
289
+ >>> mixture = train_data.transform(mixture, **batch["transform_args"])
290
+ >>>
291
+ >>> # Say a model takes the mix and gives back (n_batch, n_src, n_time).
292
+ >>> # Construct the targets:
293
+ >>> targets = at.AudioSignal.batch([batch[k]["transformed"] for k in sources], dim=1)
294
+
295
+ Similarly, here's example code for loading Slakh data:
296
+
297
+ >>> import audiotools as at
298
+ >>> from pathlib import Path
299
+ >>> from audiotools import transforms as tfm
300
+ >>> import numpy as np
301
+ >>> import torch
302
+ >>> import glob
303
+ >>>
304
+ >>> def build_dataset(
305
+ >>> sample_rate: int = 16000,
306
+ >>> duration: float = 10.0,
307
+ >>> slakh_path: str = "~/.data/slakh/",
308
+ >>> ):
309
+ >>> slakh_path = Path(slakh_path).expanduser()
310
+ >>>
311
+ >>> # Find the max number of sources in Slakh
312
+ >>> src_names = [x.name for x in list(slakh_path.glob("**/*.wav")) if "S" in str(x.name)]
313
+ >>> n_sources = len(list(set(src_names)))
314
+ >>>
315
+ >>> loaders = {
316
+ >>> f"S{i:02d}": at.datasets.AudioLoader(
317
+ >>> sources=[slakh_path],
318
+ >>> transform=tfm.Compose(
319
+ >>> tfm.VolumeNorm(("uniform", -20, -10)),
320
+ >>> tfm.Silence(prob=0.1),
321
+ >>> ),
322
+ >>> ext=[f"S{i:02d}.wav"],
323
+ >>> )
324
+ >>> for i in range(n_sources)
325
+ >>> }
326
+ >>> dataset = at.datasets.AudioDataset(
327
+ >>> loaders=loaders,
328
+ >>> sample_rate=sample_rate,
329
+ >>> duration=duration,
330
+ >>> num_channels=1,
331
+ >>> aligned=True,
332
+ >>> transform=tfm.RescaleAudio(),
333
+ >>> shuffle_loaders=False,
334
+ >>> )
335
+ >>>
336
+ >>> return dataset, list(loaders.keys())
337
+ >>>
338
+ >>> train_data, sources = build_dataset()
339
+ >>> dataloader = torch.utils.data.DataLoader(
340
+ >>> train_data,
341
+ >>> batch_size=16,
342
+ >>> num_workers=0,
343
+ >>> collate_fn=train_data.collate,
344
+ >>> )
345
+ >>> batch = next(iter(dataloader))
346
+ >>>
347
+ >>> for k in sources:
348
+ >>> src = batch[k]
349
+ >>> src["transformed"] = train_data.loaders[k].transform(
350
+ >>> src["signal"].clone(), **src["transform_args"]
351
+ >>> )
352
+ >>>
353
+ >>> mixture = sum(batch[k]["transformed"] for k in sources)
354
+ >>> mixture = train_data.transform(mixture, **batch["transform_args"])
355
+
356
+ """
357
+
358
+ def __init__(
359
+ self,
360
+ loaders: Union[AudioLoader, List[AudioLoader], Dict[str, AudioLoader]],
361
+ sample_rate: int,
362
+ n_examples: int = 1000,
363
+ duration: float = 0.5,
364
+ offset: float = None,
365
+ loudness_cutoff: float = -40,
366
+ num_channels: int = 1,
367
+ transform: Callable = None,
368
+ aligned: bool = False,
369
+ shuffle_loaders: bool = False,
370
+ matcher: Callable = default_matcher,
371
+ without_replacement: bool = True,
372
+ ):
373
+ # Internally we convert loaders to a dictionary
374
+ if isinstance(loaders, list):
375
+ loaders = {i: l for i, l in enumerate(loaders)}
376
+ elif isinstance(loaders, AudioLoader):
377
+ loaders = {0: loaders}
378
+
379
+ self.loaders = loaders
380
+ self.loudness_cutoff = loudness_cutoff
381
+ self.num_channels = num_channels
382
+
383
+ self.length = n_examples
384
+ self.transform = transform
385
+ self.sample_rate = sample_rate
386
+ self.duration = duration
387
+ self.offset = offset
388
+ self.aligned = aligned
389
+ self.shuffle_loaders = shuffle_loaders
390
+ self.without_replacement = without_replacement
391
+
392
+ if aligned:
393
+ loaders_list = list(loaders.values())
394
+ for i in range(len(loaders_list[0].audio_lists)):
395
+ input_lists = [l.audio_lists[i] for l in loaders_list]
396
+ # Alignment happens in-place
397
+ align_lists(input_lists, matcher)
398
+
399
+ def __getitem__(self, idx):
400
+ state = util.random_state(idx)
401
+ offset = None if self.offset is None else self.offset
402
+ item = {}
403
+
404
+ keys = list(self.loaders.keys())
405
+ if self.shuffle_loaders:
406
+ state.shuffle(keys)
407
+
408
+ loader_kwargs = {
409
+ "state": state,
410
+ "sample_rate": self.sample_rate,
411
+ "duration": self.duration,
412
+ "loudness_cutoff": self.loudness_cutoff,
413
+ "num_channels": self.num_channels,
414
+ "global_idx": idx if self.without_replacement else None,
415
+ }
416
+
417
+ # Draw item from first loader
418
+ loader = self.loaders[keys[0]]
419
+ item[keys[0]] = loader(**loader_kwargs)
420
+
421
+ for key in keys[1:]:
422
+ loader = self.loaders[key]
423
+ if self.aligned:
424
+ # Path mapper takes the current loader + everything
425
+ # returned by the first loader.
426
+ offset = item[keys[0]]["signal"].metadata["offset"]
427
+ loader_kwargs.update(
428
+ {
429
+ "offset": offset,
430
+ "source_idx": item[keys[0]]["source_idx"],
431
+ "item_idx": item[keys[0]]["item_idx"],
432
+ }
433
+ )
434
+ item[key] = loader(**loader_kwargs)
435
+
436
+ # Sort dictionary back into original order
437
+ keys = list(self.loaders.keys())
438
+ item = {k: item[k] for k in keys}
439
+
440
+ item["idx"] = idx
441
+ if self.transform is not None:
442
+ item["transform_args"] = self.transform.instantiate(
443
+ state=state, signal=item[keys[0]]["signal"]
444
+ )
445
+
446
+ # If there's only one loader, pop it up
447
+ # to the main dictionary, instead of keeping it
448
+ # nested.
449
+ if len(keys) == 1:
450
+ item.update(item.pop(keys[0]))
451
+
452
+ return item
453
+
454
+ def __len__(self):
455
+ return self.length
456
+
457
+ @staticmethod
458
+ def collate(list_of_dicts: Union[list, dict], n_splits: int = None):
459
+ """Collates items drawn from this dataset. Uses
460
+ :py:func:`audiotools.core.util.collate`.
461
+
462
+ Parameters
463
+ ----------
464
+ list_of_dicts : typing.Union[list, dict]
465
+ Data drawn from each item.
466
+ n_splits : int
467
+ Number of splits to make when creating the batches (split into
468
+ sub-batches). Useful for things like gradient accumulation.
469
+
470
+ Returns
471
+ -------
472
+ dict
473
+ Dictionary of batched data.
474
+ """
475
+ return util.collate(list_of_dicts, n_splits=n_splits)
476
+
477
+
478
+ class ConcatDataset(AudioDataset):
479
+ def __init__(self, datasets: list):
480
+ self.datasets = datasets
481
+
482
+ def __len__(self):
483
+ return sum([len(d) for d in self.datasets])
484
+
485
+ def __getitem__(self, idx):
486
+ dataset = self.datasets[idx % len(self.datasets)]
487
+ return dataset[idx // len(self.datasets)]
488
+
489
+
490
+ class ResumableDistributedSampler(DistributedSampler): # pragma: no cover
491
+ """Distributed sampler that can be resumed from a given start index."""
492
+
493
+ def __init__(self, dataset, start_idx: int = None, **kwargs):
494
+ super().__init__(dataset, **kwargs)
495
+ # Start index, allows to resume an experiment at the index it was
496
+ self.start_idx = start_idx // self.num_replicas if start_idx is not None else 0
497
+
498
+ def __iter__(self):
499
+ for i, idx in enumerate(super().__iter__()):
500
+ if i >= self.start_idx:
501
+ yield idx
502
+ self.start_idx = 0 # set the index back to 0 so for the next epoch
503
+
504
+
505
+ class ResumableSequentialSampler(SequentialSampler): # pragma: no cover
506
+ """Sequential sampler that can be resumed from a given start index."""
507
+
508
+ def __init__(self, dataset, start_idx: int = None, **kwargs):
509
+ super().__init__(dataset, **kwargs)
510
+ # Start index, allows to resume an experiment at the index it was
511
+ self.start_idx = start_idx if start_idx is not None else 0
512
+
513
+ def __iter__(self):
514
+ for i, idx in enumerate(super().__iter__()):
515
+ if i >= self.start_idx:
516
+ yield idx
517
+ self.start_idx = 0 # set the index back to 0 so for the next epoch
audiotools/data/preprocess.py ADDED
@@ -0,0 +1,81 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import csv
2
+ import os
3
+ from pathlib import Path
4
+
5
+ from tqdm import tqdm
6
+
7
+ from ..core import AudioSignal
8
+
9
+
10
+ def create_csv(
11
+ audio_files: list, output_csv: Path, loudness: bool = False, data_path: str = None
12
+ ):
13
+ """Converts a folder of audio files to a CSV file. If ``loudness = True``,
14
+ the output of this function will create a CSV file that looks something
15
+ like:
16
+
17
+ .. csv-table::
18
+ :header: path,loudness
19
+
20
+ daps/produced/f1_script1_produced.wav,-16.299999237060547
21
+ daps/produced/f1_script2_produced.wav,-16.600000381469727
22
+ daps/produced/f1_script3_produced.wav,-17.299999237060547
23
+ daps/produced/f1_script4_produced.wav,-16.100000381469727
24
+ daps/produced/f1_script5_produced.wav,-16.700000762939453
25
+ daps/produced/f3_script1_produced.wav,-16.5
26
+
27
+ .. note::
28
+ The paths above are written relative to the ``data_path`` argument
29
+ which defaults to the environment variable ``PATH_TO_DATA`` if
30
+ it isn't passed to this function, and defaults to the empty string
31
+ if that environment variable is not set.
32
+
33
+ You can produce a CSV file from a directory of audio files via:
34
+
35
+ >>> import audiotools
36
+ >>> directory = ...
37
+ >>> audio_files = audiotools.util.find_audio(directory)
38
+ >>> output_path = "train.csv"
39
+ >>> audiotools.data.preprocess.create_csv(
40
+ >>> audio_files, output_csv, loudness=True
41
+ >>> )
42
+
43
+ Note that you can create empty rows in the CSV file by passing an empty
44
+ string or None in the ``audio_files`` list. This is useful if you want to
45
+ sync multiple CSV files in a multitrack setting. The loudness of these
46
+ empty rows will be set to -inf.
47
+
48
+ Parameters
49
+ ----------
50
+ audio_files : list
51
+ List of audio files.
52
+ output_csv : Path
53
+ Output CSV, with each row containing the relative path of every file
54
+ to ``data_path``, if specified (defaults to None).
55
+ loudness : bool
56
+ Compute loudness of entire file and store alongside path.
57
+ """
58
+
59
+ info = []
60
+ pbar = tqdm(audio_files)
61
+ for af in pbar:
62
+ af = Path(af)
63
+ pbar.set_description(f"Processing {af.name}")
64
+ _info = {}
65
+ if af.name == "":
66
+ _info["path"] = ""
67
+ if loudness:
68
+ _info["loudness"] = -float("inf")
69
+ else:
70
+ _info["path"] = af.relative_to(data_path) if data_path is not None else af
71
+ if loudness:
72
+ _info["loudness"] = AudioSignal(af).ffmpeg_loudness().item()
73
+
74
+ info.append(_info)
75
+
76
+ with open(output_csv, "w") as f:
77
+ writer = csv.DictWriter(f, fieldnames=list(info[0].keys()))
78
+ writer.writeheader()
79
+
80
+ for item in info:
81
+ writer.writerow(item)
audiotools/data/transforms.py ADDED
@@ -0,0 +1,1592 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import copy
2
+ from contextlib import contextmanager
3
+ from inspect import signature
4
+ from typing import List
5
+
6
+ import numpy as np
7
+ import torch
8
+ from flatten_dict import flatten
9
+ from flatten_dict import unflatten
10
+ from numpy.random import RandomState
11
+
12
+ from .. import ml
13
+ from ..core import AudioSignal
14
+ from ..core import util
15
+ from .datasets import AudioLoader
16
+
17
+ tt = torch.tensor
18
+ """Shorthand for converting things to torch.tensor."""
19
+
20
+
21
+ class BaseTransform:
22
+ """This is the base class for all transforms that are implemented
23
+ in this library. Transforms have two main operations: ``transform``
24
+ and ``instantiate``.
25
+
26
+ ``instantiate`` sets the parameters randomly
27
+ from distribution tuples for each parameter. For example, for the
28
+ ``BackgroundNoise`` transform, the signal-to-noise ratio (``snr``)
29
+ is chosen randomly by instantiate. By default, it chosen uniformly
30
+ between 10.0 and 30.0 (the tuple is set to ``("uniform", 10.0, 30.0)``).
31
+
32
+ ``transform`` applies the transform using the instantiated parameters.
33
+ A simple example is as follows:
34
+
35
+ >>> seed = 0
36
+ >>> signal = ...
37
+ >>> transform = transforms.NoiseFloor(db = ("uniform", -50.0, -30.0))
38
+ >>> kwargs = transform.instantiate()
39
+ >>> output = transform(signal.clone(), **kwargs)
40
+
41
+ By breaking apart the instantiation of parameters from the actual audio
42
+ processing of the transform, we can make things more reproducible, while
43
+ also applying the transform on batches of data efficiently on GPU,
44
+ rather than on individual audio samples.
45
+
46
+ .. note::
47
+ We call ``signal.clone()`` for the input to the ``transform`` function
48
+ because signals are modified in-place! If you don't clone the signal,
49
+ you will lose the original data.
50
+
51
+ Parameters
52
+ ----------
53
+ keys : list, optional
54
+ Keys that the transform looks for when
55
+ calling ``self.transform``, by default []. In general this is
56
+ set automatically, and you won't need to manipulate this argument.
57
+ name : str, optional
58
+ Name of this transform, used to identify it in the dictionary
59
+ produced by ``self.instantiate``, by default None
60
+ prob : float, optional
61
+ Probability of applying this transform, by default 1.0
62
+
63
+ Examples
64
+ --------
65
+
66
+ >>> seed = 0
67
+ >>>
68
+ >>> audio_path = "tests/audio/spk/f10_script4_produced.wav"
69
+ >>> signal = AudioSignal(audio_path, offset=10, duration=2)
70
+ >>> transform = tfm.Compose(
71
+ >>> [
72
+ >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
73
+ >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
74
+ >>> ],
75
+ >>> )
76
+ >>>
77
+ >>> kwargs = transform.instantiate(seed, signal)
78
+ >>> output = transform(signal, **kwargs)
79
+
80
+ """
81
+
82
+ def __init__(self, keys: list = [], name: str = None, prob: float = 1.0):
83
+ # Get keys from the _transform signature.
84
+ tfm_keys = list(signature(self._transform).parameters.keys())
85
+
86
+ # Filter out signal and kwargs keys.
87
+ ignore_keys = ["signal", "kwargs"]
88
+ tfm_keys = [k for k in tfm_keys if k not in ignore_keys]
89
+
90
+ # Combine keys specified by the child class, the keys found in
91
+ # _transform signature, and the mask key.
92
+ self.keys = keys + tfm_keys + ["mask"]
93
+
94
+ self.prob = prob
95
+
96
+ if name is None:
97
+ name = self.__class__.__name__
98
+ self.name = name
99
+
100
+ def _prepare(self, batch: dict):
101
+ sub_batch = batch[self.name]
102
+
103
+ for k in self.keys:
104
+ assert k in sub_batch.keys(), f"{k} not in batch"
105
+
106
+ return sub_batch
107
+
108
+ def _transform(self, signal):
109
+ return signal
110
+
111
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
112
+ return {}
113
+
114
+ @staticmethod
115
+ def apply_mask(batch: dict, mask: torch.Tensor):
116
+ """Applies a mask to the batch.
117
+
118
+ Parameters
119
+ ----------
120
+ batch : dict
121
+ Batch whose values will be masked in the ``transform`` pass.
122
+ mask : torch.Tensor
123
+ Mask to apply to batch.
124
+
125
+ Returns
126
+ -------
127
+ dict
128
+ A dictionary that contains values only where ``mask = True``.
129
+ """
130
+ masked_batch = {k: v[mask] for k, v in flatten(batch).items()}
131
+ return unflatten(masked_batch)
132
+
133
+ def transform(self, signal: AudioSignal, **kwargs):
134
+ """Apply the transform to the audio signal,
135
+ with given keyword arguments.
136
+
137
+ Parameters
138
+ ----------
139
+ signal : AudioSignal
140
+ Signal that will be modified by the transforms in-place.
141
+ kwargs: dict
142
+ Keyword arguments to the specific transforms ``self._transform``
143
+ function.
144
+
145
+ Returns
146
+ -------
147
+ AudioSignal
148
+ Transformed AudioSignal.
149
+
150
+ Examples
151
+ --------
152
+
153
+ >>> for seed in range(10):
154
+ >>> kwargs = transform.instantiate(seed, signal)
155
+ >>> output = transform(signal.clone(), **kwargs)
156
+
157
+ """
158
+ tfm_kwargs = self._prepare(kwargs)
159
+ mask = tfm_kwargs["mask"]
160
+
161
+ if torch.any(mask):
162
+ tfm_kwargs = self.apply_mask(tfm_kwargs, mask)
163
+ tfm_kwargs = {k: v for k, v in tfm_kwargs.items() if k != "mask"}
164
+ signal[mask] = self._transform(signal[mask], **tfm_kwargs)
165
+
166
+ return signal
167
+
168
+ def __call__(self, *args, **kwargs):
169
+ return self.transform(*args, **kwargs)
170
+
171
+ def instantiate(
172
+ self,
173
+ state: RandomState = None,
174
+ signal: AudioSignal = None,
175
+ ):
176
+ """Instantiates parameters for the transform.
177
+
178
+ Parameters
179
+ ----------
180
+ state : RandomState, optional
181
+ _description_, by default None
182
+ signal : AudioSignal, optional
183
+ _description_, by default None
184
+
185
+ Returns
186
+ -------
187
+ dict
188
+ Dictionary containing instantiated arguments for every keyword
189
+ argument to ``self._transform``.
190
+
191
+ Examples
192
+ --------
193
+
194
+ >>> for seed in range(10):
195
+ >>> kwargs = transform.instantiate(seed, signal)
196
+ >>> output = transform(signal.clone(), **kwargs)
197
+
198
+ """
199
+ state = util.random_state(state)
200
+
201
+ # Not all instantiates need the signal. Check if signal
202
+ # is needed before passing it in, so that the end-user
203
+ # doesn't need to have variables they're not using flowing
204
+ # into their function.
205
+ needs_signal = "signal" in set(signature(self._instantiate).parameters.keys())
206
+ kwargs = {}
207
+ if needs_signal:
208
+ kwargs = {"signal": signal}
209
+
210
+ # Instantiate the parameters for the transform.
211
+ params = self._instantiate(state, **kwargs)
212
+ for k in list(params.keys()):
213
+ v = params[k]
214
+ if isinstance(v, (AudioSignal, torch.Tensor, dict)):
215
+ params[k] = v
216
+ else:
217
+ params[k] = tt(v)
218
+ mask = state.rand() <= self.prob
219
+ params[f"mask"] = tt(mask)
220
+
221
+ # Put the params into a nested dictionary that will be
222
+ # used later when calling the transform. This is to avoid
223
+ # collisions in the dictionary.
224
+ params = {self.name: params}
225
+
226
+ return params
227
+
228
+ def batch_instantiate(
229
+ self,
230
+ states: list = None,
231
+ signal: AudioSignal = None,
232
+ ):
233
+ """Instantiates arguments for every item in a batch,
234
+ given a list of states. Each state in the list
235
+ corresponds to one item in the batch.
236
+
237
+ Parameters
238
+ ----------
239
+ states : list, optional
240
+ List of states, by default None
241
+ signal : AudioSignal, optional
242
+ AudioSignal to pass to the ``self.instantiate`` section
243
+ if it is needed for this transform, by default None
244
+
245
+ Returns
246
+ -------
247
+ dict
248
+ Collated dictionary of arguments.
249
+
250
+ Examples
251
+ --------
252
+
253
+ >>> batch_size = 4
254
+ >>> signal = AudioSignal(audio_path, offset=10, duration=2)
255
+ >>> signal_batch = AudioSignal.batch([signal.clone() for _ in range(batch_size)])
256
+ >>>
257
+ >>> states = [seed + idx for idx in list(range(batch_size))]
258
+ >>> kwargs = transform.batch_instantiate(states, signal_batch)
259
+ >>> batch_output = transform(signal_batch, **kwargs)
260
+ """
261
+ kwargs = []
262
+ for state in states:
263
+ kwargs.append(self.instantiate(state, signal))
264
+ kwargs = util.collate(kwargs)
265
+ return kwargs
266
+
267
+
268
+ class Identity(BaseTransform):
269
+ """This transform just returns the original signal."""
270
+
271
+ pass
272
+
273
+
274
+ class SpectralTransform(BaseTransform):
275
+ """Spectral transforms require STFT data to exist, since manipulations
276
+ of the STFT require the spectrogram. This just calls ``stft`` before
277
+ the transform is called, and calls ``istft`` after the transform is
278
+ called so that the audio data is written to after the spectral
279
+ manipulation.
280
+ """
281
+
282
+ def transform(self, signal, **kwargs):
283
+ signal.stft()
284
+ super().transform(signal, **kwargs)
285
+ signal.istft()
286
+ return signal
287
+
288
+
289
+ class Compose(BaseTransform):
290
+ """Compose applies transforms in sequence, one after the other. The
291
+ transforms are passed in as positional arguments or as a list like so:
292
+
293
+ >>> transform = tfm.Compose(
294
+ >>> [
295
+ >>> tfm.RoomImpulseResponse(sources=["tests/audio/irs.csv"]),
296
+ >>> tfm.BackgroundNoise(sources=["tests/audio/noises.csv"]),
297
+ >>> ],
298
+ >>> )
299
+
300
+ This will convolve the signal with a room impulse response, and then
301
+ add background noise to the signal. Instantiate instantiates
302
+ all the parameters for every transform in the transform list so the
303
+ interface for using the Compose transform is the same as everything
304
+ else:
305
+
306
+ >>> kwargs = transform.instantiate()
307
+ >>> output = transform(signal.clone(), **kwargs)
308
+
309
+ Under the hood, the transform maps each transform to a unique name
310
+ under the hood of the form ``{position}.{name}``, where ``position``
311
+ is the index of the transform in the list. ``Compose`` can nest
312
+ within other ``Compose`` transforms, like so:
313
+
314
+ >>> preprocess = transforms.Compose(
315
+ >>> tfm.GlobalVolumeNorm(),
316
+ >>> tfm.CrossTalk(),
317
+ >>> name="preprocess",
318
+ >>> )
319
+ >>> augment = transforms.Compose(
320
+ >>> tfm.RoomImpulseResponse(),
321
+ >>> tfm.BackgroundNoise(),
322
+ >>> name="augment",
323
+ >>> )
324
+ >>> postprocess = transforms.Compose(
325
+ >>> tfm.VolumeChange(),
326
+ >>> tfm.RescaleAudio(),
327
+ >>> tfm.ShiftPhase(),
328
+ >>> name="postprocess",
329
+ >>> )
330
+ >>> transform = transforms.Compose(preprocess, augment, postprocess),
331
+
332
+ This defines 3 composed transforms, and then composes them in sequence
333
+ with one another.
334
+
335
+ Parameters
336
+ ----------
337
+ *transforms : list
338
+ List of transforms to apply
339
+ name : str, optional
340
+ Name of this transform, used to identify it in the dictionary
341
+ produced by ``self.instantiate``, by default None
342
+ prob : float, optional
343
+ Probability of applying this transform, by default 1.0
344
+ """
345
+
346
+ def __init__(self, *transforms: list, name: str = None, prob: float = 1.0):
347
+ if isinstance(transforms[0], list):
348
+ transforms = transforms[0]
349
+
350
+ for i, tfm in enumerate(transforms):
351
+ tfm.name = f"{i}.{tfm.name}"
352
+
353
+ keys = [tfm.name for tfm in transforms]
354
+ super().__init__(keys=keys, name=name, prob=prob)
355
+
356
+ self.transforms = transforms
357
+ self.transforms_to_apply = keys
358
+
359
+ @contextmanager
360
+ def filter(self, *names: list):
361
+ """This can be used to skip transforms entirely when applying
362
+ the sequence of transforms to a signal. For example, take
363
+ the following transforms with the names ``preprocess, augment, postprocess``.
364
+
365
+ >>> preprocess = transforms.Compose(
366
+ >>> tfm.GlobalVolumeNorm(),
367
+ >>> tfm.CrossTalk(),
368
+ >>> name="preprocess",
369
+ >>> )
370
+ >>> augment = transforms.Compose(
371
+ >>> tfm.RoomImpulseResponse(),
372
+ >>> tfm.BackgroundNoise(),
373
+ >>> name="augment",
374
+ >>> )
375
+ >>> postprocess = transforms.Compose(
376
+ >>> tfm.VolumeChange(),
377
+ >>> tfm.RescaleAudio(),
378
+ >>> tfm.ShiftPhase(),
379
+ >>> name="postprocess",
380
+ >>> )
381
+ >>> transform = transforms.Compose(preprocess, augment, postprocess)
382
+
383
+ If we wanted to apply all 3 to a signal, we do:
384
+
385
+ >>> kwargs = transform.instantiate()
386
+ >>> output = transform(signal.clone(), **kwargs)
387
+
388
+ But if we only wanted to apply the ``preprocess`` and ``postprocess``
389
+ transforms to the signal, we do:
390
+
391
+ >>> with transform_fn.filter("preprocess", "postprocess"):
392
+ >>> output = transform(signal.clone(), **kwargs)
393
+
394
+ Parameters
395
+ ----------
396
+ *names : list
397
+ List of transforms, identified by name, to apply to signal.
398
+ """
399
+ old_transforms = self.transforms_to_apply
400
+ self.transforms_to_apply = names
401
+ yield
402
+ self.transforms_to_apply = old_transforms
403
+
404
+ def _transform(self, signal, **kwargs):
405
+ for transform in self.transforms:
406
+ if any([x in transform.name for x in self.transforms_to_apply]):
407
+ signal = transform(signal, **kwargs)
408
+ return signal
409
+
410
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
411
+ parameters = {}
412
+ for transform in self.transforms:
413
+ parameters.update(transform.instantiate(state, signal=signal))
414
+ return parameters
415
+
416
+ def __getitem__(self, idx):
417
+ return self.transforms[idx]
418
+
419
+ def __len__(self):
420
+ return len(self.transforms)
421
+
422
+ def __iter__(self):
423
+ for transform in self.transforms:
424
+ yield transform
425
+
426
+
427
+ class Choose(Compose):
428
+ """Choose logic is the same as :py:func:`audiotools.data.transforms.Compose`,
429
+ but instead of applying all the transforms in sequence, it applies just a single transform,
430
+ which is chosen for each item in the batch.
431
+
432
+ Parameters
433
+ ----------
434
+ *transforms : list
435
+ List of transforms to apply
436
+ weights : list
437
+ Probability of choosing any specific transform.
438
+ name : str, optional
439
+ Name of this transform, used to identify it in the dictionary
440
+ produced by ``self.instantiate``, by default None
441
+ prob : float, optional
442
+ Probability of applying this transform, by default 1.0
443
+
444
+ Examples
445
+ --------
446
+
447
+ >>> transforms.Choose(tfm.LowPass(), tfm.HighPass())
448
+ """
449
+
450
+ def __init__(
451
+ self,
452
+ *transforms: list,
453
+ weights: list = None,
454
+ name: str = None,
455
+ prob: float = 1.0,
456
+ ):
457
+ super().__init__(*transforms, name=name, prob=prob)
458
+
459
+ if weights is None:
460
+ _len = len(self.transforms)
461
+ weights = [1 / _len for _ in range(_len)]
462
+ self.weights = np.array(weights)
463
+
464
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
465
+ kwargs = super()._instantiate(state, signal)
466
+ tfm_idx = list(range(len(self.transforms)))
467
+ tfm_idx = state.choice(tfm_idx, p=self.weights)
468
+ one_hot = []
469
+ for i, t in enumerate(self.transforms):
470
+ mask = kwargs[t.name]["mask"]
471
+ if mask.item():
472
+ kwargs[t.name]["mask"] = tt(i == tfm_idx)
473
+ one_hot.append(kwargs[t.name]["mask"])
474
+ kwargs["one_hot"] = one_hot
475
+ return kwargs
476
+
477
+
478
+ class Repeat(Compose):
479
+ """Repeatedly applies a given transform ``n_repeat`` times."
480
+
481
+ Parameters
482
+ ----------
483
+ transform : BaseTransform
484
+ Transform to repeat.
485
+ n_repeat : int, optional
486
+ Number of times to repeat transform, by default 1
487
+ """
488
+
489
+ def __init__(
490
+ self,
491
+ transform,
492
+ n_repeat: int = 1,
493
+ name: str = None,
494
+ prob: float = 1.0,
495
+ ):
496
+ transforms = [copy.copy(transform) for _ in range(n_repeat)]
497
+ super().__init__(transforms, name=name, prob=prob)
498
+
499
+ self.n_repeat = n_repeat
500
+
501
+
502
+ class RepeatUpTo(Choose):
503
+ """Repeatedly applies a given transform up to ``max_repeat`` times."
504
+
505
+ Parameters
506
+ ----------
507
+ transform : BaseTransform
508
+ Transform to repeat.
509
+ max_repeat : int, optional
510
+ Max number of times to repeat transform, by default 1
511
+ weights : list
512
+ Probability of choosing any specific number up to ``max_repeat``.
513
+ """
514
+
515
+ def __init__(
516
+ self,
517
+ transform,
518
+ max_repeat: int = 5,
519
+ weights: list = None,
520
+ name: str = None,
521
+ prob: float = 1.0,
522
+ ):
523
+ transforms = []
524
+ for n in range(1, max_repeat):
525
+ transforms.append(Repeat(transform, n_repeat=n))
526
+ super().__init__(transforms, name=name, prob=prob, weights=weights)
527
+
528
+ self.max_repeat = max_repeat
529
+
530
+
531
+ class ClippingDistortion(BaseTransform):
532
+ """Adds clipping distortion to signal. Corresponds
533
+ to :py:func:`audiotools.core.effects.EffectMixin.clip_distortion`.
534
+
535
+ Parameters
536
+ ----------
537
+ perc : tuple, optional
538
+ Clipping percentile. Values are between 0.0 to 1.0.
539
+ Typical values are 0.1 or below, by default ("uniform", 0.0, 0.1)
540
+ name : str, optional
541
+ Name of this transform, used to identify it in the dictionary
542
+ produced by ``self.instantiate``, by default None
543
+ prob : float, optional
544
+ Probability of applying this transform, by default 1.0
545
+ """
546
+
547
+ def __init__(
548
+ self,
549
+ perc: tuple = ("uniform", 0.0, 0.1),
550
+ name: str = None,
551
+ prob: float = 1.0,
552
+ ):
553
+ super().__init__(name=name, prob=prob)
554
+
555
+ self.perc = perc
556
+
557
+ def _instantiate(self, state: RandomState):
558
+ return {"perc": util.sample_from_dist(self.perc, state)}
559
+
560
+ def _transform(self, signal, perc):
561
+ return signal.clip_distortion(perc)
562
+
563
+
564
+ class Equalizer(BaseTransform):
565
+ """Applies an equalization curve to the audio signal. Corresponds
566
+ to :py:func:`audiotools.core.effects.EffectMixin.equalizer`.
567
+
568
+ Parameters
569
+ ----------
570
+ eq_amount : tuple, optional
571
+ The maximum dB cut to apply to the audio in any band,
572
+ by default ("const", 1.0 dB)
573
+ n_bands : int, optional
574
+ Number of bands in EQ, by default 6
575
+ name : str, optional
576
+ Name of this transform, used to identify it in the dictionary
577
+ produced by ``self.instantiate``, by default None
578
+ prob : float, optional
579
+ Probability of applying this transform, by default 1.0
580
+ """
581
+
582
+ def __init__(
583
+ self,
584
+ eq_amount: tuple = ("const", 1.0),
585
+ n_bands: int = 6,
586
+ name: str = None,
587
+ prob: float = 1.0,
588
+ ):
589
+ super().__init__(name=name, prob=prob)
590
+
591
+ self.eq_amount = eq_amount
592
+ self.n_bands = n_bands
593
+
594
+ def _instantiate(self, state: RandomState):
595
+ eq_amount = util.sample_from_dist(self.eq_amount, state)
596
+ eq = -eq_amount * state.rand(self.n_bands)
597
+ return {"eq": eq}
598
+
599
+ def _transform(self, signal, eq):
600
+ return signal.equalizer(eq)
601
+
602
+
603
+ class Quantization(BaseTransform):
604
+ """Applies quantization to the input waveform. Corresponds
605
+ to :py:func:`audiotools.core.effects.EffectMixin.quantization`.
606
+
607
+ Parameters
608
+ ----------
609
+ channels : tuple, optional
610
+ Number of evenly spaced quantization channels to quantize
611
+ to, by default ("choice", [8, 32, 128, 256, 1024])
612
+ name : str, optional
613
+ Name of this transform, used to identify it in the dictionary
614
+ produced by ``self.instantiate``, by default None
615
+ prob : float, optional
616
+ Probability of applying this transform, by default 1.0
617
+ """
618
+
619
+ def __init__(
620
+ self,
621
+ channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
622
+ name: str = None,
623
+ prob: float = 1.0,
624
+ ):
625
+ super().__init__(name=name, prob=prob)
626
+
627
+ self.channels = channels
628
+
629
+ def _instantiate(self, state: RandomState):
630
+ return {"channels": util.sample_from_dist(self.channels, state)}
631
+
632
+ def _transform(self, signal, channels):
633
+ return signal.quantization(channels)
634
+
635
+
636
+ class MuLawQuantization(BaseTransform):
637
+ """Applies mu-law quantization to the input waveform. Corresponds
638
+ to :py:func:`audiotools.core.effects.EffectMixin.mulaw_quantization`.
639
+
640
+ Parameters
641
+ ----------
642
+ channels : tuple, optional
643
+ Number of mu-law spaced quantization channels to quantize
644
+ to, by default ("choice", [8, 32, 128, 256, 1024])
645
+ name : str, optional
646
+ Name of this transform, used to identify it in the dictionary
647
+ produced by ``self.instantiate``, by default None
648
+ prob : float, optional
649
+ Probability of applying this transform, by default 1.0
650
+ """
651
+
652
+ def __init__(
653
+ self,
654
+ channels: tuple = ("choice", [8, 32, 128, 256, 1024]),
655
+ name: str = None,
656
+ prob: float = 1.0,
657
+ ):
658
+ super().__init__(name=name, prob=prob)
659
+
660
+ self.channels = channels
661
+
662
+ def _instantiate(self, state: RandomState):
663
+ return {"channels": util.sample_from_dist(self.channels, state)}
664
+
665
+ def _transform(self, signal, channels):
666
+ return signal.mulaw_quantization(channels)
667
+
668
+
669
+ class NoiseFloor(BaseTransform):
670
+ """Adds a noise floor of Gaussian noise to the signal at a specified
671
+ dB.
672
+
673
+ Parameters
674
+ ----------
675
+ db : tuple, optional
676
+ Level of noise to add to signal, by default ("const", -50.0)
677
+ name : str, optional
678
+ Name of this transform, used to identify it in the dictionary
679
+ produced by ``self.instantiate``, by default None
680
+ prob : float, optional
681
+ Probability of applying this transform, by default 1.0
682
+ """
683
+
684
+ def __init__(
685
+ self,
686
+ db: tuple = ("const", -50.0),
687
+ name: str = None,
688
+ prob: float = 1.0,
689
+ ):
690
+ super().__init__(name=name, prob=prob)
691
+
692
+ self.db = db
693
+
694
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
695
+ db = util.sample_from_dist(self.db, state)
696
+ audio_data = state.randn(signal.num_channels, signal.signal_length)
697
+ nz_signal = AudioSignal(audio_data, signal.sample_rate)
698
+ nz_signal.normalize(db)
699
+ return {"nz_signal": nz_signal}
700
+
701
+ def _transform(self, signal, nz_signal):
702
+ # Clone bg_signal so that transform can be repeatedly applied
703
+ # to different signals with the same effect.
704
+ return signal + nz_signal
705
+
706
+
707
+ class BackgroundNoise(BaseTransform):
708
+ """Adds background noise from audio specified by a set of CSV files.
709
+ A valid CSV file looks like, and is typically generated by
710
+ :py:func:`audiotools.data.preprocess.create_csv`:
711
+
712
+ .. csv-table::
713
+ :header: path
714
+
715
+ room_tone/m6_script2_clean.wav
716
+ room_tone/m6_script2_cleanraw.wav
717
+ room_tone/m6_script2_ipad_balcony1.wav
718
+ room_tone/m6_script2_ipad_bedroom1.wav
719
+ room_tone/m6_script2_ipad_confroom1.wav
720
+ room_tone/m6_script2_ipad_confroom2.wav
721
+ room_tone/m6_script2_ipad_livingroom1.wav
722
+ room_tone/m6_script2_ipad_office1.wav
723
+
724
+ .. note::
725
+ All paths are relative to an environment variable called ``PATH_TO_DATA``,
726
+ so that CSV files are portable across machines where data may be
727
+ located in different places.
728
+
729
+ This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
730
+ and :py:func:`audiotools.core.effects.EffectMixin.equalizer` under the
731
+ hood.
732
+
733
+ Parameters
734
+ ----------
735
+ snr : tuple, optional
736
+ Signal-to-noise ratio, by default ("uniform", 10.0, 30.0)
737
+ sources : List[str], optional
738
+ Sources containing folders, or CSVs with paths to audio files,
739
+ by default None
740
+ weights : List[float], optional
741
+ Weights to sample audio files from each source, by default None
742
+ eq_amount : tuple, optional
743
+ Amount of equalization to apply, by default ("const", 1.0)
744
+ n_bands : int, optional
745
+ Number of bands in equalizer, by default 3
746
+ name : str, optional
747
+ Name of this transform, used to identify it in the dictionary
748
+ produced by ``self.instantiate``, by default None
749
+ prob : float, optional
750
+ Probability of applying this transform, by default 1.0
751
+ loudness_cutoff : float, optional
752
+ Loudness cutoff when loading from audio files, by default None
753
+ """
754
+
755
+ def __init__(
756
+ self,
757
+ snr: tuple = ("uniform", 10.0, 30.0),
758
+ sources: List[str] = None,
759
+ weights: List[float] = None,
760
+ eq_amount: tuple = ("const", 1.0),
761
+ n_bands: int = 3,
762
+ name: str = None,
763
+ prob: float = 1.0,
764
+ loudness_cutoff: float = None,
765
+ ):
766
+ super().__init__(name=name, prob=prob)
767
+
768
+ self.snr = snr
769
+ self.eq_amount = eq_amount
770
+ self.n_bands = n_bands
771
+ self.loader = AudioLoader(sources, weights)
772
+ self.loudness_cutoff = loudness_cutoff
773
+
774
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
775
+ eq_amount = util.sample_from_dist(self.eq_amount, state)
776
+ eq = -eq_amount * state.rand(self.n_bands)
777
+ snr = util.sample_from_dist(self.snr, state)
778
+
779
+ bg_signal = self.loader(
780
+ state,
781
+ signal.sample_rate,
782
+ duration=signal.signal_duration,
783
+ loudness_cutoff=self.loudness_cutoff,
784
+ num_channels=signal.num_channels,
785
+ )["signal"]
786
+
787
+ return {"eq": eq, "bg_signal": bg_signal, "snr": snr}
788
+
789
+ def _transform(self, signal, bg_signal, snr, eq):
790
+ # Clone bg_signal so that transform can be repeatedly applied
791
+ # to different signals with the same effect.
792
+ return signal.mix(bg_signal.clone(), snr, eq)
793
+
794
+
795
+ class CrossTalk(BaseTransform):
796
+ """Adds crosstalk between speakers, whose audio is drawn from a CSV file
797
+ that was produced via :py:func:`audiotools.data.preprocess.create_csv`.
798
+
799
+ This transform calls :py:func:`audiotools.core.effects.EffectMixin.mix`
800
+ under the hood.
801
+
802
+ Parameters
803
+ ----------
804
+ snr : tuple, optional
805
+ How loud cross-talk speaker is relative to original signal in dB,
806
+ by default ("uniform", 0.0, 10.0)
807
+ sources : List[str], optional
808
+ Sources containing folders, or CSVs with paths to audio files,
809
+ by default None
810
+ weights : List[float], optional
811
+ Weights to sample audio files from each source, by default None
812
+ name : str, optional
813
+ Name of this transform, used to identify it in the dictionary
814
+ produced by ``self.instantiate``, by default None
815
+ prob : float, optional
816
+ Probability of applying this transform, by default 1.0
817
+ loudness_cutoff : float, optional
818
+ Loudness cutoff when loading from audio files, by default -40
819
+ """
820
+
821
+ def __init__(
822
+ self,
823
+ snr: tuple = ("uniform", 0.0, 10.0),
824
+ sources: List[str] = None,
825
+ weights: List[float] = None,
826
+ name: str = None,
827
+ prob: float = 1.0,
828
+ loudness_cutoff: float = -40,
829
+ ):
830
+ super().__init__(name=name, prob=prob)
831
+
832
+ self.snr = snr
833
+ self.loader = AudioLoader(sources, weights)
834
+ self.loudness_cutoff = loudness_cutoff
835
+
836
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
837
+ snr = util.sample_from_dist(self.snr, state)
838
+ crosstalk_signal = self.loader(
839
+ state,
840
+ signal.sample_rate,
841
+ duration=signal.signal_duration,
842
+ loudness_cutoff=self.loudness_cutoff,
843
+ num_channels=signal.num_channels,
844
+ )["signal"]
845
+
846
+ return {"crosstalk_signal": crosstalk_signal, "snr": snr}
847
+
848
+ def _transform(self, signal, crosstalk_signal, snr):
849
+ # Clone bg_signal so that transform can be repeatedly applied
850
+ # to different signals with the same effect.
851
+ loudness = signal.loudness()
852
+ mix = signal.mix(crosstalk_signal.clone(), snr)
853
+ mix.normalize(loudness)
854
+ return mix
855
+
856
+
857
+ class RoomImpulseResponse(BaseTransform):
858
+ """Convolves signal with a room impulse response, at a specified
859
+ direct-to-reverberant ratio, with equalization applied. Room impulse
860
+ response data is drawn from a CSV file that was produced via
861
+ :py:func:`audiotools.data.preprocess.create_csv`.
862
+
863
+ This transform calls :py:func:`audiotools.core.effects.EffectMixin.apply_ir`
864
+ under the hood.
865
+
866
+ Parameters
867
+ ----------
868
+ drr : tuple, optional
869
+ _description_, by default ("uniform", 0.0, 30.0)
870
+ sources : List[str], optional
871
+ Sources containing folders, or CSVs with paths to audio files,
872
+ by default None
873
+ weights : List[float], optional
874
+ Weights to sample audio files from each source, by default None
875
+ eq_amount : tuple, optional
876
+ Amount of equalization to apply, by default ("const", 1.0)
877
+ n_bands : int, optional
878
+ Number of bands in equalizer, by default 6
879
+ name : str, optional
880
+ Name of this transform, used to identify it in the dictionary
881
+ produced by ``self.instantiate``, by default None
882
+ prob : float, optional
883
+ Probability of applying this transform, by default 1.0
884
+ use_original_phase : bool, optional
885
+ Whether or not to use the original phase, by default False
886
+ offset : float, optional
887
+ Offset from each impulse response file to use, by default 0.0
888
+ duration : float, optional
889
+ Duration of each impulse response, by default 1.0
890
+ """
891
+
892
+ def __init__(
893
+ self,
894
+ drr: tuple = ("uniform", 0.0, 30.0),
895
+ sources: List[str] = None,
896
+ weights: List[float] = None,
897
+ eq_amount: tuple = ("const", 1.0),
898
+ n_bands: int = 6,
899
+ name: str = None,
900
+ prob: float = 1.0,
901
+ use_original_phase: bool = False,
902
+ offset: float = 0.0,
903
+ duration: float = 1.0,
904
+ ):
905
+ super().__init__(name=name, prob=prob)
906
+
907
+ self.drr = drr
908
+ self.eq_amount = eq_amount
909
+ self.n_bands = n_bands
910
+ self.use_original_phase = use_original_phase
911
+
912
+ self.loader = AudioLoader(sources, weights)
913
+ self.offset = offset
914
+ self.duration = duration
915
+
916
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
917
+ eq_amount = util.sample_from_dist(self.eq_amount, state)
918
+ eq = -eq_amount * state.rand(self.n_bands)
919
+ drr = util.sample_from_dist(self.drr, state)
920
+
921
+ ir_signal = self.loader(
922
+ state,
923
+ signal.sample_rate,
924
+ offset=self.offset,
925
+ duration=self.duration,
926
+ loudness_cutoff=None,
927
+ num_channels=signal.num_channels,
928
+ )["signal"]
929
+ ir_signal.zero_pad_to(signal.sample_rate)
930
+
931
+ return {"eq": eq, "ir_signal": ir_signal, "drr": drr}
932
+
933
+ def _transform(self, signal, ir_signal, drr, eq):
934
+ # Clone ir_signal so that transform can be repeatedly applied
935
+ # to different signals with the same effect.
936
+ return signal.apply_ir(
937
+ ir_signal.clone(), drr, eq, use_original_phase=self.use_original_phase
938
+ )
939
+
940
+
941
+ class VolumeChange(BaseTransform):
942
+ """Changes the volume of the input signal.
943
+
944
+ Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
945
+
946
+ Parameters
947
+ ----------
948
+ db : tuple, optional
949
+ Change in volume in decibels, by default ("uniform", -12.0, 0.0)
950
+ name : str, optional
951
+ Name of this transform, used to identify it in the dictionary
952
+ produced by ``self.instantiate``, by default None
953
+ prob : float, optional
954
+ Probability of applying this transform, by default 1.0
955
+ """
956
+
957
+ def __init__(
958
+ self,
959
+ db: tuple = ("uniform", -12.0, 0.0),
960
+ name: str = None,
961
+ prob: float = 1.0,
962
+ ):
963
+ super().__init__(name=name, prob=prob)
964
+ self.db = db
965
+
966
+ def _instantiate(self, state: RandomState):
967
+ return {"db": util.sample_from_dist(self.db, state)}
968
+
969
+ def _transform(self, signal, db):
970
+ return signal.volume_change(db)
971
+
972
+
973
+ class VolumeNorm(BaseTransform):
974
+ """Normalizes the volume of the excerpt to a specified decibel.
975
+
976
+ Uses :py:func:`audiotools.core.effects.EffectMixin.normalize`.
977
+
978
+ Parameters
979
+ ----------
980
+ db : tuple, optional
981
+ dB to normalize signal to, by default ("const", -24)
982
+ name : str, optional
983
+ Name of this transform, used to identify it in the dictionary
984
+ produced by ``self.instantiate``, by default None
985
+ prob : float, optional
986
+ Probability of applying this transform, by default 1.0
987
+ """
988
+
989
+ def __init__(
990
+ self,
991
+ db: tuple = ("const", -24),
992
+ name: str = None,
993
+ prob: float = 1.0,
994
+ ):
995
+ super().__init__(name=name, prob=prob)
996
+
997
+ self.db = db
998
+
999
+ def _instantiate(self, state: RandomState):
1000
+ return {"db": util.sample_from_dist(self.db, state)}
1001
+
1002
+ def _transform(self, signal, db):
1003
+ return signal.normalize(db)
1004
+
1005
+
1006
+ class GlobalVolumeNorm(BaseTransform):
1007
+ """Similar to :py:func:`audiotools.data.transforms.VolumeNorm`, this
1008
+ transform also normalizes the volume of a signal, but it uses
1009
+ the volume of the entire audio file the loaded excerpt comes from,
1010
+ rather than the volume of just the excerpt. The volume of the
1011
+ entire audio file is expected in ``signal.metadata["loudness"]``.
1012
+ If loading audio from a CSV generated by :py:func:`audiotools.data.preprocess.create_csv`
1013
+ with ``loudness = True``, like the following:
1014
+
1015
+ .. csv-table::
1016
+ :header: path,loudness
1017
+
1018
+ daps/produced/f1_script1_produced.wav,-16.299999237060547
1019
+ daps/produced/f1_script2_produced.wav,-16.600000381469727
1020
+ daps/produced/f1_script3_produced.wav,-17.299999237060547
1021
+ daps/produced/f1_script4_produced.wav,-16.100000381469727
1022
+ daps/produced/f1_script5_produced.wav,-16.700000762939453
1023
+ daps/produced/f3_script1_produced.wav,-16.5
1024
+
1025
+ The ``AudioLoader`` will automatically load the loudness column into
1026
+ the metadata of the signal.
1027
+
1028
+ Uses :py:func:`audiotools.core.effects.EffectMixin.volume_change`.
1029
+
1030
+ Parameters
1031
+ ----------
1032
+ db : tuple, optional
1033
+ dB to normalize signal to, by default ("const", -24)
1034
+ name : str, optional
1035
+ Name of this transform, used to identify it in the dictionary
1036
+ produced by ``self.instantiate``, by default None
1037
+ prob : float, optional
1038
+ Probability of applying this transform, by default 1.0
1039
+ """
1040
+
1041
+ def __init__(
1042
+ self,
1043
+ db: tuple = ("const", -24),
1044
+ name: str = None,
1045
+ prob: float = 1.0,
1046
+ ):
1047
+ super().__init__(name=name, prob=prob)
1048
+
1049
+ self.db = db
1050
+
1051
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
1052
+ if "loudness" not in signal.metadata:
1053
+ db_change = 0.0
1054
+ elif float(signal.metadata["loudness"]) == float("-inf"):
1055
+ db_change = 0.0
1056
+ else:
1057
+ db = util.sample_from_dist(self.db, state)
1058
+ db_change = db - float(signal.metadata["loudness"])
1059
+
1060
+ return {"db": db_change}
1061
+
1062
+ def _transform(self, signal, db):
1063
+ return signal.volume_change(db)
1064
+
1065
+
1066
+ class Silence(BaseTransform):
1067
+ """Zeros out the signal with some probability.
1068
+
1069
+ Parameters
1070
+ ----------
1071
+ name : str, optional
1072
+ Name of this transform, used to identify it in the dictionary
1073
+ produced by ``self.instantiate``, by default None
1074
+ prob : float, optional
1075
+ Probability of applying this transform, by default 0.1
1076
+ """
1077
+
1078
+ def __init__(self, name: str = None, prob: float = 0.1):
1079
+ super().__init__(name=name, prob=prob)
1080
+
1081
+ def _transform(self, signal):
1082
+ _loudness = signal._loudness
1083
+ signal = AudioSignal(
1084
+ torch.zeros_like(signal.audio_data),
1085
+ sample_rate=signal.sample_rate,
1086
+ stft_params=signal.stft_params,
1087
+ )
1088
+ # So that the amound of noise added is as if it wasn't silenced.
1089
+ # TODO: improve this hack
1090
+ signal._loudness = _loudness
1091
+
1092
+ return signal
1093
+
1094
+
1095
+ class LowPass(BaseTransform):
1096
+ """Applies a LowPass filter.
1097
+
1098
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.low_pass`.
1099
+
1100
+ Parameters
1101
+ ----------
1102
+ cutoff : tuple, optional
1103
+ Cutoff frequency distribution,
1104
+ by default ``("choice", [4000, 8000, 16000])``
1105
+ zeros : int, optional
1106
+ Number of zero-crossings in filter, argument to
1107
+ ``julius.LowPassFilters``, by default 51
1108
+ name : str, optional
1109
+ Name of this transform, used to identify it in the dictionary
1110
+ produced by ``self.instantiate``, by default None
1111
+ prob : float, optional
1112
+ Probability of applying this transform, by default 1.0
1113
+ """
1114
+
1115
+ def __init__(
1116
+ self,
1117
+ cutoff: tuple = ("choice", [4000, 8000, 16000]),
1118
+ zeros: int = 51,
1119
+ name: str = None,
1120
+ prob: float = 1,
1121
+ ):
1122
+ super().__init__(name=name, prob=prob)
1123
+
1124
+ self.cutoff = cutoff
1125
+ self.zeros = zeros
1126
+
1127
+ def _instantiate(self, state: RandomState):
1128
+ return {"cutoff": util.sample_from_dist(self.cutoff, state)}
1129
+
1130
+ def _transform(self, signal, cutoff):
1131
+ return signal.low_pass(cutoff, zeros=self.zeros)
1132
+
1133
+
1134
+ class HighPass(BaseTransform):
1135
+ """Applies a HighPass filter.
1136
+
1137
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.high_pass`.
1138
+
1139
+ Parameters
1140
+ ----------
1141
+ cutoff : tuple, optional
1142
+ Cutoff frequency distribution,
1143
+ by default ``("choice", [50, 100, 250, 500, 1000])``
1144
+ zeros : int, optional
1145
+ Number of zero-crossings in filter, argument to
1146
+ ``julius.LowPassFilters``, by default 51
1147
+ name : str, optional
1148
+ Name of this transform, used to identify it in the dictionary
1149
+ produced by ``self.instantiate``, by default None
1150
+ prob : float, optional
1151
+ Probability of applying this transform, by default 1.0
1152
+ """
1153
+
1154
+ def __init__(
1155
+ self,
1156
+ cutoff: tuple = ("choice", [50, 100, 250, 500, 1000]),
1157
+ zeros: int = 51,
1158
+ name: str = None,
1159
+ prob: float = 1,
1160
+ ):
1161
+ super().__init__(name=name, prob=prob)
1162
+
1163
+ self.cutoff = cutoff
1164
+ self.zeros = zeros
1165
+
1166
+ def _instantiate(self, state: RandomState):
1167
+ return {"cutoff": util.sample_from_dist(self.cutoff, state)}
1168
+
1169
+ def _transform(self, signal, cutoff):
1170
+ return signal.high_pass(cutoff, zeros=self.zeros)
1171
+
1172
+
1173
+ class RescaleAudio(BaseTransform):
1174
+ """Rescales the audio so it is in between ``-val`` and ``val``
1175
+ only if the original audio exceeds those bounds. Useful if
1176
+ transforms have caused the audio to clip.
1177
+
1178
+ Uses :py:func:`audiotools.core.effects.EffectMixin.ensure_max_of_audio`.
1179
+
1180
+ Parameters
1181
+ ----------
1182
+ val : float, optional
1183
+ Max absolute value of signal, by default 1.0
1184
+ name : str, optional
1185
+ Name of this transform, used to identify it in the dictionary
1186
+ produced by ``self.instantiate``, by default None
1187
+ prob : float, optional
1188
+ Probability of applying this transform, by default 1.0
1189
+ """
1190
+
1191
+ def __init__(self, val: float = 1.0, name: str = None, prob: float = 1):
1192
+ super().__init__(name=name, prob=prob)
1193
+
1194
+ self.val = val
1195
+
1196
+ def _transform(self, signal):
1197
+ return signal.ensure_max_of_audio(self.val)
1198
+
1199
+
1200
+ class ShiftPhase(SpectralTransform):
1201
+ """Shifts the phase of the audio.
1202
+
1203
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.shift)phase`.
1204
+
1205
+ Parameters
1206
+ ----------
1207
+ shift : tuple, optional
1208
+ How much to shift phase by, by default ("uniform", -np.pi, np.pi)
1209
+ name : str, optional
1210
+ Name of this transform, used to identify it in the dictionary
1211
+ produced by ``self.instantiate``, by default None
1212
+ prob : float, optional
1213
+ Probability of applying this transform, by default 1.0
1214
+ """
1215
+
1216
+ def __init__(
1217
+ self,
1218
+ shift: tuple = ("uniform", -np.pi, np.pi),
1219
+ name: str = None,
1220
+ prob: float = 1,
1221
+ ):
1222
+ super().__init__(name=name, prob=prob)
1223
+ self.shift = shift
1224
+
1225
+ def _instantiate(self, state: RandomState):
1226
+ return {"shift": util.sample_from_dist(self.shift, state)}
1227
+
1228
+ def _transform(self, signal, shift):
1229
+ return signal.shift_phase(shift)
1230
+
1231
+
1232
+ class InvertPhase(ShiftPhase):
1233
+ """Inverts the phase of the audio.
1234
+
1235
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.shift_phase`.
1236
+
1237
+ Parameters
1238
+ ----------
1239
+ name : str, optional
1240
+ Name of this transform, used to identify it in the dictionary
1241
+ produced by ``self.instantiate``, by default None
1242
+ prob : float, optional
1243
+ Probability of applying this transform, by default 1.0
1244
+ """
1245
+
1246
+ def __init__(self, name: str = None, prob: float = 1):
1247
+ super().__init__(shift=("const", np.pi), name=name, prob=prob)
1248
+
1249
+
1250
+ class CorruptPhase(SpectralTransform):
1251
+ """Corrupts the phase of the audio.
1252
+
1253
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.corrupt_phase`.
1254
+
1255
+ Parameters
1256
+ ----------
1257
+ scale : tuple, optional
1258
+ How much to corrupt phase by, by default ("uniform", 0, np.pi)
1259
+ name : str, optional
1260
+ Name of this transform, used to identify it in the dictionary
1261
+ produced by ``self.instantiate``, by default None
1262
+ prob : float, optional
1263
+ Probability of applying this transform, by default 1.0
1264
+ """
1265
+
1266
+ def __init__(
1267
+ self, scale: tuple = ("uniform", 0, np.pi), name: str = None, prob: float = 1
1268
+ ):
1269
+ super().__init__(name=name, prob=prob)
1270
+ self.scale = scale
1271
+
1272
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
1273
+ scale = util.sample_from_dist(self.scale, state)
1274
+ corruption = state.normal(scale=scale, size=signal.phase.shape[1:])
1275
+ return {"corruption": corruption.astype("float32")}
1276
+
1277
+ def _transform(self, signal, corruption):
1278
+ return signal.shift_phase(shift=corruption)
1279
+
1280
+
1281
+ class FrequencyMask(SpectralTransform):
1282
+ """Masks a band of frequencies at a center frequency
1283
+ from the audio.
1284
+
1285
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_frequencies`.
1286
+
1287
+ Parameters
1288
+ ----------
1289
+ f_center : tuple, optional
1290
+ Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
1291
+ f_width : tuple, optional
1292
+ Width of zero'd out band, by default ("const", 0.1)
1293
+ name : str, optional
1294
+ Name of this transform, used to identify it in the dictionary
1295
+ produced by ``self.instantiate``, by default None
1296
+ prob : float, optional
1297
+ Probability of applying this transform, by default 1.0
1298
+ """
1299
+
1300
+ def __init__(
1301
+ self,
1302
+ f_center: tuple = ("uniform", 0.0, 1.0),
1303
+ f_width: tuple = ("const", 0.1),
1304
+ name: str = None,
1305
+ prob: float = 1,
1306
+ ):
1307
+ super().__init__(name=name, prob=prob)
1308
+ self.f_center = f_center
1309
+ self.f_width = f_width
1310
+
1311
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
1312
+ f_center = util.sample_from_dist(self.f_center, state)
1313
+ f_width = util.sample_from_dist(self.f_width, state)
1314
+
1315
+ fmin = max(f_center - (f_width / 2), 0.0)
1316
+ fmax = min(f_center + (f_width / 2), 1.0)
1317
+
1318
+ fmin_hz = (signal.sample_rate / 2) * fmin
1319
+ fmax_hz = (signal.sample_rate / 2) * fmax
1320
+
1321
+ return {"fmin_hz": fmin_hz, "fmax_hz": fmax_hz}
1322
+
1323
+ def _transform(self, signal, fmin_hz: float, fmax_hz: float):
1324
+ return signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
1325
+
1326
+
1327
+ class TimeMask(SpectralTransform):
1328
+ """Masks out contiguous time-steps from signal.
1329
+
1330
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_timesteps`.
1331
+
1332
+ Parameters
1333
+ ----------
1334
+ t_center : tuple, optional
1335
+ Center time in terms of 0.0 and 1.0 (duration of signal),
1336
+ by default ("uniform", 0.0, 1.0)
1337
+ t_width : tuple, optional
1338
+ Width of dropped out portion, by default ("const", 0.025)
1339
+ name : str, optional
1340
+ Name of this transform, used to identify it in the dictionary
1341
+ produced by ``self.instantiate``, by default None
1342
+ prob : float, optional
1343
+ Probability of applying this transform, by default 1.0
1344
+ """
1345
+
1346
+ def __init__(
1347
+ self,
1348
+ t_center: tuple = ("uniform", 0.0, 1.0),
1349
+ t_width: tuple = ("const", 0.025),
1350
+ name: str = None,
1351
+ prob: float = 1,
1352
+ ):
1353
+ super().__init__(name=name, prob=prob)
1354
+ self.t_center = t_center
1355
+ self.t_width = t_width
1356
+
1357
+ def _instantiate(self, state: RandomState, signal: AudioSignal):
1358
+ t_center = util.sample_from_dist(self.t_center, state)
1359
+ t_width = util.sample_from_dist(self.t_width, state)
1360
+
1361
+ tmin = max(t_center - (t_width / 2), 0.0)
1362
+ tmax = min(t_center + (t_width / 2), 1.0)
1363
+
1364
+ tmin_s = signal.signal_duration * tmin
1365
+ tmax_s = signal.signal_duration * tmax
1366
+ return {"tmin_s": tmin_s, "tmax_s": tmax_s}
1367
+
1368
+ def _transform(self, signal, tmin_s: float, tmax_s: float):
1369
+ return signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s)
1370
+
1371
+
1372
+ class MaskLowMagnitudes(SpectralTransform):
1373
+ """Masks low magnitude regions out of signal.
1374
+
1375
+ Uses :py:func:`audiotools.core.dsp.DSPMixin.mask_low_magnitudes`.
1376
+
1377
+ Parameters
1378
+ ----------
1379
+ db_cutoff : tuple, optional
1380
+ Decibel value for which things below it will be masked away,
1381
+ by default ("uniform", -10, 10)
1382
+ name : str, optional
1383
+ Name of this transform, used to identify it in the dictionary
1384
+ produced by ``self.instantiate``, by default None
1385
+ prob : float, optional
1386
+ Probability of applying this transform, by default 1.0
1387
+ """
1388
+
1389
+ def __init__(
1390
+ self,
1391
+ db_cutoff: tuple = ("uniform", -10, 10),
1392
+ name: str = None,
1393
+ prob: float = 1,
1394
+ ):
1395
+ super().__init__(name=name, prob=prob)
1396
+ self.db_cutoff = db_cutoff
1397
+
1398
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
1399
+ return {"db_cutoff": util.sample_from_dist(self.db_cutoff, state)}
1400
+
1401
+ def _transform(self, signal, db_cutoff: float):
1402
+ return signal.mask_low_magnitudes(db_cutoff)
1403
+
1404
+
1405
+ class Smoothing(BaseTransform):
1406
+ """Convolves the signal with a smoothing window.
1407
+
1408
+ Uses :py:func:`audiotools.core.effects.EffectMixin.convolve`.
1409
+
1410
+ Parameters
1411
+ ----------
1412
+ window_type : tuple, optional
1413
+ Type of window to use, by default ("const", "average")
1414
+ window_length : tuple, optional
1415
+ Length of smoothing window, by
1416
+ default ("choice", [8, 16, 32, 64, 128, 256, 512])
1417
+ name : str, optional
1418
+ Name of this transform, used to identify it in the dictionary
1419
+ produced by ``self.instantiate``, by default None
1420
+ prob : float, optional
1421
+ Probability of applying this transform, by default 1.0
1422
+ """
1423
+
1424
+ def __init__(
1425
+ self,
1426
+ window_type: tuple = ("const", "average"),
1427
+ window_length: tuple = ("choice", [8, 16, 32, 64, 128, 256, 512]),
1428
+ name: str = None,
1429
+ prob: float = 1,
1430
+ ):
1431
+ super().__init__(name=name, prob=prob)
1432
+ self.window_type = window_type
1433
+ self.window_length = window_length
1434
+
1435
+ def _instantiate(self, state: RandomState, signal: AudioSignal = None):
1436
+ window_type = util.sample_from_dist(self.window_type, state)
1437
+ window_length = util.sample_from_dist(self.window_length, state)
1438
+ window = signal.get_window(
1439
+ window_type=window_type, window_length=window_length, device="cpu"
1440
+ )
1441
+ return {"window": AudioSignal(window, signal.sample_rate)}
1442
+
1443
+ def _transform(self, signal, window):
1444
+ sscale = signal.audio_data.abs().max(dim=-1, keepdim=True).values
1445
+ sscale[sscale == 0.0] = 1.0
1446
+
1447
+ out = signal.convolve(window)
1448
+
1449
+ oscale = out.audio_data.abs().max(dim=-1, keepdim=True).values
1450
+ oscale[oscale == 0.0] = 1.0
1451
+
1452
+ out = out * (sscale / oscale)
1453
+ return out
1454
+
1455
+
1456
+ class TimeNoise(TimeMask):
1457
+ """Similar to :py:func:`audiotools.data.transforms.TimeMask`, but
1458
+ replaces with noise instead of zeros.
1459
+
1460
+ Parameters
1461
+ ----------
1462
+ t_center : tuple, optional
1463
+ Center time in terms of 0.0 and 1.0 (duration of signal),
1464
+ by default ("uniform", 0.0, 1.0)
1465
+ t_width : tuple, optional
1466
+ Width of dropped out portion, by default ("const", 0.025)
1467
+ name : str, optional
1468
+ Name of this transform, used to identify it in the dictionary
1469
+ produced by ``self.instantiate``, by default None
1470
+ prob : float, optional
1471
+ Probability of applying this transform, by default 1.0
1472
+ """
1473
+
1474
+ def __init__(
1475
+ self,
1476
+ t_center: tuple = ("uniform", 0.0, 1.0),
1477
+ t_width: tuple = ("const", 0.025),
1478
+ name: str = None,
1479
+ prob: float = 1,
1480
+ ):
1481
+ super().__init__(t_center=t_center, t_width=t_width, name=name, prob=prob)
1482
+
1483
+ def _transform(self, signal, tmin_s: float, tmax_s: float):
1484
+ signal = signal.mask_timesteps(tmin_s=tmin_s, tmax_s=tmax_s, val=0.0)
1485
+ mag, phase = signal.magnitude, signal.phase
1486
+
1487
+ mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
1488
+ mask = (mag == 0.0) * (phase == 0.0)
1489
+
1490
+ mag[mask] = mag_r[mask]
1491
+ phase[mask] = phase_r[mask]
1492
+
1493
+ signal.magnitude = mag
1494
+ signal.phase = phase
1495
+ return signal
1496
+
1497
+
1498
+ class FrequencyNoise(FrequencyMask):
1499
+ """Similar to :py:func:`audiotools.data.transforms.FrequencyMask`, but
1500
+ replaces with noise instead of zeros.
1501
+
1502
+ Parameters
1503
+ ----------
1504
+ f_center : tuple, optional
1505
+ Center frequency between 0.0 and 1.0 (Nyquist), by default ("uniform", 0.0, 1.0)
1506
+ f_width : tuple, optional
1507
+ Width of zero'd out band, by default ("const", 0.1)
1508
+ name : str, optional
1509
+ Name of this transform, used to identify it in the dictionary
1510
+ produced by ``self.instantiate``, by default None
1511
+ prob : float, optional
1512
+ Probability of applying this transform, by default 1.0
1513
+ """
1514
+
1515
+ def __init__(
1516
+ self,
1517
+ f_center: tuple = ("uniform", 0.0, 1.0),
1518
+ f_width: tuple = ("const", 0.1),
1519
+ name: str = None,
1520
+ prob: float = 1,
1521
+ ):
1522
+ super().__init__(f_center=f_center, f_width=f_width, name=name, prob=prob)
1523
+
1524
+ def _transform(self, signal, fmin_hz: float, fmax_hz: float):
1525
+ signal = signal.mask_frequencies(fmin_hz=fmin_hz, fmax_hz=fmax_hz)
1526
+ mag, phase = signal.magnitude, signal.phase
1527
+
1528
+ mag_r, phase_r = torch.randn_like(mag), torch.randn_like(phase)
1529
+ mask = (mag == 0.0) * (phase == 0.0)
1530
+
1531
+ mag[mask] = mag_r[mask]
1532
+ phase[mask] = phase_r[mask]
1533
+
1534
+ signal.magnitude = mag
1535
+ signal.phase = phase
1536
+ return signal
1537
+
1538
+
1539
+ class SpectralDenoising(Equalizer):
1540
+ """Applies denoising algorithm detailed in
1541
+ :py:func:`audiotools.ml.layers.spectral_gate.SpectralGate`,
1542
+ using a randomly generated noise signal for denoising.
1543
+
1544
+ Parameters
1545
+ ----------
1546
+ eq_amount : tuple, optional
1547
+ Amount of eq to apply to noise signal, by default ("const", 1.0)
1548
+ denoise_amount : tuple, optional
1549
+ Amount to denoise by, by default ("uniform", 0.8, 1.0)
1550
+ nz_volume : float, optional
1551
+ Volume of noise to denoise with, by default -40
1552
+ n_bands : int, optional
1553
+ Number of bands in equalizer, by default 6
1554
+ n_freq : int, optional
1555
+ Number of frequency bins to smooth by, by default 3
1556
+ n_time : int, optional
1557
+ Number of time bins to smooth by, by default 5
1558
+ name : str, optional
1559
+ Name of this transform, used to identify it in the dictionary
1560
+ produced by ``self.instantiate``, by default None
1561
+ prob : float, optional
1562
+ Probability of applying this transform, by default 1.0
1563
+ """
1564
+
1565
+ def __init__(
1566
+ self,
1567
+ eq_amount: tuple = ("const", 1.0),
1568
+ denoise_amount: tuple = ("uniform", 0.8, 1.0),
1569
+ nz_volume: float = -40,
1570
+ n_bands: int = 6,
1571
+ n_freq: int = 3,
1572
+ n_time: int = 5,
1573
+ name: str = None,
1574
+ prob: float = 1,
1575
+ ):
1576
+ super().__init__(eq_amount=eq_amount, n_bands=n_bands, name=name, prob=prob)
1577
+
1578
+ self.nz_volume = nz_volume
1579
+ self.denoise_amount = denoise_amount
1580
+ self.spectral_gate = ml.layers.SpectralGate(n_freq, n_time)
1581
+
1582
+ def _transform(self, signal, nz, eq, denoise_amount):
1583
+ nz = nz.normalize(self.nz_volume).equalizer(eq)
1584
+ self.spectral_gate = self.spectral_gate.to(signal.device)
1585
+ signal = self.spectral_gate(signal, nz, denoise_amount)
1586
+ return signal
1587
+
1588
+ def _instantiate(self, state: RandomState):
1589
+ kwargs = super()._instantiate(state)
1590
+ kwargs["denoise_amount"] = util.sample_from_dist(self.denoise_amount, state)
1591
+ kwargs["nz"] = AudioSignal(state.randn(22050), 44100)
1592
+ return kwargs
audiotools/metrics/__init__.py ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ """
2
+ Functions for comparing AudioSignal objects to one another.
3
+ """ # fmt: skip
4
+ from . import distance
5
+ from . import quality
6
+ from . import spectral
audiotools/metrics/distance.py ADDED
@@ -0,0 +1,131 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ from torch import nn
3
+
4
+ from .. import AudioSignal
5
+
6
+
7
+ class L1Loss(nn.L1Loss):
8
+ """L1 Loss between AudioSignals. Defaults
9
+ to comparing ``audio_data``, but any
10
+ attribute of an AudioSignal can be used.
11
+
12
+ Parameters
13
+ ----------
14
+ attribute : str, optional
15
+ Attribute of signal to compare, defaults to ``audio_data``.
16
+ weight : float, optional
17
+ Weight of this loss, defaults to 1.0.
18
+ """
19
+
20
+ def __init__(self, attribute: str = "audio_data", weight: float = 1.0, **kwargs):
21
+ self.attribute = attribute
22
+ self.weight = weight
23
+ super().__init__(**kwargs)
24
+
25
+ def forward(self, x: AudioSignal, y: AudioSignal):
26
+ """
27
+ Parameters
28
+ ----------
29
+ x : AudioSignal
30
+ Estimate AudioSignal
31
+ y : AudioSignal
32
+ Reference AudioSignal
33
+
34
+ Returns
35
+ -------
36
+ torch.Tensor
37
+ L1 loss between AudioSignal attributes.
38
+ """
39
+ if isinstance(x, AudioSignal):
40
+ x = getattr(x, self.attribute)
41
+ y = getattr(y, self.attribute)
42
+ return super().forward(x, y)
43
+
44
+
45
+ class SISDRLoss(nn.Module):
46
+ """
47
+ Computes the Scale-Invariant Source-to-Distortion Ratio between a batch
48
+ of estimated and reference audio signals or aligned features.
49
+
50
+ Parameters
51
+ ----------
52
+ scaling : int, optional
53
+ Whether to use scale-invariant (True) or
54
+ signal-to-noise ratio (False), by default True
55
+ reduction : str, optional
56
+ How to reduce across the batch (either 'mean',
57
+ 'sum', or none).], by default ' mean'
58
+ zero_mean : int, optional
59
+ Zero mean the references and estimates before
60
+ computing the loss, by default True
61
+ clip_min : int, optional
62
+ The minimum possible loss value. Helps network
63
+ to not focus on making already good examples better, by default None
64
+ weight : float, optional
65
+ Weight of this loss, defaults to 1.0.
66
+ """
67
+
68
+ def __init__(
69
+ self,
70
+ scaling: int = True,
71
+ reduction: str = "mean",
72
+ zero_mean: int = True,
73
+ clip_min: int = None,
74
+ weight: float = 1.0,
75
+ ):
76
+ self.scaling = scaling
77
+ self.reduction = reduction
78
+ self.zero_mean = zero_mean
79
+ self.clip_min = clip_min
80
+ self.weight = weight
81
+ super().__init__()
82
+
83
+ def forward(self, x: AudioSignal, y: AudioSignal):
84
+ eps = 1e-8
85
+ # nb, nc, nt
86
+ if isinstance(x, AudioSignal):
87
+ references = x.audio_data
88
+ estimates = y.audio_data
89
+ else:
90
+ references = x
91
+ estimates = y
92
+
93
+ nb = references.shape[0]
94
+ references = references.reshape(nb, 1, -1).permute(0, 2, 1)
95
+ estimates = estimates.reshape(nb, 1, -1).permute(0, 2, 1)
96
+
97
+ # samples now on axis 1
98
+ if self.zero_mean:
99
+ mean_reference = references.mean(dim=1, keepdim=True)
100
+ mean_estimate = estimates.mean(dim=1, keepdim=True)
101
+ else:
102
+ mean_reference = 0
103
+ mean_estimate = 0
104
+
105
+ _references = references - mean_reference
106
+ _estimates = estimates - mean_estimate
107
+
108
+ references_projection = (_references**2).sum(dim=-2) + eps
109
+ references_on_estimates = (_estimates * _references).sum(dim=-2) + eps
110
+
111
+ scale = (
112
+ (references_on_estimates / references_projection).unsqueeze(1)
113
+ if self.scaling
114
+ else 1
115
+ )
116
+
117
+ e_true = scale * _references
118
+ e_res = _estimates - e_true
119
+
120
+ signal = (e_true**2).sum(dim=1)
121
+ noise = (e_res**2).sum(dim=1)
122
+ sdr = -10 * torch.log10(signal / noise + eps)
123
+
124
+ if self.clip_min is not None:
125
+ sdr = torch.clamp(sdr, min=self.clip_min)
126
+
127
+ if self.reduction == "mean":
128
+ sdr = sdr.mean()
129
+ elif self.reduction == "sum":
130
+ sdr = sdr.sum()
131
+ return sdr
audiotools/metrics/quality.py ADDED
@@ -0,0 +1,159 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+
3
+ import numpy as np
4
+ import torch
5
+
6
+ from .. import AudioSignal
7
+
8
+
9
+ def stoi(
10
+ estimates: AudioSignal,
11
+ references: AudioSignal,
12
+ extended: int = False,
13
+ ):
14
+ """Short term objective intelligibility
15
+ Computes the STOI (See [1][2]) of a denoised signal compared to a clean
16
+ signal, The output is expected to have a monotonic relation with the
17
+ subjective speech-intelligibility, where a higher score denotes better
18
+ speech intelligibility. Uses pystoi under the hood.
19
+
20
+ Parameters
21
+ ----------
22
+ estimates : AudioSignal
23
+ Denoised speech
24
+ references : AudioSignal
25
+ Clean original speech
26
+ extended : int, optional
27
+ Boolean, whether to use the extended STOI described in [3], by default False
28
+
29
+ Returns
30
+ -------
31
+ Tensor[float]
32
+ Short time objective intelligibility measure between clean and
33
+ denoised speech
34
+
35
+ References
36
+ ----------
37
+ 1. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'A Short-Time
38
+ Objective Intelligibility Measure for Time-Frequency Weighted Noisy
39
+ Speech', ICASSP 2010, Texas, Dallas.
40
+ 2. C.H.Taal, R.C.Hendriks, R.Heusdens, J.Jensen 'An Algorithm for
41
+ Intelligibility Prediction of Time-Frequency Weighted Noisy Speech',
42
+ IEEE Transactions on Audio, Speech, and Language Processing, 2011.
43
+ 3. Jesper Jensen and Cees H. Taal, 'An Algorithm for Predicting the
44
+ Intelligibility of Speech Masked by Modulated Noise Maskers',
45
+ IEEE Transactions on Audio, Speech and Language Processing, 2016.
46
+ """
47
+ import pystoi
48
+
49
+ estimates = estimates.clone().to_mono()
50
+ references = references.clone().to_mono()
51
+
52
+ stois = []
53
+ for i in range(estimates.batch_size):
54
+ _stoi = pystoi.stoi(
55
+ references.audio_data[i, 0].detach().cpu().numpy(),
56
+ estimates.audio_data[i, 0].detach().cpu().numpy(),
57
+ references.sample_rate,
58
+ extended=extended,
59
+ )
60
+ stois.append(_stoi)
61
+ return torch.from_numpy(np.array(stois))
62
+
63
+
64
+ def pesq(
65
+ estimates: AudioSignal,
66
+ references: AudioSignal,
67
+ mode: str = "wb",
68
+ target_sr: float = 16000,
69
+ ):
70
+ """_summary_
71
+
72
+ Parameters
73
+ ----------
74
+ estimates : AudioSignal
75
+ Degraded AudioSignal
76
+ references : AudioSignal
77
+ Reference AudioSignal
78
+ mode : str, optional
79
+ 'wb' (wide-band) or 'nb' (narrow-band), by default "wb"
80
+ target_sr : float, optional
81
+ Target sample rate, by default 16000
82
+
83
+ Returns
84
+ -------
85
+ Tensor[float]
86
+ PESQ score: P.862.2 Prediction (MOS-LQO)
87
+ """
88
+ from pesq import pesq as pesq_fn
89
+
90
+ estimates = estimates.clone().to_mono().resample(target_sr)
91
+ references = references.clone().to_mono().resample(target_sr)
92
+
93
+ pesqs = []
94
+ for i in range(estimates.batch_size):
95
+ _pesq = pesq_fn(
96
+ estimates.sample_rate,
97
+ references.audio_data[i, 0].detach().cpu().numpy(),
98
+ estimates.audio_data[i, 0].detach().cpu().numpy(),
99
+ mode,
100
+ )
101
+ pesqs.append(_pesq)
102
+ return torch.from_numpy(np.array(pesqs))
103
+
104
+
105
+ def visqol(
106
+ estimates: AudioSignal,
107
+ references: AudioSignal,
108
+ mode: str = "audio",
109
+ ): # pragma: no cover
110
+ """ViSQOL score.
111
+
112
+ Parameters
113
+ ----------
114
+ estimates : AudioSignal
115
+ Degraded AudioSignal
116
+ references : AudioSignal
117
+ Reference AudioSignal
118
+ mode : str, optional
119
+ 'audio' or 'speech', by default 'audio'
120
+
121
+ Returns
122
+ -------
123
+ Tensor[float]
124
+ ViSQOL score (MOS-LQO)
125
+ """
126
+ from visqol import visqol_lib_py
127
+ from visqol.pb2 import visqol_config_pb2
128
+ from visqol.pb2 import similarity_result_pb2
129
+
130
+ config = visqol_config_pb2.VisqolConfig()
131
+ if mode == "audio":
132
+ target_sr = 48000
133
+ config.options.use_speech_scoring = False
134
+ svr_model_path = "libsvm_nu_svr_model.txt"
135
+ elif mode == "speech":
136
+ target_sr = 16000
137
+ config.options.use_speech_scoring = True
138
+ svr_model_path = "lattice_tcditugenmeetpackhref_ls2_nl60_lr12_bs2048_learn.005_ep2400_train1_7_raw.tflite"
139
+ else:
140
+ raise ValueError(f"Unrecognized mode: {mode}")
141
+ config.audio.sample_rate = target_sr
142
+ config.options.svr_model_path = os.path.join(
143
+ os.path.dirname(visqol_lib_py.__file__), "model", svr_model_path
144
+ )
145
+
146
+ api = visqol_lib_py.VisqolApi()
147
+ api.Create(config)
148
+
149
+ estimates = estimates.clone().to_mono().resample(target_sr)
150
+ references = references.clone().to_mono().resample(target_sr)
151
+
152
+ visqols = []
153
+ for i in range(estimates.batch_size):
154
+ _visqol = api.Measure(
155
+ references.audio_data[i, 0].detach().cpu().numpy().astype(float),
156
+ estimates.audio_data[i, 0].detach().cpu().numpy().astype(float),
157
+ )
158
+ visqols.append(_visqol.moslqo)
159
+ return torch.from_numpy(np.array(visqols))
audiotools/metrics/spectral.py ADDED
@@ -0,0 +1,247 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import typing
2
+ from typing import List
3
+
4
+ import numpy as np
5
+ from torch import nn
6
+
7
+ from .. import AudioSignal
8
+ from .. import STFTParams
9
+
10
+
11
+ class MultiScaleSTFTLoss(nn.Module):
12
+ """Computes the multi-scale STFT loss from [1].
13
+
14
+ Parameters
15
+ ----------
16
+ window_lengths : List[int], optional
17
+ Length of each window of each STFT, by default [2048, 512]
18
+ loss_fn : typing.Callable, optional
19
+ How to compare each loss, by default nn.L1Loss()
20
+ clamp_eps : float, optional
21
+ Clamp on the log magnitude, below, by default 1e-5
22
+ mag_weight : float, optional
23
+ Weight of raw magnitude portion of loss, by default 1.0
24
+ log_weight : float, optional
25
+ Weight of log magnitude portion of loss, by default 1.0
26
+ pow : float, optional
27
+ Power to raise magnitude to before taking log, by default 2.0
28
+ weight : float, optional
29
+ Weight of this loss, by default 1.0
30
+ match_stride : bool, optional
31
+ Whether to match the stride of convolutional layers, by default False
32
+
33
+ References
34
+ ----------
35
+
36
+ 1. Engel, Jesse, Chenjie Gu, and Adam Roberts.
37
+ "DDSP: Differentiable Digital Signal Processing."
38
+ International Conference on Learning Representations. 2019.
39
+ """
40
+
41
+ def __init__(
42
+ self,
43
+ window_lengths: List[int] = [2048, 512],
44
+ loss_fn: typing.Callable = nn.L1Loss(),
45
+ clamp_eps: float = 1e-5,
46
+ mag_weight: float = 1.0,
47
+ log_weight: float = 1.0,
48
+ pow: float = 2.0,
49
+ weight: float = 1.0,
50
+ match_stride: bool = False,
51
+ window_type: str = None,
52
+ ):
53
+ super().__init__()
54
+ self.stft_params = [
55
+ STFTParams(
56
+ window_length=w,
57
+ hop_length=w // 4,
58
+ match_stride=match_stride,
59
+ window_type=window_type,
60
+ )
61
+ for w in window_lengths
62
+ ]
63
+ self.loss_fn = loss_fn
64
+ self.log_weight = log_weight
65
+ self.mag_weight = mag_weight
66
+ self.clamp_eps = clamp_eps
67
+ self.weight = weight
68
+ self.pow = pow
69
+
70
+ def forward(self, x: AudioSignal, y: AudioSignal):
71
+ """Computes multi-scale STFT between an estimate and a reference
72
+ signal.
73
+
74
+ Parameters
75
+ ----------
76
+ x : AudioSignal
77
+ Estimate signal
78
+ y : AudioSignal
79
+ Reference signal
80
+
81
+ Returns
82
+ -------
83
+ torch.Tensor
84
+ Multi-scale STFT loss.
85
+ """
86
+ loss = 0.0
87
+ for s in self.stft_params:
88
+ x.stft(s.window_length, s.hop_length, s.window_type)
89
+ y.stft(s.window_length, s.hop_length, s.window_type)
90
+ loss += self.log_weight * self.loss_fn(
91
+ x.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
92
+ y.magnitude.clamp(self.clamp_eps).pow(self.pow).log10(),
93
+ )
94
+ loss += self.mag_weight * self.loss_fn(x.magnitude, y.magnitude)
95
+ return loss
96
+
97
+
98
+ class MelSpectrogramLoss(nn.Module):
99
+ """Compute distance between mel spectrograms. Can be used
100
+ in a multi-scale way.
101
+
102
+ Parameters
103
+ ----------
104
+ n_mels : List[int]
105
+ Number of mels per STFT, by default [150, 80],
106
+ window_lengths : List[int], optional
107
+ Length of each window of each STFT, by default [2048, 512]
108
+ loss_fn : typing.Callable, optional
109
+ How to compare each loss, by default nn.L1Loss()
110
+ clamp_eps : float, optional
111
+ Clamp on the log magnitude, below, by default 1e-5
112
+ mag_weight : float, optional
113
+ Weight of raw magnitude portion of loss, by default 1.0
114
+ log_weight : float, optional
115
+ Weight of log magnitude portion of loss, by default 1.0
116
+ pow : float, optional
117
+ Power to raise magnitude to before taking log, by default 2.0
118
+ weight : float, optional
119
+ Weight of this loss, by default 1.0
120
+ match_stride : bool, optional
121
+ Whether to match the stride of convolutional layers, by default False
122
+ """
123
+
124
+ def __init__(
125
+ self,
126
+ n_mels: List[int] = [150, 80],
127
+ window_lengths: List[int] = [2048, 512],
128
+ loss_fn: typing.Callable = nn.L1Loss(),
129
+ clamp_eps: float = 1e-5,
130
+ mag_weight: float = 1.0,
131
+ log_weight: float = 1.0,
132
+ pow: float = 2.0,
133
+ weight: float = 1.0,
134
+ match_stride: bool = False,
135
+ mel_fmin: List[float] = [0.0, 0.0],
136
+ mel_fmax: List[float] = [None, None],
137
+ window_type: str = None,
138
+ ):
139
+ super().__init__()
140
+ self.stft_params = [
141
+ STFTParams(
142
+ window_length=w,
143
+ hop_length=w // 4,
144
+ match_stride=match_stride,
145
+ window_type=window_type,
146
+ )
147
+ for w in window_lengths
148
+ ]
149
+ self.n_mels = n_mels
150
+ self.loss_fn = loss_fn
151
+ self.clamp_eps = clamp_eps
152
+ self.log_weight = log_weight
153
+ self.mag_weight = mag_weight
154
+ self.weight = weight
155
+ self.mel_fmin = mel_fmin
156
+ self.mel_fmax = mel_fmax
157
+ self.pow = pow
158
+
159
+ def forward(self, x: AudioSignal, y: AudioSignal):
160
+ """Computes mel loss between an estimate and a reference
161
+ signal.
162
+
163
+ Parameters
164
+ ----------
165
+ x : AudioSignal
166
+ Estimate signal
167
+ y : AudioSignal
168
+ Reference signal
169
+
170
+ Returns
171
+ -------
172
+ torch.Tensor
173
+ Mel loss.
174
+ """
175
+ loss = 0.0
176
+ for n_mels, fmin, fmax, s in zip(
177
+ self.n_mels, self.mel_fmin, self.mel_fmax, self.stft_params
178
+ ):
179
+ kwargs = {
180
+ "window_length": s.window_length,
181
+ "hop_length": s.hop_length,
182
+ "window_type": s.window_type,
183
+ }
184
+ x_mels = x.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
185
+ y_mels = y.mel_spectrogram(n_mels, mel_fmin=fmin, mel_fmax=fmax, **kwargs)
186
+
187
+ loss += self.log_weight * self.loss_fn(
188
+ x_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
189
+ y_mels.clamp(self.clamp_eps).pow(self.pow).log10(),
190
+ )
191
+ loss += self.mag_weight * self.loss_fn(x_mels, y_mels)
192
+ return loss
193
+
194
+
195
+ class PhaseLoss(nn.Module):
196
+ """Difference between phase spectrograms.
197
+
198
+ Parameters
199
+ ----------
200
+ window_length : int, optional
201
+ Length of STFT window, by default 2048
202
+ hop_length : int, optional
203
+ Hop length of STFT window, by default 512
204
+ weight : float, optional
205
+ Weight of loss, by default 1.0
206
+ """
207
+
208
+ def __init__(
209
+ self, window_length: int = 2048, hop_length: int = 512, weight: float = 1.0
210
+ ):
211
+ super().__init__()
212
+
213
+ self.weight = weight
214
+ self.stft_params = STFTParams(window_length, hop_length)
215
+
216
+ def forward(self, x: AudioSignal, y: AudioSignal):
217
+ """Computes phase loss between an estimate and a reference
218
+ signal.
219
+
220
+ Parameters
221
+ ----------
222
+ x : AudioSignal
223
+ Estimate signal
224
+ y : AudioSignal
225
+ Reference signal
226
+
227
+ Returns
228
+ -------
229
+ torch.Tensor
230
+ Phase loss.
231
+ """
232
+ s = self.stft_params
233
+ x.stft(s.window_length, s.hop_length, s.window_type)
234
+ y.stft(s.window_length, s.hop_length, s.window_type)
235
+
236
+ # Take circular difference
237
+ diff = x.phase - y.phase
238
+ diff[diff < -np.pi] += 2 * np.pi
239
+ diff[diff > np.pi] -= -2 * np.pi
240
+
241
+ # Scale true magnitude to weights in [0, 1]
242
+ x_min, x_max = x.magnitude.min(), x.magnitude.max()
243
+ weights = (x.magnitude - x_min) / (x_max - x_min)
244
+
245
+ # Take weighted mean of all phase errors
246
+ loss = ((weights * diff) ** 2).mean()
247
+ return loss
audiotools/ml/__init__.py ADDED
@@ -0,0 +1,5 @@
 
 
 
 
 
 
1
+ from . import decorators
2
+ from . import layers
3
+ from .accelerator import Accelerator
4
+ from .experiment import Experiment
5
+ from .layers import BaseModel
audiotools/ml/accelerator.py ADDED
@@ -0,0 +1,184 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import typing
3
+
4
+ import torch
5
+ import torch.distributed as dist
6
+ from torch.nn.parallel import DataParallel
7
+ from torch.nn.parallel import DistributedDataParallel
8
+
9
+ from ..data.datasets import ResumableDistributedSampler as DistributedSampler
10
+ from ..data.datasets import ResumableSequentialSampler as SequentialSampler
11
+
12
+
13
+ class Accelerator: # pragma: no cover
14
+ """This class is used to prepare models and dataloaders for
15
+ usage with DDP or DP. Use the functions prepare_model, prepare_dataloader to
16
+ prepare the respective objects. In the case of models, they are moved to
17
+ the appropriate GPU and SyncBatchNorm is applied to them. In the case of
18
+ dataloaders, a sampler is created and the dataloader is initialized with
19
+ that sampler.
20
+
21
+ If the world size is 1, prepare_model and prepare_dataloader are
22
+ no-ops. If the environment variable ``LOCAL_RANK`` is not set, then the
23
+ script was launched without ``torchrun``, and ``DataParallel``
24
+ will be used instead of ``DistributedDataParallel`` (not recommended), if
25
+ the world size (number of GPUs) is greater than 1.
26
+
27
+ Parameters
28
+ ----------
29
+ amp : bool, optional
30
+ Whether or not to enable automatic mixed precision, by default False
31
+ """
32
+
33
+ def __init__(self, amp: bool = False):
34
+ local_rank = os.getenv("LOCAL_RANK", None)
35
+ self.world_size = torch.cuda.device_count()
36
+
37
+ self.use_ddp = self.world_size > 1 and local_rank is not None
38
+ self.use_dp = self.world_size > 1 and local_rank is None
39
+ self.device = "cpu" if self.world_size == 0 else "cuda"
40
+
41
+ if self.use_ddp:
42
+ local_rank = int(local_rank)
43
+ dist.init_process_group(
44
+ "nccl",
45
+ init_method="env://",
46
+ world_size=self.world_size,
47
+ rank=local_rank,
48
+ )
49
+
50
+ self.local_rank = 0 if local_rank is None else local_rank
51
+ self.amp = amp
52
+
53
+ class DummyScaler:
54
+ def __init__(self):
55
+ pass
56
+
57
+ def step(self, optimizer):
58
+ optimizer.step()
59
+
60
+ def scale(self, loss):
61
+ return loss
62
+
63
+ def unscale_(self, optimizer):
64
+ return optimizer
65
+
66
+ def update(self):
67
+ pass
68
+
69
+ self.scaler = torch.cuda.amp.GradScaler() if amp else DummyScaler()
70
+ self.device_ctx = (
71
+ torch.cuda.device(self.local_rank) if torch.cuda.is_available() else None
72
+ )
73
+
74
+ def __enter__(self):
75
+ if self.device_ctx is not None:
76
+ self.device_ctx.__enter__()
77
+ return self
78
+
79
+ def __exit__(self, exc_type, exc_value, traceback):
80
+ if self.device_ctx is not None:
81
+ self.device_ctx.__exit__(exc_type, exc_value, traceback)
82
+
83
+ def prepare_model(self, model: torch.nn.Module, **kwargs):
84
+ """Prepares model for DDP or DP. The model is moved to
85
+ the device of the correct rank.
86
+
87
+ Parameters
88
+ ----------
89
+ model : torch.nn.Module
90
+ Model that is converted for DDP or DP.
91
+
92
+ Returns
93
+ -------
94
+ torch.nn.Module
95
+ Wrapped model, or original model if DDP and DP are turned off.
96
+ """
97
+ model = model.to(self.device)
98
+ if self.use_ddp:
99
+ model = torch.nn.SyncBatchNorm.convert_sync_batchnorm(model)
100
+ model = DistributedDataParallel(
101
+ model, device_ids=[self.local_rank], **kwargs
102
+ )
103
+ elif self.use_dp:
104
+ model = DataParallel(model, **kwargs)
105
+ return model
106
+
107
+ # Automatic mixed-precision utilities
108
+ def autocast(self, *args, **kwargs):
109
+ """Context manager for autocasting. Arguments
110
+ go to ``torch.cuda.amp.autocast``.
111
+ """
112
+ return torch.cuda.amp.autocast(self.amp, *args, **kwargs)
113
+
114
+ def backward(self, loss: torch.Tensor):
115
+ """Backwards pass, after scaling the loss if ``amp`` is
116
+ enabled.
117
+
118
+ Parameters
119
+ ----------
120
+ loss : torch.Tensor
121
+ Loss value.
122
+ """
123
+ self.scaler.scale(loss).backward()
124
+
125
+ def step(self, optimizer: torch.optim.Optimizer):
126
+ """Steps the optimizer, using a ``scaler`` if ``amp`` is
127
+ enabled.
128
+
129
+ Parameters
130
+ ----------
131
+ optimizer : torch.optim.Optimizer
132
+ Optimizer to step forward.
133
+ """
134
+ self.scaler.step(optimizer)
135
+
136
+ def update(self):
137
+ """Updates the scale factor."""
138
+ self.scaler.update()
139
+
140
+ def prepare_dataloader(
141
+ self, dataset: typing.Iterable, start_idx: int = None, **kwargs
142
+ ):
143
+ """Wraps a dataset with a DataLoader, using the correct sampler if DDP is
144
+ enabled.
145
+
146
+ Parameters
147
+ ----------
148
+ dataset : typing.Iterable
149
+ Dataset to build Dataloader around.
150
+ start_idx : int, optional
151
+ Start index of sampler, useful if resuming from some epoch,
152
+ by default None
153
+
154
+ Returns
155
+ -------
156
+ _type_
157
+ _description_
158
+ """
159
+
160
+ if self.use_ddp:
161
+ sampler = DistributedSampler(
162
+ dataset,
163
+ start_idx,
164
+ num_replicas=self.world_size,
165
+ rank=self.local_rank,
166
+ )
167
+ if "num_workers" in kwargs:
168
+ kwargs["num_workers"] = max(kwargs["num_workers"] // self.world_size, 1)
169
+ kwargs["batch_size"] = max(kwargs["batch_size"] // self.world_size, 1)
170
+ else:
171
+ sampler = SequentialSampler(dataset, start_idx)
172
+
173
+ dataloader = torch.utils.data.DataLoader(dataset, sampler=sampler, **kwargs)
174
+ return dataloader
175
+
176
+ @staticmethod
177
+ def unwrap(model):
178
+ """Unwraps the model if it was wrapped in DDP or DP, otherwise
179
+ just returns the model. Use this to unwrap the model returned by
180
+ :py:func:`audiotools.ml.accelerator.Accelerator.prepare_model`.
181
+ """
182
+ if hasattr(model, "module"):
183
+ return model.module
184
+ return model
audiotools/ml/decorators.py ADDED
@@ -0,0 +1,440 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import os
3
+ import time
4
+ from collections import defaultdict
5
+ from functools import wraps
6
+
7
+ import torch
8
+ import torch.distributed as dist
9
+ from rich import box
10
+ from rich.console import Console
11
+ from rich.console import Group
12
+ from rich.live import Live
13
+ from rich.markdown import Markdown
14
+ from rich.padding import Padding
15
+ from rich.panel import Panel
16
+ from rich.progress import BarColumn
17
+ from rich.progress import Progress
18
+ from rich.progress import SpinnerColumn
19
+ from rich.progress import TimeElapsedColumn
20
+ from rich.progress import TimeRemainingColumn
21
+ from rich.rule import Rule
22
+ from rich.table import Table
23
+ from torch.utils.tensorboard import SummaryWriter
24
+
25
+
26
+ # This is here so that the history can be pickled.
27
+ def default_list():
28
+ return []
29
+
30
+
31
+ class Mean:
32
+ """Keeps track of the running mean, along with the latest
33
+ value.
34
+ """
35
+
36
+ def __init__(self):
37
+ self.reset()
38
+
39
+ def __call__(self):
40
+ mean = self.total / max(self.count, 1)
41
+ return mean
42
+
43
+ def reset(self):
44
+ self.count = 0
45
+ self.total = 0
46
+
47
+ def update(self, val):
48
+ if math.isfinite(val):
49
+ self.count += 1
50
+ self.total += val
51
+
52
+
53
+ def when(condition):
54
+ """Runs a function only when the condition is met. The condition is
55
+ a function that is run.
56
+
57
+ Parameters
58
+ ----------
59
+ condition : Callable
60
+ Function to run to check whether or not to run the decorated
61
+ function.
62
+
63
+ Example
64
+ -------
65
+ Checkpoint only runs every 100 iterations, and only if the
66
+ local rank is 0.
67
+
68
+ >>> i = 0
69
+ >>> rank = 0
70
+ >>>
71
+ >>> @when(lambda: i % 100 == 0 and rank == 0)
72
+ >>> def checkpoint():
73
+ >>> print("Saving to /runs/exp1")
74
+ >>>
75
+ >>> for i in range(1000):
76
+ >>> checkpoint()
77
+
78
+ """
79
+
80
+ def decorator(fn):
81
+ @wraps(fn)
82
+ def decorated(*args, **kwargs):
83
+ if condition():
84
+ return fn(*args, **kwargs)
85
+
86
+ return decorated
87
+
88
+ return decorator
89
+
90
+
91
+ def timer(prefix: str = "time"):
92
+ """Adds execution time to the output dictionary of the decorated
93
+ function. The function decorated by this must output a dictionary.
94
+ The key added will follow the form "[prefix]/[name_of_function]"
95
+
96
+ Parameters
97
+ ----------
98
+ prefix : str, optional
99
+ The key added will follow the form "[prefix]/[name_of_function]",
100
+ by default "time".
101
+ """
102
+
103
+ def decorator(fn):
104
+ @wraps(fn)
105
+ def decorated(*args, **kwargs):
106
+ s = time.perf_counter()
107
+ output = fn(*args, **kwargs)
108
+ assert isinstance(output, dict)
109
+ e = time.perf_counter()
110
+ output[f"{prefix}/{fn.__name__}"] = e - s
111
+ return output
112
+
113
+ return decorated
114
+
115
+ return decorator
116
+
117
+
118
+ class Tracker:
119
+ """
120
+ A tracker class that helps to monitor the progress of training and logging the metrics.
121
+
122
+ Attributes
123
+ ----------
124
+ metrics : dict
125
+ A dictionary containing the metrics for each label.
126
+ history : dict
127
+ A dictionary containing the history of metrics for each label.
128
+ writer : SummaryWriter
129
+ A SummaryWriter object for logging the metrics.
130
+ rank : int
131
+ The rank of the current process.
132
+ step : int
133
+ The current step of the training.
134
+ tasks : dict
135
+ A dictionary containing the progress bars and tables for each label.
136
+ pbar : Progress
137
+ A progress bar object for displaying the progress.
138
+ consoles : list
139
+ A list of console objects for logging.
140
+ live : Live
141
+ A Live object for updating the display live.
142
+
143
+ Methods
144
+ -------
145
+ print(msg: str)
146
+ Prints the given message to all consoles.
147
+ update(label: str, fn_name: str)
148
+ Updates the progress bar and table for the given label.
149
+ done(label: str, title: str)
150
+ Resets the progress bar and table for the given label and prints the final result.
151
+ track(label: str, length: int, completed: int = 0, op: dist.ReduceOp = dist.ReduceOp.AVG, ddp_active: bool = "LOCAL_RANK" in os.environ)
152
+ A decorator for tracking the progress and metrics of a function.
153
+ log(label: str, value_type: str = "value", history: bool = True)
154
+ A decorator for logging the metrics of a function.
155
+ is_best(label: str, key: str) -> bool
156
+ Checks if the latest value of the given key in the label is the best so far.
157
+ state_dict() -> dict
158
+ Returns a dictionary containing the state of the tracker.
159
+ load_state_dict(state_dict: dict) -> Tracker
160
+ Loads the state of the tracker from the given state dictionary.
161
+ """
162
+
163
+ def __init__(
164
+ self,
165
+ writer: SummaryWriter = None,
166
+ log_file: str = None,
167
+ rank: int = 0,
168
+ console_width: int = 100,
169
+ step: int = 0,
170
+ ):
171
+ """
172
+ Initializes the Tracker object.
173
+
174
+ Parameters
175
+ ----------
176
+ writer : SummaryWriter, optional
177
+ A SummaryWriter object for logging the metrics, by default None.
178
+ log_file : str, optional
179
+ The path to the log file, by default None.
180
+ rank : int, optional
181
+ The rank of the current process, by default 0.
182
+ console_width : int, optional
183
+ The width of the console, by default 100.
184
+ step : int, optional
185
+ The current step of the training, by default 0.
186
+ """
187
+ self.metrics = {}
188
+ self.history = {}
189
+ self.writer = writer
190
+ self.rank = rank
191
+ self.step = step
192
+
193
+ # Create progress bars etc.
194
+ self.tasks = {}
195
+ self.pbar = Progress(
196
+ SpinnerColumn(),
197
+ "[progress.description]{task.description}",
198
+ "{task.completed}/{task.total}",
199
+ BarColumn(),
200
+ TimeElapsedColumn(),
201
+ "/",
202
+ TimeRemainingColumn(),
203
+ )
204
+ self.consoles = [Console(width=console_width)]
205
+ self.live = Live(console=self.consoles[0], refresh_per_second=10)
206
+ if log_file is not None:
207
+ self.consoles.append(Console(width=console_width, file=open(log_file, "a")))
208
+
209
+ def print(self, msg):
210
+ """
211
+ Prints the given message to all consoles.
212
+
213
+ Parameters
214
+ ----------
215
+ msg : str
216
+ The message to be printed.
217
+ """
218
+ if self.rank == 0:
219
+ for c in self.consoles:
220
+ c.log(msg)
221
+
222
+ def update(self, label, fn_name):
223
+ """
224
+ Updates the progress bar and table for the given label.
225
+
226
+ Parameters
227
+ ----------
228
+ label : str
229
+ The label of the progress bar and table to be updated.
230
+ fn_name : str
231
+ The name of the function associated with the label.
232
+ """
233
+ if self.rank == 0:
234
+ self.pbar.advance(self.tasks[label]["pbar"])
235
+
236
+ # Create table
237
+ table = Table(title=label, expand=True, box=box.MINIMAL)
238
+ table.add_column("key", style="cyan")
239
+ table.add_column("value", style="bright_blue")
240
+ table.add_column("mean", style="bright_green")
241
+
242
+ keys = self.metrics[label]["value"].keys()
243
+ for k in keys:
244
+ value = self.metrics[label]["value"][k]
245
+ mean = self.metrics[label]["mean"][k]()
246
+ table.add_row(k, f"{value:10.6f}", f"{mean:10.6f}")
247
+
248
+ self.tasks[label]["table"] = table
249
+ tables = [t["table"] for t in self.tasks.values()]
250
+ group = Group(*tables, self.pbar)
251
+ self.live.update(
252
+ Group(
253
+ Padding("", (0, 0)),
254
+ Rule(f"[italic]{fn_name}()", style="white"),
255
+ Padding("", (0, 0)),
256
+ Panel.fit(
257
+ group, padding=(0, 5), title="[b]Progress", border_style="blue"
258
+ ),
259
+ )
260
+ )
261
+
262
+ def done(self, label: str, title: str):
263
+ """
264
+ Resets the progress bar and table for the given label and prints the final result.
265
+
266
+ Parameters
267
+ ----------
268
+ label : str
269
+ The label of the progress bar and table to be reset.
270
+ title : str
271
+ The title to be displayed when printing the final result.
272
+ """
273
+ for label in self.metrics:
274
+ for v in self.metrics[label]["mean"].values():
275
+ v.reset()
276
+
277
+ if self.rank == 0:
278
+ self.pbar.reset(self.tasks[label]["pbar"])
279
+ tables = [t["table"] for t in self.tasks.values()]
280
+ group = Group(Markdown(f"# {title}"), *tables, self.pbar)
281
+ self.print(group)
282
+
283
+ def track(
284
+ self,
285
+ label: str,
286
+ length: int,
287
+ completed: int = 0,
288
+ op: dist.ReduceOp = dist.ReduceOp.AVG,
289
+ ddp_active: bool = "LOCAL_RANK" in os.environ,
290
+ ):
291
+ """
292
+ A decorator for tracking the progress and metrics of a function.
293
+
294
+ Parameters
295
+ ----------
296
+ label : str
297
+ The label to be associated with the progress and metrics.
298
+ length : int
299
+ The total number of iterations to be completed.
300
+ completed : int, optional
301
+ The number of iterations already completed, by default 0.
302
+ op : dist.ReduceOp, optional
303
+ The reduce operation to be used, by default dist.ReduceOp.AVG.
304
+ ddp_active : bool, optional
305
+ Whether the DistributedDataParallel is active, by default "LOCAL_RANK" in os.environ.
306
+ """
307
+ self.tasks[label] = {
308
+ "pbar": self.pbar.add_task(
309
+ f"[white]Iteration ({label})", total=length, completed=completed
310
+ ),
311
+ "table": Table(),
312
+ }
313
+ self.metrics[label] = {
314
+ "value": defaultdict(),
315
+ "mean": defaultdict(lambda: Mean()),
316
+ }
317
+
318
+ def decorator(fn):
319
+ @wraps(fn)
320
+ def decorated(*args, **kwargs):
321
+ output = fn(*args, **kwargs)
322
+ if not isinstance(output, dict):
323
+ self.update(label, fn.__name__)
324
+ return output
325
+ # Collect across all DDP processes
326
+ scalar_keys = []
327
+ for k, v in output.items():
328
+ if isinstance(v, (int, float)):
329
+ v = torch.tensor([v])
330
+ if not torch.is_tensor(v):
331
+ continue
332
+ if ddp_active and v.is_cuda: # pragma: no cover
333
+ dist.all_reduce(v, op=op)
334
+ output[k] = v.detach()
335
+ if torch.numel(v) == 1:
336
+ scalar_keys.append(k)
337
+ output[k] = v.item()
338
+
339
+ # Save the outputs to tracker
340
+ for k, v in output.items():
341
+ if k not in scalar_keys:
342
+ continue
343
+ self.metrics[label]["value"][k] = v
344
+ # Update the running mean
345
+ self.metrics[label]["mean"][k].update(v)
346
+
347
+ self.update(label, fn.__name__)
348
+ return output
349
+
350
+ return decorated
351
+
352
+ return decorator
353
+
354
+ def log(self, label: str, value_type: str = "value", history: bool = True):
355
+ """
356
+ A decorator for logging the metrics of a function.
357
+
358
+ Parameters
359
+ ----------
360
+ label : str
361
+ The label to be associated with the logging.
362
+ value_type : str, optional
363
+ The type of value to be logged, by default "value".
364
+ history : bool, optional
365
+ Whether to save the history of the metrics, by default True.
366
+ """
367
+ assert value_type in ["mean", "value"]
368
+ if history:
369
+ if label not in self.history:
370
+ self.history[label] = defaultdict(default_list)
371
+
372
+ def decorator(fn):
373
+ @wraps(fn)
374
+ def decorated(*args, **kwargs):
375
+ output = fn(*args, **kwargs)
376
+ if self.rank == 0:
377
+ nonlocal value_type, label
378
+ metrics = self.metrics[label][value_type]
379
+ for k, v in metrics.items():
380
+ v = v() if isinstance(v, Mean) else v
381
+ if self.writer is not None:
382
+ self.writer.add_scalar(f"{k}/{label}", v, self.step)
383
+ if label in self.history:
384
+ self.history[label][k].append(v)
385
+
386
+ if label in self.history:
387
+ self.history[label]["step"].append(self.step)
388
+
389
+ return output
390
+
391
+ return decorated
392
+
393
+ return decorator
394
+
395
+ def is_best(self, label, key):
396
+ """
397
+ Checks if the latest value of the given key in the label is the best so far.
398
+
399
+ Parameters
400
+ ----------
401
+ label : str
402
+ The label of the metrics to be checked.
403
+ key : str
404
+ The key of the metric to be checked.
405
+
406
+ Returns
407
+ -------
408
+ bool
409
+ True if the latest value is the best so far, otherwise False.
410
+ """
411
+ return self.history[label][key][-1] == min(self.history[label][key])
412
+
413
+ def state_dict(self):
414
+ """
415
+ Returns a dictionary containing the state of the tracker.
416
+
417
+ Returns
418
+ -------
419
+ dict
420
+ A dictionary containing the history and step of the tracker.
421
+ """
422
+ return {"history": self.history, "step": self.step}
423
+
424
+ def load_state_dict(self, state_dict):
425
+ """
426
+ Loads the state of the tracker from the given state dictionary.
427
+
428
+ Parameters
429
+ ----------
430
+ state_dict : dict
431
+ A dictionary containing the history and step of the tracker.
432
+
433
+ Returns
434
+ -------
435
+ Tracker
436
+ The tracker object with the loaded state.
437
+ """
438
+ self.history = state_dict["history"]
439
+ self.step = state_dict["step"]
440
+ return self
audiotools/ml/experiment.py ADDED
@@ -0,0 +1,90 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Useful class for Experiment tracking, and ensuring code is
3
+ saved alongside files.
4
+ """ # fmt: skip
5
+ import datetime
6
+ import os
7
+ import shlex
8
+ import shutil
9
+ import subprocess
10
+ import typing
11
+ from pathlib import Path
12
+
13
+ import randomname
14
+
15
+
16
+ class Experiment:
17
+ """This class contains utilities for managing experiments.
18
+ It is a context manager, that when you enter it, changes
19
+ your directory to a specified experiment folder (which
20
+ optionally can have an automatically generated experiment
21
+ name, or a specified one), and changes the CUDA device used
22
+ to the specified device (or devices).
23
+
24
+ Parameters
25
+ ----------
26
+ exp_directory : str
27
+ Folder where all experiments are saved, by default "runs/".
28
+ exp_name : str, optional
29
+ Name of the experiment, by default uses the current time, date, and
30
+ hostname to save.
31
+ """
32
+
33
+ def __init__(
34
+ self,
35
+ exp_directory: str = "runs/",
36
+ exp_name: str = None,
37
+ ):
38
+ if exp_name is None:
39
+ exp_name = self.generate_exp_name()
40
+ exp_dir = Path(exp_directory) / exp_name
41
+ exp_dir.mkdir(parents=True, exist_ok=True)
42
+
43
+ self.exp_dir = exp_dir
44
+ self.exp_name = exp_name
45
+ self.git_tracked_files = (
46
+ subprocess.check_output(
47
+ shlex.split("git ls-tree --full-tree --name-only -r HEAD")
48
+ )
49
+ .decode("utf-8")
50
+ .splitlines()
51
+ )
52
+ self.parent_directory = Path(".").absolute()
53
+
54
+ def __enter__(self):
55
+ self.prev_dir = os.getcwd()
56
+ os.chdir(self.exp_dir)
57
+ return self
58
+
59
+ def __exit__(self, exc_type, exc_value, traceback):
60
+ os.chdir(self.prev_dir)
61
+
62
+ @staticmethod
63
+ def generate_exp_name():
64
+ """Generates a random experiment name based on the date
65
+ and a randomly generated adjective-noun tuple.
66
+
67
+ Returns
68
+ -------
69
+ str
70
+ Randomly generated experiment name.
71
+ """
72
+ date = datetime.datetime.now().strftime("%y%m%d")
73
+ name = f"{date}-{randomname.get_name()}"
74
+ return name
75
+
76
+ def snapshot(self, filter_fn: typing.Callable = lambda f: True):
77
+ """Captures a full snapshot of all the files tracked by git at the time
78
+ the experiment is run. It also captures the diff against the committed
79
+ code as a separate file.
80
+
81
+ Parameters
82
+ ----------
83
+ filter_fn : typing.Callable, optional
84
+ Function that can be used to exclude some files
85
+ from the snapshot, by default accepts all files
86
+ """
87
+ for f in self.git_tracked_files:
88
+ if filter_fn(f):
89
+ Path(f).parent.mkdir(parents=True, exist_ok=True)
90
+ shutil.copyfile(self.parent_directory / f, f)
audiotools/ml/layers/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ from .base import BaseModel
2
+ from .spectral_gate import SpectralGate
audiotools/ml/layers/base.py ADDED
@@ -0,0 +1,328 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import inspect
2
+ import shutil
3
+ import tempfile
4
+ import typing
5
+ from pathlib import Path
6
+
7
+ import torch
8
+ from torch import nn
9
+
10
+
11
+ class BaseModel(nn.Module):
12
+ """This is a class that adds useful save/load functionality to a
13
+ ``torch.nn.Module`` object. ``BaseModel`` objects can be saved
14
+ as ``torch.package`` easily, making them super easy to port between
15
+ machines without requiring a ton of dependencies. Files can also be
16
+ saved as just weights, in the standard way.
17
+
18
+ >>> class Model(ml.BaseModel):
19
+ >>> def __init__(self, arg1: float = 1.0):
20
+ >>> super().__init__()
21
+ >>> self.arg1 = arg1
22
+ >>> self.linear = nn.Linear(1, 1)
23
+ >>>
24
+ >>> def forward(self, x):
25
+ >>> return self.linear(x)
26
+ >>>
27
+ >>> model1 = Model()
28
+ >>>
29
+ >>> with tempfile.NamedTemporaryFile(suffix=".pth") as f:
30
+ >>> model1.save(
31
+ >>> f.name,
32
+ >>> )
33
+ >>> model2 = Model.load(f.name)
34
+ >>> out2 = seed_and_run(model2, x)
35
+ >>> assert torch.allclose(out1, out2)
36
+ >>>
37
+ >>> model1.save(f.name, package=True)
38
+ >>> model2 = Model.load(f.name)
39
+ >>> model2.save(f.name, package=False)
40
+ >>> model3 = Model.load(f.name)
41
+ >>> out3 = seed_and_run(model3, x)
42
+ >>>
43
+ >>> with tempfile.TemporaryDirectory() as d:
44
+ >>> model1.save_to_folder(d, {"data": 1.0})
45
+ >>> Model.load_from_folder(d)
46
+
47
+ """
48
+
49
+ EXTERN = [
50
+ "audiotools.**",
51
+ "tqdm",
52
+ "__main__",
53
+ "numpy.**",
54
+ "julius.**",
55
+ "torchaudio.**",
56
+ "scipy.**",
57
+ "einops",
58
+ ]
59
+ """Names of libraries that are external to the torch.package saving mechanism.
60
+ Source code from these libraries will not be packaged into the model. This can
61
+ be edited by the user of this class by editing ``model.EXTERN``."""
62
+ INTERN = []
63
+ """Names of libraries that are internal to the torch.package saving mechanism.
64
+ Source code from these libraries will be saved alongside the model."""
65
+
66
+ def save(
67
+ self,
68
+ path: str,
69
+ metadata: dict = None,
70
+ package: bool = True,
71
+ intern: list = [],
72
+ extern: list = [],
73
+ mock: list = [],
74
+ ):
75
+ """Saves the model, either as a torch package, or just as
76
+ weights, alongside some specified metadata.
77
+
78
+ Parameters
79
+ ----------
80
+ path : str
81
+ Path to save model to.
82
+ metadata : dict, optional
83
+ Any metadata to save alongside the model,
84
+ by default None
85
+ package : bool, optional
86
+ Whether to use ``torch.package`` to save the model in
87
+ a format that is portable, by default True
88
+ intern : list, optional
89
+ List of additional libraries that are internal
90
+ to the model, used with torch.package, by default []
91
+ extern : list, optional
92
+ List of additional libraries that are external to
93
+ the model, used with torch.package, by default []
94
+ mock : list, optional
95
+ List of libraries to mock, used with torch.package,
96
+ by default []
97
+
98
+ Returns
99
+ -------
100
+ str
101
+ Path to saved model.
102
+ """
103
+ sig = inspect.signature(self.__class__)
104
+ args = {}
105
+
106
+ for key, val in sig.parameters.items():
107
+ arg_val = val.default
108
+ if arg_val is not inspect.Parameter.empty:
109
+ args[key] = arg_val
110
+
111
+ # Look up attibutes in self, and if any of them are in args,
112
+ # overwrite them in args.
113
+ for attribute in dir(self):
114
+ if attribute in args:
115
+ args[attribute] = getattr(self, attribute)
116
+
117
+ metadata = {} if metadata is None else metadata
118
+ metadata["kwargs"] = args
119
+ if not hasattr(self, "metadata"):
120
+ self.metadata = {}
121
+ self.metadata.update(metadata)
122
+
123
+ if not package:
124
+ state_dict = {"state_dict": self.state_dict(), "metadata": metadata}
125
+ torch.save(state_dict, path)
126
+ else:
127
+ self._save_package(path, intern=intern, extern=extern, mock=mock)
128
+
129
+ return path
130
+
131
+ @property
132
+ def device(self):
133
+ """Gets the device the model is on by looking at the device of
134
+ the first parameter. May not be valid if model is split across
135
+ multiple devices.
136
+ """
137
+ return list(self.parameters())[0].device
138
+
139
+ @classmethod
140
+ def load(
141
+ cls,
142
+ location: str,
143
+ *args,
144
+ package_name: str = None,
145
+ strict: bool = False,
146
+ **kwargs,
147
+ ):
148
+ """Load model from a path. Tries first to load as a package, and if
149
+ that fails, tries to load as weights. The arguments to the class are
150
+ specified inside the model weights file.
151
+
152
+ Parameters
153
+ ----------
154
+ location : str
155
+ Path to file.
156
+ package_name : str, optional
157
+ Name of package, by default ``cls.__name__``.
158
+ strict : bool, optional
159
+ Ignore unmatched keys, by default False
160
+ kwargs : dict
161
+ Additional keyword arguments to the model instantiation, if
162
+ not loading from package.
163
+
164
+ Returns
165
+ -------
166
+ BaseModel
167
+ A model that inherits from BaseModel.
168
+ """
169
+ try:
170
+ model = cls._load_package(location, package_name=package_name)
171
+ except:
172
+ model_dict = torch.load(location, "cpu")
173
+ metadata = model_dict["metadata"]
174
+ metadata["kwargs"].update(kwargs)
175
+
176
+ sig = inspect.signature(cls)
177
+ class_keys = list(sig.parameters.keys())
178
+ for k in list(metadata["kwargs"].keys()):
179
+ if k not in class_keys:
180
+ metadata["kwargs"].pop(k)
181
+
182
+ model = cls(*args, **metadata["kwargs"])
183
+ model.load_state_dict(model_dict["state_dict"], strict=strict)
184
+ model.metadata = metadata
185
+
186
+ return model
187
+
188
+ def _save_package(self, path, intern=[], extern=[], mock=[], **kwargs):
189
+ package_name = type(self).__name__
190
+ resource_name = f"{type(self).__name__}.pth"
191
+
192
+ # Below is for loading and re-saving a package.
193
+ if hasattr(self, "importer"):
194
+ kwargs["importer"] = (self.importer, torch.package.sys_importer)
195
+ del self.importer
196
+
197
+ # Why do we use a tempfile, you ask?
198
+ # It's so we can load a packaged model and then re-save
199
+ # it to the same location. torch.package throws an
200
+ # error if it's loading and writing to the same
201
+ # file (this is undocumented).
202
+ with tempfile.NamedTemporaryFile(suffix=".pth") as f:
203
+ with torch.package.PackageExporter(f.name, **kwargs) as exp:
204
+ exp.intern(self.INTERN + intern)
205
+ exp.mock(mock)
206
+ exp.extern(self.EXTERN + extern)
207
+ exp.save_pickle(package_name, resource_name, self)
208
+
209
+ if hasattr(self, "metadata"):
210
+ exp.save_pickle(
211
+ package_name, f"{package_name}.metadata", self.metadata
212
+ )
213
+
214
+ shutil.copyfile(f.name, path)
215
+
216
+ # Must reset the importer back to `self` if it existed
217
+ # so that you can save the model again!
218
+ if "importer" in kwargs:
219
+ self.importer = kwargs["importer"][0]
220
+ return path
221
+
222
+ @classmethod
223
+ def _load_package(cls, path, package_name=None):
224
+ package_name = cls.__name__ if package_name is None else package_name
225
+ resource_name = f"{package_name}.pth"
226
+
227
+ imp = torch.package.PackageImporter(path)
228
+ model = imp.load_pickle(package_name, resource_name, "cpu")
229
+ try:
230
+ model.metadata = imp.load_pickle(package_name, f"{package_name}.metadata")
231
+ except: # pragma: no cover
232
+ pass
233
+ model.importer = imp
234
+
235
+ return model
236
+
237
+ def save_to_folder(
238
+ self,
239
+ folder: typing.Union[str, Path],
240
+ extra_data: dict = None,
241
+ package: bool = True,
242
+ ):
243
+ """Dumps a model into a folder, as both a package
244
+ and as weights, as well as anything specified in
245
+ ``extra_data``. ``extra_data`` is a dictionary of other
246
+ pickleable files, with the keys being the paths
247
+ to save them in. The model is saved under a subfolder
248
+ specified by the name of the class (e.g. ``folder/generator/[package, weights].pth``
249
+ if the model name was ``Generator``).
250
+
251
+ >>> with tempfile.TemporaryDirectory() as d:
252
+ >>> extra_data = {
253
+ >>> "optimizer.pth": optimizer.state_dict()
254
+ >>> }
255
+ >>> model.save_to_folder(d, extra_data)
256
+ >>> Model.load_from_folder(d)
257
+
258
+ Parameters
259
+ ----------
260
+ folder : typing.Union[str, Path]
261
+ _description_
262
+ extra_data : dict, optional
263
+ _description_, by default None
264
+
265
+ Returns
266
+ -------
267
+ str
268
+ Path to folder
269
+ """
270
+ extra_data = {} if extra_data is None else extra_data
271
+ model_name = type(self).__name__.lower()
272
+ target_base = Path(f"{folder}/{model_name}/")
273
+ target_base.mkdir(exist_ok=True, parents=True)
274
+
275
+ if package:
276
+ package_path = target_base / f"package.pth"
277
+ self.save(package_path)
278
+
279
+ weights_path = target_base / f"weights.pth"
280
+ self.save(weights_path, package=False)
281
+
282
+ for path, obj in extra_data.items():
283
+ torch.save(obj, target_base / path)
284
+
285
+ return target_base
286
+
287
+ @classmethod
288
+ def load_from_folder(
289
+ cls,
290
+ folder: typing.Union[str, Path],
291
+ package: bool = True,
292
+ strict: bool = False,
293
+ **kwargs,
294
+ ):
295
+ """Loads the model from a folder generated by
296
+ :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
297
+ Like that function, this one looks for a subfolder that has
298
+ the name of the class (e.g. ``folder/generator/[package, weights].pth`` if the
299
+ model name was ``Generator``).
300
+
301
+ Parameters
302
+ ----------
303
+ folder : typing.Union[str, Path]
304
+ _description_
305
+ package : bool, optional
306
+ Whether to use ``torch.package`` to load the model,
307
+ loading the model from ``package.pth``.
308
+ strict : bool, optional
309
+ Ignore unmatched keys, by default False
310
+
311
+ Returns
312
+ -------
313
+ tuple
314
+ tuple of model and extra data as saved by
315
+ :py:func:`audiotools.ml.layers.base.BaseModel.save_to_folder`.
316
+ """
317
+ folder = Path(folder) / cls.__name__.lower()
318
+ model_pth = "package.pth" if package else "weights.pth"
319
+ model_pth = folder / model_pth
320
+
321
+ model = cls.load(model_pth, strict=strict)
322
+ extra_data = {}
323
+ excluded = ["package.pth", "weights.pth"]
324
+ files = [x for x in folder.glob("*") if x.is_file() and x.name not in excluded]
325
+ for f in files:
326
+ extra_data[f.name] = torch.load(f, **kwargs)
327
+
328
+ return model, extra_data
audiotools/ml/layers/spectral_gate.py ADDED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn.functional as F
3
+ from torch import nn
4
+
5
+ from ...core import AudioSignal
6
+ from ...core import STFTParams
7
+ from ...core import util
8
+
9
+
10
+ class SpectralGate(nn.Module):
11
+ """Spectral gating algorithm for noise reduction,
12
+ as in Audacity/Ocenaudio. The steps are as follows:
13
+
14
+ 1. An FFT is calculated over the noise audio clip
15
+ 2. Statistics are calculated over FFT of the the noise
16
+ (in frequency)
17
+ 3. A threshold is calculated based upon the statistics
18
+ of the noise (and the desired sensitivity of the algorithm)
19
+ 4. An FFT is calculated over the signal
20
+ 5. A mask is determined by comparing the signal FFT to the
21
+ threshold
22
+ 6. The mask is smoothed with a filter over frequency and time
23
+ 7. The mask is appled to the FFT of the signal, and is inverted
24
+
25
+ Implementation inspired by Tim Sainburg's noisereduce:
26
+
27
+ https://timsainburg.com/noise-reduction-python.html
28
+
29
+ Parameters
30
+ ----------
31
+ n_freq : int, optional
32
+ Number of frequency bins to smooth by, by default 3
33
+ n_time : int, optional
34
+ Number of time bins to smooth by, by default 5
35
+ """
36
+
37
+ def __init__(self, n_freq: int = 3, n_time: int = 5):
38
+ super().__init__()
39
+
40
+ smoothing_filter = torch.outer(
41
+ torch.cat(
42
+ [
43
+ torch.linspace(0, 1, n_freq + 2)[:-1],
44
+ torch.linspace(1, 0, n_freq + 2),
45
+ ]
46
+ )[..., 1:-1],
47
+ torch.cat(
48
+ [
49
+ torch.linspace(0, 1, n_time + 2)[:-1],
50
+ torch.linspace(1, 0, n_time + 2),
51
+ ]
52
+ )[..., 1:-1],
53
+ )
54
+ smoothing_filter = smoothing_filter / smoothing_filter.sum()
55
+ smoothing_filter = smoothing_filter.unsqueeze(0).unsqueeze(0)
56
+ self.register_buffer("smoothing_filter", smoothing_filter)
57
+
58
+ def forward(
59
+ self,
60
+ audio_signal: AudioSignal,
61
+ nz_signal: AudioSignal,
62
+ denoise_amount: float = 1.0,
63
+ n_std: float = 3.0,
64
+ win_length: int = 2048,
65
+ hop_length: int = 512,
66
+ ):
67
+ """Perform noise reduction.
68
+
69
+ Parameters
70
+ ----------
71
+ audio_signal : AudioSignal
72
+ Audio signal that noise will be removed from.
73
+ nz_signal : AudioSignal, optional
74
+ Noise signal to compute noise statistics from.
75
+ denoise_amount : float, optional
76
+ Amount to denoise by, by default 1.0
77
+ n_std : float, optional
78
+ Number of standard deviations above which to consider
79
+ noise, by default 3.0
80
+ win_length : int, optional
81
+ Length of window for STFT, by default 2048
82
+ hop_length : int, optional
83
+ Hop length for STFT, by default 512
84
+
85
+ Returns
86
+ -------
87
+ AudioSignal
88
+ Denoised audio signal.
89
+ """
90
+ stft_params = STFTParams(win_length, hop_length, "sqrt_hann")
91
+
92
+ audio_signal = audio_signal.clone()
93
+ audio_signal.stft_data = None
94
+ audio_signal.stft_params = stft_params
95
+
96
+ nz_signal = nz_signal.clone()
97
+ nz_signal.stft_params = stft_params
98
+
99
+ nz_stft_db = 20 * nz_signal.magnitude.clamp(1e-4).log10()
100
+ nz_freq_mean = nz_stft_db.mean(keepdim=True, dim=-1)
101
+ nz_freq_std = nz_stft_db.std(keepdim=True, dim=-1)
102
+
103
+ nz_thresh = nz_freq_mean + nz_freq_std * n_std
104
+
105
+ stft_db = 20 * audio_signal.magnitude.clamp(1e-4).log10()
106
+ nb, nac, nf, nt = stft_db.shape
107
+ db_thresh = nz_thresh.expand(nb, nac, -1, nt)
108
+
109
+ stft_mask = (stft_db < db_thresh).float()
110
+ shape = stft_mask.shape
111
+
112
+ stft_mask = stft_mask.reshape(nb * nac, 1, nf, nt)
113
+ pad_tuple = (
114
+ self.smoothing_filter.shape[-2] // 2,
115
+ self.smoothing_filter.shape[-1] // 2,
116
+ )
117
+ stft_mask = F.conv2d(stft_mask, self.smoothing_filter, padding=pad_tuple)
118
+ stft_mask = stft_mask.reshape(*shape)
119
+ stft_mask *= util.ensure_tensor(denoise_amount, ndim=stft_mask.ndim).to(
120
+ audio_signal.device
121
+ )
122
+ stft_mask = 1 - stft_mask
123
+
124
+ audio_signal.stft_data *= stft_mask
125
+ audio_signal.istft()
126
+
127
+ return audio_signal
audiotools/post.py ADDED
@@ -0,0 +1,140 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import tempfile
2
+ import typing
3
+ import zipfile
4
+ from pathlib import Path
5
+
6
+ import markdown2 as md
7
+ import matplotlib.pyplot as plt
8
+ import torch
9
+ from IPython.display import HTML
10
+
11
+
12
+ def audio_table(
13
+ audio_dict: dict,
14
+ first_column: str = None,
15
+ format_fn: typing.Callable = None,
16
+ **kwargs,
17
+ ): # pragma: no cover
18
+ """Embeds an audio table into HTML, or as the output cell
19
+ in a notebook.
20
+
21
+ Parameters
22
+ ----------
23
+ audio_dict : dict
24
+ Dictionary of data to embed.
25
+ first_column : str, optional
26
+ The label for the first column of the table, by default None
27
+ format_fn : typing.Callable, optional
28
+ How to format the data, by default None
29
+
30
+ Returns
31
+ -------
32
+ str
33
+ Table as a string
34
+
35
+ Examples
36
+ --------
37
+
38
+ >>> audio_dict = {}
39
+ >>> for i in range(signal_batch.batch_size):
40
+ >>> audio_dict[i] = {
41
+ >>> "input": signal_batch[i],
42
+ >>> "output": output_batch[i]
43
+ >>> }
44
+ >>> audiotools.post.audio_zip(audio_dict)
45
+
46
+ """
47
+ from audiotools import AudioSignal
48
+
49
+ output = []
50
+ columns = None
51
+
52
+ def _default_format_fn(label, x, **kwargs):
53
+ if torch.is_tensor(x):
54
+ x = x.tolist()
55
+
56
+ if x is None:
57
+ return "."
58
+ elif isinstance(x, AudioSignal):
59
+ return x.embed(display=False, return_html=True, **kwargs)
60
+ else:
61
+ return str(x)
62
+
63
+ if format_fn is None:
64
+ format_fn = _default_format_fn
65
+
66
+ if first_column is None:
67
+ first_column = "."
68
+
69
+ for k, v in audio_dict.items():
70
+ if not isinstance(v, dict):
71
+ v = {"Audio": v}
72
+
73
+ v_keys = list(v.keys())
74
+ if columns is None:
75
+ columns = [first_column] + v_keys
76
+ output.append(" | ".join(columns))
77
+
78
+ layout = "|---" + len(v_keys) * "|:-:"
79
+ output.append(layout)
80
+
81
+ formatted_audio = []
82
+ for col in columns[1:]:
83
+ formatted_audio.append(format_fn(col, v[col], **kwargs))
84
+
85
+ row = f"| {k} | "
86
+ row += " | ".join(formatted_audio)
87
+ output.append(row)
88
+
89
+ output = "\n" + "\n".join(output)
90
+ return output
91
+
92
+
93
+ def in_notebook(): # pragma: no cover
94
+ """Determines if code is running in a notebook.
95
+
96
+ Returns
97
+ -------
98
+ bool
99
+ Whether or not this is running in a notebook.
100
+ """
101
+ try:
102
+ from IPython import get_ipython
103
+
104
+ if "IPKernelApp" not in get_ipython().config: # pragma: no cover
105
+ return False
106
+ except ImportError:
107
+ return False
108
+ except AttributeError:
109
+ return False
110
+ return True
111
+
112
+
113
+ def disp(obj, **kwargs): # pragma: no cover
114
+ """Displays an object, depending on if its in a notebook
115
+ or not.
116
+
117
+ Parameters
118
+ ----------
119
+ obj : typing.Any
120
+ Any object to display.
121
+
122
+ """
123
+ from audiotools import AudioSignal
124
+
125
+ IN_NOTEBOOK = in_notebook()
126
+
127
+ if isinstance(obj, AudioSignal):
128
+ audio_elem = obj.embed(display=False, return_html=True)
129
+ if IN_NOTEBOOK:
130
+ return HTML(audio_elem)
131
+ else:
132
+ print(audio_elem)
133
+ if isinstance(obj, dict):
134
+ table = audio_table(obj, **kwargs)
135
+ if IN_NOTEBOOK:
136
+ return HTML(md.markdown(table, extras=["tables"]))
137
+ else:
138
+ print(table)
139
+ if isinstance(obj, plt.Figure):
140
+ plt.show()
audiotools/preference.py ADDED
@@ -0,0 +1,600 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ##############################################################
2
+ ### Tools for creating preference tests (MUSHRA, ABX, etc) ###
3
+ ##############################################################
4
+ import copy
5
+ import csv
6
+ import random
7
+ import sys
8
+ import traceback
9
+ from collections import defaultdict
10
+ from pathlib import Path
11
+ from typing import List
12
+
13
+ import gradio as gr
14
+
15
+ from audiotools.core.util import find_audio
16
+
17
+ ################################################################
18
+ ### Logic for audio player, and adding audio / play buttons. ###
19
+ ################################################################
20
+
21
+ WAVESURFER = """<div id="waveform"></div><div id="wave-timeline"></div>"""
22
+
23
+ CUSTOM_CSS = """
24
+ .gradio-container {
25
+ max-width: 840px !important;
26
+ }
27
+ region.wavesurfer-region:before {
28
+ content: attr(data-region-label);
29
+ }
30
+
31
+ block {
32
+ min-width: 0 !important;
33
+ }
34
+
35
+ #wave-timeline {
36
+ background-color: rgba(0, 0, 0, 0.8);
37
+ }
38
+
39
+ .head.svelte-1cl284s {
40
+ display: none;
41
+ }
42
+ """
43
+
44
+ load_wavesurfer_js = """
45
+ function load_wavesurfer() {
46
+ function load_script(url) {
47
+ const script = document.createElement('script');
48
+ script.src = url;
49
+ document.body.appendChild(script);
50
+
51
+ return new Promise((res, rej) => {
52
+ script.onload = function() {
53
+ res();
54
+ }
55
+ script.onerror = function () {
56
+ rej();
57
+ }
58
+ });
59
+ }
60
+
61
+ function create_wavesurfer() {
62
+ var options = {
63
+ container: '#waveform',
64
+ waveColor: '#F2F2F2', // Set a darker wave color
65
+ progressColor: 'white', // Set a slightly lighter progress color
66
+ loaderColor: 'white', // Set a slightly lighter loader color
67
+ cursorColor: 'black', // Set a slightly lighter cursor color
68
+ backgroundColor: '#00AAFF', // Set a black background color
69
+ barWidth: 4,
70
+ barRadius: 3,
71
+ barHeight: 1, // the height of the wave
72
+ plugins: [
73
+ WaveSurfer.regions.create({
74
+ regionsMinLength: 0.0,
75
+ dragSelection: {
76
+ slop: 5
77
+ },
78
+ color: 'hsla(200, 50%, 70%, 0.4)',
79
+ }),
80
+ WaveSurfer.timeline.create({
81
+ container: "#wave-timeline",
82
+ primaryLabelInterval: 5.0,
83
+ secondaryLabelInterval: 1.0,
84
+ primaryFontColor: '#F2F2F2',
85
+ secondaryFontColor: '#F2F2F2',
86
+ }),
87
+ ]
88
+ };
89
+ wavesurfer = WaveSurfer.create(options);
90
+ wavesurfer.on('region-created', region => {
91
+ wavesurfer.regions.clear();
92
+ });
93
+ wavesurfer.on('finish', function () {
94
+ var loop = document.getElementById("loop-button").textContent.includes("ON");
95
+ if (loop) {
96
+ wavesurfer.play();
97
+ }
98
+ else {
99
+ var button_elements = document.getElementsByClassName('playpause')
100
+ var buttons = Array.from(button_elements);
101
+
102
+ for (let j = 0; j < buttons.length; j++) {
103
+ buttons[j].classList.remove("primary");
104
+ buttons[j].classList.add("secondary");
105
+ buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
106
+ }
107
+ }
108
+ });
109
+
110
+ wavesurfer.on('region-out', function () {
111
+ var loop = document.getElementById("loop-button").textContent.includes("ON");
112
+ if (!loop) {
113
+ var button_elements = document.getElementsByClassName('playpause')
114
+ var buttons = Array.from(button_elements);
115
+
116
+ for (let j = 0; j < buttons.length; j++) {
117
+ buttons[j].classList.remove("primary");
118
+ buttons[j].classList.add("secondary");
119
+ buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
120
+ }
121
+ wavesurfer.pause();
122
+ }
123
+ });
124
+
125
+ console.log("Created WaveSurfer object.")
126
+ }
127
+
128
+ load_script('https://unpkg.com/[email protected]')
129
+ .then(() => {
130
+ load_script("https://unpkg.com/[email protected]/dist/plugin/wavesurfer.timeline.min.js")
131
+ .then(() => {
132
+ load_script('https://unpkg.com/[email protected]/dist/plugin/wavesurfer.regions.min.js')
133
+ .then(() => {
134
+ console.log("Loaded regions");
135
+ create_wavesurfer();
136
+ document.getElementById("start-survey").click();
137
+ })
138
+ })
139
+ });
140
+ }
141
+ """
142
+
143
+ play = lambda i: """
144
+ function play() {
145
+ var audio_elements = document.getElementsByTagName('audio');
146
+ var button_elements = document.getElementsByClassName('playpause')
147
+
148
+ var audio_array = Array.from(audio_elements);
149
+ var buttons = Array.from(button_elements);
150
+
151
+ var src_link = audio_array[{i}].getAttribute("src");
152
+ console.log(src_link);
153
+
154
+ var loop = document.getElementById("loop-button").textContent.includes("ON");
155
+ var playing = buttons[{i}].textContent.includes("Stop");
156
+
157
+ for (let j = 0; j < buttons.length; j++) {
158
+ if (j != {i} || playing) {
159
+ buttons[j].classList.remove("primary");
160
+ buttons[j].classList.add("secondary");
161
+ buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
162
+ }
163
+ else {
164
+ buttons[j].classList.remove("secondary");
165
+ buttons[j].classList.add("primary");
166
+ buttons[j].textContent = buttons[j].textContent.replace("Play", "Stop")
167
+ }
168
+ }
169
+
170
+ if (playing) {
171
+ wavesurfer.pause();
172
+ wavesurfer.seekTo(0.0);
173
+ }
174
+ else {
175
+ wavesurfer.load(src_link);
176
+ wavesurfer.on('ready', function () {
177
+ var region = Object.values(wavesurfer.regions.list)[0];
178
+
179
+ if (region != null) {
180
+ region.loop = loop;
181
+ region.play();
182
+ } else {
183
+ wavesurfer.play();
184
+ }
185
+ });
186
+ }
187
+ }
188
+ """.replace(
189
+ "{i}", str(i)
190
+ )
191
+
192
+ clear_regions = """
193
+ function clear_regions() {
194
+ wavesurfer.clearRegions();
195
+ }
196
+ """
197
+
198
+ reset_player = """
199
+ function reset_player() {
200
+ wavesurfer.clearRegions();
201
+ wavesurfer.pause();
202
+ wavesurfer.seekTo(0.0);
203
+
204
+ var button_elements = document.getElementsByClassName('playpause')
205
+ var buttons = Array.from(button_elements);
206
+
207
+ for (let j = 0; j < buttons.length; j++) {
208
+ buttons[j].classList.remove("primary");
209
+ buttons[j].classList.add("secondary");
210
+ buttons[j].textContent = buttons[j].textContent.replace("Stop", "Play")
211
+ }
212
+ }
213
+ """
214
+
215
+ loop_region = """
216
+ function loop_region() {
217
+ var element = document.getElementById("loop-button");
218
+ var loop = element.textContent.includes("OFF");
219
+ console.log(loop);
220
+
221
+ try {
222
+ var region = Object.values(wavesurfer.regions.list)[0];
223
+ region.loop = loop;
224
+ } catch {}
225
+
226
+ if (loop) {
227
+ element.classList.remove("secondary");
228
+ element.classList.add("primary");
229
+ element.textContent = "Looping ON";
230
+ } else {
231
+ element.classList.remove("primary");
232
+ element.classList.add("secondary");
233
+ element.textContent = "Looping OFF";
234
+ }
235
+ }
236
+ """
237
+
238
+
239
+ class Player:
240
+ def __init__(self, app):
241
+ self.app = app
242
+
243
+ self.app.load(_js=load_wavesurfer_js)
244
+ self.app.css = CUSTOM_CSS
245
+
246
+ self.wavs = []
247
+ self.position = 0
248
+
249
+ def create(self):
250
+ gr.HTML(WAVESURFER)
251
+ gr.Markdown(
252
+ "Click and drag on the waveform above to select a region for playback. "
253
+ "Once created, the region can be moved around and resized. "
254
+ "Clear the regions using the button below. Hit play on one of the buttons below to start!"
255
+ )
256
+
257
+ with gr.Row():
258
+ clear = gr.Button("Clear region")
259
+ loop = gr.Button("Looping OFF", elem_id="loop-button")
260
+
261
+ loop.click(None, _js=loop_region)
262
+ clear.click(None, _js=clear_regions)
263
+
264
+ gr.HTML("<hr>")
265
+
266
+ def add(self, name: str = "Play"):
267
+ i = self.position
268
+ self.wavs.append(
269
+ {
270
+ "audio": gr.Audio(visible=False),
271
+ "button": gr.Button(name, elem_classes=["playpause"]),
272
+ "position": i,
273
+ }
274
+ )
275
+ self.wavs[-1]["button"].click(None, _js=play(i))
276
+ self.position += 1
277
+ return self.wavs[-1]
278
+
279
+ def to_list(self):
280
+ return [x["audio"] for x in self.wavs]
281
+
282
+
283
+ ############################################################
284
+ ### Keeping track of users, and CSS for the progress bar ###
285
+ ############################################################
286
+
287
+ load_tracker = lambda name: """
288
+ function load_name() {
289
+ function setCookie(name, value, exp_days) {
290
+ var d = new Date();
291
+ d.setTime(d.getTime() + (exp_days*24*60*60*1000));
292
+ var expires = "expires=" + d.toGMTString();
293
+ document.cookie = name + "=" + value + ";" + expires + ";path=/";
294
+ }
295
+
296
+ function getCookie(name) {
297
+ var cname = name + "=";
298
+ var decodedCookie = decodeURIComponent(document.cookie);
299
+ var ca = decodedCookie.split(';');
300
+ for(var i = 0; i < ca.length; i++){
301
+ var c = ca[i];
302
+ while(c.charAt(0) == ' '){
303
+ c = c.substring(1);
304
+ }
305
+ if(c.indexOf(cname) == 0){
306
+ return c.substring(cname.length, c.length);
307
+ }
308
+ }
309
+ return "";
310
+ }
311
+
312
+ name = getCookie("{name}");
313
+ if (name == "") {
314
+ name = Math.random().toString(36).slice(2);
315
+ console.log(name);
316
+ setCookie("name", name, 30);
317
+ }
318
+ name = getCookie("{name}");
319
+ return name;
320
+ }
321
+ """.replace(
322
+ "{name}", name
323
+ )
324
+
325
+ # Progress bar
326
+
327
+ progress_template = """
328
+ <!DOCTYPE html>
329
+ <html>
330
+ <head>
331
+ <title>Progress Bar</title>
332
+ <style>
333
+ .progress-bar {
334
+ background-color: #ddd;
335
+ border-radius: 4px;
336
+ height: 30px;
337
+ width: 100%;
338
+ position: relative;
339
+ }
340
+
341
+ .progress {
342
+ background-color: #00AAFF;
343
+ border-radius: 4px;
344
+ height: 100%;
345
+ width: {PROGRESS}%; /* Change this value to control the progress */
346
+ }
347
+
348
+ .progress-text {
349
+ position: absolute;
350
+ top: 50%;
351
+ left: 50%;
352
+ transform: translate(-50%, -50%);
353
+ font-size: 18px;
354
+ font-family: Arial, sans-serif;
355
+ font-weight: bold;
356
+ color: #333 !important;
357
+ text-shadow: 1px 1px #fff;
358
+ }
359
+ </style>
360
+ </head>
361
+ <body>
362
+ <div class="progress-bar">
363
+ <div class="progress"></div>
364
+ <div class="progress-text">{TEXT}</div>
365
+ </div>
366
+ </body>
367
+ </html>
368
+ """
369
+
370
+
371
+ def create_tracker(app, cookie_name="name"):
372
+ user = gr.Text(label="user", interactive=True, visible=False, elem_id="user")
373
+ app.load(_js=load_tracker(cookie_name), outputs=user)
374
+ return user
375
+
376
+
377
+ #################################################################
378
+ ### CSS and HTML for labeling sliders for both ABX and MUSHRA ###
379
+ #################################################################
380
+
381
+ slider_abx = """
382
+ <!DOCTYPE html>
383
+ <html>
384
+ <head>
385
+ <meta charset="UTF-8">
386
+ <title>Labels Example</title>
387
+ <style>
388
+ body {
389
+ margin: 0;
390
+ padding: 0;
391
+ }
392
+
393
+ .labels-container {
394
+ display: flex;
395
+ justify-content: space-between;
396
+ align-items: center;
397
+ width: 100%;
398
+ height: 40px;
399
+ padding: 0px 12px 0px;
400
+ }
401
+
402
+ .label {
403
+ display: flex;
404
+ justify-content: center;
405
+ align-items: center;
406
+ width: 33%;
407
+ height: 100%;
408
+ font-weight: bold;
409
+ text-transform: uppercase;
410
+ padding: 10px;
411
+ font-family: Arial, sans-serif;
412
+ font-size: 16px;
413
+ font-weight: 700;
414
+ letter-spacing: 1px;
415
+ line-height: 1.5;
416
+ }
417
+
418
+ .label-a {
419
+ background-color: #00AAFF;
420
+ color: #333 !important;
421
+ }
422
+
423
+ .label-tie {
424
+ background-color: #f97316;
425
+ color: #333 !important;
426
+ }
427
+
428
+ .label-b {
429
+ background-color: #00AAFF;
430
+ color: #333 !important;
431
+ }
432
+ </style>
433
+ </head>
434
+ <body>
435
+ <div class="labels-container">
436
+ <div class="label label-a">Prefer A</div>
437
+ <div class="label label-tie">Toss-up</div>
438
+ <div class="label label-b">Prefer B</div>
439
+ </div>
440
+ </body>
441
+ </html>
442
+ """
443
+
444
+ slider_mushra = """
445
+ <!DOCTYPE html>
446
+ <html>
447
+ <head>
448
+ <meta charset="UTF-8">
449
+ <title>Labels Example</title>
450
+ <style>
451
+ body {
452
+ margin: 0;
453
+ padding: 0;
454
+ }
455
+
456
+ .labels-container {
457
+ display: flex;
458
+ justify-content: space-between;
459
+ align-items: center;
460
+ width: 100%;
461
+ height: 30px;
462
+ padding: 10px;
463
+ }
464
+
465
+ .label {
466
+ display: flex;
467
+ justify-content: center;
468
+ align-items: center;
469
+ width: 20%;
470
+ height: 100%;
471
+ font-weight: bold;
472
+ text-transform: uppercase;
473
+ padding: 10px;
474
+ font-family: Arial, sans-serif;
475
+ font-size: 13.5px;
476
+ font-weight: 700;
477
+ line-height: 1.5;
478
+ }
479
+
480
+ .label-bad {
481
+ background-color: #ff5555;
482
+ color: #333 !important;
483
+ }
484
+
485
+ .label-poor {
486
+ background-color: #ffa500;
487
+ color: #333 !important;
488
+ }
489
+
490
+ .label-fair {
491
+ background-color: #ffd700;
492
+ color: #333 !important;
493
+ }
494
+
495
+ .label-good {
496
+ background-color: #97d997;
497
+ color: #333 !important;
498
+ }
499
+
500
+ .label-excellent {
501
+ background-color: #04c822;
502
+ color: #333 !important;
503
+ }
504
+ </style>
505
+ </head>
506
+ <body>
507
+ <div class="labels-container">
508
+ <div class="label label-bad">bad</div>
509
+ <div class="label label-poor">poor</div>
510
+ <div class="label label-fair">fair</div>
511
+ <div class="label label-good">good</div>
512
+ <div class="label label-excellent">excellent</div>
513
+ </div>
514
+ </body>
515
+ </html>
516
+ """
517
+
518
+ #########################################################
519
+ ### Handling loading audio and tracking session state ###
520
+ #########################################################
521
+
522
+
523
+ class Samples:
524
+ def __init__(self, folder: str, shuffle: bool = True, n_samples: int = None):
525
+ files = find_audio(folder)
526
+ samples = defaultdict(lambda: defaultdict())
527
+
528
+ for f in files:
529
+ condition = f.parent.stem
530
+ samples[f.name][condition] = f
531
+
532
+ self.samples = samples
533
+ self.names = list(samples.keys())
534
+ self.filtered = False
535
+ self.current = 0
536
+
537
+ if shuffle:
538
+ random.shuffle(self.names)
539
+
540
+ self.n_samples = len(self.names) if n_samples is None else n_samples
541
+
542
+ def get_updates(self, idx, order):
543
+ key = self.names[idx]
544
+ return [gr.update(value=str(self.samples[key][o])) for o in order]
545
+
546
+ def progress(self):
547
+ try:
548
+ pct = self.current / len(self) * 100
549
+ except: # pragma: no cover
550
+ pct = 100
551
+ text = f"On {self.current} / {len(self)} samples"
552
+ pbar = (
553
+ copy.copy(progress_template)
554
+ .replace("{PROGRESS}", str(pct))
555
+ .replace("{TEXT}", str(text))
556
+ )
557
+ return gr.update(value=pbar)
558
+
559
+ def __len__(self):
560
+ return self.n_samples
561
+
562
+ def filter_completed(self, user, save_path):
563
+ if not self.filtered:
564
+ done = []
565
+ if Path(save_path).exists():
566
+ with open(save_path, "r") as f:
567
+ reader = csv.DictReader(f)
568
+ done = [r["sample"] for r in reader if r["user"] == user]
569
+ self.names = [k for k in self.names if k not in done]
570
+ self.names = self.names[: self.n_samples]
571
+ self.filtered = True # Avoid filtering more than once per session.
572
+
573
+ def get_next_sample(self, reference, conditions):
574
+ random.shuffle(conditions)
575
+ if reference is not None:
576
+ self.order = [reference] + conditions
577
+ else:
578
+ self.order = conditions
579
+
580
+ try:
581
+ updates = self.get_updates(self.current, self.order)
582
+ self.current += 1
583
+ done = gr.update(interactive=True)
584
+ pbar = self.progress()
585
+ except:
586
+ traceback.print_exc()
587
+ updates = [gr.update() for _ in range(len(self.order))]
588
+ done = gr.update(value="No more samples!", interactive=False)
589
+ self.current = len(self)
590
+ pbar = self.progress()
591
+
592
+ return updates, done, pbar
593
+
594
+
595
+ def save_result(result, save_path):
596
+ with open(save_path, mode="a", newline="") as file:
597
+ writer = csv.DictWriter(file, fieldnames=sorted(list(result.keys())))
598
+ if file.tell() == 0:
599
+ writer.writeheader()
600
+ writer.writerow(result)
src/inference.py ADDED
@@ -0,0 +1,169 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import pandas as pd
4
+ import torch
5
+ import librosa
6
+ import numpy as np
7
+ import soundfile as sf
8
+ from tqdm import tqdm
9
+ from .utils import scale_shift_re
10
+
11
+
12
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
13
+ """
14
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
15
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
16
+ """
17
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
18
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
19
+ # rescale the results from guidance (fixes overexposure)
20
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
21
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
22
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
23
+ return noise_cfg
24
+
25
+
26
+ @torch.no_grad()
27
+ def inference(autoencoder, unet, gt, gt_mask,
28
+ tokenizer, text_encoder,
29
+ params, noise_scheduler,
30
+ text_raw, neg_text=None,
31
+ audio_frames=500,
32
+ guidance_scale=3, guidance_rescale=0.0,
33
+ ddim_steps=50, eta=1, random_seed=2024,
34
+ device='cuda',
35
+ ):
36
+ if neg_text is None:
37
+ neg_text = [""]
38
+ if tokenizer is not None:
39
+ text_batch = tokenizer(text_raw,
40
+ max_length=params['text_encoder']['max_length'],
41
+ padding="max_length", truncation=True, return_tensors="pt")
42
+ text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool()
43
+ text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state
44
+
45
+ uncond_text_batch = tokenizer(neg_text,
46
+ max_length=params['text_encoder']['max_length'],
47
+ padding="max_length", truncation=True, return_tensors="pt")
48
+ uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool()
49
+ uncond_text = text_encoder(input_ids=uncond_text,
50
+ attention_mask=uncond_text_mask).last_hidden_state
51
+ else:
52
+ text, text_mask = None, None
53
+ guidance_scale = None
54
+
55
+ codec_dim = params['model']['out_chans']
56
+ unet.eval()
57
+
58
+ if random_seed is not None:
59
+ generator = torch.Generator(device=device).manual_seed(random_seed)
60
+ else:
61
+ generator = torch.Generator(device=device)
62
+ generator.seed()
63
+
64
+ noise_scheduler.set_timesteps(ddim_steps)
65
+
66
+ # init noise
67
+ noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)
68
+ latents = noise
69
+
70
+ for t in noise_scheduler.timesteps:
71
+ latents = noise_scheduler.scale_model_input(latents, t)
72
+
73
+ if guidance_scale:
74
+
75
+ latents_combined = torch.cat([latents, latents], dim=0)
76
+ text_combined = torch.cat([text, uncond_text], dim=0)
77
+ text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)
78
+
79
+ if gt is not None:
80
+ gt_combined = torch.cat([gt, gt], dim=0)
81
+ gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
82
+ else:
83
+ gt_combined = None
84
+ gt_mask_combined = None
85
+
86
+ output_combined, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined,
87
+ cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined)
88
+ output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)
89
+
90
+ output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
91
+ if guidance_rescale > 0.0:
92
+ output_pred = rescale_noise_cfg(output_pred, output_text,
93
+ guidance_rescale=guidance_rescale)
94
+ else:
95
+ output_pred, mae_mask = unet(latents, t, text, context_mask=text_mask,
96
+ cls_token=None, gt=gt, mae_mask_infer=gt_mask)
97
+
98
+ latents = noise_scheduler.step(model_output=output_pred, timestep=t,
99
+ sample=latents,
100
+ eta=eta, generator=generator).prev_sample
101
+
102
+ pred = scale_shift_re(latents, params['autoencoder']['scale'],
103
+ params['autoencoder']['shift'])
104
+ if gt is not None:
105
+ pred[~gt_mask] = gt[~gt_mask]
106
+ pred_wav = autoencoder(embedding=pred)
107
+ return pred_wav
108
+
109
+
110
+ @torch.no_grad()
111
+ def eval_udit(autoencoder, unet,
112
+ tokenizer, text_encoder,
113
+ params, noise_scheduler,
114
+ val_df, subset,
115
+ audio_frames, mae=False,
116
+ guidance_scale=3, guidance_rescale=0.0,
117
+ ddim_steps=50, eta=1, random_seed=2023,
118
+ device='cuda',
119
+ epoch=0, save_path='logs/eval/', val_num=5):
120
+ val_df = pd.read_csv(val_df)
121
+ val_df = val_df[val_df['split'] == subset]
122
+ if mae:
123
+ val_df = val_df[val_df['audio_length'] != 0]
124
+
125
+ save_path = save_path + str(epoch) + '/'
126
+ os.makedirs(save_path, exist_ok=True)
127
+
128
+ for i in tqdm(range(len(val_df))):
129
+ row = val_df.iloc[i]
130
+ text = [row['caption']]
131
+ if mae:
132
+ audio_path = params['data']['val_dir'] + str(row['audio_path'])
133
+ gt, sr = librosa.load(audio_path, sr=params['data']['sr'])
134
+ gt = gt / (np.max(np.abs(gt)) + 1e-9)
135
+ sf.write(save_path + text[0] + '_gt.wav', gt, samplerate=params['data']['sr'])
136
+ num_samples = 10 * sr
137
+ if len(gt) < num_samples:
138
+ padding = num_samples - len(gt)
139
+ gt = np.pad(gt, (0, padding), 'constant')
140
+ else:
141
+ gt = gt[:num_samples]
142
+ gt = torch.tensor(gt).unsqueeze(0).unsqueeze(1).to(device)
143
+ gt = autoencoder(audio=gt)
144
+ B, D, L = gt.shape
145
+ mask_len = int(L * 0.2)
146
+ gt_mask = torch.zeros(B, D, L).to(device)
147
+ for _ in range(2):
148
+ start = random.randint(0, L - mask_len)
149
+ gt_mask[:, :, start:start + mask_len] = 1
150
+ gt_mask = gt_mask.bool()
151
+ else:
152
+ gt = None
153
+ gt_mask = None
154
+
155
+ pred = inference(autoencoder, unet, gt, gt_mask,
156
+ tokenizer, text_encoder,
157
+ params, noise_scheduler,
158
+ text, neg_text=None,
159
+ audio_frames=audio_frames,
160
+ guidance_scale=guidance_scale, guidance_rescale=guidance_rescale,
161
+ ddim_steps=ddim_steps, eta=eta, random_seed=random_seed,
162
+ device=device)
163
+
164
+ pred = pred.cpu().numpy().squeeze(0).squeeze(0)
165
+
166
+ sf.write(save_path + text[0] + '.wav', pred, samplerate=params['data']['sr'])
167
+
168
+ if i + 1 >= val_num:
169
+ break
src/inference_controlnet.py ADDED
@@ -0,0 +1,129 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import pandas as pd
4
+ import torch
5
+ import librosa
6
+ import numpy as np
7
+ import soundfile as sf
8
+ from tqdm import tqdm
9
+ from .utils import scale_shift_re
10
+
11
+
12
+ def rescale_noise_cfg(noise_cfg, noise_pred_text, guidance_rescale=0.0):
13
+ """
14
+ Rescale `noise_cfg` according to `guidance_rescale`. Based on findings of [Common Diffusion Noise Schedules and
15
+ Sample Steps are Flawed](https://arxiv.org/pdf/2305.08891.pdf). See Section 3.4
16
+ """
17
+ std_text = noise_pred_text.std(dim=list(range(1, noise_pred_text.ndim)), keepdim=True)
18
+ std_cfg = noise_cfg.std(dim=list(range(1, noise_cfg.ndim)), keepdim=True)
19
+ # rescale the results from guidance (fixes overexposure)
20
+ noise_pred_rescaled = noise_cfg * (std_text / std_cfg)
21
+ # mix with the original results from guidance by factor guidance_rescale to avoid "plain looking" images
22
+ noise_cfg = guidance_rescale * noise_pred_rescaled + (1 - guidance_rescale) * noise_cfg
23
+ return noise_cfg
24
+
25
+
26
+ @torch.no_grad()
27
+ def inference(autoencoder, unet, controlnet,
28
+ gt, gt_mask, condition,
29
+ tokenizer, text_encoder,
30
+ params, noise_scheduler,
31
+ text_raw, neg_text=None,
32
+ audio_frames=500,
33
+ guidance_scale=3, guidance_rescale=0.0,
34
+ ddim_steps=50, eta=1, random_seed=2024,
35
+ conditioning_scale=1.0,
36
+ device='cuda',
37
+ ):
38
+ if neg_text is None:
39
+ neg_text = [""]
40
+ if tokenizer is not None:
41
+ text_batch = tokenizer(text_raw,
42
+ max_length=params['text_encoder']['max_length'],
43
+ padding="max_length", truncation=True, return_tensors="pt")
44
+ text, text_mask = text_batch.input_ids.to(device), text_batch.attention_mask.to(device).bool()
45
+ text = text_encoder(input_ids=text, attention_mask=text_mask).last_hidden_state
46
+
47
+ uncond_text_batch = tokenizer(neg_text,
48
+ max_length=params['text_encoder']['max_length'],
49
+ padding="max_length", truncation=True, return_tensors="pt")
50
+ uncond_text, uncond_text_mask = uncond_text_batch.input_ids.to(device), uncond_text_batch.attention_mask.to(device).bool()
51
+ uncond_text = text_encoder(input_ids=uncond_text,
52
+ attention_mask=uncond_text_mask).last_hidden_state
53
+ else:
54
+ text, text_mask = None, None
55
+ guidance_scale = None
56
+
57
+ codec_dim = params['model']['out_chans']
58
+ unet.eval()
59
+ controlnet.eval()
60
+
61
+ if random_seed is not None:
62
+ generator = torch.Generator(device=device).manual_seed(random_seed)
63
+ else:
64
+ generator = torch.Generator(device=device)
65
+ generator.seed()
66
+
67
+ noise_scheduler.set_timesteps(ddim_steps)
68
+
69
+ # init noise
70
+ noise = torch.randn((1, codec_dim, audio_frames), generator=generator, device=device)
71
+ latents = noise
72
+
73
+ for t in noise_scheduler.timesteps:
74
+ latents = noise_scheduler.scale_model_input(latents, t)
75
+
76
+ if guidance_scale:
77
+ latents_combined = torch.cat([latents, latents], dim=0)
78
+ text_combined = torch.cat([text, uncond_text], dim=0)
79
+ text_mask_combined = torch.cat([text_mask, uncond_text_mask], dim=0)
80
+ condition_combined = torch.cat([condition, condition], dim=0)
81
+
82
+ if gt is not None:
83
+ gt_combined = torch.cat([gt, gt], dim=0)
84
+ gt_mask_combined = torch.cat([gt_mask, gt_mask], dim=0)
85
+ else:
86
+ gt_combined = None
87
+ gt_mask_combined = None
88
+
89
+ x, _ = unet(latents_combined, t, text_combined, context_mask=text_mask_combined,
90
+ cls_token=None, gt=gt_combined, mae_mask_infer=gt_mask_combined,
91
+ forward_model=False)
92
+ controlnet_skips = controlnet(x, t, text_combined,
93
+ context_mask=text_mask_combined,
94
+ cls_token=None,
95
+ condition=condition_combined,
96
+ conditioning_scale=conditioning_scale)
97
+ output_combined = unet.model(x, t, text_combined,
98
+ context_mask=text_mask_combined,
99
+ cls_token=None, controlnet_skips=controlnet_skips)
100
+
101
+ output_text, output_uncond = torch.chunk(output_combined, 2, dim=0)
102
+
103
+ output_pred = output_uncond + guidance_scale * (output_text - output_uncond)
104
+ if guidance_rescale > 0.0:
105
+ output_pred = rescale_noise_cfg(output_pred, output_text,
106
+ guidance_rescale=guidance_rescale)
107
+ else:
108
+ x, _ = unet(latents, t, text, context_mask=text_mask,
109
+ cls_token=None, gt=gt, mae_mask_infer=gt_mask,
110
+ forward_model=False)
111
+ controlnet_skips = controlnet(x, t, text,
112
+ context_mask=text_mask,
113
+ cls_token=None,
114
+ condition=condition,
115
+ conditioning_scale=conditioning_scale)
116
+ output_pred = unet.model(x, t, text,
117
+ context_mask=text_mask,
118
+ cls_token=None, controlnet_skips=controlnet_skips)
119
+
120
+ latents = noise_scheduler.step(model_output=output_pred, timestep=t,
121
+ sample=latents,
122
+ eta=eta, generator=generator).prev_sample
123
+
124
+ pred = scale_shift_re(latents, params['autoencoder']['scale'],
125
+ params['autoencoder']['shift'])
126
+ if gt is not None:
127
+ pred[~gt_mask] = gt[~gt_mask]
128
+ pred_wav = autoencoder(embedding=pred)
129
+ return pred_wav
src/models/.ipynb_checkpoints/blocks-checkpoint.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+ from .utils.attention import Attention, JointAttention
5
+ from .utils.modules import unpatchify, FeedForward
6
+ from .utils.modules import film_modulate
7
+
8
+
9
+ class AdaLN(nn.Module):
10
+ def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
11
+ super().__init__()
12
+ self.ada_mode = ada_mode
13
+ self.scale_shift_table = None
14
+ if ada_mode == 'ada':
15
+ # move nn.silu outside
16
+ self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
17
+ elif ada_mode == 'ada_single':
18
+ # adaln used in pixel-art alpha
19
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
20
+ elif ada_mode in ['ada_lora', 'ada_lora_bias']:
21
+ self.lora_a = nn.Linear(dim, r * 6, bias=False)
22
+ self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
23
+ self.scaling = alpha / r
24
+ if ada_mode == 'ada_lora_bias':
25
+ # take bias out for consistency
26
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
27
+ else:
28
+ raise NotImplementedError
29
+
30
+ def forward(self, time_token=None, time_ada=None):
31
+ if self.ada_mode == 'ada':
32
+ assert time_ada is None
33
+ B = time_token.shape[0]
34
+ time_ada = self.time_ada(time_token).reshape(B, 6, -1)
35
+ elif self.ada_mode == 'ada_single':
36
+ B = time_ada.shape[0]
37
+ time_ada = time_ada.reshape(B, 6, -1)
38
+ time_ada = self.scale_shift_table[None] + time_ada
39
+ elif self.ada_mode in ['ada_lora', 'ada_lora_bias']:
40
+ B = time_ada.shape[0]
41
+ time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
42
+ time_ada = time_ada + time_ada_lora
43
+ time_ada = time_ada.reshape(B, 6, -1)
44
+ if self.scale_shift_table is not None:
45
+ time_ada = self.scale_shift_table[None] + time_ada
46
+ else:
47
+ raise NotImplementedError
48
+ return time_ada
49
+
50
+
51
+ class DiTBlock(nn.Module):
52
+ """
53
+ A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
54
+ """
55
+
56
+ def __init__(self, dim, context_dim=None,
57
+ num_heads=8, mlp_ratio=4.,
58
+ qkv_bias=False, qk_scale=None, qk_norm=None,
59
+ act_layer='gelu', norm_layer=nn.LayerNorm,
60
+ time_fusion='none',
61
+ ada_lora_rank=None, ada_lora_alpha=None,
62
+ skip=False, skip_norm=False,
63
+ rope_mode='none',
64
+ context_norm=False,
65
+ use_checkpoint=False):
66
+
67
+ super().__init__()
68
+ self.norm1 = norm_layer(dim)
69
+ self.attn = Attention(dim=dim,
70
+ num_heads=num_heads,
71
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
72
+ qk_norm=qk_norm,
73
+ rope_mode=rope_mode)
74
+
75
+ if context_dim is not None:
76
+ self.use_context = True
77
+ self.cross_attn = Attention(dim=dim,
78
+ num_heads=num_heads,
79
+ context_dim=context_dim,
80
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
81
+ qk_norm=qk_norm,
82
+ rope_mode='none')
83
+ self.norm2 = norm_layer(dim)
84
+ if context_norm:
85
+ self.norm_context = norm_layer(context_dim)
86
+ else:
87
+ self.norm_context = nn.Identity()
88
+ else:
89
+ self.use_context = False
90
+
91
+ self.norm3 = norm_layer(dim)
92
+ self.mlp = FeedForward(dim=dim, mult=mlp_ratio,
93
+ activation_fn=act_layer, dropout=0)
94
+
95
+ self.use_adanorm = True if time_fusion != 'token' else False
96
+ if self.use_adanorm:
97
+ self.adaln = AdaLN(dim, ada_mode=time_fusion,
98
+ r=ada_lora_rank, alpha=ada_lora_alpha)
99
+ if skip:
100
+ self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
101
+ self.skip_linear = nn.Linear(2 * dim, dim)
102
+ else:
103
+ self.skip_linear = None
104
+
105
+ self.use_checkpoint = use_checkpoint
106
+
107
+ def forward(self, x, time_token=None, time_ada=None,
108
+ skip=None, context=None,
109
+ x_mask=None, context_mask=None, extras=None):
110
+ if self.use_checkpoint:
111
+ return checkpoint(self._forward, x,
112
+ time_token, time_ada, skip, context,
113
+ x_mask, context_mask, extras,
114
+ use_reentrant=False)
115
+ else:
116
+ return self._forward(x,
117
+ time_token, time_ada, skip, context,
118
+ x_mask, context_mask, extras)
119
+
120
+ def _forward(self, x, time_token=None, time_ada=None,
121
+ skip=None, context=None,
122
+ x_mask=None, context_mask=None, extras=None):
123
+ B, T, C = x.shape
124
+ if self.skip_linear is not None:
125
+ assert skip is not None
126
+ cat = torch.cat([x, skip], dim=-1)
127
+ cat = self.skip_norm(cat)
128
+ x = self.skip_linear(cat)
129
+
130
+ if self.use_adanorm:
131
+ time_ada = self.adaln(time_token, time_ada)
132
+ (shift_msa, scale_msa, gate_msa,
133
+ shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
134
+
135
+ # self attention
136
+ if self.use_adanorm:
137
+ x_norm = film_modulate(self.norm1(x), shift=shift_msa,
138
+ scale=scale_msa)
139
+ x = x + (1 - gate_msa) * self.attn(x_norm, context=None,
140
+ context_mask=x_mask,
141
+ extras=extras)
142
+ else:
143
+ x = x + self.attn(self.norm1(x), context=None, context_mask=x_mask,
144
+ extras=extras)
145
+
146
+ # cross attention
147
+ if self.use_context:
148
+ assert context is not None
149
+ x = x + self.cross_attn(x=self.norm2(x),
150
+ context=self.norm_context(context),
151
+ context_mask=context_mask, extras=extras)
152
+
153
+ # mlp
154
+ if self.use_adanorm:
155
+ x_norm = film_modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp)
156
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
157
+ else:
158
+ x = x + self.mlp(self.norm3(x))
159
+
160
+ return x
161
+
162
+
163
+ class JointDiTBlock(nn.Module):
164
+ """
165
+ A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
166
+ """
167
+
168
+ def __init__(self, dim, context_dim=None,
169
+ num_heads=8, mlp_ratio=4.,
170
+ qkv_bias=False, qk_scale=None, qk_norm=None,
171
+ act_layer='gelu', norm_layer=nn.LayerNorm,
172
+ time_fusion='none',
173
+ ada_lora_rank=None, ada_lora_alpha=None,
174
+ skip=(False, False),
175
+ rope_mode=False,
176
+ context_norm=False,
177
+ use_checkpoint=False,):
178
+
179
+ super().__init__()
180
+ # no cross attention
181
+ assert context_dim is None
182
+ self.attn_norm_x = norm_layer(dim)
183
+ self.attn_norm_c = norm_layer(dim)
184
+ self.attn = JointAttention(dim=dim,
185
+ num_heads=num_heads,
186
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
187
+ qk_norm=qk_norm,
188
+ rope_mode=rope_mode)
189
+ self.ffn_norm_x = norm_layer(dim)
190
+ self.ffn_norm_c = norm_layer(dim)
191
+ self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio,
192
+ activation_fn=act_layer, dropout=0)
193
+ self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio,
194
+ activation_fn=act_layer, dropout=0)
195
+
196
+ # Zero-out the shift table
197
+ self.use_adanorm = True if time_fusion != 'token' else False
198
+ if self.use_adanorm:
199
+ self.adaln = AdaLN(dim, ada_mode=time_fusion,
200
+ r=ada_lora_rank, alpha=ada_lora_alpha)
201
+
202
+ if skip is False:
203
+ skip_x, skip_c = False, False
204
+ else:
205
+ skip_x, skip_c = skip
206
+
207
+ self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None
208
+ self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None
209
+
210
+ self.use_checkpoint = use_checkpoint
211
+
212
+ def forward(self, x, time_token=None, time_ada=None,
213
+ skip=None, context=None,
214
+ x_mask=None, context_mask=None, extras=None):
215
+ if self.use_checkpoint:
216
+ return checkpoint(self._forward, x,
217
+ time_token, time_ada, skip,
218
+ context, x_mask, context_mask, extras,
219
+ use_reentrant=False)
220
+ else:
221
+ return self._forward(x,
222
+ time_token, time_ada, skip,
223
+ context, x_mask, context_mask, extras)
224
+
225
+ def _forward(self, x, time_token=None, time_ada=None,
226
+ skip=None, context=None,
227
+ x_mask=None, context_mask=None, extras=None):
228
+
229
+ assert context is None and context_mask is None
230
+
231
+ context, x = x[:, :extras, :], x[:, extras:, :]
232
+ context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:]
233
+
234
+ if skip is not None:
235
+ skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :]
236
+
237
+ B, T, C = x.shape
238
+ if self.skip_linear_x is not None:
239
+ x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1))
240
+
241
+ if self.skip_linear_c is not None:
242
+ context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1))
243
+
244
+ if self.use_adanorm:
245
+ time_ada = self.adaln(time_token, time_ada)
246
+ (shift_msa, scale_msa, gate_msa,
247
+ shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
248
+
249
+ # self attention
250
+ x_norm = self.attn_norm_x(x)
251
+ c_norm = self.attn_norm_c(context)
252
+ if self.use_adanorm:
253
+ x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa)
254
+ x_out, c_out = self.attn(x_norm, context=c_norm,
255
+ x_mask=x_mask, context_mask=context_mask,
256
+ extras=extras)
257
+ if self.use_adanorm:
258
+ x = x + (1 - gate_msa) * x_out
259
+ else:
260
+ x = x + x_out
261
+ context = context + c_out
262
+
263
+ # mlp
264
+ if self.use_adanorm:
265
+ x_norm = film_modulate(self.ffn_norm_x(x),
266
+ shift=shift_mlp, scale=scale_mlp)
267
+ x = x + (1 - gate_mlp) * self.mlp_x(x_norm)
268
+ else:
269
+ x = x + self.mlp_x(self.ffn_norm_x(x))
270
+
271
+ c_norm = self.ffn_norm_c(context)
272
+ context = context + self.mlp_c(c_norm)
273
+
274
+ return torch.cat((context, x), dim=1)
275
+
276
+
277
+ class FinalBlock(nn.Module):
278
+ def __init__(self, embed_dim, patch_size, in_chans,
279
+ img_size,
280
+ input_type='2d',
281
+ norm_layer=nn.LayerNorm,
282
+ use_conv=True,
283
+ use_adanorm=True):
284
+ super().__init__()
285
+ self.in_chans = in_chans
286
+ self.img_size = img_size
287
+ self.input_type = input_type
288
+
289
+ self.norm = norm_layer(embed_dim)
290
+ if use_adanorm:
291
+ self.use_adanorm = True
292
+ else:
293
+ self.use_adanorm = False
294
+
295
+ if input_type == '2d':
296
+ self.patch_dim = patch_size ** 2 * in_chans
297
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
298
+ if use_conv:
299
+ self.final_layer = nn.Conv2d(self.in_chans, self.in_chans,
300
+ 3, padding=1)
301
+ else:
302
+ self.final_layer = nn.Identity()
303
+
304
+ elif input_type == '1d':
305
+ self.patch_dim = patch_size * in_chans
306
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
307
+ if use_conv:
308
+ self.final_layer = nn.Conv1d(self.in_chans, self.in_chans,
309
+ 3, padding=1)
310
+ else:
311
+ self.final_layer = nn.Identity()
312
+
313
+ def forward(self, x, time_ada=None, extras=0):
314
+ B, T, C = x.shape
315
+ x = x[:, extras:, :]
316
+ # only handle generation target
317
+ if self.use_adanorm:
318
+ shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
319
+ x = film_modulate(self.norm(x), shift, scale)
320
+ else:
321
+ x = self.norm(x)
322
+ x = self.linear(x)
323
+ x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
324
+ x = self.final_layer(x)
325
+ return x
src/models/.ipynb_checkpoints/conditioners-checkpoint.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ from einops import repeat
5
+ import math
6
+ from .udit import UDiT
7
+ from .utils.span_mask import compute_mask_indices
8
+
9
+
10
+ class EmbeddingCFG(nn.Module):
11
+ """
12
+ Handles label dropout for classifier-free guidance.
13
+ """
14
+ # todo: support 2D input
15
+
16
+ def __init__(self, in_channels):
17
+ super().__init__()
18
+ self.cfg_embedding = nn.Parameter(
19
+ torch.randn(in_channels) / in_channels ** 0.5)
20
+
21
+ def token_drop(self, condition, condition_mask, cfg_prob):
22
+ """
23
+ Drops labels to enable classifier-free guidance.
24
+ """
25
+ b, t, device = condition.shape[0], condition.shape[1], condition.device
26
+ drop_ids = torch.rand(b, device=device) < cfg_prob
27
+ uncond = repeat(self.cfg_embedding, "c -> b t c", b=b, t=t)
28
+ condition = torch.where(drop_ids[:, None, None], uncond, condition)
29
+ if condition_mask is not None:
30
+ condition_mask[drop_ids] = False
31
+ condition_mask[drop_ids, 0] = True
32
+
33
+ return condition, condition_mask
34
+
35
+ def forward(self, condition, condition_mask, cfg_prob=0.0):
36
+ if condition_mask is not None:
37
+ condition_mask = condition_mask.clone()
38
+ if cfg_prob > 0:
39
+ condition, condition_mask = self.token_drop(condition,
40
+ condition_mask,
41
+ cfg_prob)
42
+ return condition, condition_mask
43
+
44
+
45
+ class DiscreteCFG(nn.Module):
46
+ def __init__(self, replace_id=2):
47
+ super(DiscreteCFG, self).__init__()
48
+ self.replace_id = replace_id
49
+
50
+ def forward(self, context, context_mask, cfg_prob):
51
+ context = context.clone()
52
+ if context_mask is not None:
53
+ context_mask = context_mask.clone()
54
+ if cfg_prob > 0:
55
+ cfg_mask = torch.rand(len(context)) < cfg_prob
56
+ if torch.any(cfg_mask):
57
+ context[cfg_mask] = 0
58
+ context[cfg_mask, 0] = self.replace_id
59
+ if context_mask is not None:
60
+ context_mask[cfg_mask] = False
61
+ context_mask[cfg_mask, 0] = True
62
+ return context, context_mask
63
+
64
+
65
+ class CFGModel(nn.Module):
66
+ def __init__(self, context_dim, backbone):
67
+ super().__init__()
68
+ self.model = backbone
69
+ self.context_cfg = EmbeddingCFG(context_dim)
70
+
71
+ def forward(self, x, timesteps,
72
+ context, x_mask=None, context_mask=None,
73
+ cfg_prob=0.0):
74
+ context = self.context_cfg(context, cfg_prob)
75
+ x = self.model(x=x, timesteps=timesteps,
76
+ context=context,
77
+ x_mask=x_mask, context_mask=context_mask)
78
+ return x
79
+
80
+
81
+ class ConcatModel(nn.Module):
82
+ def __init__(self, backbone, in_dim, stride=[]):
83
+ super().__init__()
84
+ self.model = backbone
85
+
86
+ self.downsample_layers = nn.ModuleList()
87
+ for i, s in enumerate(stride):
88
+ downsample_layer = nn.Conv1d(
89
+ in_dim,
90
+ in_dim * 2,
91
+ kernel_size=2 * s,
92
+ stride=s,
93
+ padding=math.ceil(s / 2),
94
+ )
95
+ self.downsample_layers.append(downsample_layer)
96
+ in_dim = in_dim * 2
97
+
98
+ self.context_cfg = EmbeddingCFG(in_dim)
99
+
100
+ def forward(self, x, timesteps,
101
+ context, x_mask=None,
102
+ cfg=False, cfg_prob=0.0):
103
+
104
+ # todo: support 2D input
105
+ # x: B, C, L
106
+ # context: B, C, L
107
+
108
+ for downsample_layer in self.downsample_layers:
109
+ context = downsample_layer(context)
110
+
111
+ context = context.transpose(1, 2)
112
+ context = self.context_cfg(caption=context,
113
+ cfg=cfg, cfg_prob=cfg_prob)
114
+ context = context.transpose(1, 2)
115
+
116
+ assert context.shape[-1] == x.shape[-1]
117
+ x = torch.cat([context, x], dim=1)
118
+ x = self.model(x=x, timesteps=timesteps,
119
+ context=None, x_mask=x_mask, context_mask=None)
120
+ return x
121
+
122
+
123
+ class MaskDiT(nn.Module):
124
+ def __init__(self, mae=False, mae_prob=0.5, mask_ratio=[0.25, 1.0], mask_span=10, **kwargs):
125
+ super().__init__()
126
+ self.model = UDiT(**kwargs)
127
+ self.mae = mae
128
+ if self.mae:
129
+ out_channel = kwargs.pop('out_chans', None)
130
+ self.mask_embed = nn.Parameter(torch.zeros((out_channel)))
131
+ self.mae_prob = mae_prob
132
+ self.mask_ratio = mask_ratio
133
+ self.mask_span = mask_span
134
+
135
+ def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
136
+ B, D, L = gt.shape
137
+ if mae_mask_infer is None:
138
+ # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
139
+ mask_ratios = mask_ratios.cpu().numpy()
140
+ mask = compute_mask_indices(shape=[B, L],
141
+ padding_mask=None,
142
+ mask_prob=mask_ratios,
143
+ mask_length=self.mask_span,
144
+ mask_type="static",
145
+ mask_other=0.0,
146
+ min_masks=1,
147
+ no_overlap=False,
148
+ min_space=0,)
149
+ mask = mask.unsqueeze(1).expand_as(gt)
150
+ else:
151
+ mask = mae_mask_infer
152
+ mask = mask.expand_as(gt)
153
+ gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask]
154
+ return gt, mask.type_as(gt)
155
+
156
+ def forward(self, x, timesteps, context,
157
+ x_mask=None, context_mask=None, cls_token=None,
158
+ gt=None, mae_mask_infer=None,
159
+ forward_model=True):
160
+ # todo: handle controlnet inside
161
+ mae_mask = torch.ones_like(x)
162
+ if self.mae:
163
+ if gt is not None:
164
+ B, D, L = gt.shape
165
+ mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(gt.device)
166
+ gt, mae_mask = self.random_masking(gt, mask_ratios, mae_mask_infer)
167
+ # apply mae only to the selected batches
168
+ if mae_mask_infer is None:
169
+ # determine mae batch
170
+ mae_batch = torch.rand(B) < self.mae_prob
171
+ gt[~mae_batch] = self.mask_embed.view(1, D, 1).expand_as(gt)[~mae_batch]
172
+ mae_mask[~mae_batch] = 1.0
173
+ else:
174
+ B, D, L = x.shape
175
+ gt = self.mask_embed.view(1, D, 1).expand_as(x)
176
+ x = torch.cat([x, gt, mae_mask[:, 0:1, :]], dim=1)
177
+
178
+ if forward_model:
179
+ x = self.model(x=x, timesteps=timesteps, context=context,
180
+ x_mask=x_mask, context_mask=context_mask,
181
+ cls_token=cls_token)
182
+ # print(mae_mask[:, 0, :].sum(dim=-1))
183
+ return x, mae_mask
src/models/.ipynb_checkpoints/controlnet-checkpoint.py ADDED
@@ -0,0 +1,318 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .utils.modules import PatchEmbed, TimestepEmbedder
5
+ from .utils.modules import PE_wrapper, RMSNorm
6
+ from .blocks import DiTBlock, JointDiTBlock
7
+ from .utils.span_mask import compute_mask_indices
8
+
9
+
10
+ class DiTControlNetEmbed(nn.Module):
11
+ def __init__(self, in_chans, out_chans, blocks,
12
+ cond_mask=False, cond_mask_prob=None,
13
+ cond_mask_ratio=None, cond_mask_span=None):
14
+ super().__init__()
15
+ self.conv_in = nn.Conv1d(in_chans, blocks[0], kernel_size=1)
16
+
17
+ self.cond_mask = cond_mask
18
+ if self.cond_mask:
19
+ self.mask_embed = nn.Parameter(torch.zeros((blocks[0])))
20
+ self.mask_prob = cond_mask_prob
21
+ self.mask_ratio = cond_mask_ratio
22
+ self.mask_span = cond_mask_span
23
+ blocks[0] = blocks[0] + 1
24
+
25
+ conv_blocks = []
26
+ for i in range(len(blocks) - 1):
27
+ channel_in = blocks[i]
28
+ channel_out = blocks[i + 1]
29
+ block = nn.Sequential(
30
+ nn.Conv1d(channel_in, channel_in, kernel_size=3, padding=1),
31
+ nn.SiLU(),
32
+ nn.Conv1d(channel_in, channel_out, kernel_size=3, padding=1, stride=2),
33
+ nn.SiLU(),)
34
+ conv_blocks.append(block)
35
+ self.blocks = nn.ModuleList(conv_blocks)
36
+
37
+ self.conv_out = nn.Conv1d(blocks[-1], out_chans, kernel_size=1)
38
+ nn.init.zeros_(self.conv_out.weight)
39
+ nn.init.zeros_(self.conv_out.bias)
40
+
41
+ def random_masking(self, gt, mask_ratios, mae_mask_infer=None):
42
+ B, D, L = gt.shape
43
+ if mae_mask_infer is None:
44
+ # mask = torch.rand(B, L).to(gt.device) < mask_ratios.unsqueeze(1)
45
+ mask_ratios = mask_ratios.cpu().numpy()
46
+ mask = compute_mask_indices(shape=[B, L],
47
+ padding_mask=None,
48
+ mask_prob=mask_ratios,
49
+ mask_length=self.mask_span,
50
+ mask_type="static",
51
+ mask_other=0.0,
52
+ min_masks=1,
53
+ no_overlap=False,
54
+ min_space=0,)
55
+ # only apply mask to some batches
56
+ mask_batch = torch.rand(B) < self.mask_prob
57
+ mask[~mask_batch] = False
58
+ mask = mask.unsqueeze(1).expand_as(gt)
59
+ else:
60
+ mask = mae_mask_infer
61
+ mask = mask.expand_as(gt)
62
+ gt[mask] = self.mask_embed.view(1, D, 1).expand_as(gt)[mask].type_as(gt)
63
+ return gt, mask.type_as(gt)
64
+
65
+ def forward(self, conditioning, cond_mask_infer=None):
66
+ embedding = self.conv_in(conditioning)
67
+
68
+ if self.cond_mask:
69
+ B, D, L = embedding.shape
70
+ if not self.training and cond_mask_infer is None:
71
+ cond_mask_infer = torch.zeros_like(embedding).bool()
72
+ mask_ratios = torch.FloatTensor(B).uniform_(*self.mask_ratio).to(embedding.device)
73
+ embedding, cond_mask = self.random_masking(embedding, mask_ratios, cond_mask_infer)
74
+ embedding = torch.cat([embedding, cond_mask[:, 0:1, :]], dim=1)
75
+
76
+ for block in self.blocks:
77
+ embedding = block(embedding)
78
+
79
+ embedding = self.conv_out(embedding)
80
+
81
+ # B, L, C
82
+ embedding = embedding.transpose(1, 2).contiguous()
83
+
84
+ return embedding
85
+
86
+
87
+ class DiTControlNet(nn.Module):
88
+ def __init__(self,
89
+ img_size=(224, 224), patch_size=16, in_chans=3,
90
+ input_type='2d', out_chans=None,
91
+ embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
92
+ qkv_bias=False, qk_scale=None, qk_norm=None,
93
+ act_layer='gelu', norm_layer='layernorm',
94
+ context_norm=False,
95
+ use_checkpoint=False,
96
+ # time fusion ada or token
97
+ time_fusion='token',
98
+ ada_lora_rank=None, ada_lora_alpha=None,
99
+ cls_dim=None,
100
+ # max length is only used for concat
101
+ context_dim=768, context_fusion='concat',
102
+ context_max_length=128, context_pe_method='sinu',
103
+ pe_method='abs', rope_mode='none',
104
+ use_conv=True,
105
+ skip=True, skip_norm=True,
106
+ # controlnet configs
107
+ cond_in=None, cond_blocks=None,
108
+ cond_mask=False, cond_mask_prob=None,
109
+ cond_mask_ratio=None, cond_mask_span=None,
110
+ **kwargs):
111
+ super().__init__()
112
+ self.num_features = self.embed_dim = embed_dim
113
+ # input
114
+ self.in_chans = in_chans
115
+ self.input_type = input_type
116
+ if self.input_type == '2d':
117
+ num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
118
+ elif self.input_type == '1d':
119
+ num_patches = img_size // patch_size
120
+ self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
121
+ embed_dim=embed_dim, input_type=input_type)
122
+ out_chans = in_chans if out_chans is None else out_chans
123
+ self.out_chans = out_chans
124
+
125
+ # position embedding
126
+ self.rope = rope_mode
127
+ self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
128
+ length=num_patches)
129
+
130
+ print(f'x position embedding: {pe_method}')
131
+ print(f'rope mode: {self.rope}')
132
+
133
+ # time embed
134
+ self.time_embed = TimestepEmbedder(embed_dim)
135
+ self.time_fusion = time_fusion
136
+ self.use_adanorm = False
137
+
138
+ # cls embed
139
+ if cls_dim is not None:
140
+ self.cls_embed = nn.Sequential(
141
+ nn.Linear(cls_dim, embed_dim, bias=True),
142
+ nn.SiLU(),
143
+ nn.Linear(embed_dim, embed_dim, bias=True),)
144
+ else:
145
+ self.cls_embed = None
146
+
147
+ # time fusion
148
+ if time_fusion == 'token':
149
+ # put token at the beginning of sequence
150
+ self.extras = 2 if self.cls_embed else 1
151
+ self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
152
+ elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
153
+ self.use_adanorm = True
154
+ # aviod repetitive silu for each adaln block
155
+ self.time_act = nn.SiLU()
156
+ self.extras = 0
157
+ if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
158
+ # shared adaln
159
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
160
+ else:
161
+ self.time_ada = None
162
+ else:
163
+ raise NotImplementedError
164
+ print(f'time fusion mode: {self.time_fusion}')
165
+
166
+ # context
167
+ # use a simple projection
168
+ self.use_context = False
169
+ self.context_cross = False
170
+ self.context_max_length = context_max_length
171
+ self.context_fusion = 'none'
172
+ if context_dim is not None:
173
+ self.use_context = True
174
+ self.context_embed = nn.Sequential(
175
+ nn.Linear(context_dim, embed_dim, bias=True),
176
+ nn.SiLU(),
177
+ nn.Linear(embed_dim, embed_dim, bias=True),)
178
+ self.context_fusion = context_fusion
179
+ if context_fusion == 'concat' or context_fusion == 'joint':
180
+ self.extras += context_max_length
181
+ self.context_pe = PE_wrapper(dim=embed_dim,
182
+ method=context_pe_method,
183
+ length=context_max_length)
184
+ # no cross attention layers
185
+ context_dim = None
186
+ elif context_fusion == 'cross':
187
+ self.context_pe = PE_wrapper(dim=embed_dim,
188
+ method=context_pe_method,
189
+ length=context_max_length)
190
+ self.context_cross = True
191
+ context_dim = embed_dim
192
+ else:
193
+ raise NotImplementedError
194
+ print(f'context fusion mode: {context_fusion}')
195
+ print(f'context position embedding: {context_pe_method}')
196
+
197
+ if self.context_fusion == 'joint':
198
+ Block = JointDiTBlock
199
+ else:
200
+ Block = DiTBlock
201
+
202
+ # norm layers
203
+ if norm_layer == 'layernorm':
204
+ norm_layer = nn.LayerNorm
205
+ elif norm_layer == 'rmsnorm':
206
+ norm_layer = RMSNorm
207
+ else:
208
+ raise NotImplementedError
209
+
210
+ self.in_blocks = nn.ModuleList([
211
+ Block(
212
+ dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
213
+ mlp_ratio=mlp_ratio,
214
+ qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
215
+ act_layer=act_layer, norm_layer=norm_layer,
216
+ time_fusion=time_fusion,
217
+ ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
218
+ skip=False, skip_norm=False,
219
+ rope_mode=self.rope,
220
+ context_norm=context_norm,
221
+ use_checkpoint=use_checkpoint)
222
+ for _ in range(depth // 2)])
223
+
224
+ self.controlnet_pre = DiTControlNetEmbed(in_chans=cond_in, out_chans=embed_dim,
225
+ blocks=cond_blocks,
226
+ cond_mask=cond_mask,
227
+ cond_mask_prob=cond_mask_prob,
228
+ cond_mask_ratio=cond_mask_ratio,
229
+ cond_mask_span=cond_mask_span)
230
+
231
+ controlnet_zero_blocks = []
232
+ for i in range(depth // 2):
233
+ block = nn.Linear(embed_dim, embed_dim)
234
+ nn.init.zeros_(block.weight)
235
+ nn.init.zeros_(block.bias)
236
+ controlnet_zero_blocks.append(block)
237
+ self.controlnet_zero_blocks = nn.ModuleList(controlnet_zero_blocks)
238
+
239
+ print('ControlNet ready \n')
240
+
241
+ def set_trainable(self):
242
+ for param in self.parameters():
243
+ param.requires_grad = False
244
+
245
+ # only train input_proj, blocks, and output_proj
246
+ for module_name in ['controlnet_pre', 'in_blocks', 'controlnet_zero_blocks']:
247
+ module = getattr(self, module_name, None)
248
+ if module is not None:
249
+ for param in module.parameters():
250
+ param.requires_grad = True
251
+ module.train()
252
+ else:
253
+ print(f'\n!!!warning missing trainable blocks: {module_name}!!!\n')
254
+
255
+ def forward(self, x, timesteps, context,
256
+ x_mask=None, context_mask=None,
257
+ cls_token=None,
258
+ condition=None, cond_mask_infer=None,
259
+ conditioning_scale=1.0):
260
+ # make it compatible with int time step during inference
261
+ if timesteps.dim() == 0:
262
+ timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
263
+
264
+ x = self.patch_embed(x)
265
+ # add condition to x
266
+ condition = self.controlnet_pre(condition)
267
+ x = x + condition
268
+ x = self.x_pe(x)
269
+
270
+ B, L, D = x.shape
271
+
272
+ if self.use_context:
273
+ context_token = self.context_embed(context)
274
+ context_token = self.context_pe(context_token)
275
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
276
+ x, x_mask = self._concat_x_context(x=x, context=context_token,
277
+ x_mask=x_mask,
278
+ context_mask=context_mask)
279
+ context_token, context_mask = None, None
280
+ else:
281
+ context_token, context_mask = None, None
282
+
283
+ time_token = self.time_embed(timesteps)
284
+ if self.cls_embed:
285
+ cls_token = self.cls_embed(cls_token)
286
+ time_ada = None
287
+ if self.use_adanorm:
288
+ if self.cls_embed:
289
+ time_token = time_token + cls_token
290
+ time_token = self.time_act(time_token)
291
+ if self.time_ada is not None:
292
+ time_ada = self.time_ada(time_token)
293
+ else:
294
+ time_token = time_token.unsqueeze(dim=1)
295
+ if self.cls_embed:
296
+ cls_token = cls_token.unsqueeze(dim=1)
297
+ time_token = torch.cat([time_token, cls_token], dim=1)
298
+ time_token = self.time_pe(time_token)
299
+ x = torch.cat((time_token, x), dim=1)
300
+ if x_mask is not None:
301
+ x_mask = torch.cat(
302
+ [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
303
+ x_mask], dim=1)
304
+ time_token = None
305
+
306
+ skips = []
307
+ for blk in self.in_blocks:
308
+ x = blk(x=x, time_token=time_token, time_ada=time_ada,
309
+ skip=None, context=context_token,
310
+ x_mask=x_mask, context_mask=context_mask,
311
+ extras=self.extras)
312
+ skips.append(x)
313
+
314
+ controlnet_skips = []
315
+ for skip, controlnet_block in zip(skips, self.controlnet_zero_blocks):
316
+ controlnet_skips.append(controlnet_block(skip) * conditioning_scale)
317
+
318
+ return controlnet_skips
src/models/.ipynb_checkpoints/udit-checkpoint.py ADDED
@@ -0,0 +1,365 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.utils.checkpoint
4
+ import math
5
+ from .utils.modules import PatchEmbed, TimestepEmbedder
6
+ from .utils.modules import PE_wrapper, RMSNorm
7
+ from .blocks import DiTBlock, JointDiTBlock, FinalBlock
8
+
9
+
10
+ class UDiT(nn.Module):
11
+ def __init__(self,
12
+ img_size=224, patch_size=16, in_chans=3,
13
+ input_type='2d', out_chans=None,
14
+ embed_dim=768, depth=12, num_heads=12, mlp_ratio=4.,
15
+ qkv_bias=False, qk_scale=None, qk_norm=None,
16
+ act_layer='gelu', norm_layer='layernorm',
17
+ context_norm=False,
18
+ use_checkpoint=False,
19
+ # time fusion ada or token
20
+ time_fusion='token',
21
+ ada_lora_rank=None, ada_lora_alpha=None,
22
+ cls_dim=None,
23
+ # max length is only used for concat
24
+ context_dim=768, context_fusion='concat',
25
+ context_max_length=128, context_pe_method='sinu',
26
+ pe_method='abs', rope_mode='none',
27
+ use_conv=True,
28
+ skip=True, skip_norm=True):
29
+ super().__init__()
30
+ self.num_features = self.embed_dim = embed_dim # num_features for consistency with other models
31
+
32
+ # input
33
+ self.in_chans = in_chans
34
+ self.input_type = input_type
35
+ if self.input_type == '2d':
36
+ num_patches = (img_size[0] // patch_size) * (img_size[1] // patch_size)
37
+ elif self.input_type == '1d':
38
+ num_patches = img_size // patch_size
39
+ self.patch_embed = PatchEmbed(patch_size=patch_size, in_chans=in_chans,
40
+ embed_dim=embed_dim, input_type=input_type)
41
+ out_chans = in_chans if out_chans is None else out_chans
42
+ self.out_chans = out_chans
43
+
44
+ # position embedding
45
+ self.rope = rope_mode
46
+ self.x_pe = PE_wrapper(dim=embed_dim, method=pe_method,
47
+ length=num_patches)
48
+
49
+ print(f'x position embedding: {pe_method}')
50
+ print(f'rope mode: {self.rope}')
51
+
52
+ # time embed
53
+ self.time_embed = TimestepEmbedder(embed_dim)
54
+ self.time_fusion = time_fusion
55
+ self.use_adanorm = False
56
+
57
+ # cls embed
58
+ if cls_dim is not None:
59
+ self.cls_embed = nn.Sequential(
60
+ nn.Linear(cls_dim, embed_dim, bias=True),
61
+ nn.SiLU(),
62
+ nn.Linear(embed_dim, embed_dim, bias=True),)
63
+ else:
64
+ self.cls_embed = None
65
+
66
+ # time fusion
67
+ if time_fusion == 'token':
68
+ # put token at the beginning of sequence
69
+ self.extras = 2 if self.cls_embed else 1
70
+ self.time_pe = PE_wrapper(dim=embed_dim, method='abs', length=self.extras)
71
+ elif time_fusion in ['ada', 'ada_single', 'ada_lora', 'ada_lora_bias']:
72
+ self.use_adanorm = True
73
+ # aviod repetitive silu for each adaln block
74
+ self.time_act = nn.SiLU()
75
+ self.extras = 0
76
+ self.time_ada_final = nn.Linear(embed_dim, 2 * embed_dim, bias=True)
77
+ if time_fusion in ['ada_single', 'ada_lora', 'ada_lora_bias']:
78
+ # shared adaln
79
+ self.time_ada = nn.Linear(embed_dim, 6 * embed_dim, bias=True)
80
+ else:
81
+ self.time_ada = None
82
+ else:
83
+ raise NotImplementedError
84
+ print(f'time fusion mode: {self.time_fusion}')
85
+
86
+ # context
87
+ # use a simple projection
88
+ self.use_context = False
89
+ self.context_cross = False
90
+ self.context_max_length = context_max_length
91
+ self.context_fusion = 'none'
92
+ if context_dim is not None:
93
+ self.use_context = True
94
+ self.context_embed = nn.Sequential(
95
+ nn.Linear(context_dim, embed_dim, bias=True),
96
+ nn.SiLU(),
97
+ nn.Linear(embed_dim, embed_dim, bias=True),)
98
+ self.context_fusion = context_fusion
99
+ if context_fusion == 'concat' or context_fusion == 'joint':
100
+ self.extras += context_max_length
101
+ self.context_pe = PE_wrapper(dim=embed_dim,
102
+ method=context_pe_method,
103
+ length=context_max_length)
104
+ # no cross attention layers
105
+ context_dim = None
106
+ elif context_fusion == 'cross':
107
+ self.context_pe = PE_wrapper(dim=embed_dim,
108
+ method=context_pe_method,
109
+ length=context_max_length)
110
+ self.context_cross = True
111
+ context_dim = embed_dim
112
+ else:
113
+ raise NotImplementedError
114
+ print(f'context fusion mode: {context_fusion}')
115
+ print(f'context position embedding: {context_pe_method}')
116
+
117
+ if self.context_fusion == 'joint':
118
+ Block = JointDiTBlock
119
+ self.use_skip = skip[0]
120
+ else:
121
+ Block = DiTBlock
122
+ self.use_skip = skip
123
+
124
+ # norm layers
125
+ if norm_layer == 'layernorm':
126
+ norm_layer = nn.LayerNorm
127
+ elif norm_layer == 'rmsnorm':
128
+ norm_layer = RMSNorm
129
+ else:
130
+ raise NotImplementedError
131
+
132
+ print(f'use long skip connection: {skip}')
133
+ self.in_blocks = nn.ModuleList([
134
+ Block(
135
+ dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
136
+ mlp_ratio=mlp_ratio,
137
+ qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
138
+ act_layer=act_layer, norm_layer=norm_layer,
139
+ time_fusion=time_fusion,
140
+ ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
141
+ skip=False, skip_norm=False,
142
+ rope_mode=self.rope,
143
+ context_norm=context_norm,
144
+ use_checkpoint=use_checkpoint)
145
+ for _ in range(depth // 2)])
146
+
147
+ self.mid_block = Block(
148
+ dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
149
+ mlp_ratio=mlp_ratio,
150
+ qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
151
+ act_layer=act_layer, norm_layer=norm_layer,
152
+ time_fusion=time_fusion,
153
+ ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
154
+ skip=False, skip_norm=False,
155
+ rope_mode=self.rope,
156
+ context_norm=context_norm,
157
+ use_checkpoint=use_checkpoint)
158
+
159
+ self.out_blocks = nn.ModuleList([
160
+ Block(
161
+ dim=embed_dim, context_dim=context_dim, num_heads=num_heads,
162
+ mlp_ratio=mlp_ratio,
163
+ qkv_bias=qkv_bias, qk_scale=qk_scale, qk_norm=qk_norm,
164
+ act_layer=act_layer, norm_layer=norm_layer,
165
+ time_fusion=time_fusion,
166
+ ada_lora_rank=ada_lora_rank, ada_lora_alpha=ada_lora_alpha,
167
+ skip=skip, skip_norm=skip_norm,
168
+ rope_mode=self.rope,
169
+ context_norm=context_norm,
170
+ use_checkpoint=use_checkpoint)
171
+ for _ in range(depth // 2)])
172
+
173
+ # FinalLayer block
174
+ self.use_conv = use_conv
175
+ self.final_block = FinalBlock(embed_dim=embed_dim,
176
+ patch_size=patch_size,
177
+ img_size=img_size,
178
+ in_chans=out_chans,
179
+ input_type=input_type,
180
+ norm_layer=norm_layer,
181
+ use_conv=use_conv,
182
+ use_adanorm=self.use_adanorm)
183
+ self.initialize_weights()
184
+
185
+ def _init_ada(self):
186
+ if self.time_fusion == 'ada':
187
+ nn.init.constant_(self.time_ada_final.weight, 0)
188
+ nn.init.constant_(self.time_ada_final.bias, 0)
189
+ for block in self.in_blocks:
190
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
191
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
192
+ nn.init.constant_(self.mid_block.adaln.time_ada.weight, 0)
193
+ nn.init.constant_(self.mid_block.adaln.time_ada.bias, 0)
194
+ for block in self.out_blocks:
195
+ nn.init.constant_(block.adaln.time_ada.weight, 0)
196
+ nn.init.constant_(block.adaln.time_ada.bias, 0)
197
+ elif self.time_fusion == 'ada_single':
198
+ nn.init.constant_(self.time_ada.weight, 0)
199
+ nn.init.constant_(self.time_ada.bias, 0)
200
+ nn.init.constant_(self.time_ada_final.weight, 0)
201
+ nn.init.constant_(self.time_ada_final.bias, 0)
202
+ elif self.time_fusion in ['ada_lora', 'ada_lora_bias']:
203
+ nn.init.constant_(self.time_ada.weight, 0)
204
+ nn.init.constant_(self.time_ada.bias, 0)
205
+ nn.init.constant_(self.time_ada_final.weight, 0)
206
+ nn.init.constant_(self.time_ada_final.bias, 0)
207
+ for block in self.in_blocks:
208
+ nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
209
+ a=math.sqrt(5))
210
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
211
+ nn.init.kaiming_uniform_(self.mid_block.adaln.lora_a.weight,
212
+ a=math.sqrt(5))
213
+ nn.init.constant_(self.mid_block.adaln.lora_b.weight, 0)
214
+ for block in self.out_blocks:
215
+ nn.init.kaiming_uniform_(block.adaln.lora_a.weight,
216
+ a=math.sqrt(5))
217
+ nn.init.constant_(block.adaln.lora_b.weight, 0)
218
+
219
+ def initialize_weights(self):
220
+ # Basic init for all layers
221
+ def _basic_init(module):
222
+ if isinstance(module, nn.Linear):
223
+ torch.nn.init.xavier_uniform_(module.weight)
224
+ if module.bias is not None:
225
+ nn.init.constant_(module.bias, 0)
226
+ self.apply(_basic_init)
227
+
228
+ # init patch Conv like Linear
229
+ w = self.patch_embed.proj.weight.data
230
+ nn.init.xavier_uniform_(w.view([w.shape[0], -1]))
231
+ nn.init.constant_(self.patch_embed.proj.bias, 0)
232
+
233
+ # Zero-out AdaLN
234
+ if self.use_adanorm:
235
+ self._init_ada()
236
+
237
+ # Zero-out Cross Attention
238
+ if self.context_cross:
239
+ for block in self.in_blocks:
240
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
241
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
242
+ nn.init.constant_(self.mid_block.cross_attn.proj.weight, 0)
243
+ nn.init.constant_(self.mid_block.cross_attn.proj.bias, 0)
244
+ for block in self.out_blocks:
245
+ nn.init.constant_(block.cross_attn.proj.weight, 0)
246
+ nn.init.constant_(block.cross_attn.proj.bias, 0)
247
+
248
+ # Zero-out cls embedding
249
+ if self.cls_embed:
250
+ if self.use_adanorm:
251
+ nn.init.constant_(self.cls_embed[-1].weight, 0)
252
+ nn.init.constant_(self.cls_embed[-1].bias, 0)
253
+
254
+ # Zero-out Output
255
+ # might not zero-out this when using v-prediction
256
+ # it could be good when using noise-prediction
257
+ # nn.init.constant_(self.final_block.linear.weight, 0)
258
+ # nn.init.constant_(self.final_block.linear.bias, 0)
259
+ # if self.use_conv:
260
+ # nn.init.constant_(self.final_block.final_layer.weight.data, 0)
261
+ # nn.init.constant_(self.final_block.final_layer.bias, 0)
262
+
263
+ # init out Conv
264
+ if self.use_conv:
265
+ nn.init.xavier_uniform_(self.final_block.final_layer.weight)
266
+ nn.init.constant_(self.final_block.final_layer.bias, 0)
267
+
268
+ def _concat_x_context(self, x, context, x_mask=None, context_mask=None):
269
+ assert context.shape[-2] == self.context_max_length
270
+ # Check if either x_mask or context_mask is provided
271
+ B = x.shape[0]
272
+ # Create default masks if they are not provided
273
+ if x_mask is None:
274
+ x_mask = torch.ones(B, x.shape[-2], device=x.device).bool()
275
+ if context_mask is None:
276
+ context_mask = torch.ones(B, context.shape[-2],
277
+ device=context.device).bool()
278
+ # Concatenate the masks along the second dimension (dim=1)
279
+ x_mask = torch.cat([context_mask, x_mask], dim=1)
280
+ # Concatenate context and x along the second dimension (dim=1)
281
+ x = torch.cat((context, x), dim=1)
282
+ return x, x_mask
283
+
284
+ def forward(self, x, timesteps, context,
285
+ x_mask=None, context_mask=None,
286
+ cls_token=None, controlnet_skips=None,
287
+ ):
288
+ # make it compatible with int time step during inference
289
+ if timesteps.dim() == 0:
290
+ timesteps = timesteps.expand(x.shape[0]).to(x.device, dtype=torch.long)
291
+
292
+ x = self.patch_embed(x)
293
+ x = self.x_pe(x)
294
+
295
+ B, L, D = x.shape
296
+
297
+ if self.use_context:
298
+ context_token = self.context_embed(context)
299
+ context_token = self.context_pe(context_token)
300
+ if self.context_fusion == 'concat' or self.context_fusion == 'joint':
301
+ x, x_mask = self._concat_x_context(x=x, context=context_token,
302
+ x_mask=x_mask,
303
+ context_mask=context_mask)
304
+ context_token, context_mask = None, None
305
+ else:
306
+ context_token, context_mask = None, None
307
+
308
+ time_token = self.time_embed(timesteps)
309
+ if self.cls_embed:
310
+ cls_token = self.cls_embed(cls_token)
311
+ time_ada = None
312
+ time_ada_final = None
313
+ if self.use_adanorm:
314
+ if self.cls_embed:
315
+ time_token = time_token + cls_token
316
+ time_token = self.time_act(time_token)
317
+ time_ada_final = self.time_ada_final(time_token)
318
+ if self.time_ada is not None:
319
+ time_ada = self.time_ada(time_token)
320
+ else:
321
+ time_token = time_token.unsqueeze(dim=1)
322
+ if self.cls_embed:
323
+ cls_token = cls_token.unsqueeze(dim=1)
324
+ time_token = torch.cat([time_token, cls_token], dim=1)
325
+ time_token = self.time_pe(time_token)
326
+ x = torch.cat((time_token, x), dim=1)
327
+ if x_mask is not None:
328
+ x_mask = torch.cat(
329
+ [torch.ones(B, time_token.shape[1], device=x_mask.device).bool(),
330
+ x_mask], dim=1)
331
+ time_token = None
332
+
333
+ skips = []
334
+ for blk in self.in_blocks:
335
+ x = blk(x=x, time_token=time_token, time_ada=time_ada,
336
+ skip=None, context=context_token,
337
+ x_mask=x_mask, context_mask=context_mask,
338
+ extras=self.extras)
339
+ if self.use_skip:
340
+ skips.append(x)
341
+
342
+ x = self.mid_block(x=x, time_token=time_token, time_ada=time_ada,
343
+ skip=None, context=context_token,
344
+ x_mask=x_mask, context_mask=context_mask,
345
+ extras=self.extras)
346
+ for blk in self.out_blocks:
347
+ if self.use_skip:
348
+ skip = skips.pop()
349
+ if controlnet_skips:
350
+ # add to skip like u-net controlnet
351
+ skip = skip + controlnet_skips.pop()
352
+ else:
353
+ skip = None
354
+ if controlnet_skips:
355
+ # directly add to x
356
+ x = x + controlnet_skips.pop()
357
+
358
+ x = blk(x=x, time_token=time_token, time_ada=time_ada,
359
+ skip=skip, context=context_token,
360
+ x_mask=x_mask, context_mask=context_mask,
361
+ extras=self.extras)
362
+
363
+ x = self.final_block(x, time_ada=time_ada_final, extras=self.extras)
364
+
365
+ return x
src/models/__pycache__/attention.cpython-311.pyc ADDED
Binary file (6.05 kB). View file
 
src/models/__pycache__/blocks.cpython-310.pyc ADDED
Binary file (7.32 kB). View file
 
src/models/__pycache__/blocks.cpython-311.pyc ADDED
Binary file (14.9 kB). View file
 
src/models/__pycache__/conditioners.cpython-310.pyc ADDED
Binary file (5.63 kB). View file
 
src/models/__pycache__/conditioners.cpython-311.pyc ADDED
Binary file (10.3 kB). View file
 
src/models/__pycache__/controlnet.cpython-311.pyc ADDED
Binary file (15.2 kB). View file
 
src/models/__pycache__/modules.cpython-311.pyc ADDED
Binary file (11.3 kB). View file
 
src/models/__pycache__/rotary.cpython-311.pyc ADDED
Binary file (4.83 kB). View file
 
src/models/__pycache__/timm.cpython-311.pyc ADDED
Binary file (6.45 kB). View file
 
src/models/__pycache__/udit.cpython-310.pyc ADDED
Binary file (7.9 kB). View file
 
src/models/__pycache__/udit.cpython-311.pyc ADDED
Binary file (18.5 kB). View file
 
src/models/blocks.py ADDED
@@ -0,0 +1,325 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from torch.utils.checkpoint import checkpoint
4
+ from .utils.attention import Attention, JointAttention
5
+ from .utils.modules import unpatchify, FeedForward
6
+ from .utils.modules import film_modulate
7
+
8
+
9
+ class AdaLN(nn.Module):
10
+ def __init__(self, dim, ada_mode='ada', r=None, alpha=None):
11
+ super().__init__()
12
+ self.ada_mode = ada_mode
13
+ self.scale_shift_table = None
14
+ if ada_mode == 'ada':
15
+ # move nn.silu outside
16
+ self.time_ada = nn.Linear(dim, 6 * dim, bias=True)
17
+ elif ada_mode == 'ada_single':
18
+ # adaln used in pixel-art alpha
19
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
20
+ elif ada_mode in ['ada_lora', 'ada_lora_bias']:
21
+ self.lora_a = nn.Linear(dim, r * 6, bias=False)
22
+ self.lora_b = nn.Linear(r * 6, dim * 6, bias=False)
23
+ self.scaling = alpha / r
24
+ if ada_mode == 'ada_lora_bias':
25
+ # take bias out for consistency
26
+ self.scale_shift_table = nn.Parameter(torch.zeros(6, dim))
27
+ else:
28
+ raise NotImplementedError
29
+
30
+ def forward(self, time_token=None, time_ada=None):
31
+ if self.ada_mode == 'ada':
32
+ assert time_ada is None
33
+ B = time_token.shape[0]
34
+ time_ada = self.time_ada(time_token).reshape(B, 6, -1)
35
+ elif self.ada_mode == 'ada_single':
36
+ B = time_ada.shape[0]
37
+ time_ada = time_ada.reshape(B, 6, -1)
38
+ time_ada = self.scale_shift_table[None] + time_ada
39
+ elif self.ada_mode in ['ada_lora', 'ada_lora_bias']:
40
+ B = time_ada.shape[0]
41
+ time_ada_lora = self.lora_b(self.lora_a(time_token)) * self.scaling
42
+ time_ada = time_ada + time_ada_lora
43
+ time_ada = time_ada.reshape(B, 6, -1)
44
+ if self.scale_shift_table is not None:
45
+ time_ada = self.scale_shift_table[None] + time_ada
46
+ else:
47
+ raise NotImplementedError
48
+ return time_ada
49
+
50
+
51
+ class DiTBlock(nn.Module):
52
+ """
53
+ A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
54
+ """
55
+
56
+ def __init__(self, dim, context_dim=None,
57
+ num_heads=8, mlp_ratio=4.,
58
+ qkv_bias=False, qk_scale=None, qk_norm=None,
59
+ act_layer='gelu', norm_layer=nn.LayerNorm,
60
+ time_fusion='none',
61
+ ada_lora_rank=None, ada_lora_alpha=None,
62
+ skip=False, skip_norm=False,
63
+ rope_mode='none',
64
+ context_norm=False,
65
+ use_checkpoint=False):
66
+
67
+ super().__init__()
68
+ self.norm1 = norm_layer(dim)
69
+ self.attn = Attention(dim=dim,
70
+ num_heads=num_heads,
71
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
72
+ qk_norm=qk_norm,
73
+ rope_mode=rope_mode)
74
+
75
+ if context_dim is not None:
76
+ self.use_context = True
77
+ self.cross_attn = Attention(dim=dim,
78
+ num_heads=num_heads,
79
+ context_dim=context_dim,
80
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
81
+ qk_norm=qk_norm,
82
+ rope_mode='none')
83
+ self.norm2 = norm_layer(dim)
84
+ if context_norm:
85
+ self.norm_context = norm_layer(context_dim)
86
+ else:
87
+ self.norm_context = nn.Identity()
88
+ else:
89
+ self.use_context = False
90
+
91
+ self.norm3 = norm_layer(dim)
92
+ self.mlp = FeedForward(dim=dim, mult=mlp_ratio,
93
+ activation_fn=act_layer, dropout=0)
94
+
95
+ self.use_adanorm = True if time_fusion != 'token' else False
96
+ if self.use_adanorm:
97
+ self.adaln = AdaLN(dim, ada_mode=time_fusion,
98
+ r=ada_lora_rank, alpha=ada_lora_alpha)
99
+ if skip:
100
+ self.skip_norm = norm_layer(2 * dim) if skip_norm else nn.Identity()
101
+ self.skip_linear = nn.Linear(2 * dim, dim)
102
+ else:
103
+ self.skip_linear = None
104
+
105
+ self.use_checkpoint = use_checkpoint
106
+
107
+ def forward(self, x, time_token=None, time_ada=None,
108
+ skip=None, context=None,
109
+ x_mask=None, context_mask=None, extras=None):
110
+ if self.use_checkpoint:
111
+ return checkpoint(self._forward, x,
112
+ time_token, time_ada, skip, context,
113
+ x_mask, context_mask, extras,
114
+ use_reentrant=False)
115
+ else:
116
+ return self._forward(x,
117
+ time_token, time_ada, skip, context,
118
+ x_mask, context_mask, extras)
119
+
120
+ def _forward(self, x, time_token=None, time_ada=None,
121
+ skip=None, context=None,
122
+ x_mask=None, context_mask=None, extras=None):
123
+ B, T, C = x.shape
124
+ if self.skip_linear is not None:
125
+ assert skip is not None
126
+ cat = torch.cat([x, skip], dim=-1)
127
+ cat = self.skip_norm(cat)
128
+ x = self.skip_linear(cat)
129
+
130
+ if self.use_adanorm:
131
+ time_ada = self.adaln(time_token, time_ada)
132
+ (shift_msa, scale_msa, gate_msa,
133
+ shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
134
+
135
+ # self attention
136
+ if self.use_adanorm:
137
+ x_norm = film_modulate(self.norm1(x), shift=shift_msa,
138
+ scale=scale_msa)
139
+ x = x + (1 - gate_msa) * self.attn(x_norm, context=None,
140
+ context_mask=x_mask,
141
+ extras=extras)
142
+ else:
143
+ x = x + self.attn(self.norm1(x), context=None, context_mask=x_mask,
144
+ extras=extras)
145
+
146
+ # cross attention
147
+ if self.use_context:
148
+ assert context is not None
149
+ x = x + self.cross_attn(x=self.norm2(x),
150
+ context=self.norm_context(context),
151
+ context_mask=context_mask, extras=extras)
152
+
153
+ # mlp
154
+ if self.use_adanorm:
155
+ x_norm = film_modulate(self.norm3(x), shift=shift_mlp, scale=scale_mlp)
156
+ x = x + (1 - gate_mlp) * self.mlp(x_norm)
157
+ else:
158
+ x = x + self.mlp(self.norm3(x))
159
+
160
+ return x
161
+
162
+
163
+ class JointDiTBlock(nn.Module):
164
+ """
165
+ A modified PixArt block with adaptive layer norm (adaLN-single) conditioning.
166
+ """
167
+
168
+ def __init__(self, dim, context_dim=None,
169
+ num_heads=8, mlp_ratio=4.,
170
+ qkv_bias=False, qk_scale=None, qk_norm=None,
171
+ act_layer='gelu', norm_layer=nn.LayerNorm,
172
+ time_fusion='none',
173
+ ada_lora_rank=None, ada_lora_alpha=None,
174
+ skip=(False, False),
175
+ rope_mode=False,
176
+ context_norm=False,
177
+ use_checkpoint=False,):
178
+
179
+ super().__init__()
180
+ # no cross attention
181
+ assert context_dim is None
182
+ self.attn_norm_x = norm_layer(dim)
183
+ self.attn_norm_c = norm_layer(dim)
184
+ self.attn = JointAttention(dim=dim,
185
+ num_heads=num_heads,
186
+ qkv_bias=qkv_bias, qk_scale=qk_scale,
187
+ qk_norm=qk_norm,
188
+ rope_mode=rope_mode)
189
+ self.ffn_norm_x = norm_layer(dim)
190
+ self.ffn_norm_c = norm_layer(dim)
191
+ self.mlp_x = FeedForward(dim=dim, mult=mlp_ratio,
192
+ activation_fn=act_layer, dropout=0)
193
+ self.mlp_c = FeedForward(dim=dim, mult=mlp_ratio,
194
+ activation_fn=act_layer, dropout=0)
195
+
196
+ # Zero-out the shift table
197
+ self.use_adanorm = True if time_fusion != 'token' else False
198
+ if self.use_adanorm:
199
+ self.adaln = AdaLN(dim, ada_mode=time_fusion,
200
+ r=ada_lora_rank, alpha=ada_lora_alpha)
201
+
202
+ if skip is False:
203
+ skip_x, skip_c = False, False
204
+ else:
205
+ skip_x, skip_c = skip
206
+
207
+ self.skip_linear_x = nn.Linear(2 * dim, dim) if skip_x else None
208
+ self.skip_linear_c = nn.Linear(2 * dim, dim) if skip_c else None
209
+
210
+ self.use_checkpoint = use_checkpoint
211
+
212
+ def forward(self, x, time_token=None, time_ada=None,
213
+ skip=None, context=None,
214
+ x_mask=None, context_mask=None, extras=None):
215
+ if self.use_checkpoint:
216
+ return checkpoint(self._forward, x,
217
+ time_token, time_ada, skip,
218
+ context, x_mask, context_mask, extras,
219
+ use_reentrant=False)
220
+ else:
221
+ return self._forward(x,
222
+ time_token, time_ada, skip,
223
+ context, x_mask, context_mask, extras)
224
+
225
+ def _forward(self, x, time_token=None, time_ada=None,
226
+ skip=None, context=None,
227
+ x_mask=None, context_mask=None, extras=None):
228
+
229
+ assert context is None and context_mask is None
230
+
231
+ context, x = x[:, :extras, :], x[:, extras:, :]
232
+ context_mask, x_mask = x_mask[:, :extras], x_mask[:, extras:]
233
+
234
+ if skip is not None:
235
+ skip_c, skip_x = skip[:, :extras, :], skip[:, extras:, :]
236
+
237
+ B, T, C = x.shape
238
+ if self.skip_linear_x is not None:
239
+ x = self.skip_linear_x(torch.cat([x, skip_x], dim=-1))
240
+
241
+ if self.skip_linear_c is not None:
242
+ context = self.skip_linear_c(torch.cat([context, skip_c], dim=-1))
243
+
244
+ if self.use_adanorm:
245
+ time_ada = self.adaln(time_token, time_ada)
246
+ (shift_msa, scale_msa, gate_msa,
247
+ shift_mlp, scale_mlp, gate_mlp) = time_ada.chunk(6, dim=1)
248
+
249
+ # self attention
250
+ x_norm = self.attn_norm_x(x)
251
+ c_norm = self.attn_norm_c(context)
252
+ if self.use_adanorm:
253
+ x_norm = film_modulate(x_norm, shift=shift_msa, scale=scale_msa)
254
+ x_out, c_out = self.attn(x_norm, context=c_norm,
255
+ x_mask=x_mask, context_mask=context_mask,
256
+ extras=extras)
257
+ if self.use_adanorm:
258
+ x = x + (1 - gate_msa) * x_out
259
+ else:
260
+ x = x + x_out
261
+ context = context + c_out
262
+
263
+ # mlp
264
+ if self.use_adanorm:
265
+ x_norm = film_modulate(self.ffn_norm_x(x),
266
+ shift=shift_mlp, scale=scale_mlp)
267
+ x = x + (1 - gate_mlp) * self.mlp_x(x_norm)
268
+ else:
269
+ x = x + self.mlp_x(self.ffn_norm_x(x))
270
+
271
+ c_norm = self.ffn_norm_c(context)
272
+ context = context + self.mlp_c(c_norm)
273
+
274
+ return torch.cat((context, x), dim=1)
275
+
276
+
277
+ class FinalBlock(nn.Module):
278
+ def __init__(self, embed_dim, patch_size, in_chans,
279
+ img_size,
280
+ input_type='2d',
281
+ norm_layer=nn.LayerNorm,
282
+ use_conv=True,
283
+ use_adanorm=True):
284
+ super().__init__()
285
+ self.in_chans = in_chans
286
+ self.img_size = img_size
287
+ self.input_type = input_type
288
+
289
+ self.norm = norm_layer(embed_dim)
290
+ if use_adanorm:
291
+ self.use_adanorm = True
292
+ else:
293
+ self.use_adanorm = False
294
+
295
+ if input_type == '2d':
296
+ self.patch_dim = patch_size ** 2 * in_chans
297
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
298
+ if use_conv:
299
+ self.final_layer = nn.Conv2d(self.in_chans, self.in_chans,
300
+ 3, padding=1)
301
+ else:
302
+ self.final_layer = nn.Identity()
303
+
304
+ elif input_type == '1d':
305
+ self.patch_dim = patch_size * in_chans
306
+ self.linear = nn.Linear(embed_dim, self.patch_dim, bias=True)
307
+ if use_conv:
308
+ self.final_layer = nn.Conv1d(self.in_chans, self.in_chans,
309
+ 3, padding=1)
310
+ else:
311
+ self.final_layer = nn.Identity()
312
+
313
+ def forward(self, x, time_ada=None, extras=0):
314
+ B, T, C = x.shape
315
+ x = x[:, extras:, :]
316
+ # only handle generation target
317
+ if self.use_adanorm:
318
+ shift, scale = time_ada.reshape(B, 2, -1).chunk(2, dim=1)
319
+ x = film_modulate(self.norm(x), shift, scale)
320
+ else:
321
+ x = self.norm(x)
322
+ x = self.linear(x)
323
+ x = unpatchify(x, self.in_chans, self.input_type, self.img_size)
324
+ x = self.final_layer(x)
325
+ return x