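"""Trainer for a Marigold-style joint RGB-D latent diffusion model.

Fine-tunes a Stable Diffusion UNet to denoise concatenated RGB and depth
latents, with several strategies for sampling the two modalities' diffusion
timesteps ('unidiffuser', 'joint', 'partition'). Training is distributed via
Hugging Face Accelerate.
"""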
import logging
import os
import random
import shutil
from datetime import datetime
from typing import List, Union

import numpy as np
import safetensors
import torch
from accelerate import Accelerator
from diffusers import DDPMScheduler
from omegaconf import OmegaConf
from PIL import Image
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader
from tqdm import tqdm

from marigold.marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput
from src.util import metric
from src.util.alignment import align_depth_least_square, depth2disparity, disparity2depth
from src.util.data_loader import skip_first_batches
from src.util.logging_util import tb_logger, eval_dic_to_text
from src.util.loss import get_loss
from src.util.lr_scheduler import IterExponential
from src.util.metric import MetricTracker
from src.util.multi_res_noise import multi_res_noise_like
from src.util.seeding import generate_seed_sequence

class MarigoldTrainer:
    def __init__(
        self,
        cfg: OmegaConf,
        model: MarigoldPipeline,
        train_dataloader: DataLoader,
        device,
        base_ckpt_dir,
        out_dir_ckpt,
        out_dir_eval,
        out_dir_vis,
        accumulation_steps: int,
        depth_model=None,
        separate_list: List = None,
        val_dataloaders: List[DataLoader] = None,
        vis_dataloaders: List[DataLoader] = None,
        timestep_method: str = 'unidiffuser',
    ):
        self.cfg: OmegaConf = cfg
        self.model: MarigoldPipeline = model
        self.depth_model = depth_model
        self.device = device
        self.seed: Union[int, None] = self.cfg.trainer.init_seed
        self.out_dir_ckpt = out_dir_ckpt
        self.out_dir_eval = out_dir_eval
        self.out_dir_vis = out_dir_vis
        self.train_loader: DataLoader = train_dataloader
        self.val_loaders: List[DataLoader] = val_dataloaders
        self.vis_loaders: List[DataLoader] = vis_dataloaders
        self.accumulation_steps: int = accumulation_steps
        self.separate_list = separate_list
        self.timestep_method = timestep_method

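        # Fixed prompts used to monitor joint RGB-D generation during training.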
        self.prompt = [
            'a view of a city skyline from a bridge',
            'a man and a woman sitting on a couch',
            'a black car parked in a parking lot next to the water',
            'Enchanted forest with glowing plants, fairies, and ancient castle.',
            'Futuristic city with skyscrapers, neon lights, and hovering vehicles.',
            'Fantasy mountain landscape with waterfalls, dragons, and mythical creatures.',
        ]

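        # Pre-compute the empty-text embedding once; it is reused as the
        # unconditional prompt embedding (e.g. for classifier-free guidance).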
        self.model.encode_empty_text()
        self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)

        self.model.unet.enable_xformers_memory_efficient_attention()

        # Only the UNet is trained; the text encoder stays frozen.
        self.model.text_encoder.requires_grad_(False)

        grad_part = filter(lambda p: p.requires_grad, self.model.unet.parameters())

        # Optimizer
        lr = self.cfg.lr
        self.optimizer = Adam(grad_part, lr=lr)

        total_params = sum(p.numel() for p in self.model.unet.parameters())
        logging.info(f"Total UNet parameters: {total_params / 1_000_000:.2f}M")
        trainable_params = sum(
            p.numel() for p in self.model.unet.parameters() if p.requires_grad
        )
        logging.info(f"Trainable UNet parameters: {trainable_params / 1_000_000:.2f}M")

        # LR scheduler: warmup followed by iteration-wise exponential decay.
        lr_func = IterExponential(
            total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
            final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
            warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
        )
        self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

        # Loss
        self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)

        # Training noise scheduler
        self.training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
            os.path.join(
                cfg.trainer.training_noise_scheduler.pretrained_path,
                "scheduler",
            )
        )
        self.prediction_type = self.training_noise_scheduler.config.prediction_type
        assert (
            self.prediction_type == self.model.scheduler.config.prediction_type
        ), "Different prediction types"
        self.scheduler_timesteps = (
            self.training_noise_scheduler.config.num_train_timesteps
        )

        # Evaluation metrics
        self.metric_funcs = [getattr(metric, _met) for _met in cfg.eval.eval_metrics]
        self.train_metrics = MetricTracker(*["loss", "rgb_loss", "depth_loss"])
        self.val_metrics = MetricTracker(*[m.__name__ for m in self.metric_funcs])

        # Main validation metric, used for selecting the best checkpoint.
        self.main_val_metric = cfg.validation.main_val_metric
        self.main_val_metric_goal = cfg.validation.main_val_metric_goal
        assert (
            self.main_val_metric in cfg.eval.eval_metrics
        ), f"Main eval metric `{self.main_val_metric}` not found in evaluation metrics."
        self.best_metric = 1e8 if "minimize" == self.main_val_metric_goal else -1e8

        self.max_epoch = self.cfg.max_epoch
        self.max_iter = self.cfg.max_iter
        self.gradient_accumulation_steps = accumulation_steps
        self.gt_depth_type = self.cfg.gt_depth_type
        self.gt_mask_type = self.cfg.gt_mask_type
        self.save_period = self.cfg.trainer.save_period
        self.backup_period = self.cfg.trainer.backup_period
        self.val_period = self.cfg.trainer.validation_period
        self.vis_period = self.cfg.trainer.visualization_period

        self.apply_multi_res_noise = self.cfg.multi_res_noise is not None
        if self.apply_multi_res_noise:
            self.mr_noise_strength = self.cfg.multi_res_noise.strength
            self.annealed_mr_noise = self.cfg.multi_res_noise.annealed
            self.mr_noise_downscale_strategy = (
                self.cfg.multi_res_noise.downscale_strategy
            )

        # Internal training state
        self.epoch = 0
        self.n_batch_in_epoch = 0  # batch index within the current epoch, for resuming
        self.effective_iter = 0  # number of effective (accumulated) optimizer updates
        self.in_evaluation = False
        self.global_seed_sequence: List = []

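    # The pretrained SD UNet takes a 4-channel latent. To condition on an
    # additional latent, conv_in is inflated from 4 to 8 input channels; the
    # new channels are zero-initialized so the inflated layer initially
    # reproduces the pretrained behavior.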
    def _replace_unet_conv_in(self):
        # Replace the first layer to accept 8 in_channels.
        _weight = self.model.unet.conv_in.weight.clone()  # [out_ch, 4, 3, 3]
        _bias = self.model.unet.conv_in.bias.clone()
        zero_weight = torch.zeros_like(_weight)
        _weight = torch.cat([_weight, zero_weight], dim=1)  # [out_ch, 8, 3, 3]

        _n_convin_out_channel = self.model.unet.conv_in.out_channels
        _new_conv_in = Conv2d(
            8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        _new_conv_in.weight = Parameter(_weight)
        _new_conv_in.bias = Parameter(_bias)
        self.model.unet.conv_in = _new_conv_in
        logging.info("Unet conv_in layer is replaced")

        # Update the config to match the new layer.
        self.model.unet.config["in_channels"] = 8
        logging.info("Unet config is updated")
        return

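    # Similarly, conv_out is widened from 4 to 8 output channels so the UNet
    # predicts RGB and depth latents jointly; the pretrained output weights
    # are duplicated for the new depth channels.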
    def _replace_unet_conv_out(self):
        # Replace the last layer to produce 8 out_channels.
        _weight = self.model.unet.conv_out.weight.clone()  # [4, in_ch, 3, 3]
        _bias = self.model.unet.conv_out.bias.clone()
        _weight = _weight.repeat((2, 1, 1, 1))  # duplicate weights for depth channels
        _bias = _bias.repeat(2)

        # Keep the original in_channels; only out_channels grows to 8.
        _n_convout_in_channel = self.model.unet.conv_out.in_channels
        _new_conv_out = Conv2d(
            _n_convout_in_channel, 8, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        _new_conv_out.weight = Parameter(_weight)
        _new_conv_out.bias = Parameter(_bias)
        self.model.unet.conv_out = _new_conv_out
        logging.info("Unet conv_out layer is replaced")

        # Update the config to match the new layer.
        self.model.unet.config["out_channels"] = 8
        logging.info("Unet config is updated")
        return

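    # Main distributed training loop (driven by Hugging Face Accelerate):
    # prepares model/optimizer/dataloaders, optionally resumes from the
    # 'latest' checkpoint, then alternates noising, UNet prediction, loss
    # computation, and periodic logging/checkpointing/validation.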
    def parallel_train(self, t_end=None, accelerator=None):
        logging.info("Start training")

        self.model, self.optimizer, self.train_loader, self.lr_scheduler = accelerator.prepare(
            self.model, self.optimizer, self.train_loader, self.lr_scheduler
        )
        self.depth_model = accelerator.prepare(self.depth_model)

        self.accelerator = accelerator
        if self.val_loaders is not None:
            for idx, loader in enumerate(self.val_loaders):
                self.val_loaders[idx] = accelerator.prepare(loader)

        # Resume from the latest checkpoint if one exists.
        if os.path.exists(os.path.join(self.out_dir_ckpt, 'latest')):
            accelerator.load_state(os.path.join(self.out_dir_ckpt, 'latest'))
            self.load_miscs(os.path.join(self.out_dir_ckpt, 'latest'))

        self.train_metrics.reset()
        accumulated_step = 0
        for epoch in range(self.epoch, self.max_epoch + 1):
            self.epoch = epoch
            logging.debug(f"epoch: {self.epoch}")

            # Skip batches that were already consumed before resuming.
            for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
                self.model.unet.train()

                # Globally consistent random generator
                if self.seed is not None:
                    local_seed = self._get_next_seed()
                    rand_num_generator = torch.Generator(device=self.model.device)
                    rand_num_generator.manual_seed(local_seed)
                else:
                    rand_num_generator = None

rgb = batch["rgb_norm"].to(self.model.device) |
|
if self.gt_depth_type not in batch: |
|
with torch.no_grad(): |
|
disparities = self.depth_model(batch["rgb_int"].numpy().astype(np.uint8), 518, device=self.model.device) |
|
depth_gt_for_latent = [] |
|
for disparity_map in disparities: |
|
depth_map = ((disparity_map - disparity_map.min()) / (disparity_map.max() - disparity_map.min())) * 2 - 1 |
|
depth_gt_for_latent.append(depth_map) |
|
depth_gt_for_latent = torch.stack(depth_gt_for_latent, dim=0) |
|
else: |
|
if "least_square_disparity" == self.cfg.eval.alignment: |
|
|
|
depth_raw_ts = batch["depth_raw_linear"].squeeze() |
|
depth_raw = depth_raw_ts.cpu().numpy() |
|
|
|
disparities = depth2disparity( |
|
depth=depth_raw |
|
) |
|
depth_gt_for_latent = [] |
|
for disparity_map in disparities: |
|
depth_map = ((disparity_map - disparity_map.min()) / ( |
|
disparity_map.max() - disparity_map.min())) * 2 - 1 |
|
depth_gt_for_latent.append(torch.from_numpy(depth_map)) |
|
depth_gt_for_latent = torch.stack(depth_gt_for_latent, dim=0).to(self.model.device) |
|
else: |
|
depth_gt_for_latent = batch[self.gt_depth_type].to(self.model.device) |
|
|
|
batch_size = rgb.shape[0] |
|
|
|
                if self.gt_mask_type is not None:
                    valid_mask_for_latent = batch[self.gt_mask_type].to(self.model.device)
                    invalid_mask = ~valid_mask_for_latent
                    # Downsample the validity mask to latent resolution (1/8)
                    # and broadcast it over the 4 latent channels.
                    valid_mask_down = ~torch.max_pool2d(invalid_mask.float(), 8, 8).bool()
                    valid_mask_down = valid_mask_down.repeat((1, 4, 1, 1))

                with torch.no_grad():
                    # Encode image and depth into the frozen VAE latent space.
                    rgb_latent = self.model.encode_rgb(rgb)
                    gt_depth_latent = self.encode_depth(depth_gt_for_latent)

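                # Sample diffusion timesteps for the two modalities:
                #   depth_factor == 1 -> depth-only training (RGB stays clean);
                #   'unidiffuser'     -> independent timesteps per modality;
                #   'joint'           -> one shared timestep;
                #   'partition'       -> per batch, randomly pick joint,
                #                        image-conditioned, or depth-conditioned.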
                if self.cfg.loss.depth_factor == 1:
                    rgb_timesteps = torch.zeros(
                        (batch_size,), device=self.model.device
                    ).long()
                    depth_timesteps = torch.randint(
                        0,
                        self.scheduler_timesteps,
                        (batch_size,),
                        device=self.model.device,
                        generator=rand_num_generator,
                    ).long()
                elif self.timestep_method == 'unidiffuser':
                    rgb_timesteps = torch.randint(
                        0,
                        self.scheduler_timesteps,
                        (batch_size,),
                        device=self.model.device,
                        generator=rand_num_generator,
                    ).long()
                    depth_timesteps = torch.randint(
                        0,
                        self.scheduler_timesteps,
                        (batch_size,),
                        device=self.model.device,
                        generator=rand_num_generator,
                    ).long()
                elif self.timestep_method == 'joint':
                    rgb_timesteps = torch.randint(
                        0,
                        self.scheduler_timesteps,
                        (batch_size,),
                        device=self.model.device,
                        generator=rand_num_generator,
                    ).long()
                    depth_timesteps = rgb_timesteps
                elif self.timestep_method == 'partition':
                    rand_num = random.random()
                    if rand_num < 0.3333:
                        # Joint generation: shared timestep for both modalities.
                        rgb_timesteps = torch.randint(
                            0,
                            self.scheduler_timesteps,
                            (batch_size,),
                            device=self.model.device,
                            generator=rand_num_generator,
                        ).long()
                        depth_timesteps = rgb_timesteps
                    elif rand_num < 0.6666:
                        # Image-conditioned: RGB stays clean, depth is noised.
                        rgb_timesteps = torch.zeros(
                            (batch_size,), device=self.model.device
                        ).long()
                        depth_timesteps = torch.randint(
                            0,
                            self.scheduler_timesteps,
                            (batch_size,),
                            device=self.model.device,
                            generator=rand_num_generator,
                        ).long()
                    else:
                        # Depth-conditioned: depth stays clean, RGB is noised.
                        rgb_timesteps = torch.randint(
                            0,
                            self.scheduler_timesteps,
                            (batch_size,),
                            device=self.model.device,
                            generator=rand_num_generator,
                        ).long()
                        depth_timesteps = torch.zeros(
                            (batch_size,), device=self.model.device
                        ).long()

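                # Sample noise: either multi-resolution noise (noise summed at
                # several scales, optionally annealed with the timestep) or
                # standard Gaussian noise.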
                if self.apply_multi_res_noise:
                    rgb_strength = self.mr_noise_strength
                    if self.annealed_mr_noise:
                        # Anneal the noise strength with the timestep.
                        rgb_strength = rgb_strength * (rgb_timesteps / self.scheduler_timesteps)
                    rgb_noise = multi_res_noise_like(
                        rgb_latent,
                        strength=rgb_strength,
                        downscale_strategy=self.mr_noise_downscale_strategy,
                        generator=rand_num_generator,
                        device=self.model.device,
                    )

                    depth_strength = self.mr_noise_strength
                    if self.annealed_mr_noise:
                        depth_strength = depth_strength * (depth_timesteps / self.scheduler_timesteps)
                    depth_noise = multi_res_noise_like(
                        gt_depth_latent,
                        strength=depth_strength,
                        downscale_strategy=self.mr_noise_downscale_strategy,
                        generator=rand_num_generator,
                        device=self.model.device,
                    )
                else:
                    rgb_noise = torch.randn(
                        rgb_latent.shape,
                        device=self.model.device,
                        generator=rand_num_generator,
                    )
                    depth_noise = torch.randn(
                        gt_depth_latent.shape,
                        device=self.model.device,
                        generator=rand_num_generator,
                    )

                # Forward diffusion. When RGB is the clean condition, i.e. all
                # rgb_timesteps are zero, skip add_noise entirely so the
                # conditioning latent stays exactly clean.
                if rgb_timesteps.sum() == 0:
                    noisy_rgb_latents = rgb_latent
                else:
                    noisy_rgb_latents = self.training_noise_scheduler.add_noise(
                        rgb_latent, rgb_noise, rgb_timesteps
                    )
                noisy_depth_latents = self.training_noise_scheduler.add_noise(
                    gt_depth_latent, depth_noise, depth_timesteps
                )

                # Concatenate RGB and depth latents along the channel dim.
                noisy_latents = torch.cat(
                    [noisy_rgb_latents, noisy_depth_latents], dim=1
                ).float()

                # Text embedding
                input_ids = self.model.tokenizer(
                    batch['text'],
                    padding="max_length",
                    max_length=self.model.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                input_ids = {k: v.to(self.model.device) for k, v in input_ids.items()}
                text_embed = self.model.text_encoder(**input_ids)[0]

                # Predict with the UNet (separate timesteps per modality).
                model_pred = self.model.unet(
                    noisy_latents, rgb_timesteps, depth_timesteps, text_embed
                ).sample  # [B, 8, h, w]
                if torch.isnan(model_pred).any():
                    logging.warning("model_pred contains NaN.")

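                # Regression targets depend on the scheduler's prediction type:
                # 'sample' targets the clean latents, 'epsilon' the injected
                # noise, and 'v_prediction' the velocity.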
if "sample" == self.prediction_type: |
|
rgb_target = rgb_latent |
|
depth_target = gt_depth_latent |
|
elif "epsilon" == self.prediction_type: |
|
rgb_target = rgb_latent |
|
depth_target = gt_depth_latent |
|
elif "v_prediction" == self.prediction_type: |
|
rgb_target = self.training_noise_scheduler.get_velocity( |
|
rgb_latent, rgb_noise, rgb_timesteps |
|
) |
|
depth_target = self.training_noise_scheduler.get_velocity( |
|
gt_depth_latent, depth_noise, depth_timesteps |
|
) |
|
else: |
|
raise ValueError(f"Unknown prediction type {self.prediction_type}") |
|
|
|
                with accelerator.accumulate(self.model):
                    # Depth loss (first 4 channels are RGB, last 4 are depth).
                    if self.gt_mask_type is not None:
                        depth_loss = self.loss(
                            model_pred[:, 4:, :, :][valid_mask_down].float(),
                            depth_target[valid_mask_down].float(),
                        )
                    else:
                        depth_loss = self.loss(
                            model_pred[:, 4:, :, :].float(), depth_target.float()
                        )

                    rgb_loss = self.loss(model_pred[:, 0:4, :, :].float(), rgb_target.float())

                    # When one modality is fully clean (a pure condition),
                    # train only on the other modality's loss.
                    if (
                        torch.sum(rgb_timesteps) == 0
                        or torch.sum(rgb_timesteps) == len(rgb_timesteps) * self.scheduler_timesteps
                    ):
                        loss = depth_loss
                    elif (
                        torch.sum(depth_timesteps) == 0
                        or torch.sum(depth_timesteps) == len(depth_timesteps) * self.scheduler_timesteps
                    ):
                        loss = rgb_loss
                    else:
                        loss = (
                            self.cfg.loss.depth_factor * depth_loss
                            + (1 - self.cfg.loss.depth_factor) * rgb_loss
                        )

                    self.train_metrics.update("loss", loss.item())
                    self.train_metrics.update("rgb_loss", rgb_loss.item())
                    self.train_metrics.update("depth_loss", depth_loss.item())

                    accelerator.backward(loss)
                    self.optimizer.step()
                    self.optimizer.zero_grad()

                    self.n_batch_in_epoch += 1
                    # Per-step lr scheduler
                    self.lr_scheduler.step(self.effective_iter)

                if accelerator.sync_gradients:
                    accumulated_step += 1

                if accumulated_step >= self.gradient_accumulation_steps:
                    accumulated_step = 0
                    self.effective_iter += 1

                    if accelerator.is_main_process:
                        # Log an initial set of generated samples as a sanity check.
                        if self.effective_iter == 1:
                            generator = torch.Generator(self.model.device).manual_seed(1024)
                            img = self.model.generate_rgbd(
                                self.prompt, num_inference_steps=50, generator=generator,
                                show_pbar=True,
                            )
                            for idx in range(len(self.prompt)):
                                tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter)
                            self._depth2image()
                            self._image2depth()

                        # Log accumulated training metrics.
                        accumulated_loss = self.train_metrics.result()["loss"]
                        rgb_loss = self.train_metrics.result()["rgb_loss"]
                        depth_loss = self.train_metrics.result()["depth_loss"]
                        tb_logger.log_dic(
                            {f"train/{k}": v for k, v in self.train_metrics.result().items()},
                            global_step=self.effective_iter,
                        )
                        tb_logger.writer.add_scalar(
                            "lr",
                            self.lr_scheduler.get_last_lr()[0],
                            global_step=self.effective_iter,
                        )
                        tb_logger.writer.add_scalar(
                            "n_batch_in_epoch",
                            self.n_batch_in_epoch,
                            global_step=self.effective_iter,
                        )
                        logging.info(
                            f"iter {self.effective_iter:5d} (epoch {epoch:2d}): "
                            f"loss={accumulated_loss:.5f}, rgb_loss={rgb_loss:.5f}, depth_loss={depth_loss:.5f}"
                        )
                    accelerator.wait_for_everyone()

                    # Periodically save the 'latest' checkpoint and log samples.
                    if self.save_period > 0 and 0 == self.effective_iter % self.save_period:
                        accelerator.save_state(output_dir=os.path.join(self.out_dir_ckpt, 'latest'))
                        unwrapped_model = accelerator.unwrap_model(self.model)
                        if accelerator.is_main_process:
                            accelerator.save_model(
                                unwrapped_model.unet,
                                os.path.join(self.out_dir_ckpt, 'latest'),
                                safe_serialization=False,
                            )
                            self.save_miscs('latest')

                            # Low-resolution previews to keep periodic logging cheap.
                            generator = torch.Generator(self.model.device).manual_seed(1024)
                            img = self.model.generate_rgbd(
                                self.prompt,
                                num_inference_steps=50,
                                generator=generator,
                                show_pbar=False,
                                height=64,
                                width=64,
                            )
                            for idx in range(len(self.prompt)):
                                tb_logger.writer.add_image(f'image/{self.prompt[idx]}', img[idx], self.effective_iter)

                            self._depth2image()
                            self._image2depth()

                        accelerator.wait_for_everyone()

                    # Periodic backup of the UNet weights.
                    if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
                        unwrapped_model = accelerator.unwrap_model(self.model)
                        if accelerator.is_main_process:
                            unwrapped_model.unet.save_pretrained(
                                os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name())
                            )
                        accelerator.wait_for_everyone()

                    # Periodic validation.
                    if self.val_period > 0 and 0 == self.effective_iter % self.val_period:
                        self.validate()

                    # Stop once the maximum number of effective iterations is reached.
                    if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                        unwrapped_model = accelerator.unwrap_model(self.model)
                        if accelerator.is_main_process:
                            unwrapped_model.unet.save_pretrained(
                                os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name())
                            )
                        accelerator.wait_for_everyone()
                        return

                    torch.cuda.empty_cache()
                    # <<< end of effective iteration <<<

            # Epoch finished; reset the intra-epoch batch counter.
            self.n_batch_in_epoch = 0

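    # Qualitative monitoring helpers. Note that the hard-coded sample paths
    # below are specific to the original training environment.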
    def _image2depth(self):
        generator = torch.Generator(self.model.device).manual_seed(1024)
        image2depth_paths = [
            '/home/aiops/wangzh/data/scannet/scene0593_00/color/000100.jpg',
            '/home/aiops/wangzh/data/scannet/scene0593_00/color/000700.jpg',
            '/home/aiops/wangzh/data/scannet/scene0591_01/color/000600.jpg',
            '/home/aiops/wangzh/data/scannet/scene0591_01/color/001500.jpg',
        ]
        for img_idx, image_path in enumerate(image2depth_paths):
            rgb_input = Image.open(image_path)
            depth_pred: MarigoldDepthOutput = self.model.image2depth(
                rgb_input,
                denoising_steps=self.cfg.validation.denoising_steps,
                ensemble_size=self.cfg.validation.ensemble_size,
                processing_res=self.cfg.validation.processing_res,
                match_input_res=self.cfg.validation.match_input_res,
                generator=generator,
                batch_size=self.cfg.validation.ensemble_size,
                color_map="Spectral",
                show_progress_bar=False,
                resample_method=self.cfg.validation.resample_method,
            )
            img = self.model.post_process_rgbd(['None'], [rgb_input], [depth_pred['depth_colored']])
            tb_logger.writer.add_image(f'image2depth_{img_idx}', img[0], self.effective_iter)

    def _depth2image(self):
        generator = torch.Generator(self.model.device).manual_seed(1024)
        if "least_square_disparity" == self.cfg.eval.alignment:
            depth2image_path = [
                '/home/aiops/wangzh/data/ori_depth_part0-0/sa_10000335.jpg',
                '/home/aiops/wangzh/data/ori_depth_part0-0/sa_3572319.jpg',
                '/home/aiops/wangzh/data/ori_depth_part0-0/sa_457934.jpg',
            ]
        else:
            depth2image_path = [
                '/home/aiops/wangzh/data/sa_001000/sa_10000335.jpg',
                '/home/aiops/wangzh/data/sa_000357/sa_3572319.jpg',
                '/home/aiops/wangzh/data/sa_000045/sa_457934.jpg',
            ]
        prompts = [
            'Red car parked in the factory',
            'White gothic church with cemetery next to it',
            'House with red roof and starry sky in the background',
        ]
        for img_idx, depth_path in enumerate(depth2image_path):
            depth_input = Image.open(depth_path)
            image_pred = self.model.single_depth2image(
                depth_input,
                prompts[img_idx],
                num_inference_steps=50,
                processing_res=self.cfg.validation.processing_res,
                generator=generator,
                show_pbar=False,
                resample_method=self.cfg.validation.resample_method,
            )
            img = self.model.post_process_rgbd([prompts[img_idx]], [image_pred], [depth_input])
            tb_logger.writer.add_image(f'depth2image_{img_idx}', img[0], self.effective_iter)

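    # The VAE was trained on 3-channel images, so a single-channel depth map
    # is replicated to 3 channels before encoding with the frozen RGB encoder.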
    def encode_depth(self, depth_in):
        # Stack depth into a 3-channel image and encode with the RGB VAE.
        stacked = self.stack_depth_images(depth_in)
        depth_latent = self.model.encode_rgb(stacked)
        return depth_latent

    @staticmethod
    def stack_depth_images(depth_in):
        if 4 == len(depth_in.shape):
            stacked = depth_in.repeat(1, 3, 1, 1)  # [B, 1, H, W] -> [B, 3, H, W]
        elif 3 == len(depth_in.shape):
            stacked = depth_in.unsqueeze(1)  # [B, H, W] -> [B, 1, H, W]
            stacked = stacked.repeat(1, 3, 1, 1)
        else:
            raise ValueError(f"Unexpected depth shape: {depth_in.shape}")
        return stacked

    def validate(self):
        for i, val_loader in enumerate(self.val_loaders):
            val_dataset_name = val_loader.dataset.disp_name
            val_metric_dic = self.validate_single_dataset(
                data_loader=val_loader, metric_tracker=self.val_metrics
            )

            if self.accelerator.is_main_process:
                val_metric_dic = {k: torch.tensor(v).cuda() for k, v in val_metric_dic.items()}

                tb_logger.log_dic(
                    {f"val/{val_dataset_name}/{k}": v for k, v in val_metric_dic.items()},
                    global_step=self.effective_iter,
                )

                # Save validation results to a text file.
                eval_text = eval_dic_to_text(
                    val_metrics=val_metric_dic,
                    dataset_name=val_dataset_name,
                    sample_list_path=val_loader.dataset.filename_ls_path,
                )
                _save_to = os.path.join(
                    self.out_dir_eval,
                    f"eval-{val_dataset_name}-iter{self.effective_iter:06d}.txt",
                )
                with open(_save_to, "w+") as f:
                    f.write(eval_text)

                # Update the best checkpoint based on the first val dataset.
                if 0 == i:
                    main_eval_metric = val_metric_dic[self.main_val_metric]
                    if (
                        "minimize" == self.main_val_metric_goal
                        and main_eval_metric < self.best_metric
                    ) or (
                        "maximize" == self.main_val_metric_goal
                        and main_eval_metric > self.best_metric
                    ):
                        self.best_metric = main_eval_metric
                        logging.info(
                            f"Best metric: {self.main_val_metric} = {self.best_metric} "
                            f"at iteration {self.effective_iter}"
                        )
                        self.save_checkpoint(ckpt_name='best', save_train_state=False)

        self.accelerator.wait_for_everyone()

    def visualize(self):
        for val_loader in self.vis_loaders:
            vis_dataset_name = val_loader.dataset.disp_name
            vis_out_dir = os.path.join(
                self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
            )
            os.makedirs(vis_out_dir, exist_ok=True)
            _ = self.validate_single_dataset(
                data_loader=val_loader,
                metric_tracker=self.val_metrics,
                save_to_dir=vis_out_dir,
            )

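    # Per-dataset evaluation: runs image2depth on each sample with a fixed
    # per-sample seed, aligns predictions to the ground truth (least squares
    # in depth or disparity space), and gathers metrics across processes.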
    @torch.no_grad()
    def validate_single_dataset(
        self,
        data_loader: DataLoader,
        metric_tracker: MetricTracker,
        save_to_dir: str = None,
    ):
        self.model.to(self.device)
        metric_tracker.reset()

        # Generate one seed per sample for reproducible evaluation.
        val_init_seed = self.cfg.validation.init_seed
        val_seed_ls = generate_seed_sequence(val_init_seed, len(data_loader))

        for i, batch in enumerate(
            tqdm(data_loader, desc=f"evaluating on {data_loader.dataset.disp_name}"),
            start=1,
        ):
            # Read input image and GT depth
            rgb_int = batch["rgb_int"]
            depth_raw_ts = batch["depth_raw_linear"].squeeze()
            depth_raw = depth_raw_ts.cpu().numpy()
            depth_raw_ts = depth_raw_ts.to(self.device)
            valid_mask_ts = batch["valid_mask_raw"].squeeze()
            valid_mask = valid_mask_ts.cpu().numpy()
            valid_mask_ts = valid_mask_ts.to(self.device)

            # Random generator
            seed = val_seed_ls.pop()
            if seed is None:
                generator = None
            else:
                generator = torch.Generator(device=self.device)
                generator.manual_seed(seed)

            # Predict depth
            pipe_out: MarigoldDepthOutput = self.model.image2depth(
                rgb_int,
                denoising_steps=self.cfg.validation.denoising_steps,
                ensemble_size=self.cfg.validation.ensemble_size,
                processing_res=self.cfg.validation.processing_res,
                match_input_res=self.cfg.validation.match_input_res,
                generator=generator,
                batch_size=self.cfg.validation.ensemble_size,
                color_map=None,
                show_progress_bar=False,
                resample_method=self.cfg.validation.resample_method,
            )

            depth_pred: np.ndarray = pipe_out.depth_np

if "least_square" == self.cfg.eval.alignment: |
|
depth_pred, scale, shift = align_depth_least_square( |
|
gt_arr=depth_raw, |
|
pred_arr=depth_pred, |
|
valid_mask_arr=valid_mask, |
|
return_scale_shift=True, |
|
max_resolution=self.cfg.eval.align_max_res, |
|
) |
|
elif "least_square_disparity" == self.cfg.eval.alignment: |
|
|
|
gt_disparity, gt_non_neg_mask = depth2disparity( |
|
depth=depth_raw, return_mask=True |
|
) |
|
|
|
pred_non_neg_mask = depth_pred > 0 |
|
valid_nonnegative_mask = valid_mask & gt_non_neg_mask & pred_non_neg_mask |
|
|
|
disparity_pred, scale, shift = align_depth_least_square( |
|
gt_arr=gt_disparity, |
|
pred_arr=depth_pred, |
|
valid_mask_arr=valid_nonnegative_mask, |
|
return_scale_shift=True, |
|
max_resolution=self.cfg.eval.align_max_res, |
|
) |
|
|
|
disparity_pred = np.clip( |
|
disparity_pred, a_min=1e-3, a_max=None |
|
) |
|
depth_pred = disparity2depth(disparity_pred) |
|
|
|
|
|
depth_pred = np.clip( |
|
depth_pred, |
|
a_min=data_loader.dataset.min_depth, |
|
a_max=data_loader.dataset.max_depth, |
|
) |
|
|
|
|
|
depth_pred = np.clip(depth_pred, a_min=1e-6, a_max=None) |
|
|
|
|
|
            # Evaluate and gather metrics across processes.
            sample_metric = []
            depth_pred_ts = torch.from_numpy(depth_pred).to(self.device)

            for met_func in self.metric_funcs:
                _metric_name = met_func.__name__
                _metric = met_func(depth_pred_ts, depth_raw_ts, valid_mask_ts).cuda(
                    self.accelerator.process_index
                )
                self.accelerator.wait_for_everyone()
                _metric = self.accelerator.gather_for_metrics(_metric.unsqueeze(0)).mean().item()
                sample_metric.append(str(_metric))
                metric_tracker.update(_metric_name, _metric)

            self.accelerator.wait_for_everyone()

            # Save as 16-bit uint png
            if save_to_dir is not None:
                img_name = batch["rgb_relative_path"][0].replace("/", "_")
                png_save_path = os.path.join(save_to_dir, f"{img_name}.png")
                depth_to_save = (pipe_out.depth_np * 65535.0).astype(np.uint16)
                Image.fromarray(depth_to_save).save(png_save_path, mode="I;16")

        return metric_tracker.result()

    def _get_next_seed(self):
        if 0 == len(self.global_seed_sequence):
            self.global_seed_sequence = generate_seed_sequence(
                initial_seed=self.seed,
                length=self.max_iter * self.gradient_accumulation_steps,
            )
            logging.info(
                f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
            )
        return self.global_seed_sequence.pop()

    def save_miscs(self, ckpt_name):
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        state = {
            "config": self.cfg,
            "effective_iter": self.effective_iter,
            "epoch": self.epoch,
            "n_batch_in_epoch": self.n_batch_in_epoch,
            "best_metric": self.best_metric,
            "in_evaluation": self.in_evaluation,
            "global_seed_sequence": self.global_seed_sequence,
        }
        train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
        torch.save(state, train_state_path)

        logging.info(f"Misc state is saved to: {train_state_path}")

    def load_miscs(self, ckpt_path):
        checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
        self.effective_iter = checkpoint["effective_iter"]
        self.epoch = checkpoint["epoch"]
        self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
        self.in_evaluation = checkpoint["in_evaluation"]
        self.global_seed_sequence = checkpoint["global_seed_sequence"]
        self.best_metric = checkpoint["best_metric"]

        logging.info(f"Misc state is loaded from {ckpt_path}")

    def save_checkpoint(self, ckpt_name, save_train_state):
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        logging.info(f"Saving checkpoint to: {ckpt_dir}")

        # Back up the existing checkpoint before overwriting it, so an
        # interrupted save cannot destroy the previous one.
        temp_ckpt_dir = None
        if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
            temp_ckpt_dir = os.path.join(
                os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
            )
            if os.path.exists(temp_ckpt_dir):
                shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            os.rename(ckpt_dir, temp_ckpt_dir)
            logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

        # Save UNet
        unet_path = os.path.join(ckpt_dir, "unet")
        self.model.unet.save_pretrained(unet_path, safe_serialization=False)
        logging.info(f"UNet is saved to: {unet_path}")

        if save_train_state:
            state = {
                "config": self.cfg,
                "effective_iter": self.effective_iter,
                "epoch": self.epoch,
                "n_batch_in_epoch": self.n_batch_in_epoch,
                "best_metric": self.best_metric,
                "in_evaluation": self.in_evaluation,
                "global_seed_sequence": self.global_seed_sequence,
                # Saved so that load_checkpoint can restore them.
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict(),
            }
            train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
            torch.save(state, train_state_path)
            # Create an empty marker file named after the iteration.
            f = open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w")
            f.close()

            logging.info(f"Trainer state is saved to: {train_state_path}")

        # Remove the backup now that the new checkpoint is complete.
        if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            logging.debug("Old checkpoint backup is removed.")

    def load_checkpoint(
        self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
    ):
        logging.info(f"Loading checkpoint from: {ckpt_path}")
        # Load UNet
        _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
        self.model.unet.load_state_dict(
            torch.load(_model_path, map_location=self.device)
        )
        self.model.unet.to(self.device)
        logging.info(f"UNet parameters are loaded from {_model_path}")

        # Load training states
        if load_trainer_state:
            checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
            self.effective_iter = checkpoint["effective_iter"]
            self.epoch = checkpoint["epoch"]
            self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
            self.in_evaluation = checkpoint["in_evaluation"]
            self.global_seed_sequence = checkpoint["global_seed_sequence"]
            self.best_metric = checkpoint["best_metric"]

            self.optimizer.load_state_dict(checkpoint["optimizer"])
            logging.info(f"optimizer state is loaded from {ckpt_path}")

            if resume_lr_scheduler:
                self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                logging.info(f"LR scheduler state is loaded from {ckpt_path}")

        logging.info(
            f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
        )
        return

    def _get_backup_ckpt_name(self):
        return f"iter_{self.effective_iter:06d}"
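
# A minimal usage sketch (hypothetical; the config keys and the external
# `train_loader` / `val_loader` objects are placeholders, not part of this
# module):
#
#   cfg = OmegaConf.load("config/train.yaml")
#   pipeline = MarigoldPipeline.from_pretrained(cfg.base_ckpt_dir)
#   trainer = MarigoldTrainer(
#       cfg=cfg,
#       model=pipeline,
#       train_dataloader=train_loader,
#       device="cuda",
#       base_ckpt_dir=cfg.base_ckpt_dir,
#       out_dir_ckpt="out/ckpt",
#       out_dir_eval="out/eval",
#       out_dir_vis="out/vis",
#       accumulation_steps=cfg.trainer.accumulation_steps,
#       val_dataloaders=[val_loader],
#       timestep_method="unidiffuser",
#   )
#   trainer.parallel_train(
#       accelerator=Accelerator(
#           gradient_accumulation_steps=cfg.trainer.accumulation_steps
#       )
#   )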