diff --git "a/notebooks/test-model.ipynb" "b/notebooks/test-model.ipynb" --- "a/notebooks/test-model.ipynb" +++ "b/notebooks/test-model.ipynb" @@ -2,8 +2,8 @@ "cells": [ { "cell_type": "code", - "execution_count": 1, - "id": "fd262b00", + "execution_count": 2, + "id": "b447e2c4", "metadata": {}, "outputs": [], "source": [ @@ -14,8 +14,8 @@ }, { "cell_type": "code", - "execution_count": 28, - "id": "d2253762", + "execution_count": 11, + "id": "c2fc0e7a", "metadata": {}, "outputs": [], "source": [ @@ -23,14 +23,14 @@ "from src.mel import Mel\n", "from PIL import ImageOps, Image\n", "from IPython.display import Audio\n", - "from datasets import load_from_disk\n", - "from diffusers import DDPMPipeline, DDIMPipeline, PNDMPipeline" + "from diffusers import DDPMPipeline\n", + "from datasets import load_from_disk" ] }, { "cell_type": "code", - "execution_count": 3, - "id": "293dd2c7", + "execution_count": 12, + "id": "a3d45c36", "metadata": {}, "outputs": [], "source": [ @@ -39,7 +39,7 @@ }, { "cell_type": "markdown", - "id": "5bdb2648", + "id": "011fb5a1", "metadata": {}, "source": [ "### Run model inference to generate Mel spectrogram" @@ -47,14 +47,14 @@ }, { "cell_type": "code", - "execution_count": 5, - "id": "aac92f90", + "execution_count": 13, + "id": "b809fed5", "metadata": {}, "outputs": [ { "data": { "application/vnd.jupyter.widget-view+json": { - "model_id": "9fa5515ab1984c45bf459e9dfa12c3b9", + "model_id": "a7fe83b5914a437e99cf1838cb47f2b5", "version_major": 2, "version_minor": 0 }, @@ -69,12 +69,13 @@ "source": [ "model_id = \"../ddpm-ema-audio-64\"\n", "ddpm = DDPMPipeline.from_pretrained(model_id) # you can replace DDPMPipeline with DDIMPipeline or PNDMPipeline for faster inference\n", - "image = ddpm()[\"sample\"][0]" + "images = ddpm(output_type=\"numpy\")[\"sample\"]\n", + "images = (images * 255).round().astype(\"uint8\").transpose(0, 3, 1, 2)" ] }, { "cell_type": "markdown", - "id": "df6c533b", + "id": "7230c280", "metadata": {}, "source": [ "### Transform Mel spectrogram to audio" @@ -82,8 +83,8 @@ }, { "cell_type": "code", - "execution_count": 8, - "id": "37c24f43", + "execution_count": 18, + "id": "5f8a149d", "metadata": {}, "outputs": [ { @@ -91,7 +92,7 @@ "text/html": [ "\n", " \n", " " @@ -100,19 +101,19 @@ "" ] }, - "execution_count": 8, + "execution_count": 18, "metadata": {}, "output_type": "execute_result" } ], "source": [ - "audio = mel.image_to_audio(ImageOps.grayscale(image))\n", + "audio = mel.image_to_audio(Image.fromarray(images[0][0]))\n", "Audio(data=audio, rate=mel.get_sample_rate())" ] }, { "cell_type": "markdown", - "id": "10805113", + "id": "ef54cef3", "metadata": {}, "source": [ "### Compare results with random sample from training set" @@ -120,8 +121,8 @@ }, { "cell_type": "code", - "execution_count": 29, - "id": "7a366813", + "execution_count": 21, + "id": "269ee816", "metadata": {}, "outputs": [], "source": [ @@ -130,8 +131,8 @@ }, { "cell_type": "code", - "execution_count": 38, - "id": "55a29505", + "execution_count": 22, + "id": "492e2334", "metadata": {}, "outputs": [ { @@ -139,7 +140,7 @@ "text/html": [ "\n", " \n", " " @@ -148,7 +149,7 @@ "" ] }, - "execution_count": 38, + "execution_count": 22, "metadata": {}, "output_type": "execute_result" } @@ -161,7 +162,7 @@ { "cell_type": "code", "execution_count": null, - "id": "afb1f699", + "id": "a8ae1a19", "metadata": {}, "outputs": [], "source": []