diff --git "a/notebooks/test-model.ipynb" "b/notebooks/test-model.ipynb"
--- "a/notebooks/test-model.ipynb"
+++ "b/notebooks/test-model.ipynb"
@@ -2,8 +2,8 @@
"cells": [
{
"cell_type": "code",
- "execution_count": 1,
- "id": "fd262b00",
+ "execution_count": 2,
+ "id": "b447e2c4",
"metadata": {},
"outputs": [],
"source": [
@@ -14,8 +14,8 @@
},
{
"cell_type": "code",
- "execution_count": 28,
- "id": "d2253762",
+ "execution_count": 11,
+ "id": "c2fc0e7a",
"metadata": {},
"outputs": [],
"source": [
@@ -23,14 +23,14 @@
"from src.mel import Mel\n",
"from PIL import ImageOps, Image\n",
"from IPython.display import Audio\n",
- "from datasets import load_from_disk\n",
- "from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline"
+ "from diffusers import DDPMPipeline\n",
+ "from datasets import load_from_disk"
]
},
{
"cell_type": "code",
- "execution_count": 3,
- "id": "293dd2c7",
+ "execution_count": 12,
+ "id": "a3d45c36",
"metadata": {},
"outputs": [],
"source": [
@@ -39,7 +39,7 @@
},
{
"cell_type": "markdown",
- "id": "5bdb2648",
+ "id": "011fb5a1",
"metadata": {},
"source": [
"### Run model inference to generate Mel spectrogram"
@@ -47,14 +47,14 @@
},
{
"cell_type": "code",
- "execution_count": 5,
- "id": "aac92f90",
+ "execution_count": 13,
+ "id": "b809fed5",
"metadata": {},
"outputs": [
{
"data": {
"application/vnd.jupyter.widget-view+json": {
- "model_id": "9fa5515ab1984c45bf459e9dfa12c3b9",
+ "model_id": "a7fe83b5914a437e99cf1838cb47f2b5",
"version_major": 2,
"version_minor": 0
},
@@ -69,12 +69,13 @@
"source": [
"model_id = \"../ddpm-ema-audio-64\"\n",
"ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference\n",
- "image = ddpm()[\"sample\"][0]"
+ "images = ddpm(output_type=\"numpy\")[\"sample\"]\n",
+ "images = (images * 255).round().astype(\"uint8\").transpose(0, 3, 1, 2)"
]
},
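The cell above replaces the single PIL image formerly taken from `ddpm()["sample"][0]` with batched numpy output, rescaled to 8-bit and reordered channel-first. As a minimal standalone sketch of the new inference step, assuming the diffusers API of this era in which the pipeline call returns a dict whose "sample" entry holds floats in [0, 1] shaped (batch, height, width, channels):

from diffusers import DDPMPipeline

model_id = "../ddpm-ema-audio-64"
# As the cell's comment notes, DDIMPipeline or PNDMPipeline could be
# swapped in for faster sampling (they would need to be imported again).
ddpm = DDPMPipeline.from_pretrained(model_id)

# output_type="numpy" returns float arrays in [0, 1] with shape
# (batch, height, width, channels) instead of a list of PIL images.
images = ddpm(output_type="numpy")["sample"]

# Quantize to uint8 and move channels first, so that images[i][0] is a
# (height, width) grayscale spectrogram array.
images = (images * 255).round().astype("uint8").transpose(0, 3, 1, 2)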
{
"cell_type": "markdown",
- "id": "df6c533b",
+ "id": "7230c280",
"metadata": {},
"source": [
"### Transform Mel spectrogram to audio"
@@ -82,8 +83,8 @@
},
{
"cell_type": "code",
- "execution_count": 8,
- "id": "37c24f43",
+ "execution_count": 18,
+ "id": "5f8a149d",
"metadata": {},
"outputs": [
{
@@ -91,7 +92,7 @@
"text/html": [
"\n",
" \n",
" "
@@ -100,19 +101,19 @@
""
]
},
- "execution_count": 8,
+ "execution_count": 18,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
- "audio = mel.image_to_audio(ImageOps.grayscale(image))\n",
+ "audio = mel.image_to_audio(Image.fromarray(images[0][0]))\n",
"Audio(data=audio, rate=mel.get_sample_rate())"
]
},
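With the batch now channel-first uint8, the grayscale conversion becomes a direct `Image.fromarray` on the first channel: a 2-D uint8 array yields a mode "L" image, so the old `ImageOps.grayscale` round-trip is no longer needed. A sketch continuing from `images` above; the `Mel` constructor arguments are not shown in this diff, so the bare call here is an assumption:

from PIL import Image
from IPython.display import Audio
from src.mel import Mel

mel = Mel()  # assumed defaults; the notebook's actual Mel(...) arguments are not in this diff

# images[0][0] is a (height, width) uint8 array; Image.fromarray on a
# 2-D uint8 array produces a mode "L" (grayscale) PIL image directly.
audio = mel.image_to_audio(Image.fromarray(images[0][0]))
Audio(data=audio, rate=mel.get_sample_rate())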
{
"cell_type": "markdown",
- "id": "10805113",
+ "id": "ef54cef3",
"metadata": {},
"source": [
"### Compare results with random sample from training set"
@@ -120,8 +121,8 @@
},
{
"cell_type": "code",
- "execution_count": 29,
- "id": "7a366813",
+ "execution_count": 21,
+ "id": "269ee816",
"metadata": {},
"outputs": [],
"source": [
@@ -130,8 +131,8 @@
},
{
"cell_type": "code",
- "execution_count": 38,
- "id": "55a29505",
+ "execution_count": 22,
+ "id": "492e2334",
"metadata": {},
"outputs": [
{
@@ -139,7 +140,7 @@
"text/html": [
"\n",
" \n",
" "
@@ -148,7 +149,7 @@
""
]
},
- "execution_count": 38,
+ "execution_count": 22,
"metadata": {},
"output_type": "execute_result"
}
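The comparison cells only pick up fresh execution counts and ids; their source is unchanged and omitted from the diff. For orientation, a hedged sketch of what comparing against a random training sample typically looks like with the retained `load_from_disk` import; the dataset path, the "train" split, and the "image" column are assumptions, not taken from this diff:

import random
from datasets import load_from_disk

ds = load_from_disk("../data/audio-64")  # hypothetical path, not shown in the diff
sample = random.choice(ds["train"])      # assumed split and record layout

# Render a real training spectrogram as audio (reusing mel from above)
# for an ear-to-ear comparison with the generated sample.
audio = mel.image_to_audio(sample["image"])  # "image" column is an assumption
Audio(data=audio, rate=mel.get_sample_rate())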
@@ -161,7 +162,7 @@
{
"cell_type": "code",
"execution_count": null,
- "id": "afb1f699",
+ "id": "a8ae1a19",
"metadata": {},
"outputs": [],
"source": []