diff --git "a/Training/MSML612_Project_DDPMTraining.ipynb" "b/Training/MSML612_Project_DDPMTraining.ipynb" new file mode 100644--- /dev/null +++ "b/Training/MSML612_Project_DDPMTraining.ipynb" @@ -0,0 +1,3092 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "A100", + "machine_shape": "hm" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Introduction\n", + "\n", + "This is a training script for the Sprite generation task. It uses a diffusion technique called DDPM.\n", + "\n", + "To get started you need following files:\n", + "\n", + "* data files:\n", + " - sprites_1788_16x16.npy (These are all images converted to numpy arrays)\n", + " - sprite_labels_nc_1788_16x16.npy (These are labels corresponding to the images)\n", + "\n" + ], + "metadata": { + "id": "3dginjdSFc2H" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Setup" + ], + "metadata": { + "id": "LCFKkR7kGM-d" + } + }, + { + "cell_type": "code", + "source": [ + "#install wandb for MLOps\n", + "!pip install wandb" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8NjPO1m4zdsr", + "outputId": "f0469d5f-3ac0-41c0-c510-5b31a89b7765" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting wandb\n", + " Downloading wandb-0.15.8-py3-none-any.whl (2.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.6)\n", + "Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)\n", + " Downloading GitPython-3.1.32-py3-none-any.whl (188 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m188.5/188.5 kB\u001b[0m \u001b[31m17.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.31.0)\n", + "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (5.9.5)\n", + "Collecting sentry-sdk>=1.0.0 (from wandb)\n", + " Downloading sentry_sdk-1.29.2-py2.py3-none-any.whl (215 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m215.6/215.6 kB\u001b[0m \u001b[31m18.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting docker-pycreds>=0.4.0 (from wandb)\n", + " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", + "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from wandb) (6.0.1)\n", + "Collecting pathtools (from wandb)\n", + " Downloading pathtools-0.1.2.tar.gz (11 kB)\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Collecting setproctitle (from wandb)\n", + " Downloading setproctitle-1.3.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)\n", + "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n", + "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", + "Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb)\n", + " Downloading gitdb-4.0.10-py3-none-any.whl (62 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.2.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2.0.4)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2023.7.22)\n", + "Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb)\n", + " Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n", + "Building wheels for collected packages: pathtools\n", + " Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8791 sha256=afe7dc05a24350c846b8cbc2b758b0a324a90709d2b5b4438b6c047351e3bad6\n", + " Stored in directory: /root/.cache/pip/wheels/e7/f3/22/152153d6eb222ee7a56ff8617d80ee5207207a8c00a7aab794\n", + "Successfully built pathtools\n", + "Installing collected packages: pathtools, smmap, setproctitle, sentry-sdk, docker-pycreds, gitdb, GitPython, wandb\n", + "Successfully installed GitPython-3.1.32 docker-pycreds-0.4.0 gitdb-4.0.10 pathtools-0.1.2 sentry-sdk-1.29.2 setproctitle-1.3.2 smmap-5.0.0 wandb-0.15.8\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "fBmYpo1QfAeD" + }, + "outputs": [], + "source": [ + "#utility imports\n", + "from typing import Dict, Tuple\n", + "from types import SimpleNamespace\n", + "from tqdm import tqdm\n", + "from pathlib import Path\n", + "import numpy as np\n", + "from IPython.display import HTML\n", + "import wandb\n", + "\n", + "#PyTorch imports\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader\n", + "from torchvision import models, transforms\n", + "from torchvision.utils import save_image, make_grid\n", + "from torch.utils.data import Dataset\n", + "\n", + "#plotting imports\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.animation import FuncAnimation, PillowWriter" + ] + }, + { + "cell_type": "code", + "source": [ + "#login to WandB: You'll need an API key\n", + "wandb.login(anonymous=\"allow\")" + ], + "metadata": { + "id": "RiNhbJIzlo04", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 70 + }, + "outputId": "389ce201-0794-4cc8-ff7c-1987f1905774" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "\n", + " window._wandbApiKey = new Promise((resolve, reject) => {\n", + " function loadScript(url) {\n", + " return new Promise(function(resolve, reject) {\n", + " let newScript = document.createElement(\"script\");\n", + " newScript.onerror = reject;\n", + " newScript.onload = resolve;\n", + " document.body.appendChild(newScript);\n", + " newScript.src = url;\n", + " });\n", + " }\n", + " loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n", + " const iframe = document.createElement('iframe')\n", + " iframe.style.cssText = \"width:0;height:0;border:none\"\n", + " document.body.appendChild(iframe)\n", + " const handshake = new Postmate({\n", + " container: iframe,\n", + " url: 'https://wandb.ai/authorize'\n", + " });\n", + " const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n", + " handshake.then(function(child) {\n", + " child.on('authorize', data => {\n", + " clearTimeout(timeout)\n", + " resolve(data)\n", + " });\n", + " });\n", + " })\n", + " });\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": {}, + "execution_count": 3 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Model Architecture" + ], + "metadata": { + "id": "X5hf7yPFGPaO" + } + }, + { + "cell_type": "code", + "source": [ + "class ResidualConvBlock(nn.Module):\n", + " \"\"\"\n", + " This class creates a residual convolution block\n", + " \"\"\"\n", + " def __init__(\n", + " self, in_channels: int, out_channels: int, is_res: bool = False\n", + " ) -> None:\n", + " super().__init__()\n", + "\n", + " # Check if input and output channels are the same for the residual connection\n", + " self.same_channels = in_channels == out_channels\n", + "\n", + " # Flag for whether or not to use residual connection\n", + " self.is_res = is_res\n", + "\n", + " # First convolutional layer\n", + " self.conv1 = nn.Sequential(\n", + " nn.Conv2d(in_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1\n", + " nn.BatchNorm2d(out_channels), # Batch normalization\n", + " nn.GELU(), # GELU activation function\n", + " )\n", + "\n", + " # Second convolutional layer\n", + " self.conv2 = nn.Sequential(\n", + " nn.Conv2d(out_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1\n", + " nn.BatchNorm2d(out_channels), # Batch normalization\n", + " nn.GELU(), # GELU activation function\n", + " )\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + "\n", + " # If using residual connection\n", + " if self.is_res:\n", + " # Apply first convolutional layer\n", + " x1 = self.conv1(x)\n", + "\n", + " # Apply second convolutional layer\n", + " x2 = self.conv2(x1)\n", + "\n", + " # If input and output channels are the same, add residual connection directly\n", + " if self.same_channels:\n", + " out = x + x2\n", + " else:\n", + " # If not, apply a 1x1 convolutional layer to match dimensions before adding residual connection\n", + " shortcut = nn.Conv2d(x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0).to(x.device)\n", + " out = shortcut(x) + x2\n", + "\n", + " # Normalize output tensor\n", + " return out / 1.414\n", + "\n", + " # If not using residual connection, return output of second convolutional layer\n", + " else:\n", + " x1 = self.conv1(x)\n", + " x2 = self.conv2(x1)\n", + " return x2\n", + "\n", + " # Method to get the number of output channels for this block\n", + " def get_out_channels(self):\n", + " return self.conv2[0].out_channels\n", + "\n", + " # Method to set the number of output channels for this block\n", + " def set_out_channels(self, out_channels):\n", + " self.conv1[0].out_channels = out_channels\n", + " self.conv2[0].in_channels = out_channels\n", + " self.conv2[0].out_channels = out_channels" + ], + "metadata": { + "id": "ob7A2BIULPN-" + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "class UnetUp(nn.Module):\n", + " def __init__(self, in_channels, out_channels):\n", + " super(UnetUp, self).__init__()\n", + "\n", + " # Create a list of layers for the upsampling block\n", + " # The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers\n", + " layers = [\n", + " nn.ConvTranspose2d(in_channels, out_channels, 2, 2),\n", + " ResidualConvBlock(out_channels, out_channels),\n", + " ResidualConvBlock(out_channels, out_channels),\n", + " ]\n", + "\n", + " # Use the layers to create a sequential model\n", + " self.model = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x, skip):\n", + " # Concatenate the input tensor x with the skip connection tensor along the channel dimension\n", + " x = torch.cat((x, skip), 1)\n", + "\n", + " # Get output\n", + " x = self.model(x)\n", + " return x" + ], + "metadata": { + "id": "DEFXf2rJLU3W" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "class UnetDown(nn.Module):\n", + " def __init__(self, in_channels, out_channels):\n", + " super(UnetDown, self).__init__()\n", + "\n", + " # Create a list of layers for the downsampling block\n", + " layers = [ResidualConvBlock(in_channels, out_channels), ResidualConvBlock(out_channels, out_channels), nn.MaxPool2d(2)]\n", + "\n", + " # Use the layers to create a sequential model\n", + " self.model = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x):\n", + " # Pass the input through the sequential model and return the output\n", + " return self.model(x)" + ], + "metadata": { + "id": "s4Lh0wPhLZFr" + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "class EmbedFC(nn.Module):\n", + " \"\"\"\n", + " This class defines a generic one layer feed-forward neural network for embedding input data of\n", + " dimensionality input_dim to an embedding space of dimensionality emb_dim.\n", + " \"\"\"\n", + " def __init__(self, input_dim, emb_dim):\n", + " super(EmbedFC, self).__init__()\n", + " self.input_dim = input_dim\n", + "\n", + " # define the layers for the network\n", + " layers = [\n", + " nn.Linear(input_dim, emb_dim),\n", + " nn.GELU(),\n", + " nn.Linear(emb_dim, emb_dim),\n", + " ]\n", + "\n", + " # create a PyTorch sequential model consisting of the defined layers\n", + " self.model = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x):\n", + " # flatten the input tensor\n", + " x = x.view(-1, self.input_dim)\n", + " # apply the model layers to the flattened tensor\n", + " return self.model(x)" + ], + "metadata": { + "id": "bgVfZKY4LaYf" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#Blueprint of Unet\n", + "class ContextUnet(nn.Module):\n", + " \"\"\"\n", + " A Unet that is conditioned on context\n", + " \"\"\"\n", + " def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28): # cfeat - context features\n", + " \"\"\"\n", + " Constructor\n", + " \"\"\"\n", + "\n", + " #call parent class constructor\n", + " super(ContextUnet, self).__init__()\n", + "\n", + " # setup input channels, number of intermediate feature maps and number of classes\n", + " self.in_channels = in_channels\n", + " self.n_feat = n_feat\n", + " self.n_cfeat = n_cfeat\n", + " self.h = height #height and width of image. Divisible by 4\n", + "\n", + " # Initialize the convolution layer\n", + " self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)\n", + "\n", + " # Two level down sampling\n", + " self.down1 = UnetDown(n_feat, n_feat) # down1 #[10, 256, 8, 8]\n", + " self.down2 = UnetDown(n_feat, 2 * n_feat) # down2 #[10, 256, 4, 4]\n", + "\n", + " #vectorize\n", + " self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())\n", + "\n", + " # Embed the timestep and context labels with a one-layer fully connected neural network\n", + " self.timeembed1 = EmbedFC(1, 2*n_feat)\n", + " self.timeembed2 = EmbedFC(1, 1*n_feat)\n", + " self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)\n", + " self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)\n", + "\n", + " # 3 level upsampling\n", + " self.up0 = nn.Sequential(\n", + " nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h//4, self.h//4), # up-sample\n", + " nn.GroupNorm(8, 2 * n_feat), # normalize\n", + " nn.ReLU(),\n", + " )\n", + " self.up1 = UnetUp(4 * n_feat, n_feat)\n", + " self.up2 = UnetUp(2 * n_feat, n_feat)\n", + "\n", + " # Final convolutional layer\n", + " self.out = nn.Sequential(\n", + " nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),\n", + " nn.GroupNorm(8, n_feat), # normalize\n", + " nn.ReLU(),\n", + " nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), # map to same number of channels as input\n", + " )\n", + "\n", + " def forward(self, x, t, c=None):\n", + " \"\"\"\n", + " x : (batch, n_feat, h, w) : input image\n", + " t : (batch, n_cfeat) : time step\n", + " c : (batch, n_classes) : context label\n", + " \"\"\"\n", + "\n", + " # pass the input image through the initial convolutional layer\n", + " x = self.init_conv(x)\n", + "\n", + " # pass the result through the down-sampling path\n", + " down1 = self.down1(x) #[10, 256, 8, 8]\n", + " down2 = self.down2(down1) #[10, 256, 4, 4]\n", + "\n", + " # convert the feature maps to a vector and apply an activation\n", + " hiddenvec = self.to_vec(down2)\n", + "\n", + " # mask out context if context_mask == 1\n", + " if c is None:\n", + " c = torch.zeros(x.shape[0], self.n_cfeat).to(x)\n", + "\n", + " # embed context and timestep\n", + " cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1) # (batch, 2*n_feat, 1,1)\n", + " temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)\n", + " cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)\n", + " temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)\n", + "\n", + " #upsample\n", + " up1 = self.up0(hiddenvec)\n", + " up2 = self.up1(cemb1*up1 + temb1, down2) # add and multiply embeddings\n", + " up3 = self.up2(cemb2*up2 + temb2, down1)\n", + " out = self.out(torch.cat((up3, x), 1))\n", + " return out" + ], + "metadata": { + "id": "gKt2pXAgfq1F" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Setup for saving the data and weights\n", + "DATA_DIR = Path('./data/')\n", + "SAVE_DIR = Path('./data/weights/')\n", + "SAVE_DIR.mkdir(exist_ok=True, parents=True)\n", + "save_dir = SAVE_DIR\n", + "data_dir = DATA_DIR\n", + "\n", + "config = SimpleNamespace(\n", + " # hyperparameters for WandB\n", + " num_samples = 30,\n", + "\n", + " # diffusion hyperparameters\n", + " timesteps = 500,\n", + " beta1 = 1e-4,\n", + " beta2 = 0.02,\n", + "\n", + " # network hyperparameters\n", + " n_feat = 64, # 64 hidden dimension feature\n", + " n_cfeat = 5, # context vector is of size 5\n", + " height = 16, # 16x16 image\n", + "\n", + " # training hyperparameters\n", + " batch_size = 100,\n", + " n_epoch = 128,\n", + " lrate = 1e-3,\n", + ")\n", + "\n", + "# Get device\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else torch.device('cpu'))\n", + "print(device)\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cppkUCRdfyzh", + "outputId": "c68fd65f-4722-4fc1-f476-75c5acd8e69d" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "cuda:0\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# DDPM Algorithm" + ], + "metadata": { + "id": "gt_RQHa0H4Q1" + } + }, + { + "cell_type": "code", + "source": [ + "def setup_ddpm(beta1, beta2, timesteps, device):\n", + "\n", + " # construct DDPM noise schedule and sampling functions\n", + " b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1\n", + " a_t = 1 - b_t\n", + " ab_t = torch.cumsum(a_t.log(), dim=0).exp()\n", + " ab_t[0] = 1\n", + "\n", + " # helper function: perturbs an image to a specified noise level\n", + " def perturb_input(x, t, noise):\n", + " return ab_t.sqrt()[t, None, None, None] * x + (1 - ab_t[t, None, None, None]) * noise\n", + "\n", + " # helper function; removes the predicted noise (but adds some noise back in to avoid collapse)\n", + " def _denoise_add_noise(x, t, pred_noise, z=None):\n", + " if z is None:\n", + " z = torch.randn_like(x)\n", + " noise = b_t.sqrt()[t] * z\n", + " mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()\n", + " return mean + noise\n", + "\n", + " # sample with context using DDPM algorithm\n", + " @torch.no_grad()\n", + " def sample_ddpm_context(nn_model, noises, context, save_rate=20):\n", + " # array to keep track of generated steps for plotting\n", + " intermediate = []\n", + " pbar = tqdm(range(timesteps, 0, -1), leave=False)\n", + " for i in pbar:\n", + " pbar.set_description(f'sampling timestep {i:3d}')\n", + "\n", + " # reshape time tensor\n", + " t = torch.tensor([i / timesteps])[:, None, None, None].to(noises.device)\n", + "\n", + " # Add some noise back in if i is not 1\n", + " z = torch.randn_like(noises) if i > 1 else 0\n", + "\n", + " eps = nn_model(noises, t, c=context) # predict noise\n", + " noises = _denoise_add_noise(noises, i, eps, z)\n", + " if i % save_rate==0 or i==timesteps or i<8:\n", + " intermediate.append(noises.detach().cpu().numpy())\n", + "\n", + " intermediate = np.stack(intermediate)\n", + " return noises.clip(-1, 1), intermediate\n", + "\n", + " return perturb_input, sample_ddpm_context" + ], + "metadata": { + "id": "blicU3wqf6d2" + }, + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#create an instance of the model\n", + "nn_model = ContextUnet(in_channels=3, n_feat=config.n_feat, n_cfeat=config.n_cfeat, height=config.height).to(device)\n" + ], + "metadata": { + "id": "rxjrXri5gI42" + }, + "execution_count": 11, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Data" + ], + "metadata": { + "id": "IKyT93CVIQ-5" + } + }, + { + "cell_type": "code", + "source": [ + "class CustomDataset(Dataset):\n", + " def __init__(self, sfilename, lfilename, transform, null_context=False):\n", + "\n", + " #load image arrays and labels\n", + " self.sprites = np.load(sfilename)\n", + " self.slabels = np.load(lfilename)\n", + " print(f\"sprite shape: {self.sprites.shape}\")\n", + " print(f\"labels shape: {self.slabels.shape}\")\n", + "\n", + " #get transformations to be done\n", + " self.transform = transform\n", + " self.null_context = null_context\n", + " self.sprites_shape = self.sprites.shape\n", + " self.slabel_shape = self.slabels.shape\n", + "\n", + " # Return the number of images in the dataset\n", + " def __len__(self):\n", + " return len(self.sprites)\n", + "\n", + " # Get the image and label at a given index\n", + " def __getitem__(self, idx):\n", + " # Return the image and label as a tuple\n", + " if self.transform:\n", + " image = self.transform(self.sprites[idx])\n", + " if self.null_context:\n", + " label = torch.tensor(0).to(torch.int64)\n", + " else:\n", + " label = torch.tensor(self.slabels[idx]).to(torch.int64)\n", + " return (image, label)\n", + "\n", + " def getshapes(self):\n", + " # return shapes of data and labels\n", + " return self.sprites_shape, self.slabel_shape\n", + "\n", + "transform = transforms.Compose([\n", + " transforms.ToTensor(), # from [0,255] to range [0.0,1.0]\n", + " transforms.Normalize((0.5,), (0.5,)) # range [-1,1]\n", + "\n", + "])" + ], + "metadata": { + "id": "GwAQE0g6LxY9" + }, + "execution_count": 12, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# load dataset and construct optimizer\n", + "dataset = CustomDataset(\"./sprites_1788_16x16.npy\", \"./sprite_labels_nc_1788_16x16.npy\", transform, null_context=False)\n", + "dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=1)\n", + "optim = torch.optim.Adam(nn_model.parameters(), lr=config.lrate)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "R7xWQnO7gTRu", + "outputId": "c7b5a888-7ebd-4879-8a7f-b033794c9546" + }, + "execution_count": 13, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "sprite shape: (89400, 16, 16, 3)\n", + "labels shape: (89400, 5)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# helper function: perturbs an image to a specified noise level\n", + "perturb_input, sample_ddpm_context = setup_ddpm(config.beta1,\n", + " config.beta2,\n", + " config.timesteps,\n", + " device)" + ], + "metadata": { + "id": "OkQs7K29gY1r" + }, + "execution_count": 14, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "noises = torch.randn(config.num_samples, 3,\n", + " config.height, config.height).to(device)\n", + "\n", + "# A fixed context vector to sample from\n", + "ctx_vector = F.one_hot(torch.tensor([0,0,0,0,0,0, # hero\n", + " 1,1,1,1,1,1, # non-hero\n", + " 2,2,2,2,2,2, # food\n", + " 3,3,3,3,3,3, # spell\n", + " 4,4,4,4,4,4]), # side-facing\n", + " 5).to(device).float()" + ], + "metadata": { + "id": "7x7h_bJA6L39" + }, + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Training" + ], + "metadata": { + "id": "hPL2wdBNJy08" + } + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "\n", + "# set into train mode\n", + "nn_model.train()\n", + "\n", + "#initialize wandb project\n", + "run = wandb.init(project=\"msml612_sprite\",\n", + " job_type=\"train\",\n", + " config=config)\n", + "\n", + "# pass config to wandb\n", + "config = wandb.config\n", + "\n", + "#train\n", + "for ep in range(config.n_epoch):\n", + " print(f'epoch {ep}')\n", + "\n", + " # LR Decay (linear)\n", + " optim.param_groups[0]['lr'] = config.lrate*(1-ep/config.n_epoch)\n", + "\n", + " #progres bar\n", + " pbar = tqdm(dataloader, mininterval=2 )\n", + " for x, c in pbar: # x: images c: context\n", + " optim.zero_grad()\n", + " x = x.to(device)\n", + " c = c.to(x)\n", + "\n", + " # randomly mask out c\n", + " context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.9).to(device)\n", + " c = c * context_mask.unsqueeze(-1)\n", + "\n", + " # perturb data\n", + " noise = torch.randn_like(x)\n", + " t = torch.randint(1, config.timesteps + 1, (x.shape[0],)).to(device)\n", + " x_pert = perturb_input(x, t, noise)\n", + "\n", + " # use network to recover noise\n", + " pred_noise = nn_model(x_pert, t / config.timesteps, c=c)\n", + "\n", + " # loss is mean squared error between the predicted and true noise\n", + " loss = F.mse_loss(pred_noise, noise)\n", + " loss.backward()\n", + "\n", + " optim.step()\n", + "\n", + " wandb.log({\"loss\": loss.item(),\n", + " \"lr\": optim.param_groups[0]['lr'],\n", + " \"epoch\": ep})\n", + "\n", + " # save model periodically\n", + " if ep%4==0 or ep == int(config.n_epoch-1):\n", + " nn_model.eval()\n", + " ckpt_file = SAVE_DIR/f\"context_model.pth\"\n", + " torch.save(nn_model.state_dict(), ckpt_file)\n", + "\n", + " #save to WandB\n", + " artifact_name = f\"{wandb.run.id}_context_model\"\n", + " at = wandb.Artifact(artifact_name, type=\"model\")\n", + " at.add_file(ckpt_file)\n", + " wandb.log_artifact(at, aliases=[f\"epoch_{ep}\"])\n", + "\n", + " samples, _ = sample_ddpm_context(nn_model,\n", + " noises,\n", + " ctx_vector[:config.num_samples])\n", + "\n", + " #Save samples\n", + " wandb.log({\n", + " \"train_samples\": [\n", + " wandb.Image(img) for img in samples.split(1)\n", + " ]})\n", + "wandb.finish()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "do0GE6yHgfPG", + "outputId": "4c31427f-07dc-4acc-b71f-b083401356b2" + }, + "execution_count": 16, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33madityapatkar\u001b[0m (\u001b[33mteamaditya\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Tracking run with wandb version 0.15.8" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Run data is saved locally in /content/wandb/run-20230813_092510-01ihqkt2" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Syncing run absurd-paper-9 to Weights & Biases (docs)
" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View project at https://wandb.ai/teamaditya/msml612_sprite" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View run at https://wandb.ai/teamaditya/msml612_sprite/runs/01ihqkt2" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 0\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:23<00:00, 37.98it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 1\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 56.82it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 2\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.92it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 3\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.48it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 4\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.29it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 5\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.09it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 6\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.59it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 7\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.16it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 8\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.86it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 9\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.08it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 10\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.33it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 11\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.97it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 12\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.74it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 13\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 59.73it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 14\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.73it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 15\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.81it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 16\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.96it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 17\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 59.90it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 18\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.68it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 19\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.28it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 20\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.87it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 21\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 59.33it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 22\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.64it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 23\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.56it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 24\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.50it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 25\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.06it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 26\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.13it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 27\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.16it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 28\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.58it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 29\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.97it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 30\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.40it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 31\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.98it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 32\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.75it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 33\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.37it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 34\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.19it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 35\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.53it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 36\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.14it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 37\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.53it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 38\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.51it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 39\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.52it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 40\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.47it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 41\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.91it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 42\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 63.92it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 43\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.48it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 44\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.59it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 45\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.33it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 46\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.61it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 47\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.65it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 48\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.60it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 49\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.83it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 50\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 58.65it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 51\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.79it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 52\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.02it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 53\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.84it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 54\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.10it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 55\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.16it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 56\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.95it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 57\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.03it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 58\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.94it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 59\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.67it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 60\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.65it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 61\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.30it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 62\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.48it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 63\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.65it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 64\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.05it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 65\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.25it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 66\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.74it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 67\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.84it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 68\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.28it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 69\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.70it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 70\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.60it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 71\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.50it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 72\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.32it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 73\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.57it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 74\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.49it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 75\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.27it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 76\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.41it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 77\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 57.84it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 78\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.12it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 79\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.94it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 80\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.48it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 81\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 58.05it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 82\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.53it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 83\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.59it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 84\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.76it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 85\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.89it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 86\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.28it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 87\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.33it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 88\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.68it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 89\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 58.72it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 90\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.80it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 91\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.45it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 92\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.37it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 93\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.36it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 94\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.85it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 95\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.24it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 96\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.93it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 97\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.35it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 98\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.17it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 99\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.49it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 100\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.10it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 101\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.31it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 102\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.73it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 103\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.29it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 104\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.27it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 105\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 58.13it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 106\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.19it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 107\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.11it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 108\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.69it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 109\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 57.03it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 110\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.12it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 111\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 63.96it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 112\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.96it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 113\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.29it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 114\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.19it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 115\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.02it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 116\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.18it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 117\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 58.33it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 118\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.87it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 119\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 63.99it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 120\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.69it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 121\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.09it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 122\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.29it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 123\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.25it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 124\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 63.87it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 125\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.88it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 126\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.77it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 127\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.05it/s]\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Waiting for W&B process to finish... (success)." + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "

Run history:


epoch▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss█▅▄▄▄▄▃▂▄▃▃▃▃▃▂▂▂▄▃▃▃▂▄▂▂▁▃▂▂▃▂▃▁▂▁▁▃▁▁▂
lr███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

Run summary:


epoch127
loss0.0658
lr1e-05

" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View run absurd-paper-9 at: https://wandb.ai/teamaditya/msml612_sprite/runs/01ihqkt2
Synced 5 W&B file(s), 990 media file(s), 33 artifact file(s) and 0 other file(s)" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Find logs at: ./wandb/run-20230813_092510-01ihqkt2/logs" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Inference" + ], + "metadata": { + "id": "d6vF6aPcJ2Bp" + } + }, + { + "cell_type": "code", + "source": [ + "#setup DDPM\n", + "_, sample_ddpm_context = setup_ddpm(config.beta1,\n", + " config.beta2,\n", + " config.timesteps,\n", + " device)" + ], + "metadata": { + "id": "0riboFnkiPXy" + }, + "execution_count": 27, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#create noise\n", + "noises = torch.randn(config.num_samples, 3,\n", + " config.height, config.height).to(device)\n", + "\n", + "# A fixed context vector to sample from\n", + "ctx_vector = F.one_hot(torch.tensor([0,0,0,0,0,0, # hero\n", + " 1,1,1,1,1,1, # non-hero\n", + " 2,2,2,2,2,2, # food\n", + " 3,3,3,3,3,3, # spell\n", + " 4,4,4,4,4,4]), # side-facing\n", + " 5).to(device).float()" + ], + "metadata": { + "id": "rXb9q47biWGR" + }, + "execution_count": 18, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "ddpm_samples, _ = sample_ddpm_context(nn_model, noises, ctx_vector)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AzEPoaJXAiPu", + "outputId": "0ac84abb-7a37-4679-ed06-c7180514f8ea" + }, + "execution_count": 19, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [] + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "vgVhBlxyVxYP" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#create a wandb table\n", + "table = wandb.Table(columns=[\"input_noise\", \"ddpm\", \"class\"])" + ], + "metadata": { + "id": "cVSXoNPfAnSW" + }, + "execution_count": 20, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#get classes for each ctx vector\n", + "def to_classes(ctx_vector):\n", + " classes = \"hero,non-hero,food,spell,side-facing\".split(\",\")\n", + " return [classes[i] for i in ctx_vector.argmax(dim=1)]" + ], + "metadata": { + "id": "-Ynl0Ch8AzyC" + }, + "execution_count": 21, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#add images to WandB\n", + "for noise, ddpm_s, c in zip(noises,\n", + " ddpm_samples,\n", + " to_classes(ctx_vector)):\n", + "\n", + " # add data row by row to the Table\n", + " table.add_data(wandb.Image(noise),\n", + " wandb.Image(ddpm_s),\n", + " c)" + ], + "metadata": { + "id": "KG_sbuOyAogK" + }, + "execution_count": 29, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#save samples to wandb\n", + "with wandb.init(project=\"msml612_sprite\",\n", + " job_type=\"samples\",\n", + " config=config):\n", + "\n", + " wandb.log({\"samplers_table\":table})" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 174 + }, + "id": "nCNI3tRkA8ON", + "outputId": "91906c1d-76a3-4078-a6be-5b95695d84fe" + }, + "execution_count": 23, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Tracking run with wandb version 0.15.8" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Run data is saved locally in /content/wandb/run-20230813_095756-2j5113bc" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Syncing run glamorous-mountain-10 to Weights & Biases (docs)
" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View project at https://wandb.ai/teamaditya/msml612_sprite" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View run at https://wandb.ai/teamaditya/msml612_sprite/runs/2j5113bc" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Waiting for W&B process to finish... (success)." + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View run glamorous-mountain-10 at: https://wandb.ai/teamaditya/msml612_sprite/runs/2j5113bc
Synced 4 W&B file(s), 1 media file(s), 61 artifact file(s) and 0 other file(s)" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Find logs at: ./wandb/run-20230813_095756-2j5113bc/logs" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [ + "#Function to show images\n", + "def show_images(imgs, nrow=2):\n", + " _, axs = plt.subplots(nrow, imgs.shape[0] // nrow, figsize=(4,2 ))\n", + " axs = axs.flatten()\n", + " for img, ax in zip(imgs, axs):\n", + " img = (img.permute(1, 2, 0).clip(-1, 1).detach().cpu().numpy() + 1) / 2\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + " ax.imshow(img)\n", + " plt.show()" + ], + "metadata": { + "id": "1CoL9AMJiYow" + }, + "execution_count": 24, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# user defined context\n", + "ctx = torch.tensor([\n", + " # hero, non-hero, food, spell, side-facing\n", + " [1,0,0,0,0],\n", + " [1,0,0,0,0],\n", + " [0,0,0,0,1],\n", + " [0,0,0,0,1],\n", + " [0,1,0,0,0],\n", + " [0,1,0,0,0],\n", + " [0,0,1,0,0],\n", + " [0,0,1,0,0],\n", + "]).float().to(device)\n", + "\n", + "noises = torch.randn(8, 3,\n", + " config.height, config.height).to(device)\n", + "\n", + "samples, _ = sample_ddpm_context(nn_model, noises, ctx)\n", + "show_images(samples)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 188 + }, + "id": "Z-XhQD-XiZGs", + "outputId": "2f48ecd8-b9b1-4fc1-f620-e681e1078840" + }, + "execution_count": 37, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [ + "# mix of defined context\n", + "ctx = torch.tensor([\n", + " # hero, non-hero, food, spell, side-facing\n", + " [1,0,0,0,0], #human\n", + " [1,0,0.6,0,0],\n", + " [0,0,0.6,0.4,0],\n", + " [1,0,0,0,1],\n", + " [1,1,0,0,0],\n", + " [1,0,0,1,0],\n", + " [1,0,0,0,0], #human\n", + " [0,1,0,0,0],\n", + "]).float().to(device)\n", + "samples, _ = sample_ddpm_context(nn_model, noises, ctx)\n", + "show_images(samples)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 188 + }, + "id": "YsJoy_Ahichx", + "outputId": "c04e7ad9-9ba9-48e9-c578-6c17b60435e0" + }, + "execution_count": 42, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "\n" + }, + "metadata": {} + } + ] + } + ] +} \ No newline at end of file