# An official reimplemented version of Marigold training script.
# Last modified: 2024-04-29
#
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------

from diffusers import StableDiffusionInpaintPipeline
import logging
import os
import pdb
import cv2
import shutil
import json
from pycocotools import mask as coco_mask
from datetime import datetime
from typing import List, Union
import random
import safetensors
import numpy as np
import torch
from diffusers import DDPMScheduler
from omegaconf import OmegaConf
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from PIL import Image
# import torch.optim.lr_scheduler
from diffusers.schedulers import PNDMScheduler
from torchvision.transforms.functional import pil_to_tensor
from marigold.marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput
from src.util import metric
from src.util.data_loader import skip_first_batches
from src.util.logging_util import tb_logger, eval_dic_to_text
from src.util.loss import get_loss
from src.util.lr_scheduler import IterExponential
from src.util.metric import MetricTracker
from src.util.multi_res_noise import multi_res_noise_like
from src.util.alignment import align_depth_least_square, depth2disparity, disparity2depth
from src.util.seeding import generate_seed_sequence
from accelerate import Accelerator
import os
from torchvision.transforms import InterpolationMode, Resize, CenterCrop
import torchvision.transforms as transforms

# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'


class MarigoldInpaintTrainer:
    """Train the UNet of a ``MarigoldPipeline`` for joint RGB / depth inpainting.

    The trainer owns the optimizer, LR scheduler, loss and the two training
    noise schedulers. ``parallel_train`` runs the training loop through a
    Hugging Face ``Accelerator``; the remaining methods handle preview
    rendering, seeding and checkpoint bookkeeping.
    """

    def __init__(
        self,
        cfg: OmegaConf,
        model: MarigoldPipeline,
        train_dataloader: DataLoader,
        device,
        base_ckpt_dir,
        out_dir_ckpt,
        out_dir_eval,
        out_dir_vis,
        accumulation_steps: int,
        depth_model=None,
        separate_list: List = None,
        val_dataloaders: List[DataLoader] = None,
        vis_dataloaders: List[DataLoader] = None,
        train_dataset: Dataset = None,
        timestep_method: str = 'unidiffuser',
        connection: bool = False
    ):
        """Build optimizer, schedulers and loss from ``cfg`` and store the run settings.

        Args:
            cfg: Run configuration. Reads ``trainer.*``, ``lr``,
                ``lr_scheduler.kwargs.*``, ``loss.*``, ``max_epoch``,
                ``max_iter``, ``gt_depth_type``, ``gt_mask_type`` and
                ``multi_res_noise``.
            model: Pipeline whose ``unet`` is trained. Only UNet parameters
                that already have ``requires_grad=True`` are optimized.
            train_dataloader: Loader yielding the training batches.
            device: Device for the cached empty-text embedding and for
                ``load_checkpoint``.
            base_ckpt_dir: Accepted for interface compatibility; not used here.
            out_dir_ckpt: Directory under which checkpoints are written.
            out_dir_eval: Stored only; not used by the visible methods.
            out_dir_vis: Root directory used by ``visualize``.
            accumulation_steps: Number of synchronized optimizer steps that
                make up one "effective" iteration.
            depth_model: Callable producing disparity maps from RGB images;
                required by ``parallel_train`` and ``_inpaint_rgbd``.
            separate_list: Stored only; not used by the visible methods.
            val_dataloaders: Stored only; not used by the visible methods.
            vis_dataloaders: Loaders iterated by ``visualize``.
            train_dataset: Stored only; not used by the visible methods.
            timestep_method: Masking strategy for training. ``parallel_train``
                supports only ``'joint'`` and ``'partition'``; the default
                ``'unidiffuser'`` is rejected there.
            connection: Forwarded to the UNet as ``controlnet_connection``.
        """
        self.cfg: OmegaConf = cfg
        self.model: MarigoldPipeline = model
        self.depth_model = depth_model
        self.device = device
        self.seed: Union[int, None] = (
            self.cfg.trainer.init_seed
        )  # used to generate seed sequence, set to `None` to train w/o seeding
        self.out_dir_ckpt = out_dir_ckpt
        self.out_dir_eval = out_dir_eval
        self.out_dir_vis = out_dir_vis
        self.train_loader: DataLoader = train_dataloader
        self.val_loaders: List[DataLoader] = val_dataloaders
        self.vis_loaders: List[DataLoader] = vis_dataloaders
        self.accumulation_steps: int = accumulation_steps
        self.separate_list = separate_list
        self.timestep_method = timestep_method
        self.train_dataset = train_dataset
        self.connection = connection

        # Input/output layer adaptation is disabled; `_replace_unet_conv_in`
        # is kept for reference but is not called.
        self.train_metrics = MetricTracker(*["loss", 'rgb_loss', 'depth_loss'])

        # Encode empty text prompt
        self.model.encode_empty_text()
        self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)

        self.model.unet.enable_xformers_memory_efficient_attention()

        # Trainability: the text encoder is frozen; for the UNet only the
        # parameters that are already marked trainable are optimized.
        self.model.text_encoder.requires_grad_(False)
        grad_part = filter(lambda p: p.requires_grad, self.model.unet.parameters())

        # Optimizer !should be defined after input layer is adapted
        lr = self.cfg.lr
        self.optimizer = Adam(grad_part, lr=lr)

        total_params = sum(p.numel() for p in self.model.unet.parameters())
        total_params_m = total_params / 1_000_000
        print(f"Total parameters: {total_params_m:.2f}M")
        trainable_params = sum(
            p.numel() for p in self.model.unet.parameters() if p.requires_grad
        )
        trainable_params_m = trainable_params / 1_000_000
        print(f"Trainable parameters: {trainable_params_m:.2f}M")

        # LR scheduler
        lr_func = IterExponential(
            total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
            final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
            warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
        )
        self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

        # Loss
        self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)

        # Training noise schedulers.
        # NOTE(review): the RGB scheduler is loaded from the *depth*
        # scheduler's pretrained_path - confirm this is intentional.
        self.rgb_training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
            cfg.trainer.depth_training_noise_scheduler.pretrained_path,
            subfolder="scheduler",
        )
        self.depth_training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
            cfg.trainer.depth_training_noise_scheduler.pretrained_path,
            subfolder="scheduler",
        )
        self.rgb_prediction_type = self.rgb_training_noise_scheduler.config.prediction_type
        self.depth_prediction_type = (
            self.depth_training_noise_scheduler.config.prediction_type
        )
        assert (
            self.depth_prediction_type == self.model.depth_scheduler.config.prediction_type
        ), "Different prediction types"
        self.scheduler_timesteps = (
            self.rgb_training_noise_scheduler.config.num_train_timesteps
        )

        # Settings
        self.max_epoch = self.cfg.max_epoch
        self.max_iter = self.cfg.max_iter
        self.gradient_accumulation_steps = accumulation_steps
        self.gt_depth_type = self.cfg.gt_depth_type
        self.gt_mask_type = self.cfg.gt_mask_type
        self.save_period = self.cfg.trainer.save_period
        self.backup_period = self.cfg.trainer.backup_period
        self.val_period = self.cfg.trainer.validation_period
        self.vis_period = self.cfg.trainer.visualization_period

        # Multi-resolution noise (settings are read here; the visible training
        # loop samples plain Gaussian noise and does not consult them).
        self.apply_multi_res_noise = self.cfg.multi_res_noise is not None
        if self.apply_multi_res_noise:
            self.mr_noise_strength = self.cfg.multi_res_noise.strength
            self.annealed_mr_noise = self.cfg.multi_res_noise.annealed
            self.mr_noise_downscale_strategy = (
                self.cfg.multi_res_noise.downscale_strategy
            )

        # Internal variables
        self.epoch = 0
        self.n_batch_in_epoch = 0  # batch index in the epoch, used when resume training
        self.effective_iter = 0  # how many effective iterations have been completed
        self.in_evaluation = False
        # Only ever assigned by `load_checkpoint`; initialised here so that
        # `save_checkpoint(save_train_state=True)` does not fail on a fresh trainer.
        self.best_metric = None
        # consistent global seed sequence, used to seed random generator,
        # to ensure consistency when resuming
        self.global_seed_sequence: List = []

    def _replace_unet_conv_in(self):
        """Replace ``unet.conv_in`` with an 8-input-channel convolution.

        The pretrained weights are kept for the existing input channels and
        zero weights of the same shape are appended along the input-channel
        dimension, so the new channels initially contribute nothing. The UNet
        config entry ``in_channels`` is updated to 8.
        """
        _weight = self.model.unet.conv_in.weight.clone()  # [320, 4, 3, 3]
        _bias = self.model.unet.conv_in.bias.clone()  # [320]
        # zeros_like keeps dtype and device of the pretrained weight, so the
        # concatenation does not promote a half-precision layer to float32.
        zero_weight = torch.zeros_like(_weight)
        _weight = torch.cat([_weight, zero_weight], dim=1)

        # new conv_in channel
        _n_convin_out_channel = self.model.unet.conv_in.out_channels
        _new_conv_in = Conv2d(
            8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        _new_conv_in.weight = Parameter(_weight)
        _new_conv_in.bias = Parameter(_bias)
        self.model.unet.conv_in = _new_conv_in
        logging.info("Unet conv_in layer is replaced")
        # replace config
        self.model.unet.config["in_channels"] = 8
        logging.info("Unet config is updated")
        return

    def _select_masks(self, mask):
        """Pick the RGB and depth inpainting masks for one training batch.

        For ``'joint'`` both modalities use ``mask``. For ``'partition'`` one
        of three cases is drawn with Python's global ``random`` module:
        joint prediction (p=0.5), depth-only prediction with an all-zero RGB
        mask (p=0.25), or RGB-only prediction where the depth mask is all-zero
        or ``mask`` with equal chance (p=0.25).

        Args:
            mask: Preprocessed mask tensor of the batch.

        Returns:
            Tuple ``(rgb_mask, depth_mask, rgb_flag, depth_flag)``. A flag of
            0 means that modality's loss is excluded from the total loss.

        Raises:
            ValueError: If ``self.timestep_method`` is neither ``'joint'`` nor
                ``'partition'``.
        """
        rgb_flag = 1
        depth_flag = 1
        if self.timestep_method == 'joint':
            return mask, mask, rgb_flag, depth_flag

        if self.timestep_method == 'partition':
            rand_num = random.random()
            if rand_num < 0.5:  # joint prediction
                rgb_mask = mask
                depth_mask = mask
            elif rand_num < 0.75:  # full rgb; depth prediction
                rgb_flag = 0
                rgb_mask = torch.zeros_like(mask)
                depth_mask = mask
            else:
                depth_flag = 0
                rgb_mask = mask
                if random.random() < 0.5:
                    depth_mask = torch.zeros_like(mask)  # full depth; rgb prediction
                else:
                    depth_mask = mask  # partial depth; rgb prediction
            return rgb_mask, depth_mask, rgb_flag, depth_flag

        raise ValueError(
            f"Unsupported timestep_method {self.timestep_method!r}; "
            "expected 'joint' or 'partition'"
        )

    @staticmethod
    def _prediction_target(prediction_type, noise_scheduler, latent, noise, timesteps, branch):
        """Return the regression target for one modality.

        Args:
            prediction_type: ``prediction_type`` of the training noise scheduler.
            noise_scheduler: Scheduler providing ``get_velocity``.
            latent: Clean latent of the modality.
            noise: Noise sampled for the modality.
            timesteps: Timesteps sampled for the batch.
            branch: ``"rgb"`` or ``"depth"``; used only in the error message.

        Raises:
            ValueError: If ``prediction_type`` is not recognised.
        """
        # NOTE(review): for "epsilon" the target is the clean latent rather
        # than the sampled noise; kept as in the original - confirm intended.
        if prediction_type in ("sample", "epsilon"):
            return latent
        if "v_prediction" == prediction_type:
            return noise_scheduler.get_velocity(latent, noise, timesteps)  # [B, 4, h, w]
        raise ValueError(f"Unknown {branch} prediction type {prediction_type}")

    def parallel_train(self, t_end=None, accelerator=None):
        """Run the training loop under a Hugging Face ``Accelerator``.

        Resumes from ``<out_dir_ckpt>/latest`` when that directory exists.
        Every ``save_period`` effective iterations the full accelerator state
        is written to ``latest``; every ``backup_period`` iterations the UNet
        is saved under ``iter_XXXXXX``. Training returns once
        ``effective_iter`` reaches ``max_iter`` (when ``max_iter > 0``).

        Args:
            t_end: Unused; kept for interface compatibility.
            accelerator: The ``Accelerator`` driving the run. Required.

        Raises:
            ValueError: If ``accelerator`` is ``None``, if
                ``self.timestep_method`` is unsupported, or if a scheduler
                prediction type is unknown.
        """
        if accelerator is None:
            raise ValueError("parallel_train requires an `accelerator` instance")

        logging.info("Start training")
        self.model, self.optimizer, self.train_loader, self.lr_scheduler = accelerator.prepare(
            self.model, self.optimizer, self.train_loader, self.lr_scheduler
        )
        self.depth_model = accelerator.prepare(self.depth_model)
        self.accelerator = accelerator

        latest_dir = os.path.join(self.out_dir_ckpt, 'latest')
        if os.path.exists(latest_dir):
            accelerator.load_state(latest_dir)
            self.load_miscs(latest_dir)

        self.train_metrics.reset()
        accumulated_step = 0

        for epoch in range(self.epoch, self.max_epoch + 1):
            self.epoch = epoch
            logging.debug(f"epoch: {self.epoch}")

            # Skip previous batches when resume
            for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
                self.model.unet.train()

                # globally consistent random generators
                if self.seed is not None:
                    local_seed = self._get_next_seed()
                    rand_num_generator = torch.Generator(device=self.model.device)
                    rand_num_generator.manual_seed(local_seed)
                else:
                    rand_num_generator = None

                # >>> With gradient accumulation >>>

                # Get data
                rgb = batch["rgb_norm"].to(self.model.device)
                with torch.no_grad():
                    disparities = self.depth_model(
                        batch["rgb_int"].numpy().astype(np.uint8),
                        518,
                        device=self.model.device,
                    )
                if len(disparities.shape) == 2:
                    disparities = disparities.unsqueeze(0)

                # Min-max normalise every disparity map to [-1, 1].
                # NOTE(review): a constant map divides by zero here.
                depth_gt_for_latent = torch.stack(
                    [
                        (
                            (disparity_map - disparity_map.min())
                            / (disparity_map.max() - disparity_map.min())
                        ) * 2 - 1
                        for disparity_map in disparities
                    ],
                    dim=0,
                )

                batch_size = rgb.shape[0]
                mask = self.model.mask_processor.preprocess(batch['mask'] * 255).to(
                    self.model.device
                )

                rgb_timesteps = torch.randint(
                    0,
                    self.scheduler_timesteps,
                    (batch_size,),
                    device=self.model.device,
                    generator=rand_num_generator,
                ).long()  # [B]
                # Both modalities share the same timesteps.
                depth_timesteps = rgb_timesteps

                rgb_mask, depth_mask, rgb_flag, depth_flag = self._select_masks(mask)

                # Zero out the regions that have to be inpainted (mask >= 0.5).
                masked_rgb = rgb * (rgb_mask < 0.5)
                masked_depth = depth_gt_for_latent * (depth_mask.squeeze() < 0.5)

                with torch.no_grad():
                    # Encode image
                    rgb_latent = self.model.encode_rgb(rgb)  # [B, 4, h, w]
                    mask_rgb_latent = self.model.encode_rgb(masked_rgb)
                    if depth_timesteps.sum() == 0:
                        gt_depth_latent = self.encode_depth(masked_depth)
                    else:
                        gt_depth_latent = self.encode_depth(depth_gt_for_latent)
                    mask_depth_latent = self.encode_depth(masked_depth)

                    # Bring the masks to latent resolution.
                    rgb_mask = torch.nn.functional.interpolate(
                        rgb_mask, size=rgb_latent.shape[-2:]
                    )
                    depth_mask = torch.nn.functional.interpolate(
                        depth_mask, size=gt_depth_latent.shape[-2:]
                    )

                # Sample noise
                rgb_noise = torch.randn(
                    rgb_latent.shape,
                    device=self.model.device,
                    generator=rand_num_generator,
                )  # [B, 4, h, w]
                depth_noise = torch.randn(
                    gt_depth_latent.shape,
                    device=self.model.device,
                    generator=rand_num_generator,
                )  # [B, 4, h, w]

                # When every sampled timestep is 0 the clean latent is used as is.
                if rgb_timesteps.sum() == 0:
                    noisy_rgb_latents = rgb_latent
                else:
                    noisy_rgb_latents = self.rgb_training_noise_scheduler.add_noise(
                        rgb_latent, rgb_noise, rgb_timesteps
                    )  # [B, 4, h, w]
                if depth_timesteps.sum() == 0:
                    noisy_depth_latents = gt_depth_latent
                else:
                    noisy_depth_latents = self.depth_training_noise_scheduler.add_noise(
                        gt_depth_latent, depth_noise, depth_timesteps
                    )  # [B, 4, h, w]

                # Channel-wise concat: RGB branch inputs, then depth branch inputs.
                noisy_latents = torch.cat(
                    [
                        noisy_rgb_latents,
                        rgb_mask,
                        mask_rgb_latent,
                        mask_depth_latent,
                        noisy_depth_latents,
                        depth_mask,
                        mask_rgb_latent,
                        mask_depth_latent,
                    ],
                    dim=1,
                ).float()

                # Text embedding
                input_ids = self.model.tokenizer(
                    batch['text'],
                    padding="max_length",
                    max_length=self.model.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                input_ids = {k: v.to(self.model.device) for k, v in input_ids.items()}
                text_embed = self.model.text_encoder(**input_ids)[0]

                # First 4 output channels are scored against the RGB target,
                # the remaining ones against the depth target.
                model_pred = self.model.unet(
                    noisy_latents,
                    rgb_timesteps,
                    depth_timesteps,
                    text_embed,
                    controlnet_connection=self.connection,
                ).sample
                if torch.isnan(model_pred).any():
                    logging.warning("model_pred contains NaN.")

                # Get the target for loss depending on the prediction type
                rgb_target = self._prediction_target(
                    self.rgb_prediction_type,
                    self.rgb_training_noise_scheduler,
                    rgb_latent,
                    rgb_noise,
                    rgb_timesteps,
                    "rgb",
                )
                depth_target = self._prediction_target(
                    self.depth_prediction_type,
                    self.depth_training_noise_scheduler,
                    gt_depth_latent,
                    depth_noise,
                    depth_timesteps,
                    "depth",
                )

                # Masked latent loss
                with accelerator.accumulate(self.model):
                    rgb_loss = self.loss(model_pred[:, 0:4, :, :].float(), rgb_target.float())
                    depth_loss = self.loss(model_pred[:, 4:, :, :].float(), depth_target.float())
                    if rgb_flag == 0:
                        loss = depth_loss
                    elif depth_flag == 0:
                        loss = rgb_loss
                    else:
                        loss = (
                            self.cfg.loss.depth_factor * depth_loss
                            + (1 - self.cfg.loss.depth_factor) * rgb_loss
                        )
                    self.train_metrics.update("loss", loss.item())
                    self.train_metrics.update("rgb_loss", rgb_loss.item())
                    self.train_metrics.update("depth_loss", depth_loss.item())

                    accelerator.backward(loss)
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                self.n_batch_in_epoch += 1
                self.lr_scheduler.step(self.effective_iter)

                if accelerator.sync_gradients:
                    accumulated_step += 1

                if accumulated_step >= self.gradient_accumulation_steps:
                    accumulated_step = 0
                    self.effective_iter += 1

                    if accelerator.is_main_process:
                        # Render a preview once right after the first iteration.
                        if self.effective_iter == 1:
                            self._inpaint_rgbd()

                        # Log to tensorboard
                        metrics = self.train_metrics.result()
                        tb_logger.log_dic(
                            {f"train/{k}": v for k, v in metrics.items()},
                            global_step=self.effective_iter,
                        )
                        tb_logger.writer.add_scalar(
                            "lr",
                            self.lr_scheduler.get_last_lr()[0],
                            global_step=self.effective_iter,
                        )
                        tb_logger.writer.add_scalar(
                            "n_batch_in_epoch",
                            self.n_batch_in_epoch,
                            global_step=self.effective_iter,
                        )
                        accumulated_loss = metrics["loss"]
                        avg_rgb_loss = metrics["rgb_loss"]
                        avg_depth_loss = metrics["depth_loss"]
                        logging.info(
                            f"iter {self.effective_iter:5d} (epoch {epoch:2d}): loss={accumulated_loss:.5f}, rgb_loss={avg_rgb_loss:.5f}, depth_loss={avg_depth_loss:.5f}"
                        )
                    accelerator.wait_for_everyone()

                    # Rolling "latest" checkpoint (full accelerator state)
                    if self.save_period > 0 and 0 == self.effective_iter % self.save_period:
                        accelerator.save_state(output_dir=os.path.join(self.out_dir_ckpt, 'latest'))
                        unwrapped_model = accelerator.unwrap_model(self.model)
                        if accelerator.is_main_process:
                            accelerator.save_model(
                                unwrapped_model.unet,
                                os.path.join(self.out_dir_ckpt, 'latest'),
                                safe_serialization=False,
                            )
                            self.save_miscs('latest')
                            self._inpaint_rgbd()
                        accelerator.wait_for_everyone()

                    # Permanent per-iteration backup of the UNet
                    if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
                        unwrapped_model = accelerator.unwrap_model(self.model)
                        if accelerator.is_main_process:
                            accelerator.save_model(
                                unwrapped_model.unet,
                                os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()),
                                safe_serialization=False,
                            )
                        accelerator.wait_for_everyone()

                    # End of training
                    if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                        unwrapped_model = accelerator.unwrap_model(self.model)
                        if accelerator.is_main_process:
                            unwrapped_model.unet.save_pretrained(
                                os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name())
                            )
                        accelerator.wait_for_everyone()
                        return

                    torch.cuda.empty_cache()
                    # <<< Effective batch end <<<

            # Epoch end
            self.n_batch_in_epoch = 0

    def _inpaint_rgbd(self):
        """Render joint RGB-D inpainting previews and log them to TensorBoard.

        Three hard-coded images are loaded together with their annotation
        files (same path with a ``.json`` suffix). A mask is built per image
        from a random subset of the annotations, everything is resized and
        centre-cropped to 512x512, and the pipeline output is written to
        TensorBoard under the prompt text at step ``self.effective_iter``.
        """
        image_path = ['/dataset/~sa-1b/data/sa_001000/sa_10000335.jpg',
                      '/dataset/~sa-1b/data/sa_000357/sa_3572319.jpg',
                      '/dataset/~sa-1b/data/sa_000045/sa_457934.jpg']
        prompt = ['A white car is parked in front of the factory',
                  'church with cemetery next to it',
                  'A house with a red brick roof']

        imgs = [pil_to_tensor(Image.open(p)) for p in image_path]
        depth_imgs = [self.depth_model(img.unsqueeze(0).cpu().numpy()) for img in imgs]

        masks = []
        for rgb_path in image_path:
            # Close the annotation file deterministically.
            with open(rgb_path.replace('.jpg', '.json')) as anno_file:
                anno = json.load(anno_file)['annotations']
            random.shuffle(anno)
            object_num = random.randint(5, 10)
            # NOTE(review): the first annotation is decoded here and added
            # again in the loop below, so the mask is a per-pixel count, not
            # a binary map - kept as in the original; confirm downstream use.
            mask = np.array(coco_mask.decode(anno[0]['segmentation']), dtype=np.uint8)
            # Slicing past the end simply yields the whole list.
            for single_anno in anno[:object_num]:
                mask += np.array(coco_mask.decode(single_anno['segmentation']), dtype=np.uint8)
            masks.append(torch.from_numpy(mask))

        resize_transform = transforms.Compose([
            Resize(size=512, interpolation=InterpolationMode.NEAREST_EXACT),
            CenterCrop(size=[512, 512]),
        ])
        imgs = [resize_transform(img) for img in imgs]
        depth_imgs = [resize_transform(depth_img.unsqueeze(0)) for depth_img in depth_imgs]
        masks = [resize_transform(mask.unsqueeze(0)) for mask in masks]

        for img, depth_img, mask, text in zip(imgs, depth_imgs, masks, prompt):
            output_image = self.model._rgbd_inpaint(
                img, depth_img, mask, [text], processing_res=512, mode='joint_inpaint'
            )
            tb_logger.writer.add_image(f'{text}', pil_to_tensor(output_image), self.effective_iter)

    def encode_depth(self, depth_in):
        """Encode a depth map with the pipeline's RGB encoder.

        The depth input is replicated to three channels first, then passed to
        ``self.model.encode_rgb``.
        """
        # stack depth into 3-channel
        stacked = self.stack_depth_images(depth_in)
        # encode using VAE encoder
        depth_latent = self.model.encode_rgb(stacked)
        return depth_latent

    @staticmethod
    def stack_depth_images(depth_in):
        """Replicate a depth tensor along the channel dimension three times.

        Args:
            depth_in: Tensor of rank 4 (channel dimension already present at
                index 1), rank 3 (a channel dimension is inserted at index 1)
                or rank 2 (batch and channel dimensions are inserted).

        Returns:
            The rank-4 tensor repeated 3x along dimension 1.

        Raises:
            ValueError: If ``depth_in`` has any other rank.
        """
        n_dims = len(depth_in.shape)
        if 4 == n_dims:
            stacked = depth_in.repeat(1, 3, 1, 1)
        elif 3 == n_dims:
            stacked = depth_in.unsqueeze(1)
            stacked = stacked.repeat(1, 3, 1, 1)
        elif 2 == n_dims:
            stacked = depth_in.unsqueeze(0).unsqueeze(0)
            stacked = stacked.repeat(1, 3, 1, 1)
        else:
            raise ValueError(
                f"depth_in must have 2, 3 or 4 dimensions, got shape {tuple(depth_in.shape)}"
            )
        return stacked

    def visualize(self):
        """Run ``validate_single_dataset`` on every visualization loader.

        Outputs go to ``<out_dir_vis>/<iter_XXXXXX>/<dataset disp_name>``.

        NOTE(review): ``self.validate_single_dataset`` and ``self.val_metrics``
        are not defined in the visible part of this class - confirm they are
        provided elsewhere before calling this method.
        """
        for val_loader in self.vis_loaders:
            vis_dataset_name = val_loader.dataset.disp_name
            vis_out_dir = os.path.join(
                self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
            )
            os.makedirs(vis_out_dir, exist_ok=True)
            _ = self.validate_single_dataset(
                data_loader=val_loader,
                metric_tracker=self.val_metrics,
                save_to_dir=vis_out_dir,
            )

    def _get_next_seed(self):
        """Pop the next seed from the global seed sequence.

        The sequence is (re)generated from ``self.seed`` with length
        ``max_iter * gradient_accumulation_steps`` whenever it is empty.
        """
        if 0 == len(self.global_seed_sequence):
            self.global_seed_sequence = generate_seed_sequence(
                initial_seed=self.seed,
                length=self.max_iter * self.gradient_accumulation_steps,
            )
            logging.info(
                f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
            )
        return self.global_seed_sequence.pop()

    def save_miscs(self, ckpt_name):
        """Write the trainer bookkeeping state to ``<out_dir_ckpt>/<ckpt_name>/trainer.ckpt``.

        Saved keys: ``config``, ``effective_iter``, ``epoch``,
        ``n_batch_in_epoch`` and ``global_seed_sequence``.
        """
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        # Make the method usable even when nothing has created the directory yet.
        os.makedirs(ckpt_dir, exist_ok=True)
        state = {
            "config": self.cfg,
            "effective_iter": self.effective_iter,
            "epoch": self.epoch,
            "n_batch_in_epoch": self.n_batch_in_epoch,
            "global_seed_sequence": self.global_seed_sequence,
        }
        train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
        torch.save(state, train_state_path)
        logging.info(f"Misc state is saved to: {train_state_path}")

    def load_miscs(self, ckpt_path):
        """Restore the bookkeeping state written by ``save_miscs`` from ``ckpt_path``."""
        checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
        self.effective_iter = checkpoint["effective_iter"]
        self.epoch = checkpoint["epoch"]
        self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
        self.global_seed_sequence = checkpoint["global_seed_sequence"]
        logging.info(f"Misc state is loaded from {ckpt_path}")

    def save_checkpoint(self, ckpt_name, save_train_state):
        """Save the UNet (and optionally the trainer state) without ``Accelerator``.

        An existing checkpoint directory of the same name is moved aside to
        ``_old_<name>`` while the new one is written and removed afterwards.

        Args:
            ckpt_name: Sub-directory of ``out_dir_ckpt`` to write to.
            save_train_state: When true, also write ``trainer.ckpt`` and an
                empty ``iter_XXXXXX`` marker file. Optimizer and LR-scheduler
                states are not part of ``trainer.ckpt``.
        """
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        logging.info(f"Saving checkpoint to: {ckpt_dir}")

        # Backup previous checkpoint
        temp_ckpt_dir = None
        if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
            temp_ckpt_dir = os.path.join(
                os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
            )
            if os.path.exists(temp_ckpt_dir):
                shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            os.rename(ckpt_dir, temp_ckpt_dir)
            logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

        # Save UNet
        unet_path = os.path.join(ckpt_dir, "unet")
        self.model.unet.save_pretrained(unet_path, safe_serialization=False)
        logging.info(f"UNet is saved to: {unet_path}")

        if save_train_state:
            state = {
                "config": self.cfg,
                "effective_iter": self.effective_iter,
                "epoch": self.epoch,
                "n_batch_in_epoch": self.n_batch_in_epoch,
                # getattr keeps this working for instances created without
                # running __init__ (e.g. restored or subclassed trainers).
                "best_metric": getattr(self, "best_metric", None),
                "in_evaluation": self.in_evaluation,
                "global_seed_sequence": self.global_seed_sequence,
            }
            train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
            torch.save(state, train_state_path)
            # iteration indicator: an empty file named after the iteration
            with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
                pass
            logging.info(f"Trainer state is saved to: {train_state_path}")

        # Remove temp ckpt
        if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            logging.debug("Old checkpoint backup is removed.")

    def load_checkpoint(
        self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
    ):
        """Load UNet weights (and optionally the trainer state) from ``ckpt_path``.

        Args:
            ckpt_path: Directory containing ``unet/diffusion_pytorch_model.bin``
                and, when ``load_trainer_state`` is true, ``trainer.ckpt``.
            load_trainer_state: Also restore iteration counters, seed sequence
                and, when stored, optimizer / LR-scheduler state.
            resume_lr_scheduler: Restore the LR-scheduler state if it is stored.
        """
        logging.info(f"Loading checkpoint from: {ckpt_path}")
        # Load UNet
        _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
        self.model.unet.load_state_dict(
            torch.load(_model_path, map_location=self.device)
        )
        self.model.unet.to(self.device)
        logging.info(f"UNet parameters are loaded from {_model_path}")

        # Load training states
        if load_trainer_state:
            checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
            self.effective_iter = checkpoint["effective_iter"]
            self.epoch = checkpoint["epoch"]
            self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
            self.global_seed_sequence = checkpoint["global_seed_sequence"]
            # These two keys are written by `save_checkpoint` only, not by
            # `save_miscs`; fall back to the current values when absent.
            self.in_evaluation = checkpoint.get("in_evaluation", self.in_evaluation)
            self.best_metric = checkpoint.get(
                "best_metric", getattr(self, "best_metric", None)
            )

            # Neither save method of this class stores optimizer or scheduler
            # state, so only restore them when the file actually has them.
            if "optimizer" in checkpoint:
                self.optimizer.load_state_dict(checkpoint["optimizer"])
                logging.info(f"optimizer state is loaded from {ckpt_path}")
            else:
                logging.warning(f"No optimizer state found in {ckpt_path}")

            if resume_lr_scheduler:
                if "lr_scheduler" in checkpoint:
                    self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                    logging.info(f"LR scheduler state is loaded from {ckpt_path}")
                else:
                    logging.warning(f"No LR scheduler state found in {ckpt_path}")

        logging.info(
            f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
        )
        return

    def _get_backup_ckpt_name(self):
        """Return the checkpoint name for the current iteration, e.g. ``iter_000123``."""
        return f"iter_{self.effective_iter:06d}"