Navyabhat committed on
Commit
81376a7
1 Parent(s): c6b3844

Update app.py

Files changed (1)
  1. app.py +89 -305
app.py CHANGED
@@ -1,307 +1,91 @@
  import gradio as gr
- from base64 import b64encode
- import numpy
  import torch
- from diffusers import AutoencoderKL, LMSDiscreteScheduler, UNet2DConditionModel
- from PIL import Image
- from torch import autocast
- from torchvision import transforms as tfms
- from tqdm.auto import tqdm
- from transformers import CLIPTextModel, CLIPTokenizer, logging
- import torchvision.transforms as T
-
- torch.manual_seed(1)
- logging.set_verbosity_error()
- torch_device = "cpu"
-
- vae = AutoencoderKL.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="vae")
-
- # Load the tokenizer and text encoder to tokenize and encode the text.
- tokenizer = CLIPTokenizer.from_pretrained("openai/clip-vit-large-patch14")
- text_encoder = CLIPTextModel.from_pretrained("openai/clip-vit-large-patch14")
-
- # The UNet model for generating the latents.
- unet = UNet2DConditionModel.from_pretrained("CompVis/stable-diffusion-v1-4", subfolder="unet")
-
- # The noise scheduler
- scheduler = LMSDiscreteScheduler(beta_start=0.00085, beta_end=0.012, beta_schedule="scaled_linear", num_train_timesteps=1000)
-
- vae = vae.to(torch_device)
- text_encoder = text_encoder.to(torch_device)
- unet = unet.to(torch_device);
-
- token_emb_layer = text_encoder.text_model.embeddings.token_embedding
- pos_emb_layer = text_encoder.text_model.embeddings.position_embedding
- position_ids = text_encoder.text_model.embeddings.position_ids[:, :77]
- position_embeddings = pos_emb_layer(position_ids)
-
- def pil_to_latent(input_im):
-     # Single image -> single latent in a batch (so size 1, 4, 64, 64)
-     with torch.no_grad():
-         latent = vae.encode(tfms.ToTensor()(input_im).unsqueeze(0).to(torch_device)*2-1) # Note scaling
-     return 0.18215 * latent.latent_dist.sample()
-
- def latents_to_pil(latents):
-     # bath of latents -> list of images
-     latents = (1 / 0.18215) * latents
-     with torch.no_grad():
-         image = vae.decode(latents).sample
-     image = (image / 2 + 0.5).clamp(0, 1)
-     image = image.detach().cpu().permute(0, 2, 3, 1).numpy()
-     images = (image * 255).round().astype("uint8")
-     pil_images = [Image.fromarray(image) for image in images]
-     return pil_images
-
- def get_output_embeds(input_embeddings):
-     # CLIP's text model uses causal mask, so we prepare it here:
-     bsz, seq_len = input_embeddings.shape[:2]
-     causal_attention_mask = text_encoder.text_model._build_causal_attention_mask(bsz, seq_len, dtype=input_embeddings.dtype)
-
-     # Getting the output embeddings involves calling the model with passing output_hidden_states=True
-     # so that it doesn't just return the pooled final predictions:
-     encoder_outputs = text_encoder.text_model.encoder(
-         inputs_embeds=input_embeddings,
-         attention_mask=None, # We aren't using an attention mask so that can be None
-         causal_attention_mask=causal_attention_mask.to(torch_device),
-         output_attentions=None,
-         output_hidden_states=True, # We want the output embs not the final output
-         return_dict=None,
-     )
-
-     # We're interested in the output hidden state only
-     output = encoder_outputs[0]
-
-     # There is a final layer norm we need to pass these through
-     output = text_encoder.text_model.final_layer_norm(output)
-
-     # And now they're ready!
-     return output
-
- def generate_with_embs(text_embeddings, seed, max_length):
-     height = 512 # default height of Stable Diffusion
-     width = 512 # default width of Stable Diffusion
-     num_inference_steps = 10 # Number of denoising steps
-     guidance_scale = 7.5 # Scale for classifier-free guidance
-     generator = torch.manual_seed(seed) # Seed generator to create the inital latent noise
-     batch_size = 1
-
-     # max_length = text_input.input_ids.shape[-1]
-     uncond_input = tokenizer(
-         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
-     )
-     with torch.no_grad():
-         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
-     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-     # Prep Scheduler
-     set_timesteps(scheduler, num_inference_steps)
-
-     # Prep latents
-     latents = torch.randn(
-         (batch_size, unet.in_channels, height // 8, width // 8),
-         generator=generator,
-     )
-     latents = latents.to(torch_device)
-     latents = latents * scheduler.init_noise_sigma
-
-     # Loop
-     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
-         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-         latent_model_input = torch.cat([latents] * 2)
-         sigma = scheduler.sigmas[i]
-         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
-
-         # predict the noise residual
-         with torch.no_grad():
-             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
-
-         # perform guidance
-         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-         # compute the previous noisy sample x_t -> x_t-1
-         latents = scheduler.step(noise_pred, t, latents).prev_sample
-
-     return latents_to_pil(latents)[0]
-
- # Prep Scheduler
- def set_timesteps(scheduler, num_inference_steps):
-     scheduler.set_timesteps(num_inference_steps)
-     scheduler.timesteps = scheduler.timesteps.to(torch.float32) # minor fix to ensure MPS compatibility, fixed in diffusers PR 3925
-
- def embed_style(prompt, style_embed, style_seed):
-     # Tokenize
-     text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
-     input_ids = text_input.input_ids.to(torch_device)
-
-     # Get token embeddings
-     token_embeddings = token_emb_layer(input_ids)
-
-     replacement_token_embedding = style_embed.to(torch_device)
-
-     # Insert this into the token embeddings
-     token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
-
-     # Combine with pos embs
-     input_embeddings = token_embeddings + position_embeddings
-
-     # Feed through to get final output embs
-     modified_output_embeddings = get_output_embeds(input_embeddings)
-
-     # And generate an image with this:
-     max_length = text_input.input_ids.shape[-1]
-     return generate_with_embs(modified_output_embeddings, style_seed, max_length)
-
- def loss_style(prompt, style_embed, style_seed):
-     # Tokenize
-     text_input = tokenizer(prompt, padding="max_length", max_length=tokenizer.model_max_length, truncation=True, return_tensors="pt")
-     input_ids = text_input.input_ids.to(torch_device)
-
-     # Get token embeddings
-     token_embeddings = token_emb_layer(input_ids)
-
-     # The new embedding - our special birb word
-     replacement_token_embedding = style_embed.to(torch_device)
-
-     # Insert this into the token embeddings
-     token_embeddings[0, torch.where(input_ids[0]==6829)] = replacement_token_embedding.to(torch_device)
-
-     # Combine with pos embs
-     input_embeddings = token_embeddings + position_embeddings
-
-     # Feed through to get final output embs
-     modified_output_embeddings = get_output_embeds(input_embeddings)
-
-     # And generate an image with this:
-     max_length = text_input.input_ids.shape[-1]
-     return generate_loss_based_image(modified_output_embeddings, style_seed,max_length)
-
-
- def color_loss(image):
-     color_channel = image[:, 1]
-     target_value = 0.7
-     error = torch.abs(color_channel - target_value).mean()
-     return error
-
- def generate_loss_based_image(text_embeddings, seed, max_length):
-
-     height = 64
-     width = 64
-     num_inference_steps = 10
-     guidance_scale = 8
-     generator = torch.manual_seed(64)
-     batch_size = 1
-     loss_scale = 200
-
-     uncond_input = tokenizer(
-         [""] * batch_size, padding="max_length", max_length=max_length, return_tensors="pt"
-     )
-     with torch.no_grad():
-         uncond_embeddings = text_encoder(uncond_input.input_ids.to(torch_device))[0]
-     text_embeddings = torch.cat([uncond_embeddings, text_embeddings])
-
-     # Prep Scheduler
-     set_timesteps(scheduler, num_inference_steps+1)
-
-     # Prep latents
-     latents = torch.randn(
-         (batch_size, unet.in_channels, height // 8, width // 8),
-         generator=generator,
-     )
-     latents = latents.to(torch_device)
-     latents = latents * scheduler.init_noise_sigma
-
-     sched_out = None
-
-     # Loop
-     for i, t in tqdm(enumerate(scheduler.timesteps), total=len(scheduler.timesteps)):
-         # expand the latents if we are doing classifier-free guidance to avoid doing two forward passes.
-         latent_model_input = torch.cat([latents] * 2)
-         sigma = scheduler.sigmas[i]
-         latent_model_input = scheduler.scale_model_input(latent_model_input, t)
-
-         # predict the noise residual
-         with torch.no_grad():
-             noise_pred = unet(latent_model_input, t, encoder_hidden_states=text_embeddings)["sample"]
-
-         # perform CFG
-         noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
-         noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)
-
-         ### ADDITIONAL GUIDANCE ###
-         if i%5 == 0 and i>0:
-             # Requires grad on the latents
-             latents = latents.detach().requires_grad_()
-
-             # Get the predicted x0:
-             scheduler._step_index -= 1
-             latents_x0 = scheduler.step(noise_pred, t, latents).pred_original_sample
-
-             # Decode to image space
-             denoised_images = vae.decode((1 / 0.18215) * latents_x0).sample / 2 + 0.5 # range (0, 1)
-
-
-             # Calculate loss
-             loss = color_loss(denoised_images) * loss_scale
-
-             # Occasionally print it out
-             # if i%10==0:
-             print(i, 'loss:', loss)
-
-             # Get gradient
-             cond_grad = torch.autograd.grad(loss, latents)[0]
-
-             # Modify the latents based on this gradient
-             latents = latents.detach() - cond_grad * sigma**2
-             # To PIL Images
-             im_t0 = latents_to_pil(latents_x0)[0]
-             im_next = latents_to_pil(latents)[0]
-
-         # Now step with scheduler
-         latents = scheduler.step(noise_pred, t, latents).prev_sample
-
-     return latents_to_pil(latents)[0]
-
-
- def generate_image_from_prompt(text_in, style_in):
-     STYLE_LIST = ['coffeemachine.bin', 'collage_style.bin', 'cube.bin', 'jerrymouse2.bin', 'zero.bin']
-     STYLE_SEEDS = [32, 64, 128, 16, 8]
-
-     print(text_in)
-     print(style_in)
-     style_file = style_in + '.bin'
-     idx = STYLE_LIST.index(style_file)
-     print(style_file)
-     print(idx)
-
-     prompt = text_in + ' a puppy'
-
-     style_seed = STYLE_SEEDS[idx]
-     style_dict = torch.load(style_file)
-     style_embed = [v for v in style_dict.values()]
-
-     generated_image = embed_style(prompt, style_embed[0], style_seed)
-
-     loss_generated_img = (loss_style(prompt, style_embed[0], style_seed))
-
-     return [generated_image, loss_generated_img]
-
-
- # Define Interface
-
- title = 'ERA-SESSION20 Generative Art and Stable Diffusion'
-
- demo = gr.Interface(generate_image_from_prompt,
-                     inputs = [gr.Textbox(1, label='prompt'),
-                               gr.Dropdown(
-                                   ['coffeemachine.bin', 'collage_style.bin', 'cube.bin', 'jerrymouse2.bin', 'zero.bin'],value="cube", label="Pretrained Styles"
-                               )
-                     ],
-                     outputs = [
-
-                         gr.Gallery(label="Generated images", show_label=False, elem_id="gallery", columns=[2], rows=[2], object_fit="contain", height="auto")
-                     ],
-
-                     title = title
-                     )
- demo.launch(debug=True)
 
  import gradio as gr
+ import random
  import torch
+ import pathlib
+
+ from src.utils import concept_styles, loss_fn
+ from src.stable_diffusion import StableDiffusion
+
+ PROJECT_PATH = "."
+ CONCEPT_LIBS_PATH = f"{PROJECT_PATH}/concept_libs"
+
+ DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
+
+
+ def generate(prompt, styles, gen_steps, loss_scale):
+     lossless_images, lossy_images = [], []
+     for style in styles:
+         concept_lib_path = f"{CONCEPT_LIBS_PATH}/{concept_styles[style]}"
+         concept_lib = pathlib.Path(concept_lib_path)
+         concept_embed = torch.load(concept_lib)
+
+         manual_seed = random.randint(0, 100)
+         diffusion = StableDiffusion(
+             device=DEVICE,
+             num_inference_steps=gen_steps,
+             manual_seed=manual_seed,
+         )
+         generated_image_lossless = diffusion.generate_image(
+             prompt=prompt,
+             loss_fn=loss_fn,
+             loss_scale=0,
+             concept_embed=concept_embed,
+         )
+         generated_image_lossy = diffusion.generate_image(
+             prompt=prompt,
+             loss_fn=loss_fn,
+             loss_scale=loss_scale,
+             concept_embed=concept_embed,
+         )
+         lossless_images.append((generated_image_lossless, style))
+         lossy_images.append((generated_image_lossy, style))
+     return {lossless_gallery: lossless_images, lossy_gallery: lossy_images}
+
+
+ with gr.Blocks() as app:
+     gr.Markdown("## ERA Session20 - Stable Diffusion: Generative Art with Guidance")
+     with gr.Row():
+         with gr.Column():
+             prompt_box = gr.Textbox(label="Prompt", interactive=True)
+             style_selector = gr.Dropdown(
+                 choices=list(concept_styles.keys()),
+                 value=list(concept_styles.keys())[0],
+                 multiselect=True,
+                 label="Select a Concept Style",
+                 interactive=True,
+             )
+             gen_steps = gr.Slider(
+                 minimum=10,
+                 maximum=50,
+                 value=30,
+                 step=10,
+                 label="Select Number of Steps",
+                 interactive=True,
+             )
+
+             loss_scale = gr.Slider(
+                 minimum=0,
+                 maximum=32,
+                 value=8,
+                 step=8,
+                 label="Select Guidance Scale",
+                 interactive=True,
+             )
+
+             submit_btn = gr.Button(value="Generate")
+
+         with gr.Column():
+             lossless_gallery = gr.Gallery(
+                 label="Generated Images without Guidance", show_label=True
+             )
+             lossy_gallery = gr.Gallery(
+                 label="Generated Images with Guidance", show_label=True
+             )
+
+     submit_btn.click(
+         generate,
+         inputs=[prompt_box, style_selector, gen_steps, loss_scale],
+         outputs=[lossless_gallery, lossy_gallery],
+     )
+
+ app.launch()
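
Note: the rewritten app.py imports concept_styles and loss_fn from src/utils and a StableDiffusion wrapper from src/stable_diffusion, none of which appear in this diff. The sketch below is a hypothetical reading of what src/utils.py might provide, inferred from the removed code: the .bin file names are taken from the old STYLE_LIST, the display names are illustrative guesses, and loss_fn mirrors the removed color_loss (channel index 1 pushed toward 0.7). The actual files in the repository may differ.

    # Hypothetical src/utils.py -- NOT part of this commit; a minimal sketch only.
    import torch

    # Assumed mapping from a display name (shown in the Dropdown) to a
    # textual-inversion embedding file stored under concept_libs/.
    # File names come from STYLE_LIST in the removed app.py; display names are guesses.
    concept_styles = {
        "Coffee Machine": "coffeemachine.bin",
        "Collage Style": "collage_style.bin",
        "Cube": "cube.bin",
        "Jerry Mouse": "jerrymouse2.bin",
        "Zero": "zero.bin",
    }


    def loss_fn(images: torch.Tensor) -> torch.Tensor:
        """Guidance loss over decoded images in [0, 1].

        Mirrors the color_loss removed from the old app.py: mean absolute
        distance of channel index 1 from the target value 0.7.
        """
        return torch.abs(images[:, 1] - 0.7).mean()

StableDiffusion.generate_image is presumably applying this loss to the intermediate decoded latents and nudging them by the loss gradient, i.e. the same additional-guidance step that the removed generate_loss_based_image performed inline.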