diff --git "a/Training/MSML612_Project_DDPMTraining.ipynb" "b/Training/MSML612_Project_DDPMTraining.ipynb" new file mode 100644--- /dev/null +++ "b/Training/MSML612_Project_DDPMTraining.ipynb" @@ -0,0 +1,3092 @@ +{ + "nbformat": 4, + "nbformat_minor": 0, + "metadata": { + "colab": { + "provenance": [], + "gpuType": "A100", + "machine_shape": "hm" + }, + "kernelspec": { + "name": "python3", + "display_name": "Python 3" + }, + "language_info": { + "name": "python" + }, + "accelerator": "GPU" + }, + "cells": [ + { + "cell_type": "markdown", + "source": [ + "# Introduction\n", + "\n", + "This is a training script for the Sprite generation task. It uses a diffusion technique called DDPM.\n", + "\n", + "To get started you need following files:\n", + "\n", + "* data files:\n", + " - sprites_1788_16x16.npy (These are all images converted to numpy arrays)\n", + " - sprite_labels_nc_1788_16x16.npy (These are labels corresponding to the images)\n", + "\n" + ], + "metadata": { + "id": "3dginjdSFc2H" + } + }, + { + "cell_type": "markdown", + "source": [ + "# Setup" + ], + "metadata": { + "id": "LCFKkR7kGM-d" + } + }, + { + "cell_type": "code", + "source": [ + "#install wandb for MLOps\n", + "!pip install wandb" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "8NjPO1m4zdsr", + "outputId": "f0469d5f-3ac0-41c0-c510-5b31a89b7765" + }, + "execution_count": 1, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "Collecting wandb\n", + " Downloading wandb-0.15.8-py3-none-any.whl (2.1 MB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m2.1/2.1 MB\u001b[0m \u001b[31m14.4 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: Click!=8.0.0,>=7.1 in /usr/local/lib/python3.10/dist-packages (from wandb) (8.1.6)\n", + "Collecting GitPython!=3.1.29,>=1.0.0 (from wandb)\n", + " Downloading GitPython-3.1.32-py3-none-any.whl (188 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m188.5/188.5 kB\u001b[0m \u001b[31m17.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: requests<3,>=2.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (2.31.0)\n", + "Requirement already satisfied: psutil>=5.0.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (5.9.5)\n", + "Collecting sentry-sdk>=1.0.0 (from wandb)\n", + " Downloading sentry_sdk-1.29.2-py2.py3-none-any.whl (215 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m215.6/215.6 kB\u001b[0m \u001b[31m18.9 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hCollecting docker-pycreds>=0.4.0 (from wandb)\n", + " Downloading docker_pycreds-0.4.0-py2.py3-none-any.whl (9.0 kB)\n", + "Requirement already satisfied: PyYAML in /usr/local/lib/python3.10/dist-packages (from wandb) (6.0.1)\n", + "Collecting pathtools (from wandb)\n", + " Downloading pathtools-0.1.2.tar.gz (11 kB)\n", + " Preparing metadata (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + "Collecting setproctitle (from wandb)\n", + " Downloading setproctitle-1.3.2-cp310-cp310-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl (30 kB)\n", + "Requirement already satisfied: setuptools in /usr/local/lib/python3.10/dist-packages (from wandb) (67.7.2)\n", + "Requirement already satisfied: appdirs>=1.4.3 in /usr/local/lib/python3.10/dist-packages (from wandb) (1.4.4)\n", + "Requirement already satisfied: protobuf!=4.21.0,<5,>=3.19.0 in /usr/local/lib/python3.10/dist-packages (from wandb) (3.20.3)\n", + "Requirement already satisfied: six>=1.4.0 in /usr/local/lib/python3.10/dist-packages (from docker-pycreds>=0.4.0->wandb) (1.16.0)\n", + "Collecting gitdb<5,>=4.0.1 (from GitPython!=3.1.29,>=1.0.0->wandb)\n", + " Downloading gitdb-4.0.10-py3-none-any.whl (62 kB)\n", + "\u001b[2K \u001b[90m━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━━\u001b[0m \u001b[32m62.7/62.7 kB\u001b[0m \u001b[31m9.7 MB/s\u001b[0m eta \u001b[36m0:00:00\u001b[0m\n", + "\u001b[?25hRequirement already satisfied: charset-normalizer<4,>=2 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.2.0)\n", + "Requirement already satisfied: idna<4,>=2.5 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (3.4)\n", + "Requirement already satisfied: urllib3<3,>=1.21.1 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2.0.4)\n", + "Requirement already satisfied: certifi>=2017.4.17 in /usr/local/lib/python3.10/dist-packages (from requests<3,>=2.0.0->wandb) (2023.7.22)\n", + "Collecting smmap<6,>=3.0.1 (from gitdb<5,>=4.0.1->GitPython!=3.1.29,>=1.0.0->wandb)\n", + " Downloading smmap-5.0.0-py3-none-any.whl (24 kB)\n", + "Building wheels for collected packages: pathtools\n", + " Building wheel for pathtools (setup.py) ... \u001b[?25l\u001b[?25hdone\n", + " Created wheel for pathtools: filename=pathtools-0.1.2-py3-none-any.whl size=8791 sha256=afe7dc05a24350c846b8cbc2b758b0a324a90709d2b5b4438b6c047351e3bad6\n", + " Stored in directory: /root/.cache/pip/wheels/e7/f3/22/152153d6eb222ee7a56ff8617d80ee5207207a8c00a7aab794\n", + "Successfully built pathtools\n", + "Installing collected packages: pathtools, smmap, setproctitle, sentry-sdk, docker-pycreds, gitdb, GitPython, wandb\n", + "Successfully installed GitPython-3.1.32 docker-pycreds-0.4.0 gitdb-4.0.10 pathtools-0.1.2 sentry-sdk-1.29.2 setproctitle-1.3.2 smmap-5.0.0 wandb-0.15.8\n" + ] + } + ] + }, + { + "cell_type": "code", + "execution_count": 2, + "metadata": { + "id": "fBmYpo1QfAeD" + }, + "outputs": [], + "source": [ + "#utility imports\n", + "from typing import Dict, Tuple\n", + "from types import SimpleNamespace\n", + "from tqdm import tqdm\n", + "from pathlib import Path\n", + "import numpy as np\n", + "from IPython.display import HTML\n", + "import wandb\n", + "\n", + "#PyTorch imports\n", + "import torch\n", + "import torch.nn as nn\n", + "import torch.nn.functional as F\n", + "from torch.utils.data import DataLoader\n", + "from torchvision import models, transforms\n", + "from torchvision.utils import save_image, make_grid\n", + "from torch.utils.data import Dataset\n", + "\n", + "#plotting imports\n", + "import matplotlib.pyplot as plt\n", + "from matplotlib.animation import FuncAnimation, PillowWriter" + ] + }, + { + "cell_type": "code", + "source": [ + "#login to WandB: You'll need an API key\n", + "wandb.login(anonymous=\"allow\")" + ], + "metadata": { + "id": "RiNhbJIzlo04", + "colab": { + "base_uri": "https://localhost:8080/", + "height": 70 + }, + "outputId": "389ce201-0794-4cc8-ff7c-1987f1905774" + }, + "execution_count": 3, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "application/javascript": [ + "\n", + " window._wandbApiKey = new Promise((resolve, reject) => {\n", + " function loadScript(url) {\n", + " return new Promise(function(resolve, reject) {\n", + " let newScript = document.createElement(\"script\");\n", + " newScript.onerror = reject;\n", + " newScript.onload = resolve;\n", + " document.body.appendChild(newScript);\n", + " newScript.src = url;\n", + " });\n", + " }\n", + " loadScript(\"https://cdn.jsdelivr.net/npm/postmate/build/postmate.min.js\").then(() => {\n", + " const iframe = document.createElement('iframe')\n", + " iframe.style.cssText = \"width:0;height:0;border:none\"\n", + " document.body.appendChild(iframe)\n", + " const handshake = new Postmate({\n", + " container: iframe,\n", + " url: 'https://wandb.ai/authorize'\n", + " });\n", + " const timeout = setTimeout(() => reject(\"Couldn't auto authenticate\"), 5000)\n", + " handshake.then(function(child) {\n", + " child.on('authorize', data => {\n", + " clearTimeout(timeout)\n", + " resolve(data)\n", + " });\n", + " });\n", + " })\n", + " });\n", + " " + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Appending key for api.wandb.ai to your netrc file: /root/.netrc\n" + ] + }, + { + "output_type": "execute_result", + "data": { + "text/plain": [ + "True" + ] + }, + "metadata": {}, + "execution_count": 3 + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Model Architecture" + ], + "metadata": { + "id": "X5hf7yPFGPaO" + } + }, + { + "cell_type": "code", + "source": [ + "class ResidualConvBlock(nn.Module):\n", + " \"\"\"\n", + " This class creates a residual convolution block\n", + " \"\"\"\n", + " def __init__(\n", + " self, in_channels: int, out_channels: int, is_res: bool = False\n", + " ) -> None:\n", + " super().__init__()\n", + "\n", + " # Check if input and output channels are the same for the residual connection\n", + " self.same_channels = in_channels == out_channels\n", + "\n", + " # Flag for whether or not to use residual connection\n", + " self.is_res = is_res\n", + "\n", + " # First convolutional layer\n", + " self.conv1 = nn.Sequential(\n", + " nn.Conv2d(in_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1\n", + " nn.BatchNorm2d(out_channels), # Batch normalization\n", + " nn.GELU(), # GELU activation function\n", + " )\n", + "\n", + " # Second convolutional layer\n", + " self.conv2 = nn.Sequential(\n", + " nn.Conv2d(out_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1\n", + " nn.BatchNorm2d(out_channels), # Batch normalization\n", + " nn.GELU(), # GELU activation function\n", + " )\n", + "\n", + " def forward(self, x: torch.Tensor) -> torch.Tensor:\n", + "\n", + " # If using residual connection\n", + " if self.is_res:\n", + " # Apply first convolutional layer\n", + " x1 = self.conv1(x)\n", + "\n", + " # Apply second convolutional layer\n", + " x2 = self.conv2(x1)\n", + "\n", + " # If input and output channels are the same, add residual connection directly\n", + " if self.same_channels:\n", + " out = x + x2\n", + " else:\n", + " # If not, apply a 1x1 convolutional layer to match dimensions before adding residual connection\n", + " shortcut = nn.Conv2d(x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0).to(x.device)\n", + " out = shortcut(x) + x2\n", + "\n", + " # Normalize output tensor\n", + " return out / 1.414\n", + "\n", + " # If not using residual connection, return output of second convolutional layer\n", + " else:\n", + " x1 = self.conv1(x)\n", + " x2 = self.conv2(x1)\n", + " return x2\n", + "\n", + " # Method to get the number of output channels for this block\n", + " def get_out_channels(self):\n", + " return self.conv2[0].out_channels\n", + "\n", + " # Method to set the number of output channels for this block\n", + " def set_out_channels(self, out_channels):\n", + " self.conv1[0].out_channels = out_channels\n", + " self.conv2[0].in_channels = out_channels\n", + " self.conv2[0].out_channels = out_channels" + ], + "metadata": { + "id": "ob7A2BIULPN-" + }, + "execution_count": 4, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "class UnetUp(nn.Module):\n", + " def __init__(self, in_channels, out_channels):\n", + " super(UnetUp, self).__init__()\n", + "\n", + " # Create a list of layers for the upsampling block\n", + " # The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers\n", + " layers = [\n", + " nn.ConvTranspose2d(in_channels, out_channels, 2, 2),\n", + " ResidualConvBlock(out_channels, out_channels),\n", + " ResidualConvBlock(out_channels, out_channels),\n", + " ]\n", + "\n", + " # Use the layers to create a sequential model\n", + " self.model = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x, skip):\n", + " # Concatenate the input tensor x with the skip connection tensor along the channel dimension\n", + " x = torch.cat((x, skip), 1)\n", + "\n", + " # Get output\n", + " x = self.model(x)\n", + " return x" + ], + "metadata": { + "id": "DEFXf2rJLU3W" + }, + "execution_count": 5, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "class UnetDown(nn.Module):\n", + " def __init__(self, in_channels, out_channels):\n", + " super(UnetDown, self).__init__()\n", + "\n", + " # Create a list of layers for the downsampling block\n", + " layers = [ResidualConvBlock(in_channels, out_channels), ResidualConvBlock(out_channels, out_channels), nn.MaxPool2d(2)]\n", + "\n", + " # Use the layers to create a sequential model\n", + " self.model = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x):\n", + " # Pass the input through the sequential model and return the output\n", + " return self.model(x)" + ], + "metadata": { + "id": "s4Lh0wPhLZFr" + }, + "execution_count": 6, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "class EmbedFC(nn.Module):\n", + " \"\"\"\n", + " This class defines a generic one layer feed-forward neural network for embedding input data of\n", + " dimensionality input_dim to an embedding space of dimensionality emb_dim.\n", + " \"\"\"\n", + " def __init__(self, input_dim, emb_dim):\n", + " super(EmbedFC, self).__init__()\n", + " self.input_dim = input_dim\n", + "\n", + " # define the layers for the network\n", + " layers = [\n", + " nn.Linear(input_dim, emb_dim),\n", + " nn.GELU(),\n", + " nn.Linear(emb_dim, emb_dim),\n", + " ]\n", + "\n", + " # create a PyTorch sequential model consisting of the defined layers\n", + " self.model = nn.Sequential(*layers)\n", + "\n", + " def forward(self, x):\n", + " # flatten the input tensor\n", + " x = x.view(-1, self.input_dim)\n", + " # apply the model layers to the flattened tensor\n", + " return self.model(x)" + ], + "metadata": { + "id": "bgVfZKY4LaYf" + }, + "execution_count": 7, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#Blueprint of Unet\n", + "class ContextUnet(nn.Module):\n", + " \"\"\"\n", + " A Unet that is conditioned on context\n", + " \"\"\"\n", + " def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28): # cfeat - context features\n", + " \"\"\"\n", + " Constructor\n", + " \"\"\"\n", + "\n", + " #call parent class constructor\n", + " super(ContextUnet, self).__init__()\n", + "\n", + " # setup input channels, number of intermediate feature maps and number of classes\n", + " self.in_channels = in_channels\n", + " self.n_feat = n_feat\n", + " self.n_cfeat = n_cfeat\n", + " self.h = height #height and width of image. Divisible by 4\n", + "\n", + " # Initialize the convolution layer\n", + " self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)\n", + "\n", + " # Two level down sampling\n", + " self.down1 = UnetDown(n_feat, n_feat) # down1 #[10, 256, 8, 8]\n", + " self.down2 = UnetDown(n_feat, 2 * n_feat) # down2 #[10, 256, 4, 4]\n", + "\n", + " #vectorize\n", + " self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())\n", + "\n", + " # Embed the timestep and context labels with a one-layer fully connected neural network\n", + " self.timeembed1 = EmbedFC(1, 2*n_feat)\n", + " self.timeembed2 = EmbedFC(1, 1*n_feat)\n", + " self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)\n", + " self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)\n", + "\n", + " # 3 level upsampling\n", + " self.up0 = nn.Sequential(\n", + " nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h//4, self.h//4), # up-sample\n", + " nn.GroupNorm(8, 2 * n_feat), # normalize\n", + " nn.ReLU(),\n", + " )\n", + " self.up1 = UnetUp(4 * n_feat, n_feat)\n", + " self.up2 = UnetUp(2 * n_feat, n_feat)\n", + "\n", + " # Final convolutional layer\n", + " self.out = nn.Sequential(\n", + " nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1),\n", + " nn.GroupNorm(8, n_feat), # normalize\n", + " nn.ReLU(),\n", + " nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), # map to same number of channels as input\n", + " )\n", + "\n", + " def forward(self, x, t, c=None):\n", + " \"\"\"\n", + " x : (batch, n_feat, h, w) : input image\n", + " t : (batch, n_cfeat) : time step\n", + " c : (batch, n_classes) : context label\n", + " \"\"\"\n", + "\n", + " # pass the input image through the initial convolutional layer\n", + " x = self.init_conv(x)\n", + "\n", + " # pass the result through the down-sampling path\n", + " down1 = self.down1(x) #[10, 256, 8, 8]\n", + " down2 = self.down2(down1) #[10, 256, 4, 4]\n", + "\n", + " # convert the feature maps to a vector and apply an activation\n", + " hiddenvec = self.to_vec(down2)\n", + "\n", + " # mask out context if context_mask == 1\n", + " if c is None:\n", + " c = torch.zeros(x.shape[0], self.n_cfeat).to(x)\n", + "\n", + " # embed context and timestep\n", + " cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1) # (batch, 2*n_feat, 1,1)\n", + " temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)\n", + " cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)\n", + " temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)\n", + "\n", + " #upsample\n", + " up1 = self.up0(hiddenvec)\n", + " up2 = self.up1(cemb1*up1 + temb1, down2) # add and multiply embeddings\n", + " up3 = self.up2(cemb2*up2 + temb2, down1)\n", + " out = self.out(torch.cat((up3, x), 1))\n", + " return out" + ], + "metadata": { + "id": "gKt2pXAgfq1F" + }, + "execution_count": 8, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# Setup for saving the data and weights\n", + "DATA_DIR = Path('./data/')\n", + "SAVE_DIR = Path('./data/weights/')\n", + "SAVE_DIR.mkdir(exist_ok=True, parents=True)\n", + "save_dir = SAVE_DIR\n", + "data_dir = DATA_DIR\n", + "\n", + "config = SimpleNamespace(\n", + " # hyperparameters for WandB\n", + " num_samples = 30,\n", + "\n", + " # diffusion hyperparameters\n", + " timesteps = 500,\n", + " beta1 = 1e-4,\n", + " beta2 = 0.02,\n", + "\n", + " # network hyperparameters\n", + " n_feat = 64, # 64 hidden dimension feature\n", + " n_cfeat = 5, # context vector is of size 5\n", + " height = 16, # 16x16 image\n", + "\n", + " # training hyperparameters\n", + " batch_size = 100,\n", + " n_epoch = 128,\n", + " lrate = 1e-3,\n", + ")\n", + "\n", + "# Get device\n", + "device = torch.device(\"cuda:0\" if torch.cuda.is_available() else torch.device('cpu'))\n", + "print(device)\n" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "cppkUCRdfyzh", + "outputId": "c68fd65f-4722-4fc1-f476-75c5acd8e69d" + }, + "execution_count": 9, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "cuda:0\n" + ] + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# DDPM Algorithm" + ], + "metadata": { + "id": "gt_RQHa0H4Q1" + } + }, + { + "cell_type": "code", + "source": [ + "def setup_ddpm(beta1, beta2, timesteps, device):\n", + "\n", + " # construct DDPM noise schedule and sampling functions\n", + " b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1\n", + " a_t = 1 - b_t\n", + " ab_t = torch.cumsum(a_t.log(), dim=0).exp()\n", + " ab_t[0] = 1\n", + "\n", + " # helper function: perturbs an image to a specified noise level\n", + " def perturb_input(x, t, noise):\n", + " return ab_t.sqrt()[t, None, None, None] * x + (1 - ab_t[t, None, None, None]) * noise\n", + "\n", + " # helper function; removes the predicted noise (but adds some noise back in to avoid collapse)\n", + " def _denoise_add_noise(x, t, pred_noise, z=None):\n", + " if z is None:\n", + " z = torch.randn_like(x)\n", + " noise = b_t.sqrt()[t] * z\n", + " mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()\n", + " return mean + noise\n", + "\n", + " # sample with context using DDPM algorithm\n", + " @torch.no_grad()\n", + " def sample_ddpm_context(nn_model, noises, context, save_rate=20):\n", + " # array to keep track of generated steps for plotting\n", + " intermediate = []\n", + " pbar = tqdm(range(timesteps, 0, -1), leave=False)\n", + " for i in pbar:\n", + " pbar.set_description(f'sampling timestep {i:3d}')\n", + "\n", + " # reshape time tensor\n", + " t = torch.tensor([i / timesteps])[:, None, None, None].to(noises.device)\n", + "\n", + " # Add some noise back in if i is not 1\n", + " z = torch.randn_like(noises) if i > 1 else 0\n", + "\n", + " eps = nn_model(noises, t, c=context) # predict noise\n", + " noises = _denoise_add_noise(noises, i, eps, z)\n", + " if i % save_rate==0 or i==timesteps or i<8:\n", + " intermediate.append(noises.detach().cpu().numpy())\n", + "\n", + " intermediate = np.stack(intermediate)\n", + " return noises.clip(-1, 1), intermediate\n", + "\n", + " return perturb_input, sample_ddpm_context" + ], + "metadata": { + "id": "blicU3wqf6d2" + }, + "execution_count": 10, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#create an instance of the model\n", + "nn_model = ContextUnet(in_channels=3, n_feat=config.n_feat, n_cfeat=config.n_cfeat, height=config.height).to(device)\n" + ], + "metadata": { + "id": "rxjrXri5gI42" + }, + "execution_count": 11, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Data" + ], + "metadata": { + "id": "IKyT93CVIQ-5" + } + }, + { + "cell_type": "code", + "source": [ + "class CustomDataset(Dataset):\n", + " def __init__(self, sfilename, lfilename, transform, null_context=False):\n", + "\n", + " #load image arrays and labels\n", + " self.sprites = np.load(sfilename)\n", + " self.slabels = np.load(lfilename)\n", + " print(f\"sprite shape: {self.sprites.shape}\")\n", + " print(f\"labels shape: {self.slabels.shape}\")\n", + "\n", + " #get transformations to be done\n", + " self.transform = transform\n", + " self.null_context = null_context\n", + " self.sprites_shape = self.sprites.shape\n", + " self.slabel_shape = self.slabels.shape\n", + "\n", + " # Return the number of images in the dataset\n", + " def __len__(self):\n", + " return len(self.sprites)\n", + "\n", + " # Get the image and label at a given index\n", + " def __getitem__(self, idx):\n", + " # Return the image and label as a tuple\n", + " if self.transform:\n", + " image = self.transform(self.sprites[idx])\n", + " if self.null_context:\n", + " label = torch.tensor(0).to(torch.int64)\n", + " else:\n", + " label = torch.tensor(self.slabels[idx]).to(torch.int64)\n", + " return (image, label)\n", + "\n", + " def getshapes(self):\n", + " # return shapes of data and labels\n", + " return self.sprites_shape, self.slabel_shape\n", + "\n", + "transform = transforms.Compose([\n", + " transforms.ToTensor(), # from [0,255] to range [0.0,1.0]\n", + " transforms.Normalize((0.5,), (0.5,)) # range [-1,1]\n", + "\n", + "])" + ], + "metadata": { + "id": "GwAQE0g6LxY9" + }, + "execution_count": 12, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# load dataset and construct optimizer\n", + "dataset = CustomDataset(\"./sprites_1788_16x16.npy\", \"./sprite_labels_nc_1788_16x16.npy\", transform, null_context=False)\n", + "dataloader = DataLoader(dataset, batch_size=config.batch_size, shuffle=True, num_workers=1)\n", + "optim = torch.optim.Adam(nn_model.parameters(), lr=config.lrate)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "R7xWQnO7gTRu", + "outputId": "c7b5a888-7ebd-4879-8a7f-b033794c9546" + }, + "execution_count": 13, + "outputs": [ + { + "output_type": "stream", + "name": "stdout", + "text": [ + "sprite shape: (89400, 16, 16, 3)\n", + "labels shape: (89400, 5)\n" + ] + } + ] + }, + { + "cell_type": "code", + "source": [ + "# helper function: perturbs an image to a specified noise level\n", + "perturb_input, sample_ddpm_context = setup_ddpm(config.beta1,\n", + " config.beta2,\n", + " config.timesteps,\n", + " device)" + ], + "metadata": { + "id": "OkQs7K29gY1r" + }, + "execution_count": 14, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "noises = torch.randn(config.num_samples, 3,\n", + " config.height, config.height).to(device)\n", + "\n", + "# A fixed context vector to sample from\n", + "ctx_vector = F.one_hot(torch.tensor([0,0,0,0,0,0, # hero\n", + " 1,1,1,1,1,1, # non-hero\n", + " 2,2,2,2,2,2, # food\n", + " 3,3,3,3,3,3, # spell\n", + " 4,4,4,4,4,4]), # side-facing\n", + " 5).to(device).float()" + ], + "metadata": { + "id": "7x7h_bJA6L39" + }, + "execution_count": 15, + "outputs": [] + }, + { + "cell_type": "markdown", + "source": [ + "# Training" + ], + "metadata": { + "id": "hPL2wdBNJy08" + } + }, + { + "cell_type": "code", + "source": [ + "import os\n", + "\n", + "# set into train mode\n", + "nn_model.train()\n", + "\n", + "#initialize wandb project\n", + "run = wandb.init(project=\"msml612_sprite\",\n", + " job_type=\"train\",\n", + " config=config)\n", + "\n", + "# pass config to wandb\n", + "config = wandb.config\n", + "\n", + "#train\n", + "for ep in range(config.n_epoch):\n", + " print(f'epoch {ep}')\n", + "\n", + " # LR Decay (linear)\n", + " optim.param_groups[0]['lr'] = config.lrate*(1-ep/config.n_epoch)\n", + "\n", + " #progres bar\n", + " pbar = tqdm(dataloader, mininterval=2 )\n", + " for x, c in pbar: # x: images c: context\n", + " optim.zero_grad()\n", + " x = x.to(device)\n", + " c = c.to(x)\n", + "\n", + " # randomly mask out c\n", + " context_mask = torch.bernoulli(torch.zeros(c.shape[0]) + 0.9).to(device)\n", + " c = c * context_mask.unsqueeze(-1)\n", + "\n", + " # perturb data\n", + " noise = torch.randn_like(x)\n", + " t = torch.randint(1, config.timesteps + 1, (x.shape[0],)).to(device)\n", + " x_pert = perturb_input(x, t, noise)\n", + "\n", + " # use network to recover noise\n", + " pred_noise = nn_model(x_pert, t / config.timesteps, c=c)\n", + "\n", + " # loss is mean squared error between the predicted and true noise\n", + " loss = F.mse_loss(pred_noise, noise)\n", + " loss.backward()\n", + "\n", + " optim.step()\n", + "\n", + " wandb.log({\"loss\": loss.item(),\n", + " \"lr\": optim.param_groups[0]['lr'],\n", + " \"epoch\": ep})\n", + "\n", + " # save model periodically\n", + " if ep%4==0 or ep == int(config.n_epoch-1):\n", + " nn_model.eval()\n", + " ckpt_file = SAVE_DIR/f\"context_model.pth\"\n", + " torch.save(nn_model.state_dict(), ckpt_file)\n", + "\n", + " #save to WandB\n", + " artifact_name = f\"{wandb.run.id}_context_model\"\n", + " at = wandb.Artifact(artifact_name, type=\"model\")\n", + " at.add_file(ckpt_file)\n", + " wandb.log_artifact(at, aliases=[f\"epoch_{ep}\"])\n", + "\n", + " samples, _ = sample_ddpm_context(nn_model,\n", + " noises,\n", + " ctx_vector[:config.num_samples])\n", + "\n", + " #Save samples\n", + " wandb.log({\n", + " \"train_samples\": [\n", + " wandb.Image(img) for img in samples.split(1)\n", + " ]})\n", + "wandb.finish()" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 1000 + }, + "id": "do0GE6yHgfPG", + "outputId": "4c31427f-07dc-4acc-b71f-b083401356b2" + }, + "execution_count": 16, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [ + "\u001b[34m\u001b[1mwandb\u001b[0m: Currently logged in as: \u001b[33madityapatkar\u001b[0m (\u001b[33mteamaditya\u001b[0m). Use \u001b[1m`wandb login --relogin`\u001b[0m to force relogin\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Tracking run with wandb version 0.15.8" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Run data is saved locally in /content/wandb/run-20230813_092510-01ihqkt2" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Syncing run absurd-paper-9 to Weights & Biases (docs)
" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View project at https://wandb.ai/teamaditya/msml612_sprite" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View run at https://wandb.ai/teamaditya/msml612_sprite/runs/01ihqkt2" + ] + }, + "metadata": {} + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 0\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:23<00:00, 37.98it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 1\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 56.82it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 2\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.92it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 3\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.48it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 4\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.29it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 5\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.09it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 6\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.59it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 7\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.16it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 8\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.86it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 9\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.08it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 10\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.33it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 11\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.97it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 12\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.74it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 13\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 59.73it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 14\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.73it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 15\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.81it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 16\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.96it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 17\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 59.90it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 18\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.68it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 19\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.28it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 20\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.87it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 21\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 59.33it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 22\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.64it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 23\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.56it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 24\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.50it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 25\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.06it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 26\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.13it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 27\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.16it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 28\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.58it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 29\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.97it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 30\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.40it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 31\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.98it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 32\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.75it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 33\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.37it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 34\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.19it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 35\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.53it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 36\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.14it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 37\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.53it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 38\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.51it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 39\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.52it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 40\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.47it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 41\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.91it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 42\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 63.92it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 43\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.48it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 44\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.59it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 45\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.33it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 46\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.61it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 47\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.65it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 48\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.60it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 49\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.83it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 50\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 58.65it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 51\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.79it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 52\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.02it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 53\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.84it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 54\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.10it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 55\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.16it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 56\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.95it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 57\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.03it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 58\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.94it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 59\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.67it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 60\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.65it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 61\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.30it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 62\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.48it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 63\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.65it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 64\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.05it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 65\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.25it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 66\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.74it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 67\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.84it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 68\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.28it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 69\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.70it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 70\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.60it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 71\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.50it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 72\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.32it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 73\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.57it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 74\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.49it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 75\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.27it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 76\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.41it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 77\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 57.84it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 78\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.12it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 79\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.94it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 80\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.48it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 81\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 58.05it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 82\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.53it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 83\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.59it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 84\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.76it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 85\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.89it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 86\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.28it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 87\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.33it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 88\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.68it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 89\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 58.72it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 90\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.80it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 91\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.45it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 92\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.37it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 93\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.36it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 94\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.85it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 95\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.24it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 96\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.93it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 97\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.35it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 98\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.17it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 99\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.49it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 100\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.10it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 101\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.31it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 102\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.73it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 103\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.29it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 104\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.27it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 105\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 58.13it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 106\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.19it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 107\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.11it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 108\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.69it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 109\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 57.03it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 110\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.12it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 111\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 63.96it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 112\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.96it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 113\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 61.29it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 114\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.19it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 115\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.02it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 116\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 67.18it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 117\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:15<00:00, 58.33it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 118\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.87it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 119\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 63.99it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 120\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 66.69it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 121\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 62.09it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 122\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 64.29it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 123\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.25it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 124\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 63.87it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 125\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 60.88it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 126\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:13<00:00, 65.77it/s]\n" + ] + }, + { + "output_type": "stream", + "name": "stdout", + "text": [ + "epoch 127\n" + ] + }, + { + "output_type": "stream", + "name": "stderr", + "text": [ + "100%|██████████| 894/894 [00:14<00:00, 63.05it/s]\n" + ] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Waiting for W&B process to finish... (success)." + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "\n", + "

Run history:


epoch▁▁▁▁▂▂▂▂▂▃▃▃▃▃▃▄▄▄▄▄▅▅▅▅▅▅▆▆▆▆▆▇▇▇▇▇▇███
loss█▅▄▄▄▄▃▂▄▃▃▃▃▃▂▂▂▄▃▃▃▂▄▂▂▁▃▂▂▃▂▃▁▂▁▁▃▁▁▂
lr███▇▇▇▇▇▇▆▆▆▆▆▆▅▅▅▅▅▄▄▄▄▄▄▃▃▃▃▃▂▂▂▂▂▂▁▁▁

Run summary:


epoch127
loss0.0658
lr1e-05

" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View run absurd-paper-9 at: https://wandb.ai/teamaditya/msml612_sprite/runs/01ihqkt2
Synced 5 W&B file(s), 990 media file(s), 33 artifact file(s) and 0 other file(s)" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Find logs at: ./wandb/run-20230813_092510-01ihqkt2/logs" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "markdown", + "source": [ + "# Inference" + ], + "metadata": { + "id": "d6vF6aPcJ2Bp" + } + }, + { + "cell_type": "code", + "source": [ + "#setup DDPM\n", + "_, sample_ddpm_context = setup_ddpm(config.beta1,\n", + " config.beta2,\n", + " config.timesteps,\n", + " device)" + ], + "metadata": { + "id": "0riboFnkiPXy" + }, + "execution_count": 27, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#create noise\n", + "noises = torch.randn(config.num_samples, 3,\n", + " config.height, config.height).to(device)\n", + "\n", + "# A fixed context vector to sample from\n", + "ctx_vector = F.one_hot(torch.tensor([0,0,0,0,0,0, # hero\n", + " 1,1,1,1,1,1, # non-hero\n", + " 2,2,2,2,2,2, # food\n", + " 3,3,3,3,3,3, # spell\n", + " 4,4,4,4,4,4]), # side-facing\n", + " 5).to(device).float()" + ], + "metadata": { + "id": "rXb9q47biWGR" + }, + "execution_count": 18, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "ddpm_samples, _ = sample_ddpm_context(nn_model, noises, ctx_vector)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/" + }, + "id": "AzEPoaJXAiPu", + "outputId": "0ac84abb-7a37-4679-ed06-c7180514f8ea" + }, + "execution_count": 19, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [] + } + ] + }, + { + "cell_type": "code", + "source": [], + "metadata": { + "id": "vgVhBlxyVxYP" + }, + "execution_count": null, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#create a wandb table\n", + "table = wandb.Table(columns=[\"input_noise\", \"ddpm\", \"class\"])" + ], + "metadata": { + "id": "cVSXoNPfAnSW" + }, + "execution_count": 20, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#get classes for each ctx vector\n", + "def to_classes(ctx_vector):\n", + " classes = \"hero,non-hero,food,spell,side-facing\".split(\",\")\n", + " return [classes[i] for i in ctx_vector.argmax(dim=1)]" + ], + "metadata": { + "id": "-Ynl0Ch8AzyC" + }, + "execution_count": 21, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#add images to WandB\n", + "for noise, ddpm_s, c in zip(noises,\n", + " ddpm_samples,\n", + " to_classes(ctx_vector)):\n", + "\n", + " # add data row by row to the Table\n", + " table.add_data(wandb.Image(noise),\n", + " wandb.Image(ddpm_s),\n", + " c)" + ], + "metadata": { + "id": "KG_sbuOyAogK" + }, + "execution_count": 29, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "#save samples to wandb\n", + "with wandb.init(project=\"msml612_sprite\",\n", + " job_type=\"samples\",\n", + " config=config):\n", + "\n", + " wandb.log({\"samplers_table\":table})" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 174 + }, + "id": "nCNI3tRkA8ON", + "outputId": "91906c1d-76a3-4078-a6be-5b95695d84fe" + }, + "execution_count": 23, + "outputs": [ + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Tracking run with wandb version 0.15.8" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Run data is saved locally in /content/wandb/run-20230813_095756-2j5113bc" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Syncing run glamorous-mountain-10 to Weights & Biases (docs)
" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View project at https://wandb.ai/teamaditya/msml612_sprite" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View run at https://wandb.ai/teamaditya/msml612_sprite/runs/2j5113bc" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Waiting for W&B process to finish... (success)." + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + " View run glamorous-mountain-10 at: https://wandb.ai/teamaditya/msml612_sprite/runs/2j5113bc
Synced 4 W&B file(s), 1 media file(s), 61 artifact file(s) and 0 other file(s)" + ] + }, + "metadata": {} + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "" + ], + "text/html": [ + "Find logs at: ./wandb/run-20230813_095756-2j5113bc/logs" + ] + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [ + "#Function to show images\n", + "def show_images(imgs, nrow=2):\n", + " _, axs = plt.subplots(nrow, imgs.shape[0] // nrow, figsize=(4,2 ))\n", + " axs = axs.flatten()\n", + " for img, ax in zip(imgs, axs):\n", + " img = (img.permute(1, 2, 0).clip(-1, 1).detach().cpu().numpy() + 1) / 2\n", + " ax.set_xticks([])\n", + " ax.set_yticks([])\n", + " ax.imshow(img)\n", + " plt.show()" + ], + "metadata": { + "id": "1CoL9AMJiYow" + }, + "execution_count": 24, + "outputs": [] + }, + { + "cell_type": "code", + "source": [ + "# user defined context\n", + "ctx = torch.tensor([\n", + " # hero, non-hero, food, spell, side-facing\n", + " [1,0,0,0,0],\n", + " [1,0,0,0,0],\n", + " [0,0,0,0,1],\n", + " [0,0,0,0,1],\n", + " [0,1,0,0,0],\n", + " [0,1,0,0,0],\n", + " [0,0,1,0,0],\n", + " [0,0,1,0,0],\n", + "]).float().to(device)\n", + "\n", + "noises = torch.randn(8, 3,\n", + " config.height, config.height).to(device)\n", + "\n", + "samples, _ = sample_ddpm_context(nn_model, noises, ctx)\n", + "show_images(samples)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 188 + }, + "id": "Z-XhQD-XiZGs", + "outputId": "2f48ecd8-b9b1-4fc1-f620-e681e1078840" + }, + "execution_count": 37, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUoAAACrCAYAAAAARtWWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAsbklEQVR4nO2debgdZZXuV817nylzSE5yMpCQmUAYIiGEEIMMgoAER6D7KhdRugWHy+XKVWmu0oq22GDTICriFRSVbkC42IgYCPMUEiAhJEBCTgYyJ/sk5+zaNfUfu84J71qV1N5n65V+XL88+znPyq5aVfXVV9+uer9VaxlJkiSkKIqiHBDzL70DiqIo73V0oFQURclBB0pFUZQcdKBUFEXJQQdKRVGUHHSgVBRFyUEHSkVRlBx0oFQURcnBrmWhOI5p06ZN1NraSoZh/Ln36T1NkiTU1dVF7e3tZJq1/85oG+5H27BxtA0bp642TGqgs7MzISL9vOvT2dlZS9NpG2obahu+xz+1tGFNd5Stra1ERLR2fSe1tbUREdGps86CZS796BVyRcMC07Jx1DZt3PyYqdOEC7+I64Qe+uxuxl/Fl1e9JHyYDi6TWAn6DHywH7zjB8LHI3f/ioiISqUSjRvb0dcmtdK7/Pp3tSEZESxjJPIXftw5HWAf/cnRbAlcZ/i4ovCRxDHYi296ky2A5ut3dgof0z97GNgrbl4DdoXdnbhBKPfDqZ7LUqlEYzrG9rsN31rfSa1pG7oxbicxLbFe+6fHgD3oFFzGDHD5CE8LERE1H9EG9obrdoH9zq/eRh+8UYmIQn494Ib90AHbteWOJOl/lUolGjtuXL/b8M21ndTaWj0mz8L+H5qe3C47nafOW4D7nlTAfuLpJcKHEbNzY7FzF+N48JFpRwsfZ3WMB3vWwEPAjmJssx+sWSp8/HjZM0SU9sOxtfXDmgbK3lv0tra2vovcNnHVJq9ZrmjWN1C2NGFnJCKymw4+UBoteIEWmuR+mA76SCwcOMIA98OyscMS0f7BrXe7dT62ZLVhLQOlaeP/OUU8/oQNlG6zHCiSGJfhPxz8mubHSkRkuuZBl6lnoOylv23Y+q42rGWgNNjxmkVmszZOIjnIWS1s35lP3h5/7oGybz/624at+9uwPwOlZbFjSdDO6kP1DpROxrkssmuzxXHBjtivXJaP/lzLNQ2UfQsnEdnpmUrYSXxt3eti+UmH4h0i738m6+SvrZB3g+NnHwN2wu5Sl6a/Dn3f21kHzfSHBO3tWzeAHYfsFoOIDCOBv/1l3mePJCsddPbuwV/hMMN3cxN2hLX3453MoBPw+56lcoBqPwI71+QPTcB1unvAHnbWUOFj2307wA7wt4ZcdlsW8cGYiJK03aOksTlENwrJjarHGbEf28EXDRbLFwewH2g22JgG/wGX+1cp4LlK9nSDHUdsHUueB4uNcgHhefFiHLB4PyUiMpLqOmYiB4B68Cgkj6r72MMGxkJZLm8UsI88+9JzYMfsR//sSfgkRER0z6qNYC8YhHf6A5rxzi7rCB9Y/xbY5l5so9jDfnfmqBnCx6LRk4mIKGBPWgdDZ70VRVFy0IFSURQlBx0oFUVRcqhLo6yYFlV6xVE22xxytZeIbDZJwvU9njLYd6RG98rLr4BdZnM1tosaXWxJbSyO0a/joTbUvWcP+kzkZI6Z7qzZYJ7jJOihJP19mn0Z03EC6bt1WAvYTKKllU+uA7vtULnvu7txpeJo/H0smNio44JRwsfw96MAvuWPO8GOI9xuwoVA2t/Z7MaakMgyqh8iMtikie3JfmiG7H6gwma9Wb80+IwJERlPYb+yhzFdk2l0lUheWiH7L4/NFPsma8OMSQY7jdgIrcYasezY5DrVHfJi1FsNT/YhijCaosdCTdplE1H3vi4jJz48bTLY35iG+qFtYZuu2Ya6OBGRySY8b37zCbD/+/TjwX69C/spEdGRIw4lIqJyFNIDW94S32ehd5SKoig56ECpKIqSgw6UiqIoOehAqSiKkkNdkzluFJGbRr5bLHrcDqXwvKrzebCnTz4W7HKAArgVyhBTHrjM35owWaBvFGcE6TKhfcXyF8D22ASQ7WQI8elvStjgb8uTP3m9782AUQtxIuqwMyeI5d0iRv/GBoroBRfbzH9HBsvbwzCw1mSTCKuX4DZaBhaEj7bpQ8Bumo2TOz3PlsCOst6O7X0DxPTld3VQSSyqpAHXLuG5NffKiYjiQDZJ8ja2kTecT0TJbQ5qwj6y3cc2jFh4tGtggDYRUcwmRCI2iWSxvmXHGZHf6XHbsZy0qodCUqZCUu1/QdIE3xkkt9ttYRs1s8kbMvCcntgxRfi44XB8eSSs4DG8shUD0rPemGEvmdHFo/CauW31SrBPHjtR+Fi+u/qCSRBnvKt6APSOUlEUJQcdKBVFUXLQgVJRFCWHujTKsmGRm0Y8P7bsj/DdmUeeKpY/cda5YK9ZvQLs9kmYMmnFsuuEj2IRg6EtlvTCZNlBiqd9TfhYvgZ1C4sFwq9+Co/lsd/9u/BhpwkEbKMxbahsELnpIWz4A2pBg2fIzEcjTsOAcx4MnwR4/D8d+QHh4+Ydu8EuBZvAntU6EOylm7cIH5bFMrswHTdiiSUcpoMSEcWRB3/7i21GZKcB7TFLDjF+rdRXL7ziHLCve+xOsLe9thvsSWdgYDQR0flTBoG9Y/ZUsGedg+m+Xro3ow2ZbulbqFl6FewPQSKPpTcBj29kBIXXgR8UyA+q/j0bt1vJ2G5zgOez7LBMR+HB+wcR0VIW3M3nExL+YoiVlQmLzVGwBCZHtOH18uAGGVD+zNY3iKiaZm3QoIHi+yz0jlJRFCUHHSgVRVFy0IFSURQlh7o0ykLkUyHqjZdCjeL+pb8Xy5901Ilgnzv7b8F+dPE1YN+++HHh4+pPnAG2bWDs4VU/vhvsKweilkRE1DrtbLCX/+stYC9+7n7cRixjEeP0eOO4riYTFMKACmli4IDFpm1bsU8sP+FoTKLbNh+P7ydjTgd79pU3Ch+//vonwH7WxqQXN3//PrDPnTdL+Fg2jsXWMa3IYrZPeJ6IiLw0ma2ZkdS2Hsyw+iEiipiMFZOMjVvyxxfB/tTM6WAHIa5jvy01unVM6vKKeHz2XpYlP0PLDmP067EE0RUbv8+KxaS4ul0vI3FHPXhmTJ5Zja/t5tUKejLOD6swUohwmY/Mnwv2IQMx7paIKJmG/S4M2HX28vqD7TIRyaQ0r5bRh7g6M5LYxGlVgJhqzw6vd5SKoig56ECpKIqSgw6UiqIoOdQluCWORYlTFYX8gCV/deT7u0uWPgr2i5//EtjLmEZRdGWpVcdEHSErqeq7eeSBB8X/zZ6yEOzHnkFNLuZ6Ukbi3t4wsYzX0evDor6qSewVdLJ53VQieuvF7WDPmYnvrt4/GWNEM2E6zfuYfnibje/68mJjRERd93WBvXMd2lTBg/Fc2R+CNH4yyEhqWw8V26KKnb7rzaruDWGJjomIStuwDZezJLzjB2O/axsgT/J3vvoU2Kd/CHXOpjKuk61/4f/FLFGtm+D5D2J5PRhW9Xgb7Yfdtkl2GoPYxIpslYvy/in28Zy5LvapzZuxP8zokBr1ixuxAOEtj2Fc9a+/fA7Yax+Wsag7SqjjG1wrZ309o7BpXyx0PTHRekepKIqSgw6UiqIoOehAqSiKkoMOlIqiKDnUpar7ZJOfrlJgVdj8jEQHLqvEd+y/XA/2kvFjwT55zpHCxwvrt4JtsvjRRSe9D+zlq98UPm686CNg8+kgh1Xp684QypvCalIAN5TJHurBJ4f8NHjds9FXJZBtyHJN0HMv4/EtnDAc7H84B5MjExG9tpJFS9t42i9fgBMTi1+WFfTeXIvBwFaAEwABSyBsJbIRnbRSp5NRsbMe3CAkN6j6SGzczr4NMunswiOxnWOW/XVacijY9z2JCWSzaKlgGyYWBosnGbMINrseKj5eQ3HPOrALg2TS2d4Ieztq7B6nyQ+pya+2YdnDfS1UMnw7e9EOcAJwyZpXwb5wLvYpIqIWVt2RJ7S591nsp+MyknP4bALGMFk1THbvlzX1W0mTSleykksfAL2jVBRFyUEHSkVRlBx0oFQURcmhvqQYlZAKaUGg0MUx1gul7sSDbo+Y1AH2py69GOy7fvxj4cNgepLBfHKd59OXfEX4OHYCaqFLV78NNoulpaaMok6R7aV/G0uY6oUxeWFV3yvbGJRbyAwixuMbeXo72C3H43l48hk8NiIi3+TnBoVervNsmSy10lFHDAN707IdYDuEOmAQyoDjONVGQ6vBgHPHpkpfYljUSh9/RiZWOGrySLAjC/dt0ZfxeH//klS2Pn0yFstatRmP/4lluF1D5mKge5++A+yOsZh8pLxlGdhJ8pLwMffoj1W/42J9ncR2QrFd9WGyYSAh+cJBxcRA/oBJsK6Fbfbzp14TPs45EguBXXLiEWD7ZfRRNjPmA5iuyboysdh5CjP0cDdNwt37txb0jlJRFCUHHSgVRVFy0IFSURQlh/riKF2bfDddhckHliN1naMPxwJMk8ejRnH89OPBHnQRxmYREcVMhIi5JsGKpC9YcLTwseJpTH47deIksFeuwZf1KUuH7BVluDhTJ75jku9Uf58KLHFrbMntTpyB+37ieTPBbh5XAvulQMaRJiy7rcGLOjGd79CJMrFE6GJM25QpqJWuWr0BbEfootSXE6LRAm1uGJEbZhcXI1smFnnxtXfA/tD8aWD/r1uXgx2H0ofLTvuSpXi8ERPHHrjnR8JH64w2sP09eK6G2vj9+BkfEj564zOz4jTrwYxMMtNYTNfCi7k7kck4mphu6bFTGBu4jklS5793OY8BHgd2ewGv/6dK2LeJiN5ns8TcLI7WYvG85YwRLjBs+FsLekepKIqSgw6UiqIoOehAqSiKkkNdGqVHFfJ6xUkLNZK5x0ltcMGco8BOmCZgOjhOT52FcVXVPWTvcjJNkridEV52/GSM3zxpIu7HuXNRS73r6dXCh+dUj9tyGnvX24mrHyKimBV4P/78cWL5j1+7AOzy7m6wn34G9dXBowcKH7xIPNd5OVEk4+gWjB0B9plTUZO64CR8L/kXi98QPsJUEwsbTNzrWxb5adt5LM40znh/d+ZxqKd6AeqtQ4dhjKhTkPcPJjtXg0di8axdLK7yhSVPCB/T41PAPvs8TCjtUivYNs/sTERhUt23egpjZVE2LXLTmESHZQEucAGSiKgb9fNyE7ZzIcE+E/NqZER0yqQPgP0v/w37yPbtmMg3Syveswu14XseuhrsQ0bjfvFcCURETpr826mjQJveUSqKouSgA6WiKEoOOlAqiqLkoAOloihKDvUFnIcu+WmyA8dmb59nqKbTZswAe8WqNWBfcTlWZbz+hzcKH2bCZx5YlbUAhefrvvmPwsfJIzGQ9eTjUbz/7RPow8t4Gb8SufC3v5hhRGYaLB2wYFnTlAL9IfPQfuMhXKb7TTyFQ6dJId40cJ8D1qYRq2S45fXxwodxGJ7vOZMw8PfxYC1uI+M32EmTPduWFOnrwTMC8oyqjzDGSYYzLp8jlv/wZaeBbQe7wH7tDdyfyhpWYZKIyMYJIJNNBHzqoyeA/YUL5H7cdN+/gv2x885Hn2zyxo9klhQvDajPqthZD5Zd/RARWRUWHB7L7e5rwnZurrBkyA4mFjmWXftEREcMmgv2tp04eZOwa3vAFpmEOogwaP3jc/4H2L9egVVYCw5OfhIRBWkZ1N6/taB3lIqiKDnoQKkoipKDDpSKoig51BdwbpfJS5PNJixo9/Gnl4nljzsCkw+QhePysHEYxHz99f8kfMQsUWcSMQ0u4nqbHPvvfgH1s988j/YfXsLA1yiQPlwnTfZpNZbQwXct8t3eYGn09fjPZUKLuRdiIHecoFY40B4H9r43pe5iMe3HZUk/TBaBfogjtdInXlgH9jNLMdHEvz+JCYONxBc+KE610rjBgPPYIT/VJj2m1QXlfWJ5awD21VFHoUb9u0dQOx/D9XeSQfhNTRhQHVZwP+74PUu0QkS3/vQpsHlyEp/J8Z4t9bU4KqZ/G2tDpxKRU6lqomWX6+6y/zd34/nsbkJNsoldFi0ZSXHtIu6zzZJy8+mIf9x9jvBx18xvg/3Q68+DPYYlC3nrkDHCh5Pq271/a0HvKBVFUXLQgVJRFCUHHSgVRVFyqC+OMiqQH1X1noTlmC1EUpN6evlKsCd0YHKKUeOx8HwcSm2otHP7QfeppQ0TCVgZsYjL38GEBe90bgK7zHQNIyNvrxWmAkrYWFEnO65+iEjoq4Elf7eevgP1s47Ro8GeeSTubMAT2RLRru14/AkTg0a3o2YXcx2YiJZvwXOz9g0sVl+OsA2LvIMQkZ8mUvF5IpM68UyfPLPa34IQtbI//EgW5Jp36WywCw/jvjnNg3GFUZhEg4goMfFSGfbWRrB/dt8zYJuJPJcxqzhmJ6iR2fxqrBSIY7ppQTCrsX7ou9UPEVGB3y/1yKS7QRPqmE0BxlGWHfx+8St47RMRzToWE+d078Z4VZMlkK4UZT+8+A/fBJsneAmiQ8D+4zN3Ch+VNOa7knWhHwC9o1QURclBB0pFUZQcdKBUFEXJob44yigmL0p1BBa/5mcUxvJi1Bj29ewFe/NG1Aq9gtRkOtdjok5eU6mjA/WkMM5Idhqyd2hZ0tOCwfTVjBDAwKtqYUFW4bE6sOKYrLQQlc/iFQuBjOuKWbJj38ed27J9D/polglTOzdtBttgx99cxGOyLdkt9u1h7z8zLbQYY8xfxZaF4nqTPnu8Ml2dRKFHUapNOqyY2JAFo8Xys849HPejjPu+/HZ853jr1J3Cx5DjTwLb/Ay+pz36xClgb35ilfARVnC7touaXMDiSy1XanSV0E7/NnaPk5BBSW8/8HE7SVFehzwstuKhJllgOReGtkud97TTTwXbYJrkurWo+27cuF74mHgoznPYbE5i2Y6HwbZMGZttpfGrrpwSOSB6R6koipKDDpSKoig56ECpKIqSgw6UiqIoOdQ1mRO6BoVuVTw1WKCvR1IZDZnQunU7Bo+3s8DeKCOY+52tW8A22WzOmPEo7lIsg5m3vrMVffDEvD5OZoSe3A8nnRBywtort2URm9UPEZHDm9+Q2+VBylu3YPD4SDaZFcfyPGzeyI8flxkyDJPwWoYMWt+2Fc9dYvLqhziJ5JKs5Lg/oUNjyY8Ta/9cUlJD8lXXxz5hssB+m1VY7FqDfY6IKByLlTkHTsEJIouwD8WxvAcpODjhEbCg9AKb/AwyEui6dgX+9pdCkFAhqPatyMPtRBm5IlyPXzNohh6bEBSJNuTkTfc+DGxfvx6T1RQ8OXGaROiDv6Nisnu/KKMqp2lW2zkxa09wo3eUiqIoOehAqSiKkkNNj9697waXSqW+/zNCfJyxMh+98bHRYk/F/DEx5i9usm0SyUdv8aiZ8egtfIhHb2yGrEdvO0Rf/H3pPLLakEz8nTIzpIfQZm0YH/z4sx695fHjMglrd/64n+UjMVGCMFh+RNOU7wvH6TPdn7INLYMdS0a+gKCMj1j8qTiOeP2nDAnEx+fRqAefPWMmyfD2qu4rLsPrCjlsP7JquvTGjZZK1bjWfrdh1/79ixz2uBrIa8h16nv0zuqHFZazM2Bxw3ydiJ8XIgpDPJc8bQD3kXUeeh+96+qHSQ10dnYmVK3qpZ/009nZWUvTaRtqG2obvsc/tbShkST5w2kcx7Rp0yZqbW0lo8HML//VSZKEurq6qL29nUyzduVC23A/2oaNo23YOPW0YU0DpaIoyl8zOpmjKIqSgw6UiqIoOehAqSiKkoMOlIqiKDnoQKkoipKDDpSKoig56ECpKIqSgw6UiqIoOehAqSiKkoMOlIqiKDnoQKkoipKDDpSKoig56ECpKIqSQ02JezU10340vVXjaBs2jrZh49TVhprs88+X7FPbUNtQ2/C9/6mlDWu6o2xtbSUiorVr11FrWxsRERkJplw3DOnKYv+1YdMmsEcOGwV2YsjybwlP9R4luA1CO66gTUT0hdPOB/vMyy4F+4RTZoPtNWNFQSIi267+4pRKJRo3Zmxfm9RK7/Kd6zqpLW3D+bOPhWU6OzaI9b7/uR+A/cGTzgDbMLGRLUuWDzBYNcwoYpUkWekDw5SVJr993S1gf/HvPgO2aWLVPYNkhTvLre5rqVSicePH978N167ta8OKjXcCbkZJkoiV+oiwgCg5IatrYLAFiCiw0K/L2jA2sRRCksj94GUryuyaKbAmSyxZyTJKq12WSiUaP7aj3224tnNt37VsB7gfZUdeQ4UK9iHDxdIQ5RDPf8GSpUBCo4A+KtiGpovb9QM5pni8JAUrWxE6eB6cQO5HxanuR1epRId21NaGNQ2UvbforW1tfR20PwNla1cX2L2+evlzDZQu25GmpuaD7sfBBspe6n1s6V2+7V1tyAc105Y+m5qaDrqv/78GSq+AnZzvRz0DZd8yf4I2/K88ULq5A6Us1xqxssD9bcN3X8t8oHT7MVC6YqCU5Wr/MgOl3I/egbJvP2pow7rqepORfojIZLWfo1gOct/79tdwB6/9DthbXLzYvvjsMuFjzPixYNtMS+AFy6676rvCR5E2g33H33wS7B233wD2+R/9mPBhpongzQYTwkdRQlE62A8/42j4bkD7ALH8QxtuA/vez/w/sH96K37vnjBX+BiHzUw3ffVasOfNnwf22f/nbuFjD00Fe/fitWCvXfIK2JeeOlL4WDh/DhERxRmDSD2UbZtcu9p1bXbD4Bek1uSwOzM3xIupbOPAWAjlnZwR4ACVsMJ5U/7+OPQxSu7H8q88ictEvOgZK/JlyB9sJy2uZUe116TOwk5istPzELFddYOMO1kHf7ADdhkUWZ3xJMwYWlibmTYrthazwdbZJ1zEFbzJCV027rDfvMCTA6WbFjlzKxkFzA+AznoriqLkoAOloihKDjpQKoqi5FCXRmmk/4iIeNhRaEjdydiyDeypHcPBPnM3Tu5csvAU4WPk+vVglwn1pFIzaiNRkYm9RHR3yyCwt7ah9nPj7XeAfcHHzhM+KmHv38a0IbLi6oeIelbhdp3Rx4nFe7qmgf3RBbjMnEs+Avain3xd+Nhj42m+7J7HwP71GNQTH9u2RvhomnQybnfwGLC7914P9qkn/0T4iIOqJmWSnHCqh0JYpkI6eZC4KFLbgfztr1jYZ+wQlykE3bi8I7VBN0QxNCDUvmw2YWSV5QTBidd+AGyTaXZUwHZ57PIHhI/AttK/jcVABolLQVI9BjfCY/NdefxmgNd30cI2jNmk6t2bvyV8rHsb9cUr5vwd2G6E127o4qQLEZGdoK5oEJu8sZmu6UsfPV51nR5XTpYdCL2jVBRFyUEHSkVRlBx0oFQURclBB0pFUZQc6prMScKYkrAq6oY2iru+3y2WN8dOBvvVNT8FexITzec2vyN8XDV8BG6nghNAHgs4/dbAQ4SPb5RQaC6wAOxpE9vBTvjrQERkpW+3WHUkIMgiNqsfIqJXnsFXOA8dvVYsf+kUDEpvvv9RsLefiAH5EwuDhY9CeQfYx336/WB/5pU7wf7c+xYJHz98fh3Yj/n/AfYvLv8q2GGQ8WZO+taQYTXWhhUqUIWqIj2fQ7QsGUQcsTePXAeX2RdiMHWzfDGJYrYhm03ExOx6MHwZ6LxyKevfbE5r+nzs61EifZgHeOGjXmyzQrZZnTyJ2aSGU5YNYHp4fKGB53D1K1eDfXjGXNPMgThG/PKhK8H++GnfAzvOuI/rYcddrMTse5wwynjJiopBNSo9CHz55QHQO0pFUZQcdKBUFEXJQQdKRVGUHOoLOLcMMtIsFFGMIoRjNYnlP/cpDChdtHoX2Kf/281gP9uBGg0R0dad+IK+76J+csOoDrCvXosaJhHRcYTJJj5/CeopF1z2QVzBkL8fVqqrZWXnqYfITyjyq/rWL++/B7675hfni+W/bL8Ktm9NBLt5CO7rG2ufFz6GDsSA8sJ2FG66x08H+/kxq4SP7541GuxffemzYE/88nKwrUS2YZJm9kkyMvzUg0sVcqmqr3WzLDVOLM+PHbNMPwlqcs0GBilHlhS2ogr2b8dAn1GIwdLbNsv0XoUB6NdkWnjrANTsrYwsTpTOEVhh7Qkdsggil4Ko2nYeS7CRFKTAWGFZeFwH1zF2skDvijz/bksL2MfPmInbZVl8XD+jn7i43X0sA1Nzguch7pFDXFCoHktgSw34QOgdpaIoSg46UCqKouSgA6WiKEoOdWmUcVL9EMmEuZtLMo7yrhPmgz0zQs1ht48p2KPybuHjLAeDHtu2Y2KNnWNQo+wcIeP3HtmB+3blrd8He85pGIt4zBFYGoKIKKQA/vYX2zHIdtLEItFe+G7TzhaxfGElaj+DTsfjvaYbNdlbSjuFj/JwjBN1WCyox/S15gSTiBAR/eqHvwE7JEyK0VNBTa7NxXg2IqIw7TwN5j6msuWQm2b/bmL77ttSo/RYl9jLen0Lq4iS+PKycGx0MuuqE8BuPgTX6VmXka2fZeM+ZA620Y5SCezj/hkTkRARPX75YiIiCozaEzpk4cXVDxFRwBKLWGGGvmiz42FzFMvu7AR74cWHCx8xy7RvbMa4UmM0arIhL5FARCE77mYW89lTQN2xmFFOI059GEaGBnwA9I5SURQlBx0oFUVRctCBUlEUJYc6i4sZ1Q8RxTGOsSuWvS4W387ew9y6CzWKX+7ZCPZ/TMHEpkREz72OyX8tE7WSweuxcNgNG2Ts1ZYCimKd7D3dy8+9AOzHV78hfJhpHKVZW4XfAxInCcWpSNdcxKSipRYZ13XMMahJPv/m22B/fSK+r+pZMlHpTv9xsLe2oL42uII6zuDDpI8lU1HHHbwG4ztv+P1NYH9x4eXCR8HrbbvGRMpC4lMh6T1ujE2MMgVQPGfFBPXUKMHjta2MolYmLjP+SPT57AOoNyeB7IfNRdTX4s14DQWsXcyMImyOEcLffmMF1U/VG3y1+v6vicVNFuPojMdSyy57F/7lu98SPo5ZhDHA+8rYd3klVzMjjrTAdN6yh8sUWcLksCBjYu00PtOu/VVvvaNUFEXJQwdKRVGUHHSgVBRFyUEHSkVRlBzqmpkwjf2JQ20XA3vbR7aL5ZdP/zjYex78N9y4jcHk21bJoPXDQpzMcRMMON+zB5MVfHKwfKHfSLaDvdDCCZGOq3AiomsfCvNERAPaqsHxJjVW/c61bXLTqojfuwED3wNngFj+ZAuDch/cholFhp95Ith7t+0RPrYuxcmao3bhdpde+An0+dLLwsfmXfeB3TpsGNiPrHkW7M+eKJOTFJ0hRERkxI39PpfDArlhdXKF595oMmQfCggDu50IJ00CBycizFBOZk2+ai6uswt9WCzBhZURLL3qdkxYErOdt0zcDysjJ0SY3tuEdc7DcvzYIT+uTuJ4hBMi9925Wixv9OD1/tGLcOKxfRJOOlquPMevLcXrcMIEvHZ3vvgdsJeG8gWMDxxzCe6XgecqJp6sJCM5i1f9v8CrvQ31jlJRFCUHHSgVRVFy0IFSURQlh/oS9xoGGWngacgSXBgZOTBnTUZt7IcmBske9hTqPN/au0L4OM9H7eOE6ahbnLEe9cbpgdSX/snARBH3FvCwg1u+CfY/r5HB8//zmmuIiKjHbywpRtgdUJgmGNi9EzWbk+cPFcv/oBXbedJ5GAz9xm+xYJvXvVX4mDQFE1hEs7ANh+x8FOxdpkysMW0B6nx7ItQCH+taD/amHVLnHdycarANBksXzG4qmNVzGNoYUBxWMpJFsGTPcYDn3w7x+yNvxIB8IiKbBVwTSyyy5v++hN9nCIyGj+u4LFi6XEEd0HIzinzF1bYzqbE2tK2QbKvqI0lwu9MGSW1w1llYtM9kweLjZ2ChPMuW58FiL3o4Dp6HuIzHdPJcTA5NRBSwFwq8kGnSrGChbcmocitN6OHEFfHdgdA7SkVRlBx0oFQURclBB0pFUZQc6gvGSpK+rKsmK8A1fuRwsfiomCUjeAFjIttPQt1n2hhW5IuI7tz4Jti/KaOecm6ACUJnTZfaWPPOIWCPXrcb7AED14HdXZBJZ8vdVS2k3CPj9OrBtKsfIqKxTVj0654Vz4nlJx+P8YrFoaj7DrkQNZmuveOED4MlWwgLGJs5wFwrd5LRsx2T+e4bPRjsEc9iEuYZ41EXJSKKg6rmFjdWW4wSy6UkLSpm+EzXc6TzOMTjn3rRh8EOBqLubPgyVtaxsL+vvOlJsEO2WTdD/4oc1M8tFvMncvEGcj8ip7ofkdlYkTsrMsiKqv57mP76wR/9SCz/8NV/C7Zv4PbnzZsEdhDI4moG2+eAJf0YOP+LYGcVoXNYwuCIFZczmIYZhzIpRu/tYURaXExRFOVPhg6UiqIoOehAqSiKkkNdGmVC+1Ouci3g5hs+LZb/2arXwHbPRa1w+rH4bvPpBuptRETz3nc02CYrHlVkCVWb3YHCx8p7fg/2yKOwENKDG/Gd0zkZ7yIPHlyNLbPtxgS2ShJSJanGi114+WXw3eP/cLFYft8OtIvDUSPl4XpeS0Y2UqavlT32jnEPnpe2Lhkr+nYZ12lqQ6204wXcUT+UPry0IJhpNaavhYZNYZrk1TRRC3zkjW+I5Te8/QLYN1+FcYKf+fvdYMddcv9efQqTHwdM1yuG2B5lV/blgo9xgn4BfXghbte3ZRyll7677ISN5RzwLZP8tF8UY+wzUSy1uza2b7ubUfsb5OJ1+L9vxXf/iYi+tAhjLWN2DDt/9RWwjbKMFd1dwjmIdbtRC50YjQB7zDXXCR9GGs9tmbVn7tU7SkVRlBx0oFQURclBB0pFUZQcdKBUFEXJoc7sn/unc7jMvLlLvkjvjsClnM4tYDdNnIDLy5h1GrkXk0Ds8HACqNvCSQRzhwz03eVhoPuqZ44E+4rbMHHviLHjhI8wPe6wwQqCpmmSmSZUKGC+D9qz6BWxvDsQE8Y2VVC83kc4aRIZMilI1IaRzCYL/O12UDS322S3GLsFj/vNVTgB9PDdT6EPnlGX+gp4Es8vUS9OFJOTJmV55dFvwXdjIpn8eMBebMNyBSsEHncxvmCw5mWZWMRgCR2KrHKl7+HkTSGQLyaEBVwm4UmgYzyXiYhAJ6I0oUrf336SGAYl6YkoJzgx45jS97HfxOQrFrsOtn/9b9B/RsLcb/8M2/3KM3HipdKNfbu7W/qIWAKPQz3sh7aNE7MJyYm5JKr6jSMNOFcURfmToQOloihKDjpQKoqi5NDvxL0Gk+q+/Y2b5QoWapQDW1CTuHHXKrBvtTYIF8URqH2u+t1KsBcNxyDWs6ZgMlwioqce3w32zI9j0bOJk2aAXYlloG9vQSwjbjAZgWmRlWqE09tnwndz7j5bLP/wGYvBfmMsal/jZqAml7io6RIReSX8PTRi1OSGsvwFbz+2W/hoY/rZ9ZMw2bHN2sw0pb7Wq4klDYqUcRJTnCZUMHrw2FqGSX2ttYDnbOU23P7ZJmqHZ33jt8KHEXFdF7XgpgSDl3sS1MqIiIosCN9mOl+Ph21WjGXAdRBVlwmCDP2yDgpBSIWg6p8XVwszkkW4Pn+zgV0jLAHOF6bIgPskwWKCLz6H8wkR07XnnSQTa+x7DhMI+yVso9AqgW2Q7A+VVPvlLw0cDL2jVBRFyUEHSkVRlBx0oFQURcmhLo0yimOK0qyrFkvcyxObEhEFrAj8heePBbu5AxPI/uRWqVEuHIPJNl7d/BDYTz+C8Xsrr/iI8GEQxryddcVJYPN377OUizgtTh8nDWadNZI+gddiGWyv+uLlYvFPbD8L7Iu+exHYO23UynY1SX0pGYT6WiHEdfbGA/H7Re8XPq6dgIXnR49ArYhrt6YtW7GxCNR3+bFNSuxqf3u4jIXQPjSiQyz/6oo1YB814VCwFz+0DuxFGR2gwna+KUB9reKyJNWszxERBawQWsKcFkPsiD0ZyUMKaUGw3sJg/aXs2OSmxb14sg5yZSxyxWMNEOGxDGVxlo9egIl+iaQ2LboIu7aee1j2ZTtGDdJliXmPuetGsONYnkzPqmqfniE10AOhd5SKoig56ECpKIqSgw6UiqIoOdSlUcZRTHH6jq3p4rN/kFGQyWaJMTtRTqL5Hag/jhpzj/Bx003X438YqElsZgXQ7/Ww6BUR0eFTcd++fDUWVr/r5yvArvCYMaL9Pylmgy8qk9nnzDbxWKafMlssPdPC95Qvug01ys8vGQp28p2ThY8rvnQt2J885jSwx/4WE9s6nfg9EdGCwzGBcuJiHJ3n4LHEcYYi2atBRY3pvEZkkpG+r7t3OyZy/eVLm8TyE1tQ6/vRc9gRr7nyFrArlQxt0EU9q2ygfuaV8XgTT8YRWj6LG3TRR4V1rWJZxvNSKi/zd8/r5d1xlN0ODgOGkZEwmMWRGiFe24GH53/+z28TPuKEbYfprEaE30c8KzUROSypNn+l3AzQZ1aS6CBx4W8t6B2loihKDjpQKoqi5KADpaIoSg46UCqKouRgJEmSqwqXSiUaMGAAPfzLR6i5qZqk4vAzMGi3mAwU6wUsKNdiQemVAEXjnp49wkdrG07OJGwiYNdOFPPDHTLpasVDQXfYIExYUBg8GmwzK+Q8ncQplUo0ZNBA2rNnD7W1tcnlDkBvG27fsbVvPcNk4nXGqbAt3JclDz0A9s233w72bTf8RvhwWcC5341Bu6UytmGBTcwQETW1tOJ+mXguTbaf3YRJmomI1vx2IxER7e3eS/M/ubDfbbhr584DrhdnJDowc7p4zPbd5gHYRBSxhBVWmSVbYIk3yhnbLES4nQphX3bZ3EXoyokIo1Jt91KpRIOHDeh/P9yzq289x2eTNxkvC0Rs8ijmL5ywgoZRxr5TgkH4RowTXobNnFTkZEvk8KB1VkGSJeNNLBk8b6dB6qVSiQYMqa0N9Y5SURQlBx0oFUVRcqgpjrL36Xxf9/5ch6VSFywTZNRIqffRu1xGn0RECdtF/ujd1YX5F8O9+BhZ3Q4+BnissHzFxkfRvEdvov1tUiu9y3e9q9368+i9bx/GjQYBPgKWuvBYiIhc6+CP3l0+tmFgy8eVkMVF5j96y3O5t7t6bnr7UX/bsPccZPEXe/Su5D96V/6Uj95djfbD/W34l3v0xjb8Szx6E9XYhkkNdHZ29lYV00/66ezsrKXptA21DbUN3+OfWtqwpsmcOI5p06ZN1Nra2pfh/K+VJEmoq6uL2tvb+6op1oK24X60DRtH27Bx6mnDmgZKRVGUv2Z0MkdRFCUHHSgVRVFy0IFSURQlBx0oFUVRctCBUlEUJQcdKBVFUXLQgVJRFCWH/wQqmobSYuWSvQAAAABJRU5ErkJggg==\n" + }, + "metadata": {} + } + ] + }, + { + "cell_type": "code", + "source": [ + "# mix of defined context\n", + "ctx = torch.tensor([\n", + " # hero, non-hero, food, spell, side-facing\n", + " [1,0,0,0,0], #human\n", + " [1,0,0.6,0,0],\n", + " [0,0,0.6,0.4,0],\n", + " [1,0,0,0,1],\n", + " [1,1,0,0,0],\n", + " [1,0,0,1,0],\n", + " [1,0,0,0,0], #human\n", + " [0,1,0,0,0],\n", + "]).float().to(device)\n", + "samples, _ = sample_ddpm_context(nn_model, noises, ctx)\n", + "show_images(samples)" + ], + "metadata": { + "colab": { + "base_uri": "https://localhost:8080/", + "height": 188 + }, + "id": "YsJoy_Ahichx", + "outputId": "c04e7ad9-9ba9-48e9-c578-6c17b60435e0" + }, + "execution_count": 42, + "outputs": [ + { + "output_type": "stream", + "name": "stderr", + "text": [] + }, + { + "output_type": "display_data", + "data": { + "text/plain": [ + "
" + ], + "image/png": "iVBORw0KGgoAAAANSUhEUgAAAUoAAACrCAYAAAAARtWWAAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjcuMSwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy/bCgiHAAAACXBIWXMAAA9hAAAPYQGoP6dpAAAsUUlEQVR4nO2dd5Qc9ZXvb+Xu0cxIQigNikhgQDJKYLAIEiatyRjb4OOwj2W9YGxM2HUA89acY7/1Gvvts2HXgNc2GHsxxgaTTDAgosiwZJTTKGf1aKa7K74/qmbE994S1a2Wd3nP93POnJk7XX2r6le//nXV93d/9xpJkiSkKIqi7Bbzv/sAFEVRPujoQKkoilKADpSKoigF6ECpKIpSgA6UiqIoBehAqSiKUoAOlIqiKAXoQKkoilKA3chGcRzT2rVrqaOjgwzD+HMf0weaJEmop6eHurq6yDQb/57RNtyFtmHraBu2TlNtmDRAd3d3QkT6856f7u7uRppO21DbUNvwA/7TSBs2dEfZ0dFBRETd3Suos7Mz/WeC30Z+zreTG8f4DwNHbfYqnXTJBcJHHCdgJ4aFLi30+diPb5A+CN9jGgH6jBx83YqEjyhJfVQqFZowbuxAmzRK//arupdTZ2f6txHgcdVt+a3mGSEeK3udX8Kpf3ecPPYSb0O8VkGE7bH0uielD9bOVsiuno0+w0Cei+2k7VqpVGj82Al73IbdK97TD7EJKUjkfh3WhyKD2TE6cW1f+Eh89Bu52O5WVGf78KQPE6+lE+F+q+wzVc45jjBxiShtw4kt9MPu7iUD/TCp4rklZTksBHXcxvPwfJMQz/ebV14hfCQG9hmLPdCarA/9r+98T/ioW/gep14DO/RKYLtxRfggSvtOpVKhseMba8OGBsr+W/TOzs4/60Bpuw5xosKBEu2B44P97L2BcmC/TT627GrDjoFj/HMMlKbDRg4iStz3HyjNCK9EXhvuzYGynz1vw87/LwdKp4mBsp+90Q8TZ+8PlJ6Lx0lElJhsoBR9F88lrx/KgRL3IwdK4YL6B8p+GmnDhgbKgYOILAqzi5tY2NncnNwaARtcLNZQcy48H+yDPn+O8HHAR48GO2H7Wbt0BdjHf+VvhI95190MdswGV5PwAxzwTx8ROXG6jRXLQbQZEjIoofTCREwX8agutg9D7AiWhW046sIj0YeHH0YiInswdh6LzeFFlSrYY//ucOFjzU9fAjuwWRuyz7TlyuMwQzP73Zo2VrcsqmfX0MTvPHIs+cm49ci5YLMxjy54bh7YSSQ/5DHzaxLbcYIDhWXI8zci/LgFbMAuW9iIkS9vHGw37SO2IftKM4S1EoVu2i+CMr5Wrss29Fzs99/+2pVgxyGey2fPkE8206Z9CI8hwn6wdsNGsC++CMcHIqKf/OTnYAcetpG7k3XEdnm32Nv/u4mpbJ31VhRFKUAHSkVRlAJ0oFQURSmgKY2SLBoQzx2m1UU5E+iOidtMOftUsI/5ysVgT559vPARxH1gm2yGduzIiWDXPys1ysmnfxzsJfc/ADbT1MkJpb4UZSJyZEr9shkS36bET33FbJIlCuQEgM0mlkZdcCweK/uuGzZGCuD86zBh02i1PtR5jJwJkX5ddWC/MWsj1iz1QPYH006vXZIzWdYMXhSSF2X7ZzLe7X/zWbnfCM/n/GcfATthktyWfzhJ7hRlXHJDPOGwws4plJp9dSu22Zh5j4EdsDaOTanlWnF2wrHUL5vBLtXILqVarF3Hc7nkb78gtp99KOqL5x8/A2yvrQ3s9mH7CB/+jh78R2kQmMP3HQn2NVd+Xfg4+eP4WX7gkT/hBmXUl3fmdLX2MP1nVG+8H+odpaIoSgE6UCqKohSgA6WiKEoBOlAqiqIU0NRkjh3HZGerbSIuNEvtmiae9jGwp56AExFltppj7XwM/CUiGj0LA6qJBWCvXfAq2G2mnBCZcdqZYB9w4mlgL37kj7gLW07Y2P0B50lrExGWE5LlhJkvvrpFXo7h5x4BtjkMo4NZDD+te5sFQhPRuOloV9lEwEeOHg72i0+ukz6+NBVsN0Qfi372GtieLQOioyi9NnHU2kSEb9jkG2lbOWyxwII3V4nt5xyGfeLXp50I9tGjcZJhzBicZCAi2vxh7ODGEDZpZuC1s3OCxUs/WwN27ewD8fV7FuIbInkt63bqt95EIow8qlQih9KAc491O8eQAedHHYqTpnxFXBzgddi6abPwMXLsePTBVu7tqOFyRMpZqXb5FZeC/ZlzPwX2b35zF9jtVi9xapRe31rO53x36B2loihKATpQKoqiFKADpaIoSgHNBZwnxkDWoJiNsQ7PFEREJtMxuIzZ/fSzYH9nnAxSvcVfD3bEMhBdW98G9jnvymDxUnkI2EHItB8mt9o5AedxppfkBQE3Q2SkP0RECdNbeXA5EZHj4f/CGI/dZ1khTC5aEtGqN/E6DJmIPuu9eBzO5pykIJNQt0tW7QDbEmmNpA/DTOD3nuJGCblR6iNike4XzNohtn/HHQz28A/h+bsJBjr37I/JGYiIykwTdHzUNW0LddAgR7QP/3o0vmfzQWBHLNA/cWUbelmWIi9qLSlGuR4PJL84+URMgnLTlReL7WOWYSoJMPlEEKO+2FOXx7d+y3Z8D8sEtr2KfbknkTrvqC5sQ5cFmJ/3mbPB/u1//EH4KDnpIhbf6hOv7Q69o1QURSlAB0pFUZQCdKBUFEUpoCmNMrAMCqxUYOOa5OhzZUKLiWfjwvm+EHWM+6dibJb38cuFj+sf/Uewz/SXg+2ejok8q/OuEz7Ofx1jDw8/HROCHjjnE2AveErqGmYmLPIEC81iBTZZ/QkjHNRkDrjkBLF9aQjqa30+akPnH4vX4c6nZVKM2adgOvwXnkM97cWbl4J9xlX7Ch/PPoParD0Gt5nwpY+CveSG54UP20iPnWeYb5bYCSnOYlFNpuvZltS1xpmoJ/bZeP7uocvANnkmWyIiJlsHIctobuEGnSOGCRc7t2JMXzD9LbBff3wW2Ece97LwEWRaaGDJeOFmqHkmuV6mu7P7pedef0dsP2oExtqOaW8H27QwOfSdd7JkFUR0zicwfnmfESPA/vnPfwv2iSfPFT4WLsFrNWP6TLA3bNkCduxI3T8M0gQefiDnInaH3lEqiqIUoAOloihKATpQKoqiFNCURmnSe0ZWNsTyGDAiIl5vzGcxfqsWLQY7rxraX/uo9Th2F9sCNcqNyxaTIET9zGbJd21W/c5IcmIl+zUoq3FdIw/fSX+IiBxWbCqnCclksWbEKgrGFTzW/YZJTWbZ4iFgD+5CH+s3oI/H79skfLQNxlhDHguZ+Lz4liTOYivjnBjLZjADh8wgbUQuGSe90nf7fqgnDutgHXM7vifK0bUcC9uIx426rCTw+iVSXxw6C/tywqo/lmxWcTOnG/ZvYrfWDakURFTK1me3lTDpbjVHQl61cTvYyzdgvGofK1D36ALUvYmInvrBjWC3daDOuoXFWc6cNU342H9/nNf4yNGYP2LTBuy7dpiTmyG7vq4jq1zuDr2jVBRFKUAHSkVRlAJ0oFQURSlAB0pFUZQCmprMsaKQrKz6XWjiW+OcQF8KcBwOXFSgv/XLJ8Be+AO5GH/9uygKJw4K4MtuvBLsb9x8t/BRPREnc0yWFMOx8DgNnlCXiMKsqcKcapPN4EYhuQNtyCr59UqFfuhoDOQdN3Iy2D9/8BmwD5gtj8+us/PZiec/6nQ8jurDMtmpdwpe37iOx9pX3Qm2acoZgTBLGNzqZE7oxBQ66TnZPPfxotFie3M89iGDTcTEITqxDHn/EEU4KWBYfMd4TqMPHSV9sEUaIbP7lrHZm6PkcRhmCL/3lJplkZslremvMNpPPaiJ7Xm+FsfDZBS/f3Q+2KOHywQ3cYzn57PkMx3tOKm0ZaOcVDztbFwcMnwoJmuZ//SjYB94MFaPJCIKs0m0WuyK13aH3lEqiqIUoAOloihKATpQKoqiFNCU4BYaNoVZESWbJTbgegMRkenhwvnh13eD3Xb4J8G+/NbfCB8G00ZiFsVuOXgKB05GzZKI6PU/PQK2e/zHcYMa045ykvPaWVYE22hNG6pbJtUzTdSKmGZTk8lOS+EQsN946XGwE2KB3pFMGLtmPUvgUGdJV1kvKB0nEzps2ozJbGs1bAczZklnY9m1zEzoMnMSFDeD7Ztk+1lCBwfP3w9kELHHFjKYrN1NVijNyIv8Z5qkGeP5R0x/W/saJpwmItr3EJZshAXt+5uY7mnKdgqyj2xgtqaVl4w6lYysXzBNNsgJ0t60A4uFRSwafvI4XAiyat0G4cNwcD8R67uOjZrh2ImThA+Xtfu99z8AtsVXuURSDy8laR/xYw04VxRF2WvoQKkoilKADpSKoigFNCV0GEb6Q0QUx6jrjN9P6lrud58Ce9KUj4E9yjwa7NKHMI6KiChhSSBCrp8wSWLkUCyyTkR0ZB0L3s+7+w7cL9OKzFjqfFHWVFGO9tYMXhiT1x+3xzQoqyb3u9lcBfbwKUznqrLC8yHGMxIRuewtFhMlQ5acwzBlfFmENdzIDriO1YDOG6b/M8LWCrSRG6Y/JBNL1Hpl/GZ5C/bVnStRO++YhO8JI6lDGyxOssrO33wVz3/YqUOFj7D+/kXtDP6PHH3NodSHE7WW/LgaeOQEaVKKPzzwMLz2hTNPFdvvPxI/3yGLPeWxqOVyTvJjphUfMWcO2LNnHwn22P3GCBcL3ngd7DZW9CxkHyE7kdcytB343Qh6R6koilKADpSKoigF6ECpKIpSQHNrvZOArCTVRmK2TnnYi7Ig1VlTLwPbZLJFwoqmHzL8MOHDrDKNgemHpo26h8MX/xJRaRAWMfr02C+C/dCa29BnTvH6ONMTjZzYtmbwLYv8bI0tTxBsuVK7G1zGYzHY+lQT5TYKcrK9BgErXm+hvmWZ+H3Zt1yusQ1qqB9bIXYdLqc5idTQfDPVhAKzxe9n305/iMhi1zsePERsbi3DNhxUwoM1VmDHtHksHhElbK23zXTMgH2U1v1CXoelfZjstusEPPZtz6CumXxdHkdoZG2Yl1uhCcpuL5XdrB0CXC+9br2MAe0z8bN6xCxMmHvgGCxgF8Y5cdUsXnPEgQeDXd2OReDs8XkaLWKw3A+//c29YJ94zqeEj6RmwO9G0DtKRVGUAnSgVBRFKUAHSkVRlAJ0oFQURSmgqcmc2DQpzoT4mCVB4IkGiIgOH42L2v9z6yKw73jgWrA/d9Y3hA+HB4zypLpsv/fMu0X4OP6wC8A+atpJYD+67ndghzmJZe0kS+gg9fWmcMOY3Cw4N7FZUoQ+mTDXcXDyxmCTCLUAfQQ5x97G5p8qRYlq18tuYYkKgey4WAB2kpPg2DXTJASu0XgygjxqFlH/PASfF5r03B/E9k8fMhfsQ/fnSTBYZctYTtglvF3Z7JXNkmIkUU7yZx/bbNFDmKzEaWfJOmL5mTKzNnSS1tqwWhtEjptO4ngl9JXwpMREdNM1M8AeNRwrKN7yr3zSVfqwWX/f9BQmq+nZhosldjyFgfBERJOOxUmkg/Y/COxfse0DMf1DVHLTipGOWxWv7Q69o1QURSlAB0pFUZQCdKBUFEUpoCmNMoosijJtxmLJVx996Tqx/QmTLwI7JtR1Lpx+Db5hpdRkqlYnHgMrBBazSlEnzfya8PHQoh+C/diim8G+42VUNizKWUhvWNnv1r5bfMckP0tg6jJdb/Xd74rt9/sk6rylwVhsbPQILOJU38KyVxBRj8GCw1nC0loftmG9Paeolc80OqZzrr0fjz0nnwNRkGp0Ydh4Uac8SmadSmam74WolcVWTrB4G57/C2tQm0oSHtQv9TWhW7ImYusvyMwJqjdYUoh9XLx2Rz04D4/LkImc4yQ93yhprQ3LpZ1ULmXHWMVVC/Pmvy62/8ynMDj8gKOwyN09f1wJ9uXnniF83PU8zlHYq7DPOGzxyLAtMrHI2yyp8o2/xgBzi40HpVC2YZStfImo8cQiekepKIpSgA6UiqIoBehAqSiKUkBTGqVj1cmxsmd+phWEsYxXemTxjWB/Yso5uAEv4sRjJIloefgKbsJeHzcUE2kkPHMnEfVsQc3xkeW3gx2y7wsjJymGlSWbsHKSTjSDGxrkZolreYGoPM/rfrcE7CFTDsQNDtgKZp1k8Xom21DMEjzwpBiOLQXGyMc27H54Mdgh05+57kdE5NqpJmTbrSWdrRse1Y1Uq/N4cgpbduk5Lz0K9iMz54LNdW7eXkREDmsjk10t/hYnx4eZ4Gdk9jOoSZomtksSo/763m34ts1SpXZyKNUm+Ud3UE6c6zeuvBXsy664HOzTT8Sk3Ms2Y5IMIqIhhNqwO2IU2E7Akk7nFCx8a/6L+J4RWNTs9nvvA7tuy3tBk/oT+6hGqSiKstfQgVJRFKUAHSgVRVEKaEqjDMijgFLdxGFaoOFIfdFiSXZXb8GEoKuGPoQHU8IEokREz69HfclgWljgscy1OfraujrGeHGZ0WHxjHVH+vCytd4GtZa4l8wg/SESwXe2n5Mw2MXvsn0mYBylE6BtbkCbiGjBKy/hfiL0mbAF7FIZIrJYbCHXdd0aaxcvZ61z5MDvPcXzQ/IyzbTmYRcu8QJeRBR4uL+TXnoc7IjF71k5xc8SmxUTY20YWCyBdE6BOl5wzY5QC4wijI1MLBkDaGexqEbQWpG7cq9PZSvbP2ufo+ecKLY/6TjUIMd3oTa4aeMWsE1DXv8NmzEhtMEK0nkGfh4cK+c+LsJ2veM+1CRrTJMsVXPWxGdFxZwm2lDvKBVFUQrQgVJRFKUAHSgVRVEK0IFSURSlgOYCzqOInCxYOXSYMJ0zAxAxYfXlDfPBnr0vCxbPSSTw7IaXxP/ey9T9jkYfOWHb76x7DuyYCfHk4MF7YU5ShGySKDbysj00Tpw4FGeBx06CYn1iycthMPF62R8xYcGYKWPAXvl2t/Bhhni+PMeyaeD5u4E8x4BVO3TZnEnI5pDMQE7Y2Fk723mdpQlC16Iwy9xbCrANY08mi3B8Nmni4jY2C6aPXHn+VoCNxvu/4+NkT5RTUZOfdsiut2mwhQB12R/6i3DGrc3lEA0y0x8iCtikCk9kTEQ0dNgQsI+cjZ+766//N3xDzuKRBe8uADtmzXzowVPBNnOO442Fb4Edss+ySyz5tYvVQ4mI6lnilHreqoDdoHeUiqIoBehAqSiKUkBDN/D963YrlV3rN0OrkUdvVgOE3UqHCauzkhN79t595hGxQut5j97cB3/0tiz+TCS/P+JMRuj3lbeW+f3Ia0PTZI/euc9TeD48HWbMarPktZfJkkOKR28Ln6PN3EdvvFYOa6OQPWnn+TCzR++92YZ2gI/VsSMfvU0fz48/elsNPXqzvuvg+dsNPHrz+MyQ5c402SJzU4ZRDsTV7s02DFgfC3PWWFermEPAYO+JIh5fXPxZ5o/e3EfeuXEfUcTkCwsfvc1Inkv/o3dPM22YNEB3d3dC6ZnrT/bT3d3dSNNpG2obaht+wH8aaUMjSYqH0ziOae3atdTR0SGyNP+lkSQJ9fT0UFdXV24W692hbbgLbcPW0TZsnWbasKGBUlEU5S8ZncxRFEUpQAdKRVGUAnSgVBRFKUAHSkVRlAJ0oFQURSlAB0pFUZQCdKBUFEUpQAdKRVGUAnSgVBRFKUAHSkVRlAJ0oFQURSlAB0pFUZQCdKBUFEUpoKHEvZqaaRea3qp1tA1bR9uwdZpqQ032+edL9qltqG2obfjB/2mkDRu6o+zo6CAiouUruqmjs5OIiGwTU8PX45J4X8nCNP1hwKrfOZii389RAiwf88VbLu43THhFPVn9LfRYKQQfv0k9l1XpYxUGU8fpfiqVCk0YN3agTRploA27Vw20oRNi6nvflpfDZdtENmsPVgrioDMmCx8RJfgPloLUZDcWC/64TPjg1e4cVj7DZyn5HUs2Yn/1yZbbcPmufuiwGiSJL++SQheP/WuXXgr20uUrwL73nrvksbOHLyvG6xKbeF0+dsxxwseU6VPAvv76H4HNr0MtlP3BzqphVioVmjh2/J634eLl1NmRtmHMPnd33Hu3fF95ENinnHIC2HGM/dCycj7LCTtBZicm66c5N7xRgPs57dZfgH3LzNlgj5p6kPBhZJ+hSqVCYyeMa6gNGxoo+2/ROzo7qXNgoMQB6r9uoEQf/9UDZT/NPrbkteGfY6A0bdmGSZMDZf/xvZe9OVD2s1facA8GSpeXq2Xtnnf+zQ6UliXr7vD98v3w6+C+z0DZz562YWfHrjbkA2W5LEu8trXh//ix/3cNlHa5DHZHe/v7HifRroFywG6gDZur623UyDHSix0YODCWoqrY3jfwJBKHnWSAu3etnHraLitaxApI//OMY8H+1mtPyAMn9OGx2tx8bDUMWaDKTtLzMxN5ns1gBwbZWY3o0MHjcAN5/sQKUFlskBt1HNb1nnDMCOGiY3/sLLaD7d793Dr0efJI4WPDQ7hNGPFrhwNjEMm63g6lg5oVt1bX24kjcrKBqs4+5KYtO73D2tViXxyPPfIw2Gecfqrw0TUM2zVin+I1G9aD/cxTDwofl3/9arCNEK9/ndUKL9msRjURJUF6Z2cFrc3DmmSTmX38z/zyN+C1h//H98X2+xs3gX37eZeA7Y7ENv3p//6x8OGybwL+BdWI1rpx+xawn3myC+zDF2Lt8H3+hMdJRDTvmfTa9PQ0/lnWWW9FUZQCdKBUFEUpQAdKRVGUAprSKH2jRH6mTbq9rKj8IKnruT6r4G6jbnXBxZ8De9LIPuGjHHhgX/pPqJVMTLaCfdlh04WPH7/yGth1BzUyz8djj2ypXURJOfudN9PTOIYdkpFNQBgR00od+b3FZF3qmoua5P6zUTvrnCSFeIvPEhjodBLz6Y1BQZyIaMjcsWBve2IN2EkD37lRNsER5Ux0NEPiJpS4qSbm1Fkb5vRog8mlpov/eP7F+WBfedU3hY+ZM6Yzp7ijF+Y/Dfarr7wqfNRr2K98pkl6IU4q1kycaSYiKjnpZ8pw6uK1Znh72RZqb0/398pBF8Nr5dE1sf3iXx8B9tybjgT70WlXgH3UxnuEj50W9qtt7fjZdj08p0nVRcLHghdQb5/ScQfYvSffDPaKj3xa+BjxnXnpH3WpAe8OvaNUFEUpQAdKRVGUAnSgVBRFKUAHSkVRlAKamsxxqU4upYKr76AgnifmW8z7t8/CZV0//t41YLePPzBnp8wJW6ly7hP3gv3Zd1BUJyK67gvngv3lW24Hu4/NLbTlfX0E6eSBHSc5LzZOnWyqZ83u8dUMpgw4H3bUBLA7hg/GDdgyoh0r5ITYsPE4wWPWsQ03rsRJBteQky1jDh+NxzEDJ3d6/nM12E6MExNERJSt1rDYAoBmqZFNbtaGZQMn5kpmTjB76LF/4P5tCydVZhz+UeEiYdfdYIH/xx43B+wnnn5B+LBcPA4vwOMI2EKAUl1O2IRWOvEYhnLytBnmv/0clbLVNhu698H9XrVAbH9g175gP3s19rPlF/8ObGdozqKNYDvY1jacEKxNx/Fhfvhx4WPIgovAXnQiTrwNWfMu2DsHj5I+JkwjIqKk1kM7xKv56B2loihKATpQKoqiFKADpaIoSgFNaZRxYlCcZfwweEaWqtRTgjLqFHGIWtj2DUvArm5BzYKIKApRcyoPwkQb7Qbqb8FIXCRPRBQEqOOFLHFCW51rZlJf6cuCg/uc5rK1cLw4HtAmA5a1xMmRP9tcTAKSEOqYG1fg+Y+oDhU+jHgj+mCaXakbg3hX2zLguGMcalQ2YXICnpEocHKSYmTJQChoqtsJylFA5Uyb7SPcj2lI3zyRTcKy1hgJtulrr74ofBw85cO4H/b6vPnPge0YMpsW74cBu+AOW+cQlnPOJXNhtSaV00VnnzaQWec7154Gr7k/kmnmhl76EtgvXzceN+hCnbPcI/tQNcKFDdXlw3GDGgsAHyMD7nvXsH42B7XhyqZt+Iab/qfw0TUo1TGjIFCNUlEUZW+hA6WiKEoBOlAqiqIU0FxSDNMlP8tsXmKaZL0sdT0eJ/bKpgrYJQP1hvF/9UXhY92fbgDbCFDHsI85D+y+t/4kfJx5FSZM9VhW8MDD74u4Jr8/2rLSF2EitZemsML0h4gcljW8c+pYsXnboahBWh7qi9dMPwfsr/yTTJg6ZwTGXlplvC4v3rUY7L8/F9P8ExHd5qOOOfrYk8EeegzGHm5/6lnhI7b6f7cWR1m1THIy4bEtYmJdnJOMlWXf7+nFbdo9/BgcNO1o4eLRJx5iPnG/J51wCtjP55z/ju078R+8EovHKgLkxEraTtr/Wk2K0Rv2kpUlDn7ybiyncN6594ntS9aHwG67EdW98477Fdjt0Vrh41H7q2Bvm4Za8U4bk2bs/JRMoGyzWNRgJI4pbd2bwP7+Yd3Cx6Kpabv71YQWPixezkXvKBVFUQrQgVJRFKUAHSgVRVEKaEqjLNVDKtVTfa1exjHWq8v4Qp/FWlosca/BK9lFcq1z3zrUHAbvi1qZwcb67W9jPBsR0aXX3wn2XU+/ArZXQ72pJkPgiOqZXuS3tsbWj13y49SHyarOmTkFoyym48Ys+etdv0M9ya7JxMKfGHUw84nX6nET41dXs0JZRET2IFZwjK2pjky8toGVV+QrPT8zai1xrxVaA/oaEbZHZMqLl/BmZWvZ33kDNVrHljGgVsIq93GfrOsuW/i28GGwpM8Ooc7Iq5SWPLlevh6Ust85a+mbwKRdd0k/uur38Frtps+L7Z9kVSZn/hLFvf1KuD589WrMr0BEtOn6r6H9MOrpEwK8Dkd8EucfiIjeegSTLK+JcTzYNgx14O/Ok3r44UFa0jnIWUu/O/SOUlEUpQAdKBVFUQrQgVJRFKUAHSgVRVEKaC47gR2nP0SUMOGZHBmI7cRsHDZQWL3wm9eDveCnX5f7DFFwrWzdDvbq338X7KtvlNXf7n3yZbCtKk40+GxiqpQjlNc8N/vd2neLGyXkZkHSicOE5lAKz7UePH+XJZ3t3oqTPZ8/YpLwsXRBD9j8snzhVEyY+sJrzwsfGw1M3Ouw5BNJiBMgdk7+XN/B33uKa9TINdLrUXexH5r1nEqWLIPEzT/7Gdin/dVJYF933bXSB0v6EbGA89tuwaDtX9/+G+HjvgdZ0LqPx+qz/MJ2n5wQ89rSvuk5rU3mGKZJhpnu/+VeTJqyaZvc7yF9eNHWz8V++KtbsQ/RDpyYISJav5a162+/BOYgfzbYQ//l28KH97lP4T94IcXNaFaTnGTgo/82/V3rIaIfytdz0DtKRVGUAnSgVBRFKUAHSkVRlAKa0ihrlkOulWoVJZaEN3JyArFZ7PPKTRvA9jzUPa6+DQNfiYiiACN5DRY9HLHErauH8UJSRONmYcLQNS+vZlugJlPLOZdSPdWE/Hpr2hDZQfpDRMSSHti+9D2oiglzd25Eu5ud7lsbtwofwRbUOR0L32TbWChq6UFnCR+1pcvAbh+Jx8FrrsWm/A52s4BrN5FB8c3QZzlkZ/2wjeXACMo514cFuJ93FiZb4P0wCmUg8tEfO0n8773MvxP1x7yg9bNOwcQZ9977INhtxHT+nOD5emRkv1tb+FDprVKSLRKYMXQEvLayt11sf/wstDv/GfvMlK+ygPw+ef0/+jouBlnpLQf7wQfeBHvIFz4mfIybdRTYP7RwfIj37wC7d64sFBdNTjtr3Nd49mO9o1QURSlAB0pFUZQCdKBUFEUpoLmkGEFCpSB9rq87qJ+4oXzeP3g6Jowt26yo1QT0scqV2hVPlDFkGCsmZaEmNfNUGYs44gU8jglH4EL6FS+yhKq+jAlNSumxJ15rQYB1w6V6FgPI80bE8RCx/bbeE8EeMhXjwswRU8BeuPQB4cNgCYK5fFj2UJPqGCJjz9p2YDvv3LkKbMvEa2uaOcl5o6zt4tbasK1mUluWcCUoY/+weXwvEZ123qfBPvVo1K3OZK9fctmVwscVl14O9s4+7DO/+OMdYF995TeFjzfffB2P46yzwL7//j+AXbdlPGPJSrVBz+oTrzXD4A6HOjvT67DQXQev7ZtTP++Tvajblq9EzdaoozbuBtJJ6KCOOZb+EewLLsL3xAdj3yYiem0hFrWLt7LOHGJgZftEnJ8gIlrcnhYojI3GtXK9o1QURSlAB0pFUZQCdKBUFEUpoLm13kmc/hCRY6KecMhBUgvoHF4Gu70DtUHTRR/msJx9sjXFVUIttK0DdU5/u9QdQgcLqU85GQslHXwKJrZ96yFMQkpEFIbpsfphjoDTBF5SJy/J9B52bo61j9i+cx9s1yjC90QbMb6xbd/D5E5NfI/lsOBLFptqB3wBLVEyCNvd8ybg6ztxvbAhczAPLG32W/169oL0h4gSFk/4pUsvEpt/52osahWzoM/NW1Bf+/a3rhA+HIslnWZrvxPms57TTw49dDrY06YdCvbFX8Zjv+G6fxU+/Kgt+52zmL4J3lyxkQa1pzpnEODn4dS7F4nth16C59/XgxfRMLAIXs2QxxcZeK16+sbhBj5qmEM3ygTP/z58Mv5jMWrFlo+f/52r2OJvItr5iyzZddR4oUC9o1QURSlAB0pFUZQCdKBUFEUpQAdKRVGUAppLiuHa5LrpW0q8opwpXU2cuB/YOyoonq5cuQ3srrIUwGOW9GLCgZiMYclqrBi44iE59k8YhRM+Hkt64ZiYWcGKZPC8kWWidfIy0jZBPfSoHqaTKbaFvmJPitdGB07m2AZOtNT7MPA48WRCA17tUky0WLjf3gevEj7M6ceAPXQwzrytXY8B6CHJc3HjAH7vKTVyyaX0GpZ45hWenYOIDjngALAXL10K9nX/50dgf/USnPwhIopZUmWLlWG02ITYDbf8Uvi45Ivng33ApIlg88qOAcngedeqwu89ZelLL1NbOZ2ASRauhNduWPY7sb01dQ7YZx1yLNg9bGJ2kIvJKYiINtr4v4qPk71GuB3sp+6UixaW/eQJ/MencXFEfAhWbT159ZPChzPmRSIiCoKAHn5DvJyL3lEqiqIUoAOloihKATpQKoqiFGAkSVKYvbJSqdDgwYNpx/aN1NmZBo3XQgxadh2p3e1/ABak4juaOeMIsNtKMlGpywN9Pdzv+q0vgP3RmTOFj/seeA2Pg8WZPv/cW2DXebYKImrz06OvVCo0ePgQ2rFjx0BbNMJAG+7YNPC+gCVwsGwZpb3PwagNWayIV2nWOWAHhkw44baxf5io69WYRmm0Y4A+EVHvYzfhP1i36Xn9KXacslvFdvq9XKlUaOjgoXvehpu3DLyv7uB3fV66jQu/+EWw6z4meFi7FpNC5OQcJtfFvhn5LKE0k2RjkvraIZPHg22xdv/eD27A1w2ZQDiK0+OoVCo0bN/Be9yGS5YuoI6OVDP89Fevg202vD1bvG+Ljef/4fHYd7fvwGPo8DAZMBGRH2K/6uhZAvbSKn62rQ6ZhHnxYSPBHjYKXw9WYrD8tybL6/APf58m9KhUKjRkcGNtqHeUiqIoBehAqSiKUoAOlIqiKAU0FUfpGxb5mRhTYgl1azmulizGRAnTD8G4McfhmmROotIybtPpoeYwasKRYNe3yeMwXNQ+nn/yFbYBanZtOWvl66VUk6m7rSXFCCKXgiyRg2PjjnySGu22t58Ae8iHsbjSoC6MmwxNeXz1pW+DbZqoL5mDMbFxsqVH+CjV8Tt187uPgx0HqLf5OWKhk0lORotZMUInptBJ+59XRV81V/r+6U0/BXvUiC6wD5uOSSF4ggsiopUr1oAdMzl5zGgUyyxb9sO773kY7EUrMfbUYOJ5vx6Jfmvwe08ZMnSfAV3uQ4TF9hZueVpsP2nKXLA7VuHncMsmljBXSpRkDUN9cUOMGmWvh5/DbaukVp5MrYC9OcYkzOMf/TnYF177D8JHv36cpyPvDr2jVBRFKUAHSkVRlAJ0oFQURSmgKY0yqtkUZWu9qYQxXo6IkiSy6qgNRjHqWGXC2MuXn3pW+Jg2C+MiV+7Atc3TpuE63u4NuNaTiOiZR7DwusGSqsY2ahV+CdegEhF5NR9+7ylOkpCTxSDWDWwfz5droAOXxZHWUaOpVnbgG2KpUW57FYtWOQZeq8FzPodvyImsXbv2XfxHhNfS5nG0vtR/gixhcNhabTGK6i5F9VRnjcp4sKW6LLoVeqj1HXc86rxHH3882HFOcbnvf/+HuE2IIuVlF1/IDlLeg9yeYOE3m63BpxCPM3FlAuUoi0WMotb6IcXJwLr4H/8KNdx/CeU6/ZmfvAXsDZ2YdNiyMX6xdzvqr0REz7+CMc81Qm18H8JEvp+5UPblW0eh+Jnc8X2wF71xGdgGT1JNREY2F2LkzInsDr2jVBRFKUAHSkVRlAJ0oFQURSlAB0pFUZQCmprMKZdqVM4Cr2M+UZMTiG25OOHz9jsYYDppwliwpxw0Sfi49t9uA/vqb2GlurXrUPCefOgM4cMM8Pug7qKIXgpw8saKcio5ltKmCv3mCldyEjugxE79eyxIO8wJlnYCnDTYugSDx70DsOqiZcnkJH2rsKqkwSYaOsfgdUgSeS2DHZgg2WDJjXmL2aYU0Z0s2bOTtJa41/Oq5HnZjFAVg+ejsgzS5mUf/+M2TEx78glzwY5zRP7ulRiUbbB7jLPOwOQkliP7yWMPYcB5xDIoBy7aJV9OKlpZ37HM1tpw4XOLqH1QuljhwDmYYNuyBsvt78fJKuOwS3GDACfETjoe+wsR0eMLbwE7ZBN+tot96gf/fp/w8dgsXCww95prwE7YSoC8jD/9VTh5Nc73Q+8oFUVRCtCBUlEUpYCGniP7U1ZWKrvWAPNH7zDn0dslfPSOE3xMitltchDKx8aeHowb9FmsYcIeq2s1GQNXqaCPgD3i+uwRmPIevbPi7P2+GkjjiceZ04YGf/TOiS+0Ax7ziXbCYvGSnPWr/Pz5o7d4XMl59C7yEVjsMSrMWXOfPS72t8HeaEPx6B3IY0/Yo7fBjjVk/S7v0VucP7vHCALsM3HOQx/3IR692fZ+TqhkksX8ttqGvX27JCtoz/TAxPt8VmeKIpYrM8L4VT8nFpWff9Gjd60qY2J7d+Kxcp8NPXrTrtyyRA22YdIA3d3dSbZP/cl+uru7G2k6bUNtQ23DD/hPI23YUIbzOI5p7dq11NHRQYbRWvac/9dJkoR6enqoq6uLzLxU2LtB23AX2oato23YOs20YUMDpaIoyl8yOpmjKIpSgA6UiqIoBehAqSiKUoAOlIqiKAXoQKkoilKADpSKoigF6ECpKIpSwP8FDMTL19RQF8cAAAAASUVORK5CYII=\n" + }, + "metadata": {} + } + ] + } + ] +} \ No newline at end of file