import numpy as np
import torch
import torch.nn as nn
from tqdm.auto import tqdm
class ContextUnet(nn.Module):
def __init__(
self, in_channels, n_feat=256, n_cfeat=10, height=28
): # cfeat - context features
super(ContextUnet, self).__init__()
# number of input channels, number of intermediate feature maps and number of classes
self.in_channels = in_channels
self.n_feat = n_feat
self.n_cfeat = n_cfeat
self.h = height # assume h == w. must be divisible by 4, so 28,24,20,16...
# Initialize the initial convolutional layer
self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
# Initialize the down-sampling path of the U-Net with two levels
        self.down1 = UnetDown(n_feat, n_feat)      # down1: (batch, n_feat, h/2, h/2)
        self.down2 = UnetDown(n_feat, 2 * n_feat)  # down2: (batch, 2*n_feat, h/4, h/4)
# original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
        self.to_vec = nn.Sequential(nn.AvgPool2d(4), nn.GELU())
# Embed the timestep and context labels with a one-layer fully connected neural network
self.timeembed1 = EmbedFC(1, 2 * n_feat)
self.timeembed2 = EmbedFC(1, 1 * n_feat)
self.contextembed1 = EmbedFC(n_cfeat, 2 * n_feat)
self.contextembed2 = EmbedFC(n_cfeat, 1 * n_feat)
# Initialize the up-sampling path of the U-Net with three levels
self.up0 = nn.Sequential(
nn.ConvTranspose2d(
2 * n_feat, 2 * n_feat, self.h // 4, self.h // 4
), # up-sample
nn.GroupNorm(8, 2 * n_feat), # normalize
nn.ReLU(),
)
self.up1 = UnetUp(4 * n_feat, n_feat)
self.up2 = UnetUp(2 * n_feat, n_feat)
# Initialize the final convolutional layers to map to the same number of channels as the input image
self.out = nn.Sequential(
nn.Conv2d(
2 * n_feat, n_feat, 3, 1, 1
), # reduce number of feature maps #in_channels, out_channels, kernel_size, stride=1, padding=0
nn.GroupNorm(8, n_feat), # normalize
nn.ReLU(),
nn.Conv2d(
n_feat, self.in_channels, 3, 1, 1
), # map to same number of channels as input
)
def forward(self, x, t, c=None):
"""
x : (batch, n_feat, h, w) : input image
t : (batch, n_cfeat) : time step
c : (batch, n_classes) : context label
"""
        # x is the input image, t is the timestep, and c is the (optional) context label
# pass the input image through the initial convolutional layer
x = self.init_conv(x)
# pass the result through the down-sampling path
        down1 = self.down1(x)      # (batch, n_feat, h/2, h/2)
        down2 = self.down2(down1)  # (batch, 2*n_feat, h/4, h/4)
# convert the feature maps to a vector and apply an activation
hiddenvec = self.to_vec(down2)
        # if no context is provided, fall back to a zero (unconditional) context vector
if c is None:
c = torch.zeros(x.shape[0], self.n_cfeat).to(x)
# embed context and timestep
cemb1 = self.contextembed1(c).view(
-1, self.n_feat * 2, 1, 1
) # (batch, 2*n_feat, 1,1)
temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
        # print(f"unet forward: cemb1 {cemb1.shape}, temb1 {temb1.shape}, cemb2 {cemb2.shape}, temb2 {temb2.shape}")
up1 = self.up0(hiddenvec)
        up2 = self.up1(cemb1 * up1 + temb1, down2)  # scale by context embedding, shift by time embedding
up3 = self.up2(cemb2 * up2 + temb2, down1)
out = self.out(torch.cat((up3, x), 1))
return out
class ResidualConvBlock(nn.Module):
def __init__(
self, in_channels: int, out_channels: int, is_res: bool = False
) -> None:
super().__init__()
# Check if input and output channels are the same for the residual connection
self.same_channels = in_channels == out_channels
# Flag for whether or not to use residual connection
self.is_res = is_res
# First convolutional layer
self.conv1 = nn.Sequential(
nn.Conv2d(
in_channels, out_channels, 3, 1, 1
), # 3x3 kernel with stride 1 and padding 1
nn.BatchNorm2d(out_channels), # Batch normalization
nn.GELU(), # GELU activation function
)
# Second convolutional layer
self.conv2 = nn.Sequential(
nn.Conv2d(
out_channels, out_channels, 3, 1, 1
), # 3x3 kernel with stride 1 and padding 1
nn.BatchNorm2d(out_channels), # Batch normalization
nn.GELU(), # GELU activation function
)
def forward(self, x: torch.Tensor) -> torch.Tensor:
# If using residual connection
if self.is_res:
# Apply first convolutional layer
x1 = self.conv1(x)
# Apply second convolutional layer
x2 = self.conv2(x1)
# If input and output channels are the same, add residual connection directly
if self.same_channels:
out = x + x2
else:
                # If not, apply a 1x1 convolution to match dimensions before adding the residual
                # (note: this creates a freshly initialized, untrained conv on every forward pass; kept as in the original)
shortcut = nn.Conv2d(
x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0
).to(x.device)
out = shortcut(x) + x2
# print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}")
            # Scale by 1/sqrt(2) so the variance of the sum stays comparable to its inputs
            return out / 1.414
# If not using residual connection, return output of second convolutional layer
else:
x1 = self.conv1(x)
x2 = self.conv2(x1)
return x2
# Method to get the number of output channels for this block
def get_out_channels(self):
return self.conv2[0].out_channels
# Method to set the number of output channels for this block
def set_out_channels(self, out_channels):
self.conv1[0].out_channels = out_channels
self.conv2[0].in_channels = out_channels
self.conv2[0].out_channels = out_channels
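
# Illustrative sketch (not part of the original file): a quick shape check for
# ResidualConvBlock. The name `_demo_residual_block` is made up for this example.
def _demo_residual_block():
    block = ResidualConvBlock(3, 64, is_res=True)  # 3 -> 64 channels with a residual path
    x = torch.randn(2, 3, 16, 16)                  # (batch, channels, h, w)
    y = block(x)
    assert y.shape == (2, 64, 16, 16)              # spatial size preserved, channels changed
    return y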
class UnetUp(nn.Module):
def __init__(self, in_channels, out_channels):
super(UnetUp, self).__init__()
# Create a list of layers for the upsampling block
# The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers
layers = [
nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
ResidualConvBlock(out_channels, out_channels),
ResidualConvBlock(out_channels, out_channels),
]
# Use the layers to create a sequential model
self.model = nn.Sequential(*layers)
def forward(self, x, skip):
# Concatenate the input tensor x with the skip connection tensor along the channel dimension
x = torch.cat((x, skip), 1)
# Pass the concatenated tensor through the sequential model and return the output
x = self.model(x)
return x
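
# Illustrative sketch (not part of the original file): UnetUp doubles the spatial
# resolution after concatenating `x` with the skip tensor along the channel axis,
# so `in_channels` must equal x.channels + skip.channels. `_demo_unet_up` is a made-up name.
def _demo_unet_up():
    up = UnetUp(in_channels=256, out_channels=64)  # 128 (x) + 128 (skip) -> 64
    x = torch.randn(2, 128, 8, 8)
    skip = torch.randn(2, 128, 8, 8)
    y = up(x, skip)
    assert y.shape == (2, 64, 16, 16)              # 2x upsampled by the transposed conv
    return y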
class UnetDown(nn.Module):
def __init__(self, in_channels, out_channels):
super(UnetDown, self).__init__()
# Create a list of layers for the downsampling block
# Each block consists of two ResidualConvBlock layers, followed by a MaxPool2d layer for downsampling
layers = [
ResidualConvBlock(in_channels, out_channels),
ResidualConvBlock(out_channels, out_channels),
nn.MaxPool2d(2),
]
# Use the layers to create a sequential model
self.model = nn.Sequential(*layers)
def forward(self, x):
# Pass the input through the sequential model and return the output
return self.model(x)
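
# Illustrative sketch (not part of the original file): UnetDown maps to `out_channels`
# and halves the spatial resolution via MaxPool2d(2). `_demo_unet_down` is a made-up name.
def _demo_unet_down():
    down = UnetDown(64, 128)
    x = torch.randn(2, 64, 16, 16)
    y = down(x)
    assert y.shape == (2, 128, 8, 8)               # channels 64 -> 128, spatial 16 -> 8
    return y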
class EmbedFC(nn.Module):
def __init__(self, input_dim, emb_dim):
super(EmbedFC, self).__init__()
"""
This class defines a generic one layer feed-forward neural network for embedding input data of
dimensionality input_dim to an embedding space of dimensionality emb_dim.
"""
self.input_dim = input_dim
# define the layers for the network
layers = [
nn.Linear(input_dim, emb_dim),
nn.GELU(),
nn.Linear(emb_dim, emb_dim),
]
# create a PyTorch sequential model consisting of the defined layers
self.model = nn.Sequential(*layers)
def forward(self, x):
# flatten the input tensor
x = x.view(-1, self.input_dim)
# apply the model layers to the flattened tensor
return self.model(x)
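
# Illustrative sketch (not part of the original file): EmbedFC is used both for the
# scalar timestep (input_dim=1) and for the context vector (input_dim=n_cfeat).
# `_demo_embed_fc` is a made-up name.
def _demo_embed_fc():
    time_embed = EmbedFC(1, 256)
    ctx_embed = EmbedFC(10, 256)
    t = torch.rand(4, 1)    # normalized timesteps in [0, 1]
    c = torch.zeros(4, 10)  # e.g. one-hot class labels
    assert time_embed(t).shape == (4, 256)
    assert ctx_embed(c).shape == (4, 256)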
def unorm(x):
# unity norm. results in range of [0,1]
# assume x (h,w,3)
xmax = x.max((0, 1))
xmin = x.min((0, 1))
return (x - xmin) / (xmax - xmin)
def norm_all(store, n_t, n_s):
# runs unity norm on all timesteps of all samples
nstore = np.zeros_like(store)
for t in range(n_t):
for s in range(n_s):
nstore[t, s] = unorm(store[t, s])
return nstore
def norm_torch(x_all):
# runs unity norm on all timesteps of all samples
# input is (n_samples, 3,h,w), the torch image format
x = x_all.cpu().numpy()
xmax = x.max((2, 3))
xmin = x.min((2, 3))
xmax = np.expand_dims(xmax, (2, 3))
xmin = np.expand_dims(xmin, (2, 3))
nstore = (x - xmin) / (xmax - xmin)
return torch.from_numpy(nstore)
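
# Illustrative sketch (not part of the original file): the unity-norm helpers rescale
# each image independently into [0, 1] for plotting. `_demo_norm` is a made-up name.
def _demo_norm():
    img = np.random.randn(28, 28, 3)   # (h, w, 3) numpy image
    out = unorm(img)
    assert out.min() >= 0.0 and out.max() <= 1.0
    batch = torch.randn(4, 3, 28, 28)  # (n_samples, 3, h, w) torch format
    assert norm_torch(batch).shape == (4, 3, 28, 28)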
## diffusion functions
def setup_ddpm(beta1, beta2, timesteps, device):
# construct DDPM noise schedule and sampling functions
b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
a_t = 1 - b_t
ab_t = torch.cumsum(a_t.log(), dim=0).exp()
ab_t[0] = 1
    # helper function: perturbs an image to a specified noise level
    # (note: the noise term is scaled by (1 - ab_t) rather than the standard DDPM sqrt(1 - ab_t); left as in the original)
def perturb_input(x, t, noise):
return (
ab_t.sqrt()[t, None, None, None] * x
+ (1 - ab_t[t, None, None, None]) * noise
)
    # helper function: removes the predicted noise (but adds some noise back in to avoid collapse)
def _denoise_add_noise(x, t, pred_noise, z=None):
if z is None:
z = torch.randn_like(x)
noise = b_t.sqrt()[t] * z
mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()
return mean + noise
    # sample with context using the standard DDPM sampling algorithm
    # (modified slightly so that the context and the initial noise tensor are passed in explicitly)
@torch.no_grad()
def sample_ddpm_context(nn_model, noises, context, save_rate=20):
# array to keep track of generated steps for plotting
intermediate = []
pbar = tqdm(range(timesteps, 0, -1), leave=False)
for i in pbar:
pbar.set_description(f"sampling timestep {i:3d}")
# reshape time tensor
t = torch.tensor([i / timesteps])[:, None, None, None].to(noises.device)
# sample some random noise to inject back in. For i = 1, don't add back in noise
z = torch.randn_like(noises) if i > 1 else 0
eps = nn_model(noises, t, c=context) # predict noise e_(x_t,t, ctx)
noises = _denoise_add_noise(noises, i, eps, z)
if i % save_rate == 0 or i == timesteps or i < 8:
intermediate.append(noises.detach().cpu().numpy())
intermediate = np.stack(intermediate)
return noises.clip(-1, 1), intermediate
return perturb_input, sample_ddpm_context
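
# Illustrative sketch (not part of the original file): wiring ContextUnet and
# setup_ddpm together. The function name `_demo_sampling` and all hyperparameter
# values below are assumptions chosen for demonstration only.
def _demo_sampling():
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    timesteps = 100
    model = ContextUnet(in_channels=3, n_feat=64, n_cfeat=5, height=16).to(device)
    perturb_input, sample_ddpm_context = setup_ddpm(
        beta1=1e-4, beta2=0.02, timesteps=timesteps, device=device
    )

    # forward (noising) process, as used during training:
    # x0 stands in for a batch of clean training images
    x0 = torch.randn(8, 3, 16, 16, device=device)
    t = torch.randint(1, timesteps + 1, (8,), device=device)
    noise = torch.randn_like(x0)
    x_t = perturb_input(x0, t, noise)

    # reverse (sampling) process, starting from pure noise
    # (the model is untrained here, so the samples are meaningless; this only checks shapes)
    noises = torch.randn(8, 3, 16, 16, device=device)
    context = torch.eye(5, device=device)[torch.randint(0, 5, (8,), device=device)]
    samples, intermediate = sample_ddpm_context(model, noises, context)
    return x_t, samples, intermediate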