import os
import json
import warnings

import torch

from model.attn_processor import AttnProcessor2_0, SkipAttnProcessor


def _block_hidden_size(name, block_out_channels):
    """Return the channel width of the UNet block that owns processor *name*.

    Processor names are dotted paths whose first component is ``mid_block``,
    ``up_blocks`` or ``down_blocks``; for up/down blocks the second component
    is the block index.

    Raises:
        ValueError: if *name* does not belong to a mid/up/down block.
    """
    if name.startswith("mid_block"):
        return block_out_channels[-1]
    if name.startswith("up_blocks"):
        # Parse the whole index component so block ids >= 10 also work.
        block_id = int(name.split(".")[1])
        # Up blocks are indexed in the reverse channel order of the down blocks.
        return list(reversed(block_out_channels))[block_id]
    if name.startswith("down_blocks"):
        block_id = int(name.split(".")[1])
        return block_out_channels[block_id]
    raise ValueError(f"Cannot infer hidden size for attention processor {name!r}")


def init_adapter(unet, cross_attn_cls=SkipAttnProcessor, self_attn_cls=None, cross_attn_dim=None, **kwargs):
    """Install a fresh attention processor on every attention layer of *unet*.

    Processors whose name ends in ``attn1.processor`` are treated as
    self-attention (``cross_attention_dim=None``) and get ``self_attn_cls``,
    or ``AttnProcessor2_0`` when ``self_attn_cls`` is None. All other
    processors get ``cross_attn_cls`` with ``cross_attention_dim=cross_attn_dim``.
    Every processor is constructed as
    ``cls(hidden_size=..., cross_attention_dim=..., **kwargs)``.

    Args:
        unet: Model exposing ``config`` (``cross_attention_dim``,
            ``block_out_channels``), ``attn_processors`` and
            ``set_attn_processor`` (presumably a diffusers UNet — confirm).
        cross_attn_cls: Processor class for non-``attn1`` layers.
        self_attn_cls: Processor class for ``attn1`` layers, or None to use
            ``AttnProcessor2_0``.
        cross_attn_dim: Context dim for cross-attention processors; defaults to
            ``unet.config.cross_attention_dim``. If that is also None, every
            layer is built as a self-attention processor.
        **kwargs: Forwarded to every processor constructor.

    Returns:
        torch.nn.ModuleList of the processors installed on *unet*.

    Raises:
        ValueError: if a processor name does not belong to a mid/up/down block.
    """
    if cross_attn_dim is None:
        cross_attn_dim = unet.config.cross_attention_dim
    block_out_channels = unet.config.block_out_channels

    attn_procs = {}
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else cross_attn_dim
        hidden_size = _block_hidden_size(name, block_out_channels)
        if cross_attention_dim is None:
            # Self-attention: use the requested class, else retain the default processor.
            proc_cls = self_attn_cls if self_attn_cls is not None else AttnProcessor2_0
        else:
            proc_cls = cross_attn_cls
        attn_procs[name] = proc_cls(
            hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, **kwargs
        )

    unet.set_attn_processor(attn_procs)
    return torch.nn.ModuleList(unet.attn_processors.values())


def init_diffusion_model(diffusion_model_name_or_path, unet_class=None):
    """Load the CLIP text encoder, VAE, tokenizer and optionally a custom UNet.

    The text encoder, VAE and tokenizer are loaded with ``from_pretrained``
    from the ``text_encoder``, ``vae`` and ``tokenizer`` subfolders. The UNet
    is built as ``unet_class(**<path>/unet/config.json)`` and loaded (strictly)
    from ``<path>/unet/diffusion_pytorch_model.bin``.

    Returns:
        ``(text_encoder, vae, tokenizer, unet)``. ``unet`` is None when
        ``unet_class`` is None or when building/loading the UNet fails (a
        warning is emitted in the latter case).
    """
    from diffusers import AutoencoderKL
    from transformers import CLIPTextModel, CLIPTokenizer

    text_encoder = CLIPTextModel.from_pretrained(diffusion_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(diffusion_model_name_or_path, subfolder="vae")
    tokenizer = CLIPTokenizer.from_pretrained(diffusion_model_name_or_path, subfolder="tokenizer")

    unet = None
    if unet_class is not None:
        unet_folder = os.path.join(diffusion_model_name_or_path, "unet")
        try:
            with open(os.path.join(unet_folder, "config.json"), "r") as f:
                unet_configs = json.load(f)
            unet = unet_class(**unet_configs)
            # NOTE(review): torch.load unpickles arbitrary objects; only load trusted checkpoints.
            state_dict = torch.load(os.path.join(unet_folder, "diffusion_pytorch_model.bin"), map_location="cpu")
            unet.load_state_dict(state_dict, strict=True)
        except Exception as err:
            # Best-effort load: fall back to None, but no longer silently.
            warnings.warn(f"Could not load UNet from {unet_folder!r}: {err}")
            unet = None
    return text_encoder, vae, tokenizer, unet


def attn_of_unet(unet):
    """Return a ModuleList of every module of *unet* whose qualified name contains "attn1".

    Because the match is a substring test on ``named_modules`` paths, the list
    holds each ``attn1`` module *and* all of its submodules (e.g. ``...attn1.to_q``),
    so entries overlap. This layout is kept deliberately: changing it would change
    the list's indices and ``state_dict`` keys.
    """
    return torch.nn.ModuleList(
        module for name, module in unet.named_modules() if "attn1" in name
    )


def get_trainable_module(unet, trainable_module_name):
    """Select the part of *unet* to train.

    Args:
        unet: Model with ``down_blocks``, ``mid_block`` and ``up_blocks``.
        trainable_module_name: ``"unet"`` (the whole model), ``"transformer"``
            (every block's ``attentions`` container) or ``"attention"``
            (see ``attn_of_unet``).

    Returns:
        *unet* itself or a torch.nn.ModuleList of the selected submodules.

    Raises:
        ValueError: for any other *trainable_module_name*.
    """
    if trainable_module_name == "unet":
        return unet
    if trainable_module_name == "transformer":
        trainable_modules = torch.nn.ModuleList()
        for blocks in [unet.down_blocks, unet.mid_block, unet.up_blocks]:
            # mid_block is a single block; down/up blocks are iterable containers.
            if hasattr(blocks, "attentions"):
                trainable_modules.append(blocks.attentions)
            else:
                for block in blocks:
                    if hasattr(block, "attentions"):
                        trainable_modules.append(block.attentions)
        return trainable_modules
    if trainable_module_name == "attention":
        return attn_of_unet(unet)
    raise ValueError(f"Unknown trainable_module_name: {trainable_module_name}")