import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import tqdm


class ContextUnet(nn.Module):
    def __init__(
        self, in_channels, n_feat=256, n_cfeat=10, height=28
    ):  # cfeat - context features
        super(ContextUnet, self).__init__()

        # number of input channels, number of intermediate feature maps and number of classes
        self.in_channels = in_channels
        self.n_feat = n_feat
        self.n_cfeat = n_cfeat
        self.h = height  # assume h == w; must be divisible by 4, so 28, 24, 20, 16...

        # Initialize the initial convolutional layer
        self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)

        # Initialize the down-sampling path of the U-Net with two levels
        self.down1 = UnetDown(n_feat, n_feat)  # output: [batch, n_feat, h/2, w/2]
        self.down2 = UnetDown(n_feat, 2 * n_feat)  # output: [batch, 2*n_feat, h/4, w/4]

        # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.GELU())

        # Embed the timestep and context labels with a one-layer fully connected neural network
        self.timeembed1 = EmbedFC(1, 2 * n_feat)
        self.timeembed2 = EmbedFC(1, 1 * n_feat)
        self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
        self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)

        # Initialize the up-sampling path of the U-Net with three levels
        self.up0 = nn.Sequential(
            nn.ConvTranspose2d(
                2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4
            ),  # up-sample
            nn.GroupNorm(8, 2 * n_feat),  # normalize
            nn.ReLU(),
        )
        self.up1 = UnetUp(4 * n_feat, n_feat)
        self.up2 = UnetUp(2 * n_feat, n_feat)

        # Initialize the final convolutional layers to map to the same number of channels as the input image
        self.out = nn.Sequential(
            nn.Conv2d(
                2 * n_feat, n_feat, 3, 1, 1
            ),  # reduce number of feature maps  # in_channels, out_channels, kernel_size, stride=1, padding=1
            nn.GroupNorm(8, n_feat),  # normalize
            nn.ReLU(),
            nn.Conv2d(
                n_feat, self.in_channels, 3, 1, 1
            ),  # map to same number of channels as input
        )

    def forward(self, x, t, c=None):
        """
        x : (batch, in_channels, h, w) : input image
        t : (batch, 1)                 : time step
        c : (batch, n_cfeat)           : context label
        """
        # x is the input image, c is the context label, t is the timestep

        # pass the input image through the initial convolutional layer
        x = self.init_conv(x)
        # pass the result through the down-sampling path
        down1 = self.down1(x)      # [batch, n_feat, h/2, w/2]
        down2 = self.down2(down1)  # [batch, 2*n_feat, h/4, w/4]

        # convert the feature maps to a vector and apply an activation
        hiddenvec = self.to_vec(down2)

        # if no context is provided, fall back to a zero context vector (unconditional)
        if c is None:
            c = torch.zeros(x.shape[0], self.n_cfeat).to(x)

        # embed context and timestep
        cemb1 = self.contextembed1(c).view(
            -1, self.n_feat * 2, 1, 1
        )  # (batch, 2*n_feat, 1, 1)
        temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
        cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
        temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
        # print(f"uunet forward: cemb1 {cemb1.shape}. temb1 {temb1.shape}, cemb2 {cemb2.shape}. temb2 {temb2.shape}")

        # pass through the up-sampling path, scaling by the context embedding and shifting by the time embedding
        up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1 * up1 + temb1, down2)
        up3 = self.up2(cemb2 * up2 + temb2, down1)
        out = self.out(torch.cat((up3, x), 1))
        return out


class ResidualConvBlock(nn.Module):
    def __init__(
        self, in_channels: int, out_channels: int, is_res: bool = False
    ) -> None:
        super().__init__()

        # Check if input and output channels are the same for the residual connection
        self.same_channels = in_channels == out_channels

        # Flag for whether or not to use residual connection
        self.is_res = is_res

        # First convolutional layer
        self.conv1 = nn.Sequential(
            nn.Conv2d(
                in_channels, out_channels, 3, 1, 1
            ),  # 3x3 kernel with stride 1 and padding 1
            nn.BatchNorm2d(out_channels),  # Batch normalization
            nn.GELU(),  # GELU activation function
        )

        # Second convolutional layer
        self.conv2 = nn.Sequential(
            nn.Conv2d(
                out_channels, out_channels, 3, 1, 1
            ),  # 3x3 kernel with stride 1 and padding 1
            nn.BatchNorm2d(out_channels),  # Batch normalization
            nn.GELU(),  # GELU activation function
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        # If using residual connection
        if self.is_res:
            # Apply first convolutional layer
            x1 = self.conv1(x)

            # Apply second convolutional layer
            x2 = self.conv2(x1)

            # If input and output channels are the same, add residual connection directly
            if self.same_channels:
                out = x + x2
            else:
                # If not, apply a 1x1 convolutional layer to match dimensions before adding residual connection
                shortcut = nn.Conv2d(
                    x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0
                ).to(x.device)
                out = shortcut(x) + x2
            # print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}")

            # Scale by 1/sqrt(2) to keep the output variance comparable to the input
            return out / 1.414

        # If not using residual connection, return output of second convolutional layer
        else:
            x1 = self.conv1(x)
            x2 = self.conv2(x1)
            return x2

    # Method to get the number of output channels for this block
    def get_out_channels(self):
        return self.conv2[0].out_channels

    # Method to set the number of output channels for this block
    def set_out_channels(self, out_channels):
        self.conv1[0].out_channels = out_channels
        self.conv2[0].in_channels = out_channels
        self.conv2[0].out_channels = out_channels
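

# A minimal sanity check, not part of the original file: with is_res=True and matching
# channel counts, the block keeps the spatial size and adds the input back in before
# scaling by 1/sqrt(2). The shapes below are arbitrary assumptions for illustration.
if __name__ == "__main__":
    _block = ResidualConvBlock(8, 8, is_res=True)
    _out = _block(torch.randn(2, 8, 16, 16))
    print(_out.shape)  # expected: torch.Size([2, 8, 16, 16])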


class UnetUp(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetUp, self).__init__()

        # Create a list of layers for the upsampling block
        # The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers
        layers = [
            nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
            ResidualConvBlock(out_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
        ]

        # Use the layers to create a sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x, skip):
        # Concatenate the input tensor x with the skip connection tensor along the channel dimension
        x = torch.cat((x, skip), 1)

        # Pass the concatenated tensor through the sequential model and return the output
        x = self.model(x)
        return x


class UnetDown(nn.Module):
    def __init__(self, in_channels, out_channels):
        super(UnetDown, self).__init__()

        # Create a list of layers for the downsampling block
        # Each block consists of two ResidualConvBlock layers, followed by a MaxPool2d layer for downsampling
        layers = [
            ResidualConvBlock(in_channels, out_channels),
            ResidualConvBlock(out_channels, out_channels),
            nn.MaxPool2d(2),
        ]

        # Use the layers to create a sequential model
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # Pass the input through the sequential model and return the output
        return self.model(x)


class EmbedFC(nn.Module):
    def __init__(self, input_dim, emb_dim):
        """
        This class defines a generic one-layer feed-forward neural network for embedding input data of
        dimensionality input_dim to an embedding space of dimensionality emb_dim.
        """
        super(EmbedFC, self).__init__()
        self.input_dim = input_dim

        # define the layers for the network
        layers = [
            nn.Linear(input_dim, emb_dim),
            nn.GELU(),
            nn.Linear(emb_dim, emb_dim),
        ]

        # create a PyTorch sequential model consisting of the defined layers
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        # flatten the input tensor
        x = x.view(-1, self.input_dim)
        # apply the model layers to the flattened tensor
        return self.model(x)
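

# A minimal forward-pass sketch, not part of the original file, placed here so that all
# building blocks used by ContextUnet are already defined. Batch size, feature width,
# number of classes, and image size are illustrative assumptions; height must be
# divisible by 4, as noted in ContextUnet.__init__.
if __name__ == "__main__":
    _model = ContextUnet(in_channels=3, n_feat=64, n_cfeat=5, height=16)
    _x = torch.randn(4, 3, 16, 16)   # (batch, in_channels, h, w)
    _t = torch.rand(4, 1)            # normalized timestep in [0, 1]
    _c = torch.zeros(4, 5)           # context labels; all zeros behaves like "no context"
    print(_model(_x, _t, _c).shape)  # expected: torch.Size([4, 3, 16, 16])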


def unorm(x):
    # unity norm: results in range of [0, 1]
    # assume x has shape (h, w, 3)
    xmax = x.max((0, 1))
    xmin = x.min((0, 1))
    return (x - xmin) / (xmax - xmin)


def norm_all(store, n_t, n_s):
    # runs unity norm on all timesteps of all samples
    nstore = np.zeros_like(store)
    for t in range(n_t):
        for s in range(n_s):
            nstore[t, s] = unorm(store[t, s])
    return nstore


def norm_torch(x_all):
    # runs unity norm on all timesteps of all samples
    # input is (n_samples, 3, h, w), the torch image format
    x = x_all.cpu().numpy()
    xmax = x.max((2, 3))
    xmin = x.min((2, 3))
    xmax = np.expand_dims(xmax, (2, 3))
    xmin = np.expand_dims(xmin, (2, 3))
    nstore = (x - xmin) / (xmax - xmin)
    return torch.from_numpy(nstore)
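

# A small illustration, not part of the original file: norm_torch rescales each channel of
# each sample to [0, 1] independently, which is how sampled images are prepared for
# plotting. The tensor shape below is an arbitrary assumption.
if __name__ == "__main__":
    _imgs = torch.randn(2, 3, 16, 16)
    _scaled = norm_torch(_imgs)
    print(float(_scaled.min()), float(_scaled.max()))  # expected: 0.0 1.0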


## diffusion functions
def setup_ddpm(beta1, beta2, timesteps, device):
    # construct DDPM noise schedule and sampling functions
    b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
    a_t = 1 - b_t
    ab_t = torch.cumsum(a_t.log(), dim=0).exp()
    ab_t[0] = 1

    # helper function: perturbs an image to a specified noise level
    def perturb_input(x, t, noise):
        return (
            ab_t.sqrt()[t, None, None, None] * x
            + (1 - ab_t[t, None, None, None]) * noise
        )

    # helper function: removes the predicted noise (but adds some noise back in to avoid collapse)
    def _denoise_add_noise(x, t, pred_noise, z=None):
        if z is None:
            z = torch.randn_like(x)
        noise = b_t.sqrt()[t] * z
        mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()
        return mean + noise

    # sample with context using the standard DDPM algorithm
    # this version differs from the original in that the starting noise and the context are passed in explicitly
    def sample_ddpm_context(nn_model, noises, context, save_rate=20):
        # array to keep track of generated steps for plotting
        intermediate = []
        pbar = tqdm(range(timesteps, 0, -1), leave=False)
        for i in pbar:
            pbar.set_description(f"sampling timestep {i:3d}")

            # reshape time tensor
            t = torch.tensor([i / timesteps])[:, None, None, None].to(noises.device)

            # sample some random noise to inject back in; for i = 1, don't add noise back in
            z = torch.randn_like(noises) if i > 1 else 0

            eps = nn_model(noises, t, c=context)  # predict noise e_(x_t, t, ctx)
            noises = _denoise_add_noise(noises, i, eps, z)
            if i % save_rate == 0 or i == timesteps or i < 8:
                intermediate.append(noises.detach().cpu().numpy())

        intermediate = np.stack(intermediate)
        return noises.clip(-1, 1), intermediate

    return perturb_input, sample_ddpm_context
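

# A minimal end-to-end sampling sketch, not part of the original file. The beta schedule,
# number of timesteps, and model size below are illustrative assumptions rather than tuned
# values, and the model is untrained, so the outputs are only noise-shaped samples.
if __name__ == "__main__":
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    timesteps = 50
    perturb_input, sample_ddpm_context = setup_ddpm(
        beta1=1e-4, beta2=0.02, timesteps=timesteps, device=device
    )
    model = ContextUnet(in_channels=3, n_feat=64, n_cfeat=5, height=16).to(device)
    model.eval()
    noises = torch.randn(4, 3, 16, 16, device=device)  # starting noise x_T
    context = torch.zeros(4, 5, device=device)         # all-zero context = unconditional
    with torch.no_grad():
        samples, intermediate = sample_ddpm_context(model, noises, context)
    print(samples.shape, intermediate.shape)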