import dataclasses
import uuid
import warnings
from pathlib import Path

warnings.filterwarnings("ignore")

import gradio as gr
import matplotlib.pyplot as plt
import numpy as np
import spaces
import torch
import torch.nn.functional as F
from diffusers import AutoencoderKL, UNet2DConditionModel
from diffusers.models.attention_processor import AttnProcessor, Attention
from PIL import Image
from rich import traceback
from torchvision.transforms.functional import to_tensor
from tqdm import tqdm
from transformers import CLIPTokenizer, CLIPTextModel

MODEL_ID = "CompVis/stable-diffusion-v1-4"
SEED = 1117
UNET_TIMESTEP = 1

traceback.install()


@dataclasses.dataclass
class AttentionStore:
    index: int
    query: torch.Tensor
    key: torch.Tensor
    value: torch.Tensor
    attention_probs: torch.Tensor


class NewAttnProcessor(AttnProcessor):
    def __init__(
        self,
        save_uncond_attention: bool = True,
        save_cond_attention: bool = True,
        max_cross_attention_maps: int = 64,
        max_self_attention_maps: int = 64,
    ):
        super().__init__()
        self.save_uncond_attn = save_uncond_attention
        self.save_cond_attn = save_cond_attention
        self.max_cross_size = max_cross_attention_maps
        self.max_self_size = max_self_attention_maps
        self.cross_attention_stores = []
        self.self_attention_stores = []

    def _save_attention_store(
        self,
        is_cross: bool,
        q: torch.Tensor,
        k: torch.Tensor,
        v: torch.Tensor,
        attn_probs: torch.Tensor,
    ) -> None:
        # Split a tensor into its unconditional and conditional halves along the batch dimension
        def split_tensors(tensor):
            half_size = tensor.shape[0] // 2
            return tensor[:half_size], tensor[half_size:]

        # Split attention probabilities and q, k, v tensors
        uncond_attn_probs, cond_attn_probs = split_tensors(attn_probs)
        uncond_q, cond_q = split_tensors(q)
        uncond_k, cond_k = split_tensors(k)
        uncond_v, cond_v = split_tensors(v)

        # Select tensors based on flags
        if self.save_cond_attn and self.save_uncond_attn:
            selected_probs, selected_q, selected_k, selected_v = attn_probs, q, k, v
        elif self.save_cond_attn:
            selected_probs, selected_q, selected_k, selected_v = cond_attn_probs, cond_q, cond_k, cond_v
        elif self.save_uncond_attn:
            selected_probs, selected_q, selected_k, selected_v = uncond_attn_probs, uncond_q, uncond_k, uncond_v
        else:
            return

        # Determine max size based on attention type (cross or self)
        max_size = self.max_cross_size if is_cross else self.max_self_size

        # Skip attention maps that are larger than the configured maximum
        if selected_probs.shape[1] > max_size ** 2:
            return

        # Create and append attention store object
        store = AttentionStore(
            index=len(self.cross_attention_stores) if is_cross else len(self.self_attention_stores),
            query=selected_q,
            key=selected_k,
            value=selected_v,
            attention_probs=selected_probs,
        )
        target_store = self.cross_attention_stores if is_cross else self.self_attention_stores
        target_store.append(store)
        return

    def __call__(
        self,
        attn: Attention,
        hidden_states: torch.FloatTensor,
        encoder_hidden_states: torch.FloatTensor = None,
        attention_mask: torch.FloatTensor = None,
        temb: torch.FloatTensor = None,
        *args,
        **kwargs,
    ) -> torch.Tensor:
        residual = hidden_states

        if attn.spatial_norm is not None:
            hidden_states = attn.spatial_norm(hidden_states, temb)

        input_ndim = hidden_states.ndim

        if input_ndim == 4:
            batch_size, channel, height, width = hidden_states.shape
            hidden_states = hidden_states.view(batch_size, channel, height * width).transpose(1, 2)

        batch_size, sequence_length, _ = (
            hidden_states.shape if encoder_hidden_states is None else encoder_hidden_states.shape
        )
        attention_mask = attn.prepare_attention_mask(attention_mask, sequence_length, batch_size)

        if attn.group_norm is not None:
            hidden_states = attn.group_norm(hidden_states.transpose(1, 2)).transpose(1, 2)

        query = attn.to_q(hidden_states)

        is_cross_attention = encoder_hidden_states is not None
        if encoder_hidden_states is None:
            encoder_hidden_states = hidden_states
        elif attn.norm_cross:
            encoder_hidden_states = attn.norm_encoder_hidden_states(encoder_hidden_states)

        key = attn.to_k(encoder_hidden_states)
        value = attn.to_v(encoder_hidden_states)

        query = attn.head_to_batch_dim(query)
        key = attn.head_to_batch_dim(key)
        value = attn.head_to_batch_dim(value)

        attention_probs = attn.get_attention_scores(query, key, attention_mask)

        # Save attention maps
        self._save_attention_store(is_cross=is_cross_attention, q=query, k=key, v=value, attn_probs=attention_probs)

        hidden_states = torch.bmm(attention_probs, value)
        hidden_states = attn.batch_to_head_dim(hidden_states)

        # linear proj
        hidden_states = attn.to_out[0](hidden_states)
        # dropout
        hidden_states = attn.to_out[1](hidden_states)

        if input_ndim == 4:
            hidden_states = hidden_states.transpose(-1, -2).reshape(batch_size, channel, height, width)

        if attn.residual_connection:
            hidden_states = hidden_states + residual

        hidden_states = hidden_states / attn.rescale_output_factor

        return hidden_states

    def reset_attention_stores(self) -> None:
        self.cross_attention_stores = []
        self.self_attention_stores = []
        return


device = "cuda" if torch.cuda.is_available() else "cpu"

tokenizer: CLIPTokenizer = CLIPTokenizer.from_pretrained(MODEL_ID, subfolder="tokenizer")
text_encoder: CLIPTextModel = CLIPTextModel.from_pretrained(MODEL_ID, subfolder="text_encoder").to(device)
unet: UNet2DConditionModel = UNet2DConditionModel.from_pretrained(MODEL_ID, subfolder="unet").to(device)
vae: AutoencoderKL = AutoencoderKL.from_pretrained(MODEL_ID, subfolder="vae").to(device)

unet.set_attn_processor(
    NewAttnProcessor(
        save_uncond_attention=False,
        save_cond_attention=True,
    )
)


@spaces.GPU()
@torch.inference_mode()
def inference(
    image_path: str,
    prompt: str,
    has_include_special_tokens: bool = False,
    progress=gr.Progress(track_tqdm=False),
):
    progress(0, "Initializing...")
    image = Image.open(image_path)
    image = image.convert("RGB").resize((512, 512))
    image = to_tensor(image).unsqueeze(0).to(device)

    progress(0.1, "Generating text embeddings...")
    input_ids = tokenizer(
        prompt,
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
    ).input_ids.to(device)

    n_cond_tokens = len(
        tokenizer(
            prompt,
            return_tensors="pt",
            truncation=True,
        ).input_ids[0]
    )

    cond_text_embeddings = text_encoder(input_ids).last_hidden_state[0].to(device)

    uncond_input_ids = tokenizer(
        "",
        return_tensors="pt",
        padding="max_length",
        truncation=True,
        max_length=tokenizer.model_max_length,
    ).input_ids.to(device)
    uncond_text_embeddings = text_encoder(uncond_input_ids).last_hidden_state[0].to(device)

    text_embeddings = torch.stack([uncond_text_embeddings, cond_text_embeddings], dim=0)

    progress(0.2, "Encoding the input image...")
    init_image = image.to(device)
    init_latent_dist = vae.encode(init_image).latent_dist

    # Fix the random seed for reproducibility
    progress(0.3, "Generating the latents...")
    generator = torch.Generator(device=device).manual_seed(SEED)
    latent = init_latent_dist.sample(generator=generator)
    latent = latent * vae.config.scaling_factor  # scaling_factor = 0.18215
    latents = latent.expand(len(image), unet.config.in_channels, 512 // 8, 512 // 8)
    latents_input = torch.cat([latents] * 2).to(device)

    progress(0.5, "Forwarding the UNet model...")
    _ = unet(latents_input, UNET_TIMESTEP, encoder_hidden_states=text_embeddings)

    attn_processor = next(iter(unet.attn_processors.values()))
    cross_attention_stores = attn_processor.cross_attention_stores

    progress(0.7, "Processing the cross attention maps...")
    cross_attention_probs_list = []
    # Iterate over the outputs of all cross-attention layers saved during the UNet forward pass
    for i, cross_attn_store in enumerate(cross_attention_stores):
        cross_attn_probs = cross_attn_store.attention_probs  # (8, 8x8~64x64, 77)
        n_heads, scale_pow, n_tokens = cross_attn_probs.shape
        # scale: 8, 16, 32, 64
        scale = int(np.sqrt(scale_pow))

        # Average over the attention heads to obtain a single attention map per token
        mean_cross_attn_probs = (
            cross_attn_probs
            .permute(0, 2, 1)                          # (8, 77, 8x8~64x64)
            .reshape(n_heads, n_tokens, scale, scale)  # (8, 77, 8~64, 8~64)
            .mean(dim=0)                               # (77, 8~64, 8~64)
        )

        # Upsample every attention map to 512x512
        mean_cross_attn_probs = F.interpolate(
            mean_cross_attn_probs.unsqueeze(0),
            size=(512, 512),
            mode='bilinear',
            align_corners=True,
        ).squeeze(0)  # (77, 512, 512)

        # Keep only the prompt tokens; drop the special BOS/EOS tokens unless requested
        if has_include_special_tokens:
            mean_cross_attn_probs = mean_cross_attn_probs[:n_cond_tokens, ...]  # (n_tokens, 512, 512)
        else:
            mean_cross_attn_probs = mean_cross_attn_probs[1:n_cond_tokens - 1, ...]  # (n_tokens-2, 512, 512)

        cross_attention_probs_list.append(mean_cross_attn_probs)

    # list -> torch.Tensor
    cross_attention_probs = torch.stack(cross_attention_probs_list)  # (16, n_tokens, 512, 512)
    n_layers, n_cond_tokens, _, _ = cross_attention_probs.shape

    progress(0.9, "Post-processing the attention maps...")
    image_list = []
    # Create and save one figure per cross-attention layer (one column per token)
    for i in tqdm(range(cross_attention_probs.shape[0]), desc="Saving images..."):
        fig, ax = plt.subplots(1, n_cond_tokens, figsize=(16, 4))
        for j in range(cross_attention_probs.shape[1]):
            # Min-max normalize each token's attention map to the range [0, 1]
            min_val = cross_attention_probs[i, j].min()
            max_val = cross_attention_probs[i, j].max()
            cross_attention_probs[i, j] = (cross_attention_probs[i, j] - min_val) / (max_val - min_val)

            attn_probs = cross_attention_probs[i, j].cpu().detach().numpy()
            ax[j].imshow(attn_probs, alpha=0.9)
            ax[j].axis('off')
            if has_include_special_tokens:
                ax[j].set_title(tokenizer.decode(input_ids[0, j].item()))
            else:
                ax[j].set_title(tokenizer.decode(input_ids[0, j + 1].item()))

        # Save the figure for this layer
        out_dir = Path("output")
        out_dir.mkdir(exist_ok=True)

        # Generate a unique random file name
        unique_filename = str(uuid.uuid4())
        filepath = out_dir / f"{unique_filename}.png"
        plt.savefig(filepath, bbox_inches='tight', pad_inches=0)
        plt.close(fig)

        # Load the saved image with PIL and append it to the output list
        image_list.append(Image.open(filepath))

    attn_processor.reset_attention_stores()

    return image_list


if __name__ == '__main__':
    unet_mapping = [
        "0: Down 64",
        "1: Down 64",
        "2: Down 32",
        "3: Down 32",
        "4: Down 16",
        "5: Down 16",
        "6: Mid 8",
        "7: Up 16",
        "8: Up 16",
        "9: Up 16",
        "10: Up 32",
        "11: Up 32",
        "12: Up 32",
        "13: Up 64",
        "14: Up 64",
        "15: Up 64",
    ]
    ca_output = [gr.Image(type="pil", label=unet_mapping[i]) for i in range(16)]

    iface = gr.Interface(
        title="Stable Diffusion Attention Visualizer",
        description="This is a visualizer for the attention maps of the Stable Diffusion model.",
        fn=inference,
        inputs=[
            gr.Image(type="filepath", label="Input", width=512, height=512),
            gr.Textbox(label="Prompt", placeholder="e.g., a photo of a dog"),
            gr.Checkbox(label="Include Special Tokens", value=False),
        ],
        outputs=ca_output,
        cache_examples=True,
        examples=[
            ["assets/aeroplane.png", "plane background", False],
            ["assets/dogcat.png", "a photo of dog", False],
        ],
    )
    iface.launch()
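
# Usage note (assumption: this file is the app entry point, e.g. `app.py` in a Hugging Face Space
# or a local checkout with the `assets/` example images available):
#   python app.py
# Gradio prints a local URL; upload an image, enter a prompt, and the 16 output panels show the
# per-token cross-attention maps of the UNet layers listed in `unet_mapping`.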