import os
import random
import argparse
from pathlib import Path
import json
import itertools
import time

import torch
import torch.nn.functional as F
from torchvision import transforms
from PIL import Image
from transformers import CLIPImageProcessor
from accelerate import Accelerator
from accelerate.logging import get_logger
from accelerate.utils import ProjectConfiguration
from diffusers import AutoencoderKL, DDPMScheduler, UNet2DConditionModel
from transformers import CLIPTextModel, CLIPTokenizer, CLIPVisionModelWithProjection

from ip_adapter.resampler import Resampler
from ip_adapter.utils import is_torch2_available

if is_torch2_available():
    from ip_adapter.attention_processor import IPAttnProcessor2_0 as IPAttnProcessor, AttnProcessor2_0 as AttnProcessor
else:
    from ip_adapter.attention_processor import IPAttnProcessor, AttnProcessor


# Dataset
class MyDataset(torch.utils.data.Dataset):

    def __init__(self, json_file, tokenizer, size=512, t_drop_rate=0.05, i_drop_rate=0.05, ti_drop_rate=0.05,
                 image_root_path=""):
        super().__init__()

        self.tokenizer = tokenizer
        self.size = size
        self.i_drop_rate = i_drop_rate
        self.t_drop_rate = t_drop_rate
        self.ti_drop_rate = ti_drop_rate
        self.image_root_path = image_root_path

        self.data = json.load(open(json_file))  # list of dict: [{"image_file": "1.png", "text": "A dog"}]

        self.transform = transforms.Compose([
            transforms.Resize(self.size, interpolation=transforms.InterpolationMode.BILINEAR),
            transforms.CenterCrop(self.size),
            transforms.ToTensor(),
            transforms.Normalize([0.5], [0.5]),
        ])

        self.clip_image_processor = CLIPImageProcessor()

    def __getitem__(self, idx):
        item = self.data[idx]
        text = item["text"]
        image_file = item["image_file"]

        # read image
        raw_image = Image.open(os.path.join(self.image_root_path, image_file))
        image = self.transform(raw_image.convert("RGB"))
        clip_image = self.clip_image_processor(images=raw_image, return_tensors="pt").pixel_values

        # randomly drop the image condition, the text condition, or both (for classifier-free guidance)
        drop_image_embed = 0
        rand_num = random.random()
        if rand_num < self.i_drop_rate:
            drop_image_embed = 1
        elif rand_num < (self.i_drop_rate + self.t_drop_rate):
            text = ""
        elif rand_num < (self.i_drop_rate + self.t_drop_rate + self.ti_drop_rate):
            text = ""
            drop_image_embed = 1

        # get text and tokenize
        text_input_ids = self.tokenizer(
            text,
            max_length=self.tokenizer.model_max_length,
            padding="max_length",
            truncation=True,
            return_tensors="pt"
        ).input_ids

        return {
            "image": image,
            "text_input_ids": text_input_ids,
            "clip_image": clip_image,
            "drop_image_embed": drop_image_embed
        }

    def __len__(self):
        return len(self.data)


def collate_fn(data):
    images = torch.stack([example["image"] for example in data])
    text_input_ids = torch.cat([example["text_input_ids"] for example in data], dim=0)
    clip_images = torch.cat([example["clip_image"] for example in data], dim=0)
    drop_image_embeds = [example["drop_image_embed"] for example in data]

    return {
        "images": images,
        "text_input_ids": text_input_ids,
        "clip_images": clip_images,
        "drop_image_embeds": drop_image_embeds
    }


class IPAdapter(torch.nn.Module):
    """IP-Adapter"""

    def __init__(self, unet, image_proj_model, adapter_modules):
        super().__init__()
        self.unet = unet
        self.image_proj_model = image_proj_model
        self.adapter_modules = adapter_modules

    def forward(self, noisy_latents, timesteps, encoder_hidden_states, image_embeds):
        # project image embeddings to extra context tokens and append them to the text context
        ip_tokens = self.image_proj_model(image_embeds)
        encoder_hidden_states = torch.cat([encoder_hidden_states, ip_tokens], dim=1)
        # Predict the noise residual
        noise_pred = self.unet(noisy_latents, timesteps, encoder_hidden_states).sample
        return noise_pred

def parse_args():
    parser = argparse.ArgumentParser(description="Simple example of a training script.")
    parser.add_argument(
        "--pretrained_model_name_or_path",
        type=str,
        default=None,
        required=True,
        help="Path to pretrained model or model identifier from huggingface.co/models.",
    )
    parser.add_argument(
        "--data_json_file",
        type=str,
        default=None,
        required=True,
        help="Path to the training data JSON file.",
    )
    parser.add_argument(
        "--data_root_path",
        type=str,
        default="",
        required=True,
        help="Root path of the training images.",
    )
    parser.add_argument(
        "--image_encoder_path",
        type=str,
        default=None,
        required=True,
        help="Path to the CLIP image encoder.",
    )
    parser.add_argument(
        "--output_dir",
        type=str,
        default="sd-ip_adapter",
        help="The output directory where the model predictions and checkpoints will be written.",
    )
    parser.add_argument(
        "--logging_dir",
        type=str,
        default="logs",
        help=(
            "[TensorBoard](https://www.tensorflow.org/tensorboard) log directory. Will default to"
            " *output_dir/runs/**CURRENT_DATETIME_HOSTNAME***."
        ),
    )
    parser.add_argument(
        "--resolution",
        type=int,
        default=512,
        help="The resolution for input images.",
    )
    parser.add_argument(
        "--learning_rate",
        type=float,
        default=1e-4,
        help="Learning rate to use.",
    )
    parser.add_argument("--weight_decay", type=float, default=1e-2, help="Weight decay to use.")
    parser.add_argument("--num_train_epochs", type=int, default=100)
    parser.add_argument(
        "--train_batch_size", type=int, default=8, help="Batch size (per device) for the training dataloader."
    )
    parser.add_argument(
        "--dataloader_num_workers",
        type=int,
        default=0,
        help="Number of subprocesses to use for data loading. 0 means that the data will be loaded in the main process.",
    )
    parser.add_argument(
        "--save_steps",
        type=int,
        default=2000,
        help="Save a checkpoint of the training state every X updates.",
    )
    parser.add_argument(
        "--mixed_precision",
        type=str,
        default=None,
        choices=["no", "fp16", "bf16"],
        help=(
            "Whether to use mixed precision. Choose between fp16 and bf16 (bfloat16). Bf16 requires PyTorch >="
            " 1.10 and an Nvidia Ampere GPU. Defaults to the value of the accelerate config of the current system or the"
            " flag passed with the `accelerate.launch` command. Use this argument to override the accelerate config."
        ),
    )
    parser.add_argument(
        "--report_to",
        type=str,
        default="tensorboard",
        help=(
            'The integration to report the results and logs to. Supported platforms are `"tensorboard"`'
            ' (default), `"wandb"` and `"comet_ml"`. Use `"all"` to report to all integrations.'
        ),
    )
    parser.add_argument("--local_rank", type=int, default=-1, help="For distributed training: local_rank")

    args = parser.parse_args()
    env_local_rank = int(os.environ.get("LOCAL_RANK", -1))
    if env_local_rank != -1 and env_local_rank != args.local_rank:
        args.local_rank = env_local_rank

    return args


def main():
    args = parse_args()
    logging_dir = Path(args.output_dir, args.logging_dir)

    accelerator_project_config = ProjectConfiguration(project_dir=args.output_dir, logging_dir=logging_dir)

    accelerator = Accelerator(
        mixed_precision=args.mixed_precision,
        log_with=args.report_to,
        project_config=accelerator_project_config,
    )

    if accelerator.is_main_process:
        if args.output_dir is not None:
            os.makedirs(args.output_dir, exist_ok=True)

    # Load scheduler, tokenizer and models.
    noise_scheduler = DDPMScheduler.from_pretrained(args.pretrained_model_name_or_path, subfolder="scheduler")
    tokenizer = CLIPTokenizer.from_pretrained(args.pretrained_model_name_or_path, subfolder="tokenizer")
    text_encoder = CLIPTextModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="text_encoder")
    vae = AutoencoderKL.from_pretrained(args.pretrained_model_name_or_path, subfolder="vae")
    unet = UNet2DConditionModel.from_pretrained(args.pretrained_model_name_or_path, subfolder="unet")
    image_encoder = CLIPVisionModelWithProjection.from_pretrained(args.image_encoder_path)

    # freeze parameters of models to save more memory
    unet.requires_grad_(False)
    vae.requires_grad_(False)
    text_encoder.requires_grad_(False)
    image_encoder.requires_grad_(False)

    # ip-adapter-plus
    num_tokens = 16
    image_proj_model = Resampler(
        dim=unet.config.cross_attention_dim,
        depth=4,
        dim_head=64,
        heads=12,
        num_queries=num_tokens,
        embedding_dim=image_encoder.config.hidden_size,
        output_dim=unet.config.cross_attention_dim,
        ff_mult=4,
    )

    # init adapter modules
    attn_procs = {}
    unet_sd = unet.state_dict()
    for name in unet.attn_processors.keys():
        cross_attention_dim = None if name.endswith("attn1.processor") else unet.config.cross_attention_dim
        if name.startswith("mid_block"):
            hidden_size = unet.config.block_out_channels[-1]
        elif name.startswith("up_blocks"):
            block_id = int(name[len("up_blocks.")])
            hidden_size = list(reversed(unet.config.block_out_channels))[block_id]
        elif name.startswith("down_blocks"):
            block_id = int(name[len("down_blocks.")])
            hidden_size = unet.config.block_out_channels[block_id]
        if cross_attention_dim is None:
            attn_procs[name] = AttnProcessor()
        else:
            layer_name = name.split(".processor")[0]
            weights = {
                "to_k_ip.weight": unet_sd[layer_name + ".to_k.weight"],
                "to_v_ip.weight": unet_sd[layer_name + ".to_v.weight"],
            }
            attn_procs[name] = IPAttnProcessor(hidden_size=hidden_size, cross_attention_dim=cross_attention_dim, num_tokens=num_tokens)
            attn_procs[name].load_state_dict(weights)
    unet.set_attn_processor(attn_procs)
    adapter_modules = torch.nn.ModuleList(unet.attn_processors.values())

    ip_adapter = IPAdapter(unet, image_proj_model, adapter_modules)

    weight_dtype = torch.float32
    if accelerator.mixed_precision == "fp16":
        weight_dtype = torch.float16
    elif accelerator.mixed_precision == "bf16":
        weight_dtype = torch.bfloat16
    # unet.to(accelerator.device, dtype=weight_dtype)
    vae.to(accelerator.device, dtype=weight_dtype)
    text_encoder.to(accelerator.device, dtype=weight_dtype)
    image_encoder.to(accelerator.device, dtype=weight_dtype)

    # optimizer
    params_to_opt = itertools.chain(ip_adapter.image_proj_model.parameters(), ip_adapter.adapter_modules.parameters())
    optimizer = torch.optim.AdamW(params_to_opt, lr=args.learning_rate, weight_decay=args.weight_decay)

    # dataloader
    train_dataset = MyDataset(args.data_json_file, tokenizer=tokenizer, size=args.resolution, image_root_path=args.data_root_path)
    train_dataloader = torch.utils.data.DataLoader(
        train_dataset,
        shuffle=True,
        collate_fn=collate_fn,
        batch_size=args.train_batch_size,
        num_workers=args.dataloader_num_workers,
    )

    # Prepare everything with our `accelerator`.
    ip_adapter, optimizer, train_dataloader = accelerator.prepare(ip_adapter, optimizer, train_dataloader)

    global_step = 0
    for epoch in range(0, args.num_train_epochs):
        begin = time.perf_counter()
        for step, batch in enumerate(train_dataloader):
            load_data_time = time.perf_counter() - begin
            with accelerator.accumulate(ip_adapter):
                # Convert images to latent space
                with torch.no_grad():
                    latents = vae.encode(batch["images"].to(accelerator.device, dtype=weight_dtype)).latent_dist.sample()
                    latents = latents * vae.config.scaling_factor

                # Sample noise that we'll add to the latents
                noise = torch.randn_like(latents)
                bsz = latents.shape[0]
                # Sample a random timestep for each image
                timesteps = torch.randint(0, noise_scheduler.num_train_timesteps, (bsz,), device=latents.device)
                timesteps = timesteps.long()

                # Add noise to the latents according to the noise magnitude at each timestep
                # (this is the forward diffusion process)
                noisy_latents = noise_scheduler.add_noise(latents, noise, timesteps)

                # zero out the CLIP image for samples where the image condition is dropped
                clip_images = []
                for clip_image, drop_image_embed in zip(batch["clip_images"], batch["drop_image_embeds"]):
                    if drop_image_embed == 1:
                        clip_images.append(torch.zeros_like(clip_image))
                    else:
                        clip_images.append(clip_image)
                clip_images = torch.stack(clip_images, dim=0)
                with torch.no_grad():
                    image_embeds = image_encoder(clip_images.to(accelerator.device, dtype=weight_dtype), output_hidden_states=True).hidden_states[-2]

                with torch.no_grad():
                    encoder_hidden_states = text_encoder(batch["text_input_ids"].to(accelerator.device))[0]

                noise_pred = ip_adapter(noisy_latents, timesteps, encoder_hidden_states, image_embeds)

                loss = F.mse_loss(noise_pred.float(), noise.float(), reduction="mean")

                # Gather the losses across all processes for logging (if we use distributed training).
                avg_loss = accelerator.gather(loss.repeat(args.train_batch_size)).mean().item()

                # Backpropagate
                accelerator.backward(loss)
                optimizer.step()
                optimizer.zero_grad()

                if accelerator.is_main_process:
                    print("Epoch {}, step {}, data_time: {}, time: {}, step_loss: {}".format(
                        epoch, step, load_data_time, time.perf_counter() - begin, avg_loss))

            global_step += 1

            if global_step % args.save_steps == 0:
                save_path = os.path.join(args.output_dir, f"checkpoint-{global_step}")
                accelerator.save_state(save_path)

            begin = time.perf_counter()


if __name__ == "__main__":
    main()
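
# Example launch command (a sketch, not part of the original script): the script filename,
# model identifier, and all paths below are placeholders and depend on your setup.
#
#   accelerate launch train_ip_adapter_plus.py \
#     --pretrained_model_name_or_path "runwayml/stable-diffusion-v1-5" \
#     --image_encoder_path "path/to/clip_image_encoder" \
#     --data_json_file "path/to/data.json" \
#     --data_root_path "path/to/images" \
#     --output_dir "sd-ip_adapter" \
#     --mixed_precision "fp16" \
#     --resolution 512 \
#     --train_batch_size 8 \
#     --learning_rate 1e-4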