{
 "cells": [
  {
   "cell_type": "markdown",
   "id": "7c1e5a90",
   "metadata": {},
   "source": [
    "# Autoencoder (VAE) round trip\n",
    "\n",
    "Encode an image from the `teticio/audio-diffusion-256` dataset with a pretrained `AutoencoderKL`, decode it again, and finally decode random latents. Every image is also played back as audio via `Mel.image_to_audio`."
   ]
  },
  {
   "cell_type": "markdown",
   "id": "2b8d41f6",
   "metadata": {},
   "source": [
    "## Setup"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "bcbbe26c",
   "metadata": {},
   "outputs": [],
   "source": [
    "import os\n",
    "import sys\n",
    "# Put the parent of the current working directory first on sys.path\n",
    "# (presumably so that the audiodiffusion package below is importable from there).\n",
    "sys.path.insert(0, os.path.dirname(os.path.abspath(\"\")))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "b451ab22",
   "metadata": {},
   "outputs": [],
   "source": [
    "import torch\n",
    "import random\n",
    "import numpy as np\n",
    "from PIL import Image\n",
    "from datasets import load_dataset\n",
    "from IPython.display import Audio\n",
    "from diffusers import AutoencoderKL\n",
    "from audiodiffusion.mel import Mel"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "e3a94b17",
   "metadata": {},
   "outputs": [],
   "source": [
    "# Seed Python's and torch's global RNGs (used by random.choice and the latent\n",
    "# sampling below) so that a fresh Restart & Run All is repeatable.\n",
    "# Change SEED to explore other examples.\n",
    "SEED = 42\n",
    "random.seed(SEED)\n",
    "torch.manual_seed(SEED);"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "5f0c7d3a",
   "metadata": {},
   "source": [
    "## Load model and data"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "324cef44",
   "metadata": {},
   "outputs": [],
   "source": [
    "mel = Mel()\n",
    "vae = AutoencoderKL.from_pretrained('../models/autoencoder-kl')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "da55ce79",
   "metadata": {},
   "outputs": [],
   "source": [
    "vae.config"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "9d02f6c4",
   "metadata": {},
   "outputs": [],
   "source": [
    "def decode_to_image(vae, latents):\n",
    "    \"\"\"Decode a batch of latents with `vae` and return the first item as a PIL image.\n",
    "\n",
    "    The decoded tensor is clamped to [-1, 1], rescaled to 0..255 uint8,\n",
    "    rearranged from (n, c, h, w) to (h, w, c) and converted to mode 'L'.\n",
    "    \"\"\"\n",
    "    with torch.no_grad():  # inference only, no autograd graph needed\n",
    "        decoded = vae.decode(latents)['sample']\n",
    "        decoded = torch.clamp(decoded, -1., 1.)\n",
    "        decoded = (decoded + 1.0) / 2.0  # -1,1 -> 0,1\n",
    "    pixels = (decoded.cpu().numpy() * 255).round().astype(\"uint8\")\n",
    "    pixels = pixels.transpose(0, 2, 3, 1)[0]  # n,c,h,w -> h,w,c of first item\n",
    "    return Image.fromarray(pixels).convert('L')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "5fea99ff",
   "metadata": {},
   "outputs": [],
   "source": [
    "ds = load_dataset('teticio/audio-diffusion-256')"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "426c6edd",
   "metadata": {},
   "outputs": [],
   "source": [
    "# pick a random training image, show it and play it back as audio\n",
    "image = random.choice(ds['train'])['image']\n",
    "display(image)\n",
    "Audio(data=mel.image_to_audio(image), rate=mel.get_sample_rate())"
   ]
  },
  {
   "cell_type": "markdown",
   "id": "c46e18b2",
   "metadata": {},
   "source": [
    "## Encode, reconstruct and decode random latents"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "d123f8a0",
   "metadata": {},
   "outputs": [],
   "source": [
    "# encode\n",
    "# PIL image -> uint8 array of shape (height, width, 3)\n",
    "pixels = np.asarray(image.convert('RGB'), dtype=\"uint8\")\n",
    "# scale 0..255 to -1..1 and move channels first: (3, height, width)\n",
    "input_image = ((pixels / 255) * 2 - 1).transpose(2, 0, 1)\n",
    "# Add the batch axis in numpy: building the tensor from a Python list of\n",
    "# ndarrays (torch.tensor([arr])) is slow and triggers a UserWarning.\n",
    "input_tensor = torch.tensor(input_image[np.newaxis], dtype=torch.float32)\n",
    "with torch.no_grad():  # inference only, no autograd graph needed\n",
    "    posterior = vae.encode(input_tensor).latent_dist\n",
    "    latents = posterior.sample()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "482c458f",
   "metadata": {},
   "outputs": [],
   "source": [
    "# reconstruct\n",
    "output_image = decode_to_image(vae, latents)\n",
    "display(output_image)\n",
    "Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate())"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "id": "f10db020",
   "metadata": {},
   "outputs": [],
   "source": [
    "# sample\n",
    "# standard normal noise with the same shape/dtype/device as the encoded latents\n",
    "output_image = decode_to_image(vae, torch.randn_like(latents))\n",
    "display(output_image)\n",
    "Audio(data=mel.image_to_audio(output_image), rate=mel.get_sample_rate())"
   ]
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "huggingface",
   "language": "python",
   "name": "huggingface"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.10.6"
  },
  "toc": {
   "base_numbering": 1,
   "nav_menu": {},
   "number_sections": true,
   "sideBar": true,
   "skip_h1_title": false,
   "title_cell": "Table of Contents",
   "title_sidebar": "Contents",
   "toc_cell": false,
   "toc_position": {},
   "toc_section_display": true,
   "toc_window_display": false
  }
 },
 "nbformat": 4,
 "nbformat_minor": 5
}