Aditya Patkar committed on
Commit
cb0d40a
1 Parent(s): a338de2

added feature2sprite

Files changed (5)
  1. app.py +28 -1
  2. constants.py +2 -0
  3. feature_to_sprite.py +85 -0
  4. requirements.txt +6 -0
  5. utilities.py +286 -0
app.py CHANGED
@@ -5,7 +5,7 @@
 # imports
 import streamlit as st
 from text_to_image import generate_image
-
+ from feature_to_sprite import generate_sprites

 def setup():
     """
@@ -81,6 +81,33 @@ def main():
         This mode generates 16*16 images of sprites based on a combination of features. It uses a custom model trained on a dataset of sprites.
         """
     )
+
+    form = st.form(key="my_form")
+
+    # add sliders
+    hero = form.slider("Hero", min_value=0.0, max_value=1.0, value=1.0, step=0.01)
+    non_hero = form.slider("Non Hero", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
+    food = form.slider("Food", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
+    spell = form.slider("Spell", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
+    side_facing = form.slider("Side Facing", min_value=0.0, max_value=1.0, value=0.0, step=0.01)
+    # add submit button
+    submit_button = form.form_submit_button(label="Generate")
+
+    # create feature vector
+    if submit_button:
+        feature_vector = [hero, non_hero, food, spell, side_facing]
+        # show loader
+        with st.spinner("Generating sprite..."):
+            # horizontal line and line break
+            st.markdown("<hr>", unsafe_allow_html=True)
+            st.markdown("<br>", unsafe_allow_html=True)
+
+            st.subheader("Your Sprite")
+            st.markdown("<br>", unsafe_allow_html=True)
+
+            generate_sprites(feature_vector)
+
+


 if __name__ == "__main__":
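Not part of this commit: the slider order above is the feature convention that generate_sprites expects. A minimal sketch of calling it directly with that convention, assuming it is executed inside a Streamlit script (the result is rendered with st.image):

    from feature_to_sprite import generate_sprites

    # feature order: [hero, non_hero, food, spell, side_facing], each in [0.0, 1.0]
    generate_sprites([0.0, 0.0, 1.0, 0.0, 0.0])  # e.g. a "food" sprite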
constants.py CHANGED
@@ -1,5 +1,7 @@
 """
 This file contains all the constants used in the project.
 """
+ import os

 MODEL_ID = "stabilityai/stable-diffusion-2-1"
+ WANDB_API_KEY = os.environ.get("WANDB_API_KEY")
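Not part of this commit, but worth noting: os.environ.get returns None when the variable is unset, so WANDB_API_KEY has to be provided to the runtime (for example as an exported environment variable or a Space secret). A minimal fail-fast sketch under that assumption, so a missing key surfaces at startup rather than when the model is first loaded:

    from constants import WANDB_API_KEY

    # assumption: fail early instead of letting the wandb call fail later
    if not WANDB_API_KEY:
        raise RuntimeError("WANDB_API_KEY is not set; export it or add it as a secret before launching the app.")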
feature_to_sprite.py ADDED
@@ -0,0 +1,85 @@
+ from pathlib import Path
+ from types import SimpleNamespace
+ import torch
+ import torch.nn.functional as F
+ import wandb
+ import streamlit as st
+
+ from utilities import ContextUnet, setup_ddpm
+ from constants import WANDB_API_KEY
+
+ def load_model():
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+     # login to wandb
+     # wandb.login(key=WANDB_API_KEY)
+
+
+     # load the model from wandb artifacts
+     api = wandb.Api(api_key=WANDB_API_KEY)
+     artifact = api.artifact("teamaditya/model-registry/Feature2Sprite:v0", type="model")
+     model_path = Path(artifact.download())
+
+     # recover model info from the registry
+     producer_run = artifact.logged_by()
+
+     # load the weights dictionary
+     model_weights = torch.load(model_path/"context_model.pth",
+                                map_location="cpu")
+
+     # create the model
+     model = ContextUnet(in_channels=3,
+                         n_feat=producer_run.config["n_feat"],
+                         n_cfeat=producer_run.config["n_cfeat"],
+                         height=producer_run.config["height"])
+
+     # load the weights into the model
+     model.load_state_dict(model_weights)
+
+     # set the model to eval mode
+     model.eval()
+     return model.to(DEVICE)
+
+ def show_image(img):
+     img = (img.permute(1, 2, 0).clip(-1, 1).detach().cpu().numpy() + 1) / 2
+     st.image(img, clamp=True)
+
+ def generate_sprites(feature_vector):
+     DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+     config = SimpleNamespace(
+         # hyperparameters
+         num_samples = 30,
+
+         # ddpm sampler hyperparameters
+         timesteps = 500,
+         beta1 = 1e-4,
+         beta2 = 0.02,
+
+         # network hyperparameters
+         height = 16,
+     )
+     nn_model = load_model()
+
+     _, sample_ddpm_context = setup_ddpm(config.beta1,
+                                         config.beta2,
+                                         config.timesteps,
+                                         DEVICE)
+
+     noises = torch.randn(config.num_samples, 3,
+                          config.height, config.height).to(DEVICE)
+
+     feature_vector = torch.tensor([feature_vector]).to(DEVICE).float()
+     ddpm_samples, _ = sample_ddpm_context(nn_model, noises, feature_vector)
+
+     # upscale the 16*16 images to 256*256
+     ddpm_samples = F.interpolate(ddpm_samples, size=(256, 256), mode="bilinear")
+     # show the images
+     show_image(ddpm_samples[0])
+
+
+
+
+
+
+
+
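Not part of this commit: load_model re-downloads the W&B artifact and rebuilds the network every time it is called, and generate_sprites calls it on every button press. One possible sketch, assuming Streamlit's st.cache_resource is acceptable here; get_model is a hypothetical wrapper name, not something in the repo:

    import streamlit as st
    from feature_to_sprite import load_model

    @st.cache_resource
    def get_model():
        # hypothetical wrapper: download and build the model once per process, reuse across reruns
        return load_model()

generate_sprites would then call get_model() instead of load_model().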
requirements.txt CHANGED
@@ -1,5 +1,11 @@
 diffusers==0.19.3
+ matplotlib==3.7.2
+ numpy==1.25.2
+ Pillow==9.5.0
 streamlit==1.25.0
 torch==2.0.1
+ torchvision==0.15.2
+ tqdm==4.65.0
+ wandb==0.15.8
 transformers==4.31.0
 accelerate==0.21.0
utilities.py ADDED
@@ -0,0 +1,286 @@
+ import numpy as np
+ import torch
+ import torch.nn as nn
+ from tqdm.auto import tqdm
+
+ class ContextUnet(nn.Module):
+     def __init__(self, in_channels, n_feat=256, n_cfeat=10, height=28): # cfeat - context features
+         super(ContextUnet, self).__init__()
+
+         # number of input channels, number of intermediate feature maps and number of classes
+         self.in_channels = in_channels
+         self.n_feat = n_feat
+         self.n_cfeat = n_cfeat
+         self.h = height # assume h == w; must be divisible by 4, so 28, 24, 20, 16, ...
+
+         # Initialize the initial convolutional layer
+         self.init_conv = ResidualConvBlock(in_channels, n_feat, is_res=True)
+
+         # Initialize the down-sampling path of the U-Net with two levels
+         self.down1 = UnetDown(n_feat, n_feat)     # down1 #[10, 256, 8, 8]
+         self.down2 = UnetDown(n_feat, 2 * n_feat) # down2 #[10, 256, 4, 4]
+
+         # original: self.to_vec = nn.Sequential(nn.AvgPool2d(7), nn.GELU())
+         self.to_vec = nn.Sequential(nn.AvgPool2d((4)), nn.GELU())
+
+         # Embed the timestep and context labels with a one-layer fully connected neural network
+         self.timeembed1 = EmbedFC(1, 2*n_feat)
+         self.timeembed2 = EmbedFC(1, 1*n_feat)
+         self.contextembed1 = EmbedFC(n_cfeat, 2*n_feat)
+         self.contextembed2 = EmbedFC(n_cfeat, 1*n_feat)
+
+         # Initialize the up-sampling path of the U-Net with three levels
+         self.up0 = nn.Sequential(
+             nn.ConvTranspose2d(2 * n_feat, 2 * n_feat, self.h//4, self.h//4), # up-sample
+             nn.GroupNorm(8, 2 * n_feat), # normalize
+             nn.ReLU(),
+         )
+         self.up1 = UnetUp(4 * n_feat, n_feat)
+         self.up2 = UnetUp(2 * n_feat, n_feat)
+
+         # Initialize the final convolutional layers to map to the same number of channels as the input image
+         self.out = nn.Sequential(
+             nn.Conv2d(2 * n_feat, n_feat, 3, 1, 1), # reduce number of feature maps #in_channels, out_channels, kernel_size, stride=1, padding=0
+             nn.GroupNorm(8, n_feat), # normalize
+             nn.ReLU(),
+             nn.Conv2d(n_feat, self.in_channels, 3, 1, 1), # map to same number of channels as input
+         )
+
+     def forward(self, x, t, c=None):
+         """
+         x : (batch, n_feat, h, w) : input image
+         t : (batch, n_cfeat) : time step
+         c : (batch, n_classes) : context label
+         """
+         # x is the input image, c is the context label, t is the timestep, context_mask says which samples to block the context on
+
+         # pass the input image through the initial convolutional layer
+         x = self.init_conv(x)
+         # pass the result through the down-sampling path
+         down1 = self.down1(x)     #[10, 256, 8, 8]
+         down2 = self.down2(down1) #[10, 256, 4, 4]
+
+         # convert the feature maps to a vector and apply an activation
+         hiddenvec = self.to_vec(down2)
+
+         # mask out context if context_mask == 1
+         if c is None:
+             c = torch.zeros(x.shape[0], self.n_cfeat).to(x)
+
+         # embed context and timestep
+         cemb1 = self.contextembed1(c).view(-1, self.n_feat * 2, 1, 1) # (batch, 2*n_feat, 1, 1)
+         temb1 = self.timeembed1(t).view(-1, self.n_feat * 2, 1, 1)
+         cemb2 = self.contextembed2(c).view(-1, self.n_feat, 1, 1)
+         temb2 = self.timeembed2(t).view(-1, self.n_feat, 1, 1)
+         # print(f"unet forward: cemb1 {cemb1.shape}, temb1 {temb1.shape}, cemb2 {cemb2.shape}, temb2 {temb2.shape}")
+
+
+         up1 = self.up0(hiddenvec)
+         up2 = self.up1(cemb1*up1 + temb1, down2) # add and multiply embeddings
+         up3 = self.up2(cemb2*up2 + temb2, down1)
+         out = self.out(torch.cat((up3, x), 1))
+         return out
+
+ class ResidualConvBlock(nn.Module):
+     def __init__(
+         self, in_channels: int, out_channels: int, is_res: bool = False
+     ) -> None:
+         super().__init__()
+
+         # Check if input and output channels are the same for the residual connection
+         self.same_channels = in_channels == out_channels
+
+         # Flag for whether or not to use residual connection
+         self.is_res = is_res
+
+         # First convolutional layer
+         self.conv1 = nn.Sequential(
+             nn.Conv2d(in_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1
+             nn.BatchNorm2d(out_channels), # Batch normalization
+             nn.GELU(), # GELU activation function
+         )
+
+         # Second convolutional layer
+         self.conv2 = nn.Sequential(
+             nn.Conv2d(out_channels, out_channels, 3, 1, 1), # 3x3 kernel with stride 1 and padding 1
+             nn.BatchNorm2d(out_channels), # Batch normalization
+             nn.GELU(), # GELU activation function
+         )
+
+     def forward(self, x: torch.Tensor) -> torch.Tensor:
+
+         # If using residual connection
+         if self.is_res:
+             # Apply first convolutional layer
+             x1 = self.conv1(x)
+
+             # Apply second convolutional layer
+             x2 = self.conv2(x1)
+
+             # If input and output channels are the same, add residual connection directly
+             if self.same_channels:
+                 out = x + x2
+             else:
+                 # If not, apply a 1x1 convolutional layer to match dimensions before adding residual connection
+                 shortcut = nn.Conv2d(x.shape[1], x2.shape[1], kernel_size=1, stride=1, padding=0).to(x.device)
+                 out = shortcut(x) + x2
+             # print(f"resconv forward: x {x.shape}, x1 {x1.shape}, x2 {x2.shape}, out {out.shape}")
+
+             # Normalize output tensor
+             return out / 1.414
+
+         # If not using residual connection, return output of second convolutional layer
+         else:
+             x1 = self.conv1(x)
+             x2 = self.conv2(x1)
+             return x2
+
+     # Method to get the number of output channels for this block
+     def get_out_channels(self):
+         return self.conv2[0].out_channels
+
+     # Method to set the number of output channels for this block
+     def set_out_channels(self, out_channels):
+         self.conv1[0].out_channels = out_channels
+         self.conv2[0].in_channels = out_channels
+         self.conv2[0].out_channels = out_channels
+
+
+
+ class UnetUp(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(UnetUp, self).__init__()
+
+         # Create a list of layers for the upsampling block
+         # The block consists of a ConvTranspose2d layer for upsampling, followed by two ResidualConvBlock layers
+         layers = [
+             nn.ConvTranspose2d(in_channels, out_channels, 2, 2),
+             ResidualConvBlock(out_channels, out_channels),
+             ResidualConvBlock(out_channels, out_channels),
+         ]
+
+         # Use the layers to create a sequential model
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x, skip):
+         # Concatenate the input tensor x with the skip connection tensor along the channel dimension
+         x = torch.cat((x, skip), 1)
+
+         # Pass the concatenated tensor through the sequential model and return the output
+         x = self.model(x)
+         return x
+
+
+ class UnetDown(nn.Module):
+     def __init__(self, in_channels, out_channels):
+         super(UnetDown, self).__init__()
+
+         # Create a list of layers for the downsampling block
+         # Each block consists of two ResidualConvBlock layers, followed by a MaxPool2d layer for downsampling
+         layers = [ResidualConvBlock(in_channels, out_channels), ResidualConvBlock(out_channels, out_channels), nn.MaxPool2d(2)]
+
+         # Use the layers to create a sequential model
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x):
+         # Pass the input through the sequential model and return the output
+         return self.model(x)
+
+ class EmbedFC(nn.Module):
+     def __init__(self, input_dim, emb_dim):
+         super(EmbedFC, self).__init__()
+         '''
+         This class defines a generic one-layer feed-forward neural network for embedding input data of
+         dimensionality input_dim to an embedding space of dimensionality emb_dim.
+         '''
+         self.input_dim = input_dim
+
+         # define the layers for the network
+         layers = [
+             nn.Linear(input_dim, emb_dim),
+             nn.GELU(),
+             nn.Linear(emb_dim, emb_dim),
+         ]
+
+         # create a PyTorch sequential model consisting of the defined layers
+         self.model = nn.Sequential(*layers)
+
+     def forward(self, x):
+         # flatten the input tensor
+         x = x.view(-1, self.input_dim)
+         # apply the model layers to the flattened tensor
+         return self.model(x)
+
+ def unorm(x):
+     # unity norm: results in range of [0,1]
+     # assume x (h,w,3)
+     xmax = x.max((0,1))
+     xmin = x.min((0,1))
+     return (x - xmin)/(xmax - xmin)
+
+ def norm_all(store, n_t, n_s):
+     # runs unity norm on all timesteps of all samples
+     nstore = np.zeros_like(store)
+     for t in range(n_t):
+         for s in range(n_s):
+             nstore[t,s] = unorm(store[t,s])
+     return nstore
+
+ def norm_torch(x_all):
+     # runs unity norm on all timesteps of all samples
+     # input is (n_samples, 3, h, w), the torch image format
+     x = x_all.cpu().numpy()
+     xmax = x.max((2,3))
+     xmin = x.min((2,3))
+     xmax = np.expand_dims(xmax,(2,3))
+     xmin = np.expand_dims(xmin,(2,3))
+     nstore = (x - xmin)/(xmax - xmin)
+     return torch.from_numpy(nstore)
+
+
+ ## diffusion functions
+
+ def setup_ddpm(beta1, beta2, timesteps, device):
+     # construct DDPM noise schedule and sampling functions
+     b_t = (beta2 - beta1) * torch.linspace(0, 1, timesteps + 1, device=device) + beta1
+     a_t = 1 - b_t
+     ab_t = torch.cumsum(a_t.log(), dim=0).exp()
+     ab_t[0] = 1
+
+     # helper function: perturbs an image to a specified noise level
+     def perturb_input(x, t, noise):
+         return ab_t.sqrt()[t, None, None, None] * x + (1 - ab_t[t, None, None, None]) * noise
+
+     # helper function: removes the predicted noise (but adds some noise back in to avoid collapse)
+     def _denoise_add_noise(x, t, pred_noise, z=None):
+         if z is None:
+             z = torch.randn_like(x)
+         noise = b_t.sqrt()[t] * z
+         mean = (x - pred_noise * ((1 - a_t[t]) / (1 - ab_t[t]).sqrt())) / a_t[t].sqrt()
+         return mean + noise
+
+     # sample with context using the standard algorithm
+     # we make a change to the original algorithm to allow for context explicitly (the noises)
+     @torch.no_grad()
+     def sample_ddpm_context(nn_model, noises, context, save_rate=20):
+         # array to keep track of generated steps for plotting
+         intermediate = []
+         pbar = tqdm(range(timesteps, 0, -1), leave=False)
+         for i in pbar:
+             pbar.set_description(f'sampling timestep {i:3d}')
+
+             # reshape time tensor
+             t = torch.tensor([i / timesteps])[:, None, None, None].to(noises.device)
+
+             # sample some random noise to inject back in. For i = 1, don't add back in noise
+             z = torch.randn_like(noises) if i > 1 else 0
+
+             eps = nn_model(noises, t, c=context) # predict noise e_(x_t, t, ctx)
+             noises = _denoise_add_noise(noises, i, eps, z)
+             if i % save_rate == 0 or i == timesteps or i < 8:
+                 intermediate.append(noises.detach().cpu().numpy())
+
+         intermediate = np.stack(intermediate)
+         return noises.clip(-1, 1), intermediate
+
+     return perturb_input, sample_ddpm_context
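For reference, the schedule built by setup_ddpm and the reverse step implemented by _denoise_add_noise correspond to (with T = timesteps and t = 0, ..., T):

    \beta_t = \beta_1 + (\beta_2 - \beta_1)\,\tfrac{t}{T}, \qquad \alpha_t = 1 - \beta_t, \qquad \bar{\alpha}_t = \prod_{s=0}^{t} \alpha_s \ \ (\bar{\alpha}_0 \text{ reset to } 1)

    x_{t-1} = \frac{1}{\sqrt{\alpha_t}} \Big( x_t - \frac{1 - \alpha_t}{\sqrt{1 - \bar{\alpha}_t}}\, \hat{\epsilon}_\theta(x_t, t, c) \Big) + \sqrt{\beta_t}\, z, \qquad z \sim \mathcal{N}(0, I) \text{ for } t > 1, \quad z = 0 \text{ for } t = 1

Note that perturb_input scales the injected noise by (1 - \bar{\alpha}_t) rather than the \sqrt{1 - \bar{\alpha}_t} of the textbook DDPM forward process; the equations above describe the sampling path actually used by sample_ddpm_context.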