Prgckwb committed
Commit 23c37e5
1 Parent(s): 3ca046d

:tada: init

Files changed (3)
  1. .gitignore +8 -0
  2. app.py +317 -4
  3. requirements.txt +9 -0
.gitignore ADDED
@@ -0,0 +1,8 @@
+ ### Example user template template
+ ### Example user template
+
+ # IntelliJ project files
+ .idea
+ *.iml
+ out
+ gen
app.py CHANGED
@@ -1,8 +1,321 @@
+ import dataclasses
+ import warnings
+
+ warnings.filterwarnings("ignore")
+
  import gradio as gr
+ import matplotlib.pyplot as plt
+ import numpy as np
+ import torch
+ import torch.nn.functional as F
+ from PIL import Image
+ from pathlib import Path
+ from diffusers import AutoencoderKL, UNet2DConditionModel
+ from diffusers.models.attention_processor import AttnProcessor, Attention
+ from rich import traceback
+ from torchvision.transforms.functional import to_tensor
+ from transformers import CLIPTokenizer, CLIPTextModel
+ from tqdm import tqdm
+
+ MODEL_ID = "CompVis/stable-diffusion-v1-4"
+ SEED = 1117
+ UNET_TIMESTEP = 1
+
+ traceback.install()
+
+
+ @dataclasses.dataclass
+ class AttentionStore:
+     index: int
+     query: torch.Tensor
+     key: torch.Tensor
+     value: torch.Tensor
+     attention_probs: torch.Tensor
+
+
+ class NewAttnProcessor(AttnProcessor):
+     def __init__(
+         self,
+         save_uncond_attention: bool = True,
+         save_cond_attention: bool = True,
+         max_cross_attention_maps: int = 64,
+         max_self_attention_maps: int = 64,
+     ):
+         super().__init__()
+         self.save_uncond_attn = save_uncond_attention
+         self.save_cond_attn = save_cond_attention
+         self.max_cross_size = max_cross_attention_maps
+         self.max_self_size = max_self_attention_maps
+
+         self.cross_attention_stores = []
+         self.self_attention_stores = []
+
+     def _save_attention_store(
+         self,
+         is_cross: bool,
+         q: torch.Tensor,
+         k: torch.Tensor,
+         v: torch.Tensor,
+         attn_probs: torch.Tensor
+     ) -> None:
+         # Helper to split a batched tensor into its unconditional and conditional halves
+         def split_tensors(tensor):
+             half_size = tensor.shape[0] // 2
+             return tensor[:half_size], tensor[half_size:]
+
+         # Split attention probabilities and q, k, v tensors
+         uncond_attn_probs, cond_attn_probs = split_tensors(attn_probs)
+         uncond_q, cond_q = split_tensors(q)
+         uncond_k, cond_k = split_tensors(k)
+         uncond_v, cond_v = split_tensors(v)
+
+         # Select tensors based on flags
+         if self.save_cond_attn and self.save_uncond_attn:
+             selected_probs, selected_q, selected_k, selected_v = attn_probs, q, k, v
+         elif self.save_cond_attn:
+             selected_probs, selected_q, selected_k, selected_v = cond_attn_probs, cond_q, cond_k, cond_v
+         elif self.save_uncond_attn:
+             selected_probs, selected_q, selected_k, selected_v = uncond_attn_probs, uncond_q, uncond_k, uncond_v
+         else:
+             return
+
+         # Determine max size based on attention type (cross or self)
+         max_size = self.max_cross_size if is_cross else self.max_self_size
+
+         # Filter out large attention maps
+         if selected_probs.shape[1] > max_size ** 2:
+             return
+
+         # Create and append attention store object
+         store = AttentionStore(
+             index=len(self.cross_attention_stores) if is_cross else len(self.self_attention_stores),
+             query=selected_q,
+             key=selected_k,
+             value=selected_v,
+             attention_probs=selected_probs
+         )
+
+         target_store = self.cross_attention_stores if is_cross else self.self_attention_stores
+         target_store.append(store)
+         return
+
+     def __call__(
+         self,
+         attn: Attention,
+         hidden_states: torch.FloatTensor,
+         encoder_hidden_states: torch.FloatTensor = None,
+         attention_mask: torch.FloatTensor = None,
+         temb: torch.FloatTensor = None,
+         *args,
+         **kwargs,
+     ) -> torch.Tensor:
+         residual = hidden_states
+
+         if attn.spatial_norm is not None:
+             hidden_states = attn.spatial_norm(hidden_states, temb)
+
+         input_ndim = hidden_states.ndim
+
+         if input_ndim == 4:
+             batch_size, channel, height, width = hidden_states.shape
+             hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)
+
+         batch_size, sequence_length, _ = (
+             hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
+         )
+         attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)
+
+         if attn.group_norm is not None:
+             hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)
+
+         query = attn.to_q(hidden_states)
+
+         is_cross_attention = encoder_hidden_states is not None
+
+         if encoder_hidden_states is None:
+             encoder_hidden_states = hidden_states
+         elif attn.norm_cross:
+             encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)
+
+         key = attn.to_k(encoder_hidden_states)
+         value = attn.to_v(encoder_hidden_states)
+
+         query = attn.head_to_batch_dim(query)
+         key = attn.head_to_batch_dim(key)
+         value = attn.head_to_batch_dim(value)
+
+         attention_probs = attn.get_attention_scores(query, key, attention_mask)
+
+         # Save attention maps
+         self._save_attention_store(is_cross=is_cross_attention, q=query, k=key, v=value, attn_probs=attention_probs)
+
+         hidden_states = torch.bmm(attention_probs, value)
+         hidden_states = attn.batch_to_head_dim(hidden_states)
+
+         # linear proj
+         hidden_states = attn.to_out[0](hidden_states)
+         # dropout
+         hidden_states = attn.to_out[1](hidden_states)
+
+         if input_ndim == 4:
+             hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)
+
+         if attn.residual_connection:
+             hidden_states = hidden_states + residual
+
+         hidden_states = hidden_states / attn.rescale_output_factor
+
+         return hidden_states
+
+     def reset_attention_stores(self) -> None:
+         self.cross_attention_stores = []
+         self.self_attention_stores = []
+         return
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
+ text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder").to(device)
+ unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet").to(device)
+ vae: AutoencoderKL = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae").to(device)
+
+ unet.set_attn_processor(
+     NewAttnProcessor(
+         save_uncond_attention=False,
+         save_cond_attention=True,
+     )
+ )
+
+
+ @torch.inference_mode()
+ def inference(image: Image.Image, prompt: str, progress=gr.Progress(track_tqdm=True)):
+     progress(0, "Initializing...")
+     image = image.convert("RGB").resize((512, 512))
+     image = to_tensor(image).unsqueeze(0).to(device)
+
+     progress(0.1, "Generating text embeddings...")
+     input_ids = tokenizer(
+         prompt,
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=tokenizer.model_max_length,
+     ).input_ids.to(device)
+
+     n_cond_tokens = len(
+         tokenizer(
+             prompt,
+             return_tensors="pt",
+             truncation=True,
+         ).input_ids[0]
+     )
+
+     cond_text_embeddings = text_encoder(input_ids).last_hidden_state[0].to(device)
+
+     uncond_input_ids = tokenizer(
+         "",
+         return_tensors="pt",
+         padding="max_length",
+         truncation=True,
+         max_length=tokenizer.model_max_length,
+     ).input_ids.to(device)
+     uncond_text_embeddings = text_encoder(uncond_input_ids).last_hidden_state[0].to(device)
+
+     text_embeddings = torch.stack([uncond_text_embeddings, cond_text_embeddings], dim=0)
+
+     progress(0.2, "Encoding the input image...")
+     init_image = image.to(device)
+     init_latent_dist = vae.encode(init_image).latent_dist
+
+     # Fix the random seed for reproducibility
+     progress(0.3, "Generating the latents...")
+     generator = torch.Generator(device=device).manual_seed(SEED)
+     latent = init_latent_dist.sample(generator=generator)
+     latent = latent * vae.config['scaling_factor']  # scaling_factor = 0.18215
+     latents = latent.expand(len(image), unet.config['in_channels'], 512 // 8, 512 // 8)
+     latents_input = torch.cat([latents] * 2).to(device)
+
+     progress(0.5, "Forwarding the UNet model...")
+     _ = unet(latents_input, UNET_TIMESTEP, encoder_hidden_states=text_embeddings)
+
+     attn_processor = next(iter(unet.attn_processors.values()))
+     cross_attention_stores = attn_processor.cross_attention_stores
+
+     progress(0.7, "Processing the cross attention maps...")
+     cross_attention_probs_list = []
+     # Retrieve the outputs of every cross-attention layer saved during the UNet forward pass
+     for i, cross_attn_store in enumerate(cross_attention_stores):
+         cross_attn_probs = cross_attn_store.attention_probs  # (8, 8x8~64x64, 77)
+         n_heads, scale_pow, n_tokens = cross_attn_probs.shape
+
+         # scale: 8, 16, 32, 64
+         scale = int(np.sqrt(scale_pow))
+
+         # Average over the attention heads to obtain a single attention map
+         mean_cross_attn_probs = (
+             cross_attn_probs
+             .permute(0, 2, 1)  # (8, 77, 8x8~64x64)
+             .reshape(n_heads, n_tokens, scale, scale)  # (8, 77, 8~64, 8~64)
+             .mean(dim=0)  # (77, 8~64, 8~64)
+         )
+
+         # Upsample every map to 512x512
+         mean_cross_attn_probs = F.interpolate(
+             mean_cross_attn_probs.unsqueeze(0),
+             size=(512, 512),
+             mode='bilinear',
+             align_corners=True
+         ).squeeze(0)  # (77, 512, 512)
+
+         # Keep only the prompt tokens (<bos> through <eos>), dropping the padding
+         mean_cross_attn_probs = mean_cross_attn_probs[:n_cond_tokens, ...]  # (n_tokens, 512, 512)
+         cross_attention_probs_list.append(mean_cross_attn_probs)
+
+     # list -> torch.Tensor
+     cross_attention_probs = torch.stack(cross_attention_probs_list)  # (16, n_classes, 512, 512)
+     n_layers, n_cond_tokens, _, _ = cross_attention_probs.shape
+
+     progress(0.9, "Post-processing the attention maps...")
+
+     image_list = []
+     # Create and save one figure (a row of token maps) per cross-attention layer
+     for i in tqdm(range(cross_attention_probs.shape[0]), desc="Saving images..."):
+         fig, ax = plt.subplots(1, n_cond_tokens, figsize=(16, 4))  # one figure per layer
+
+         for j in range(cross_attention_probs.shape[1]):
+             # Min-max normalize each token's attention map to [0, 1]
+             min_val = cross_attention_probs[i, j].min()
+             max_val = cross_attention_probs[i, j].max()
+             cross_attention_probs[i, j] = (cross_attention_probs[i, j] - min_val) / (max_val - min_val)
+
+             attn_probs = cross_attention_probs[i, j].cpu().detach().numpy()
+             ax[j].imshow(attn_probs, alpha=0.9)
+             ax[j].axis('off')
+             ax[j].set_title(tokenizer.decode(input_ids[0, j].item()))
+
+         # Save the figure for this layer
+         out_dir = Path("output")
+         out_dir.mkdir(exist_ok=True)
+         filepath = out_dir / f"output_row_{i}.png"
+         plt.savefig(filepath, bbox_inches='tight', pad_inches=0)
+         plt.close(fig)
+
+         # Load the saved image with PIL and append it to the list
+         image_list.append(Image.open(filepath))
+     return image_list
 
- def greet(name):
-     return "Hello " + name + "!!"
 
- demo = gr.Interface(fn=greet, inputs="text", outputs="text")
- demo.launch()
+ if __name__ == '__main__':
+     ca_output = [gr.Image(type="pil", label="Attention Map") for _ in range(16)]
 
+     iface = gr.Interface(
+         title="Stable Diffusion Attention Visualizer",
+         description="",
+         fn=inference,
+         inputs=[
+             gr.Image(type="pil", label="Input Image", width=512, height=512),
+             gr.Textbox(label="Prompt", placeholder="Enter a prompt here..."),
+         ],
+         outputs=ca_output,
+     )
+     iface.launch()
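
For a quick smoke test outside the Gradio UI, the `inference` function added here can also be called directly — a minimal sketch, assuming `app.py` is importable from the working directory; `input.png`, the prompt, and the no-op `progress` callback are illustrative placeholders, not part of this commit:

from PIL import Image

from app import inference  # importing app.py loads the Stable Diffusion v1-4 components up front

# Placeholder inputs for this sketch
image = Image.open("input.png")
attention_maps = inference(image, "a photo of a dog", progress=lambda *args, **kwargs: None)

# One PIL image is returned per stored cross-attention layer
for layer_idx, attention_map in enumerate(attention_maps):
    attention_map.save(f"layer_{layer_idx:02d}_attention.png")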
requirements.txt ADDED
@@ -0,0 +1,9 @@
+ gradio
+ torch
+ torchvision
+ diffusers
+ accelerate
+ safetensors
+ transformers
+ matplotlib
+ rich
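
The dependencies above are left unpinned; a quick sanity check that an installed environment resolves the core ones might look like this (a minimal sketch, not part of this commit):

import diffusers
import gradio
import torch
import transformers

# Print the resolved versions and whether CUDA is visible to torch
print("gradio:", gradio.__version__)
print("torch:", torch.__version__, "| CUDA available:", torch.cuda.is_available())
print("diffusers:", diffusers.__version__, "| transformers:", transformers.__version__)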