teticio committed on
Commit
e94fc5c
1 Parent(s): e97b301

add audio_diffusion_pipeline notebook

notebooks/audio_diffusion_pipeline.ipynb ADDED
@@ -0,0 +1,688 @@
1
+ {
2
+ "cells": [
3
+ {
4
+ "cell_type": "markdown",
5
+ "id": "fef7e1fb",
6
+ "metadata": {},
7
+ "source": [
8
+ "<a href=\"https://colab.research.google.com/github/teticio/audio-diffusion/blob/master/notebooks/audio_diffusion_pipeline.ipynb\" target=\"_parent\"><img src=\"https://colab.research.google.com/assets/colab-badge.svg\" alt=\"Open In Colab\"/></a>"
9
+ ]
10
+ },
11
+ {
12
+ "cell_type": "markdown",
13
+ "id": "2ada074b",
14
+ "metadata": {},
15
+ "source": [
16
+ "# Audio Diffusion\n",
17
+ "For training scripts and notebooks visit https://github.com/teticio/audio-diffusion"
18
+ ]
19
+ },
20
+ {
21
+ "cell_type": "code",
22
+ "execution_count": null,
23
+ "id": "6c7800a6",
24
+ "metadata": {},
25
+ "outputs": [],
26
+ "source": [
27
+ "try:\n",
28
+ " # are we running on Google Colab?\n",
29
+ " import google.colab\n",
30
+ " !pip install -q -r diffusers torch librosa\n",
31
+ "except:\n",
32
+ " pass"
33
+ ]
34
+ },
35
+ {
36
+ "cell_type": "code",
37
+ "execution_count": null,
38
+ "id": "c2fc0e7a",
39
+ "metadata": {},
40
+ "outputs": [],
41
+ "source": [
42
+ "import torch\n",
43
+ "import random\n",
44
+ "import librosa\n",
45
+ "import numpy as np\n",
46
+ "from datasets import load_dataset\n",
47
+ "from IPython.display import Audio\n",
48
+ "from librosa.beat import beat_track\n",
49
+ "from diffusers import DiffusionPipeline, Mel"
50
+ ]
51
+ },
52
+ {
53
+ "cell_type": "code",
54
+ "execution_count": null,
55
+ "id": "b294a94a",
56
+ "metadata": {},
57
+ "outputs": [],
58
+ "source": [
59
+ "mel = Mel()\n",
60
+ "sample_rate = mel.get_sample_rate()\n",
61
+ "device = \"cuda\" if torch.cuda.is_available() else \"cpu\"\n",
62
+ "generator = torch.Generator(device=device)"
63
+ ]
64
+ },
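+ {
+ "cell_type": "markdown",
+ "id": "mel_roundtrip_note",
+ "metadata": {},
+ "source": [
+ "As a quick, optional sanity check of the `Mel` helper, the sketch below round-trips a synthetic tone through the spectrogram representation (it assumes the `load_audio`, `audio_slice_to_image` and `image_to_audio` methods; any waveform can stand in for the sine tone)."
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "mel_roundtrip_code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Minimal sketch: raw audio -> mel spectrogram image -> audio again\n",
+ "tone = 0.5 * np.sin(2 * np.pi * 440 * np.arange(10 * sample_rate) / sample_rate) # stand-in clip\n",
+ "mel.load_audio(raw_audio=tone)\n",
+ "image = mel.audio_slice_to_image(0) # first spectrogram slice as a greyscale image\n",
+ "display(image)\n",
+ "display(Audio(mel.image_to_audio(image), rate=sample_rate))"
+ ]
+ },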
65
+ {
66
+ "cell_type": "markdown",
67
+ "id": "f3feb265",
68
+ "metadata": {},
69
+ "source": [
70
+ "## DDPM (De-noising Diffusion Probabilistic Models)"
71
+ ]
72
+ },
73
+ {
74
+ "cell_type": "markdown",
75
+ "id": "7fd945bb",
76
+ "metadata": {},
77
+ "source": [
78
+ "### Select model"
79
+ ]
80
+ },
81
+ {
82
+ "cell_type": "code",
83
+ "execution_count": null,
84
+ "id": "97f24046",
85
+ "metadata": {},
86
+ "outputs": [],
87
+ "source": [
88
+ "#@markdown teticio/audio-diffusion-256 - trained on my Spotify \"liked\" playlist\n",
89
+ "\n",
90
+ "#@markdown teticio/audio-diffusion-breaks-256 - trained on samples used in music\n",
91
+ "\n",
92
+ "#@markdown teticio/audio-diffusion-instrumental-hiphop-256 - trained on instrumental hiphop\n",
93
+ "\n",
94
+ "model_id = \"teticio/audio-diffusion-256\" #@param [\"teticio/audio-diffusion-256\", \"teticio/audio-diffusion-breaks-256\", \"audio-diffusion-instrumenal-hiphop-256\", \"teticio/audio-diffusion-ddim-256\"]"
95
+ ]
96
+ },
97
+ {
98
+ "cell_type": "code",
99
+ "execution_count": null,
100
+ "id": "a3d45c36",
101
+ "metadata": {},
102
+ "outputs": [],
103
+ "source": [
104
+ "audio_diffusion = DiffusionPipeline.from_pretrained(model_id).to(device)"
105
+ ]
106
+ },
107
+ {
108
+ "cell_type": "code",
109
+ "execution_count": null,
110
+ "id": "ab0d705c",
111
+ "metadata": {},
112
+ "outputs": [],
113
+ "source": [
114
+ "def loop_it(audio: np.ndarray,\n",
115
+ " sample_rate: int,\n",
116
+ " loops: int = 12) -> np.ndarray:\n",
117
+ " \"\"\"Loop audio\n",
118
+ "\n",
119
+ " Args:\n",
120
+ " audio (np.ndarray): audio as numpy array\n",
121
+ " sample_rate (int): sample rate of audio\n",
122
+ " loops (int): number of times to loop\n",
123
+ "\n",
124
+ " Returns:\n",
125
+ " (float, np.ndarray): sample rate and raw audio or None\n",
126
+ " \"\"\"\n",
127
+ " _, beats = beat_track(y=audio, sr=sample_rate, units='samples')\n",
128
+ " for beats_in_bar in [16, 12, 8, 4]:\n",
129
+ " if len(beats) > beats_in_bar:\n",
130
+ " return np.tile(audio[beats[0]:beats[beats_in_bar]], loops)\n",
131
+ " return None"
132
+ ]
133
+ },
134
+ {
135
+ "cell_type": "markdown",
136
+ "id": "011fb5a1",
137
+ "metadata": {},
138
+ "source": [
139
+ "### Run model inference to generate mel spectrogram, audios and loops"
140
+ ]
141
+ },
142
+ {
143
+ "cell_type": "code",
144
+ "execution_count": null,
145
+ "id": "b809fed5",
146
+ "metadata": {},
147
+ "outputs": [],
148
+ "source": [
149
+ "for _ in range(10):\n",
150
+ " seed = generator.seed()\n",
151
+ " print(f'Seed = {seed}')\n",
152
+ " generator.manual_seed(seed)\n",
153
+ " output = audio_diffusion(mel=mel, generator=generator)\n",
154
+ " image = output.images[0]\n",
155
+ " audio = output.audios[0, 0]\n",
156
+ " display(image)\n",
157
+ " display(Audio(audio, rate=sample_rate))\n",
158
+ " loop = loop_it(audio, sample_rate)\n",
159
+ " if loop is not None:\n",
160
+ " display(Audio(loop, rate=sample_rate))\n",
161
+ " else:\n",
162
+ " print(\"Unable to determine loop points\")"
163
+ ]
164
+ },
165
+ {
166
+ "cell_type": "markdown",
167
+ "id": "0bb03e33",
168
+ "metadata": {},
169
+ "source": [
170
+ "### Generate variations of audios"
171
+ ]
172
+ },
173
+ {
174
+ "cell_type": "markdown",
175
+ "id": "80e5b5fa",
176
+ "metadata": {},
177
+ "source": [
178
+ "Try playing around with `start_steps`. Values closer to zero will produce new samples, while values closer to 1,000 will produce samples more faithful to the original."
179
+ ]
180
+ },
181
+ {
182
+ "cell_type": "code",
183
+ "execution_count": null,
184
+ "id": "5074ec11",
185
+ "metadata": {},
186
+ "outputs": [],
187
+ "source": [
188
+ "seed = 2391504374279719 #@param {type:\"integer\"}\n",
189
+ "generator.manual_seed(seed)\n",
190
+ "output = audio_diffusion(mel=mel, generator=generator)\n",
191
+ "image = output.images[0]\n",
192
+ "audio = output.audios[0, 0]\n",
193
+ "display(image)\n",
194
+ "display(Audio(audio, rate=sample_rate))"
195
+ ]
196
+ },
197
+ {
198
+ "cell_type": "code",
199
+ "execution_count": null,
200
+ "id": "a0fefe28",
201
+ "metadata": {
202
+ "scrolled": false
203
+ },
204
+ "outputs": [],
205
+ "source": [
206
+ "start_step = 500 #@param {type:\"slider\", min:0, max:1000, step:10}\n",
207
+ "track = loop_it(audio, sample_rate, loops=1)\n",
208
+ "for variation in range(12):\n",
209
+ " output = audio_diffusion(mel=mel, raw_audio=audio, start_step=start_step)\n",
210
+ " image2 = output.images[0]\n",
211
+ " audio2 = output.audios[0, 0]\n",
212
+ " display(image2)\n",
213
+ " display(Audio(audio2, rate=sample_rate))\n",
214
+ " track = np.concatenate([track, loop_it(audio2, sample_rate, loops=1)])\n",
215
+ "display(Audio(track, rate=sample_rate))"
216
+ ]
217
+ },
218
+ {
219
+ "cell_type": "markdown",
220
+ "id": "58a876c1",
221
+ "metadata": {},
222
+ "source": [
223
+ "### Generate continuations (\"out-painting\")"
224
+ ]
225
+ },
226
+ {
227
+ "cell_type": "code",
228
+ "execution_count": null,
229
+ "id": "b95d5780",
230
+ "metadata": {},
231
+ "outputs": [],
232
+ "source": [
233
+ "overlap_secs = 2 #@param {type:\"integer\"}\n",
234
+ "start_step = 0 #@param {type:\"slider\", min:0, max:1000, step:10}\n",
235
+ "overlap_samples = overlap_secs * sample_rate\n",
236
+ "track = audio\n",
237
+ "for variation in range(12):\n",
238
+ " output = audio_diffusion(mel=mel,\n",
239
+ " raw_audio=audio[-overlap_samples:],\n",
240
+ " start_step=start_step,\n",
241
+ " mask_start_secs=overlap_secs)\n",
242
+ " image2 = output.images[0]\n",
243
+ " audio2 = output.audios[0, 0]\n",
244
+ " display(image2)\n",
245
+ " display(Audio(audio2, rate=sample_rate))\n",
246
+ " track = np.concatenate([track, audio2[overlap_samples:]])\n",
247
+ " audio = audio2\n",
248
+ "display(Audio(track, rate=sample_rate))"
249
+ ]
250
+ },
251
+ {
252
+ "cell_type": "markdown",
253
+ "id": "b6434d3f",
254
+ "metadata": {},
255
+ "source": [
256
+ "### Remix (style transfer)"
257
+ ]
258
+ },
259
+ {
260
+ "cell_type": "markdown",
261
+ "id": "0da030b2",
262
+ "metadata": {},
263
+ "source": [
264
+ "Alternatively, you can start from another audio altogether, resulting in a kind of style transfer. Maintaining the same seed during generation fixes the style, while masking helps stitch consecutive segments together more smoothly."
265
+ ]
266
+ },
267
+ {
268
+ "cell_type": "code",
269
+ "execution_count": null,
270
+ "id": "fc620a80",
271
+ "metadata": {},
272
+ "outputs": [],
273
+ "source": [
274
+ "try:\n",
275
+ " # are we running on Google Colab?\n",
276
+ " from google.colab import files\n",
277
+ " audio_file = list(files.upload().keys())[0]\n",
278
+ "except:\n",
279
+ " audio_file = \"/home/teticio/Music/liked/El Michels Affair - Glaciers Of Ice.mp3\""
280
+ ]
281
+ },
282
+ {
283
+ "cell_type": "code",
284
+ "execution_count": null,
285
+ "id": "5a257e69",
286
+ "metadata": {
287
+ "scrolled": false
288
+ },
289
+ "outputs": [],
290
+ "source": [
291
+ "start_step = 500 #@param {type:\"slider\", min:0, max:1000, step:10}\n",
292
+ "overlap_secs = 2 #@param {type:\"integer\"}\n",
293
+ "track_audio, _ = librosa.load(audio_file, mono=True, sr=sample_rate)\n",
294
+ "overlap_samples = overlap_secs * sample_rate\n",
295
+ "slice_size = mel.x_res * mel.hop_length\n",
296
+ "stride = slice_size - overlap_samples\n",
297
+ "generator = torch.Generator(device=device)\n",
298
+ "seed = generator.seed()\n",
299
+ "print(f'Seed = {seed}')\n",
300
+ "track = np.array([])\n",
301
+ "not_first = 0\n",
302
+ "for sample in range(len(track_audio) // stride):\n",
303
+ " generator.manual_seed(seed)\n",
304
+ " audio = np.array(track_audio[sample * stride:sample * stride + slice_size])\n",
305
+ " if not_first:\n",
306
+ " # Normalize and re-insert generated audio\n",
307
+ " audio[:overlap_samples] = audio2[-overlap_samples:] * np.max(\n",
308
+ " audio[:overlap_samples]) / np.max(audio2[-overlap_samples:])\n",
309
+ " output = audio_diffusion(mel=mel,\n",
310
+ " raw_audio=audio,\n",
311
+ " start_step=start_step,\n",
312
+ " generator=generator,\n",
313
+ " mask_start_secs=overlap_secs * not_first)\n",
314
+ " audio2 = output.audios[0, 0]\n",
315
+ " track = np.concatenate([track, audio2[overlap_samples * not_first:]])\n",
316
+ " not_first = 1\n",
317
+ " display(Audio(track, rate=sample_rate))"
318
+ ]
319
+ },
320
+ {
321
+ "cell_type": "markdown",
322
+ "id": "924ff9d5",
323
+ "metadata": {},
324
+ "source": [
325
+ "### Fill the gap (\"in-painting\")"
326
+ ]
327
+ },
328
+ {
329
+ "cell_type": "code",
330
+ "execution_count": null,
331
+ "id": "0200264c",
332
+ "metadata": {},
333
+ "outputs": [],
334
+ "source": [
335
+ "sample = 3 #@param {type:\"integer\"}\n",
336
+ "raw_audio = track_audio[sample * stride:sample * stride + slice_size]\n",
337
+ "output = audio_diffusion(mel=mel,\n",
338
+ " raw_audio=raw_audio,\n",
339
+ " mask_start_secs=1,\n",
340
+ " mask_end_secs=1,\n",
341
+ " step_generator=torch.Generator(device=device))\n",
342
+ "audio2 = output.audios[0, 0]\n",
343
+ "display(Audio(audio, rate=sample_rate))\n",
344
+ "display(Audio(audio2, rate=sample_rate))"
345
+ ]
346
+ },
347
+ {
348
+ "cell_type": "markdown",
349
+ "id": "efc32dae",
350
+ "metadata": {},
351
+ "source": [
352
+ "## DDIM (De-noising Diffusion Implicit Models)"
353
+ ]
354
+ },
355
+ {
356
+ "cell_type": "code",
357
+ "execution_count": null,
358
+ "id": "a021f78a",
359
+ "metadata": {},
360
+ "outputs": [],
361
+ "source": [
362
+ "audio_diffusion = DiffusionPipeline.from_pretrained('teticio/audio-diffusion-ddim-256').to(device)"
363
+ ]
364
+ },
365
+ {
366
+ "cell_type": "markdown",
367
+ "id": "deb23339",
368
+ "metadata": {},
369
+ "source": [
370
+ "### Generation can be done in many fewer steps with DDIMs"
371
+ ]
372
+ },
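+ {
+ "cell_type": "markdown",
+ "id": "ddim_steps_note",
+ "metadata": {},
+ "source": [
+ "The loop below uses the pipeline's default number of de-noising steps. As a rough sketch, `steps` can also be set explicitly to trade speed against quality (50 is just an illustrative value):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "ddim_steps_code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# Illustrative only: fewer steps is faster, more steps usually improves quality\n",
+ "output = audio_diffusion(mel=mel, steps=50, generator=generator)\n",
+ "display(output.images[0])\n",
+ "display(Audio(output.audios[0, 0], rate=sample_rate))"
+ ]
+ },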
373
+ {
374
+ "cell_type": "code",
375
+ "execution_count": null,
376
+ "id": "c105a497",
377
+ "metadata": {},
378
+ "outputs": [],
379
+ "source": [
380
+ "for _ in range(10):\n",
381
+ " seed = generator.seed()\n",
382
+ " print(f'Seed = {seed}')\n",
383
+ " generator.manual_seed(seed)\n",
384
+ " output = audio_diffusion(mel=mel, generator=generator)\n",
385
+ " image = output.images[0]\n",
386
+ " audio = output.audios[0, 0]\n",
387
+ " display(image)\n",
388
+ " display(Audio(audio, rate=sample_rate))\n",
389
+ " loop = loop_it(audio, sample_rate)\n",
390
+ " if loop is not None:\n",
391
+ " display(Audio(loop, rate=sample_rate))\n",
392
+ " else:\n",
393
+ " print(\"Unable to determine loop points\")"
394
+ ]
395
+ },
396
+ {
397
+ "cell_type": "markdown",
398
+ "id": "cab4692c",
399
+ "metadata": {},
400
+ "source": [
401
+ "The parameter eta controls the variance:\n",
402
+ "* 0 - DDIM (deterministic)\n",
403
+ "* 1 - DDPM (De-noising Diffusion Probabilistic Model)"
404
+ ]
405
+ },
406
+ {
407
+ "cell_type": "code",
408
+ "execution_count": null,
409
+ "id": "72bdd207",
410
+ "metadata": {},
411
+ "outputs": [],
412
+ "source": [
413
+ "output = audio_diffusion(mel=mel, steps=1000, generator=generator, eta=1)\n",
414
+ "image = output.images[0]\n",
415
+ "audio = output.audios[0, 0]\n",
416
+ "display(image)\n",
417
+ "display(Audio(audio, rate=sample_rate))"
418
+ ]
419
+ },
420
+ {
421
+ "cell_type": "markdown",
422
+ "id": "b8d5442c",
423
+ "metadata": {},
424
+ "source": [
425
+ "### DDIMs can be used as encoders..."
426
+ ]
427
+ },
428
+ {
429
+ "cell_type": "code",
430
+ "execution_count": null,
431
+ "id": "269ee816",
432
+ "metadata": {},
433
+ "outputs": [],
434
+ "source": [
435
+ "# Doesn't have to be an audio from the train dataset, this is just for convenience\n",
436
+ "ds = load_dataset('teticio/audio-diffusion-256')"
437
+ ]
438
+ },
439
+ {
440
+ "cell_type": "code",
441
+ "execution_count": null,
442
+ "id": "278d1d80",
443
+ "metadata": {},
444
+ "outputs": [],
445
+ "source": [
446
+ "image = ds['train'][264]['image']\n",
447
+ "display(Audio(mel.image_to_audio(image), rate=sample_rate))"
448
+ ]
449
+ },
450
+ {
451
+ "cell_type": "code",
452
+ "execution_count": null,
453
+ "id": "912b54e4",
454
+ "metadata": {},
455
+ "outputs": [],
456
+ "source": [
457
+ "noise = audio_diffusion.encode([image])"
458
+ ]
459
+ },
460
+ {
461
+ "cell_type": "code",
462
+ "execution_count": null,
463
+ "id": "c7b31f97",
464
+ "metadata": {},
465
+ "outputs": [],
466
+ "source": [
467
+ "# Reconstruct original audio from noise\n",
468
+ "output = audio_diffusion(mel=mel, noise=noise, generator=generator)\n",
469
+ "image = output.images[0]\n",
470
+ "audio = output.audios[0, 0]\n",
471
+ "display(Audio(audio, rate=sample_rate))"
472
+ ]
473
+ },
474
+ {
475
+ "cell_type": "markdown",
476
+ "id": "998c776b",
477
+ "metadata": {},
478
+ "source": [
479
+ "### ...or to interpolate between audios"
480
+ ]
481
+ },
482
+ {
483
+ "cell_type": "code",
484
+ "execution_count": null,
485
+ "id": "33f82367",
486
+ "metadata": {},
487
+ "outputs": [],
488
+ "source": [
489
+ "image2 = ds['train'][15978]['image']\n",
490
+ "display(Audio(mel.image_to_audio(image2), rate=sample_rate))"
491
+ ]
492
+ },
493
+ {
494
+ "cell_type": "code",
495
+ "execution_count": null,
496
+ "id": "f93fb6c0",
497
+ "metadata": {},
498
+ "outputs": [],
499
+ "source": [
500
+ "noise2 = audio_diffusion.encode([image2])"
501
+ ]
502
+ },
503
+ {
504
+ "cell_type": "code",
505
+ "execution_count": null,
506
+ "id": "a4190563",
507
+ "metadata": {},
508
+ "outputs": [],
509
+ "source": [
510
+ "alpha = 0.5 #@param {type:\"slider\", min:0, max:1, step:0.1}\n",
511
+ "output = audio_diffusion(\n",
512
+ " mel=mel,\n",
513
+ " noise=audio_diffusion.slerp(noise, noise2, alpha),\n",
514
+ " generator=generator)\n",
515
+ "audio = output.audios[0, 0]\n",
516
+ "display(Audio(mel.image_to_audio(image), rate=sample_rate))\n",
517
+ "display(Audio(mel.image_to_audio(image2), rate=sample_rate))\n",
518
+ "display(Audio(audio, rate=sample_rate))"
519
+ ]
520
+ },
521
+ {
522
+ "cell_type": "markdown",
523
+ "id": "9b244547",
524
+ "metadata": {},
525
+ "source": [
526
+ "## Latent Audio Diffusion\n",
527
+ "Instead of de-noising images directly in the pixel space, we can work in the latent space of a pre-trained VAE (Variational AutoEncoder). This is much faster to train and run inference on, although the quality suffers as there are now three stages involved in encoding / decoding: mel spectrogram, VAE and de-noising."
528
+ ]
529
+ },
530
+ {
531
+ "cell_type": "code",
532
+ "execution_count": null,
533
+ "id": "a88b3fbb",
534
+ "metadata": {},
535
+ "outputs": [],
536
+ "source": [
537
+ "model_id = \"teticio/latent-audio-diffusion-ddim-256\" #@param [\"teticio/latent-audio-diffusion-256\", \"teticio/latent-audio-diffusion-ddim-256\"]"
538
+ ]
539
+ },
540
+ {
541
+ "cell_type": "code",
542
+ "execution_count": null,
543
+ "id": "15e353ee",
544
+ "metadata": {},
545
+ "outputs": [],
546
+ "source": [
547
+ "audio_diffusion = DiffusionPipeline.from_pretrained(model_id).to(device)"
548
+ ]
549
+ },
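+ {
+ "cell_type": "markdown",
+ "id": "latent_shape_note",
+ "metadata": {},
+ "source": [
+ "As a quick sketch of what \"latent\" means here, compare the resolution the UNet de-noises with the resolution of the mel spectrogram it ultimately produces (`unet.sample_size` and `mel.x_res` appear elsewhere in this notebook; `y_res` is assumed to mirror `x_res`):"
+ ]
+ },
+ {
+ "cell_type": "code",
+ "execution_count": null,
+ "id": "latent_shape_code",
+ "metadata": {},
+ "outputs": [],
+ "source": [
+ "# The UNet works on a smaller latent grid; the VAE decodes it back to a full spectrogram\n",
+ "print('latent size:', audio_diffusion.unet.sample_size)\n",
+ "print('spectrogram size:', (mel.y_res, mel.x_res))"
+ ]
+ },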
550
+ {
551
+ "cell_type": "code",
552
+ "execution_count": null,
553
+ "id": "fa0f0c8c",
554
+ "metadata": {},
555
+ "outputs": [],
556
+ "source": [
557
+ "seed = 3412253600050855 #@param {type:\"integer\"}\n",
558
+ "generator.manual_seed(seed)\n",
559
+ "output = audio_diffusion(mel=mel, generator=generator)\n",
560
+ "image = output.images[0]\n",
561
+ "audio = output.audios[0, 0]\n",
562
+ "display(image)\n",
563
+ "display(Audio(audio, rate=sample_rate))"
564
+ ]
565
+ },
566
+ {
567
+ "cell_type": "code",
568
+ "execution_count": null,
569
+ "id": "73dc575d",
570
+ "metadata": {},
571
+ "outputs": [],
572
+ "source": [
573
+ "seed2 = 7016114633369557 #@param {type:\"integer\"}\n",
574
+ "generator.manual_seed(seed2)\n",
575
+ "output = audio_diffusion(mel=mel, generator=generator)\n",
576
+ "image2 = output.images[0]\n",
577
+ "audio2 = output.audios[0, 0]\n",
578
+ "display(image2)\n",
579
+ "display(Audio(audio2, rate=sample_rate))"
580
+ ]
581
+ },
582
+ {
583
+ "cell_type": "markdown",
584
+ "id": "428d2d67",
585
+ "metadata": {},
586
+ "source": [
587
+ "### Interpolation in latent space\n",
588
+ "As the VAE forces a more compact, lower dimensional representation for the spectrograms, interpolation in latent space can lead to meaningful combinations of audios. In combination with the (deterministic) DDIM from the previous section, the model can be used as an encoder / decoder to a lower dimensional space."
589
+ ]
590
+ },
591
+ {
592
+ "cell_type": "code",
593
+ "execution_count": null,
594
+ "id": "72211c2b",
595
+ "metadata": {},
596
+ "outputs": [],
597
+ "source": [
598
+ "generator.manual_seed(seed)\n",
599
+ "latents = torch.randn(\n",
600
+ " (1, audio_diffusion.unet.in_channels, audio_diffusion.unet.sample_size[0],\n",
601
+ " audio_diffusion.unet.sample_size[1]),\n",
602
+ " generator=generator, device=device)\n",
603
+ "latents.shape"
604
+ ]
605
+ },
606
+ {
607
+ "cell_type": "code",
608
+ "execution_count": null,
609
+ "id": "6c732dbe",
610
+ "metadata": {},
611
+ "outputs": [],
612
+ "source": [
613
+ "generator.manual_seed(seed2)\n",
614
+ "latents2 = torch.randn(\n",
615
+ " (1, audio_diffusion.unet.in_channels, audio_diffusion.unet.sample_size[0],\n",
616
+ " audio_diffusion.unet.sample_size[1]),\n",
617
+ " generator=generator,\n",
618
+ " device=device)\n",
619
+ "latents2.shape"
620
+ ]
621
+ },
622
+ {
623
+ "cell_type": "code",
624
+ "execution_count": null,
625
+ "id": "159bcfc4",
626
+ "metadata": {},
627
+ "outputs": [],
628
+ "source": [
629
+ "alpha = 0.5 #@param {type:\"slider\", min:0, max:1, step:0.1}\n",
630
+ "output = audio_diffusion(\n",
631
+ " mel=mel,\n",
632
+ " noise=audio_diffusion.slerp(latents, latents2, alpha),\n",
633
+ " generator=generator)\n",
634
+ "audio3 = output.audios[0, 0]\n",
635
+ "display(Audio(audio, rate=mel.get_sample_rate()))\n",
636
+ "display(Audio(audio2, rate=mel.get_sample_rate()))\n",
637
+ "display(Audio(audio3, rate=sample_rate))"
638
+ ]
639
+ },
640
+ {
641
+ "cell_type": "code",
642
+ "execution_count": null,
643
+ "id": "ce6c9cc1",
644
+ "metadata": {},
645
+ "outputs": [],
646
+ "source": []
647
+ }
648
+ ],
649
+ "metadata": {
650
+ "accelerator": "GPU",
651
+ "colab": {
652
+ "provenance": []
653
+ },
654
+ "gpuClass": "standard",
655
+ "kernelspec": {
656
+ "display_name": "huggingface",
657
+ "language": "python",
658
+ "name": "huggingface"
659
+ },
660
+ "language_info": {
661
+ "codemirror_mode": {
662
+ "name": "ipython",
663
+ "version": 3
664
+ },
665
+ "file_extension": ".py",
666
+ "mimetype": "text/x-python",
667
+ "name": "python",
668
+ "nbconvert_exporter": "python",
669
+ "pygments_lexer": "ipython3",
670
+ "version": "3.10.6"
671
+ },
672
+ "toc": {
673
+ "base_numbering": 1,
674
+ "nav_menu": {},
675
+ "number_sections": true,
676
+ "sideBar": true,
677
+ "skip_h1_title": false,
678
+ "title_cell": "Table of Contents",
679
+ "title_sidebar": "Contents",
680
+ "toc_cell": false,
681
+ "toc_position": {},
682
+ "toc_section_display": true,
683
+ "toc_window_display": false
684
+ }
685
+ },
686
+ "nbformat": 4,
687
+ "nbformat_minor": 5
688
+ }
notebooks/test_model.ipynb CHANGED
@@ -309,10 +309,10 @@
309
  "outputs": [],
310
  "source": [
311
  "slice = 3 #@param {type:\"integer\"}\n",
312
- "audio = mel.get_audio_slice(slice)\n",
313
  "_, (sample_rate,\n",
314
  " audio2) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
315
- " raw_audio=mel.get_audio_slice(slice),\n",
316
  " mask_start_secs=1,\n",
317
  " mask_end_secs=1,\n",
318
  " step_generator=torch.Generator())\n",
@@ -471,7 +471,7 @@
471
  "metadata": {},
472
  "outputs": [],
473
  "source": [
474
- "noise2 = audio_diffusion.pipe.encode([image2], steps=1000)"
475
  ]
476
  },
477
  {
 
309
  "outputs": [],
310
  "source": [
311
  "slice = 3 #@param {type:\"integer\"}\n",
312
+ "raw_audio = mel.get_audio_slice(slice)\n",
313
  "_, (sample_rate,\n",
314
  " audio2) = audio_diffusion.generate_spectrogram_and_audio_from_audio(\n",
315
+ " raw_audio=raw_audio,\n",
316
  " mask_start_secs=1,\n",
317
  " mask_end_secs=1,\n",
318
  " step_generator=torch.Generator())\n",
 
471
  "metadata": {},
472
  "outputs": [],
473
  "source": [
474
+ "noise2 = audio_diffusion.pipe.encode([image2])"
475
  ]
476
  },
477
  {