# File: Inpaint/src/trainer/marigold_inpaint_trainer.py
# Uploaded by ZehanWang using huggingface_hub (commit 864ec44, verified; 30.3 kB).
# An official reimplemented version of Marigold training script.
# Last modified: 2024-04-29
#
# Copyright 2023 Bingxin Ke, ETH Zurich. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# --------------------------------------------------------------------------
# If you find this code useful, we kindly ask you to cite our paper in your work.
# Please find bibtex at: https://github.com/prs-eth/Marigold#-citation
# If you use or adapt this code, please attribute to https://github.com/prs-eth/marigold.
# More information about the method can be found at https://marigoldmonodepth.github.io
# --------------------------------------------------------------------------
from diffusers import StableDiffusionInpaintPipeline
import logging
import os
import pdb
import cv2
import shutil
import json
from pycocotools import mask as coco_mask
from datetime import datetime
from typing import List, Union
import random
import safetensors
import numpy as np
import torch
from diffusers import DDPMScheduler
from omegaconf import OmegaConf
from torch.nn import Conv2d
from torch.nn.parameter import Parameter
from torch.optim import Adam
from torch.optim.lr_scheduler import LambdaLR
from torch.utils.data import DataLoader, Dataset
from tqdm import tqdm
from PIL import Image
# import torch.optim.lr_scheduler
from diffusers.schedulers import PNDMScheduler
from torchvision.transforms.functional import pil_to_tensor
from marigold.marigold_pipeline import MarigoldPipeline, MarigoldDepthOutput
from src.util import metric
from src.util.data_loader import skip_first_batches
from src.util.logging_util import tb_logger, eval_dic_to_text
from src.util.loss import get_loss
from src.util.lr_scheduler import IterExponential
from src.util.metric import MetricTracker
from src.util.multi_res_noise import multi_res_noise_like
from src.util.alignment import align_depth_least_square, depth2disparity, disparity2depth
from src.util.seeding import generate_seed_sequence
from accelerate import Accelerator
import os
from torchvision.transforms import InterpolationMode, Resize, CenterCrop
import torchvision.transforms as transforms
# os.environ['CUDA_LAUNCH_BLOCKING'] = '1'
class MarigoldInpaintTrainer:
    """Trainer for joint RGB + depth inpainting on top of a Marigold pipeline.

    Each training step encodes RGB and normalised disparity into latents with
    the pipeline's VAE encoder, noises both with DDPM schedulers, and trains the
    UNet to predict both modalities from the noisy latents, the masks and the
    masked-input latents. Distributed execution, gradient accumulation and state
    saving are handled by the ``accelerate.Accelerator`` passed to
    :meth:`parallel_train`.
    """

    # Default SA-1B samples rendered by ``_inpaint_rgbd`` for TensorBoard previews.
    # Each ``.jpg`` needs a sibling ``.json`` containing an ``annotations`` list.
    _VIS_IMAGE_PATHS = (
        '/dataset/~sa-1b/data/sa_001000/sa_10000335.jpg',
        '/dataset/~sa-1b/data/sa_000357/sa_3572319.jpg',
        '/dataset/~sa-1b/data/sa_000045/sa_457934.jpg',
    )
    _VIS_PROMPTS = (
        'A white car is parked in front of the factory',
        'church with cemetery next to it',
        'A house with a red brick roof',
    )

    def __init__(
        self,
        cfg: OmegaConf,
        model: MarigoldPipeline,
        train_dataloader: DataLoader,
        device,
        base_ckpt_dir,
        out_dir_ckpt,
        out_dir_eval,
        out_dir_vis,
        accumulation_steps: int,
        depth_model=None,
        separate_list: List = None,
        val_dataloaders: List[DataLoader] = None,
        vis_dataloaders: List[DataLoader] = None,
        train_dataset: Dataset = None,
        timestep_method: str = 'unidiffuser',
        connection: bool = False,
    ):
        """Set up the optimizer, LR schedule, loss and training noise schedulers.

        Args:
            cfg: Training configuration. Reads ``lr``, ``lr_scheduler.kwargs``,
                ``loss``, ``trainer.*``, ``max_epoch``, ``max_iter``,
                ``gt_depth_type``, ``gt_mask_type`` and ``multi_res_noise``.
            model: Pipeline providing ``unet``, ``text_encoder``, ``tokenizer``,
                ``mask_processor``, ``encode_rgb``, ``encode_empty_text`` and
                ``depth_scheduler``.
            train_dataloader: Training batches (see ``parallel_train`` for keys).
            device: Device the empty-text embedding is moved to.
            base_ckpt_dir: Not used by this class.
            out_dir_ckpt: Directory for ``latest`` state and iteration backups.
            out_dir_eval: Stored only.
            out_dir_vis: Root directory used by ``visualize``.
            accumulation_steps: Number of synchronised optimizer steps that make
                up one effective iteration.
            depth_model: Callable returning disparity maps for uint8 RGB arrays
                (presumably a monocular depth network — confirm with caller).
            separate_list, val_dataloaders, vis_dataloaders, train_dataset:
                Stored as-is.
            timestep_method: ``'joint'`` or ``'partition'``; see ``_select_masks``.
                The default ``'unidiffuser'`` is not supported by the training loop.
            connection: Forwarded to the UNet as ``controlnet_connection``.
        """
        self.cfg: OmegaConf = cfg
        self.model: MarigoldPipeline = model
        self.depth_model = depth_model
        self.device = device
        # Used to generate the seed sequence; set to `None` to train w/o seeding.
        self.seed: Union[int, None] = self.cfg.trainer.init_seed
        self.out_dir_ckpt = out_dir_ckpt
        self.out_dir_eval = out_dir_eval
        self.out_dir_vis = out_dir_vis
        self.train_loader: DataLoader = train_dataloader
        self.val_loaders: List[DataLoader] = val_dataloaders
        self.vis_loaders: List[DataLoader] = vis_dataloaders
        self.accumulation_steps: int = accumulation_steps
        self.separate_list = separate_list
        self.timestep_method = timestep_method
        self.train_dataset = train_dataset
        self.connection = connection

        # NOTE: input/output layer adaptation (`_replace_unet_conv_in`) is
        # disabled; the UNet presumably already accepts the 26-channel input
        # assembled in `parallel_train` — confirm against the model definition.
        self.train_metrics = MetricTracker("loss", "rgb_loss", "depth_loss")

        # Encode empty text prompt
        self.model.encode_empty_text()
        self.empty_text_embed = self.model.empty_text_embed.detach().clone().to(device)
        self.model.unet.enable_xformers_memory_efficient_attention()

        # Trainability: freeze the text encoder; optimise every UNet parameter
        # that still requires grad.
        self.model.text_encoder.requires_grad_(False)
        grad_part = filter(lambda p: p.requires_grad, self.model.unet.parameters())
        # Optimizer: must be defined after any input-layer adaptation.
        self.optimizer = Adam(grad_part, lr=self.cfg.lr)

        total_params = sum(p.numel() for p in self.model.unet.parameters())
        print(f"Total parameters: {total_params / 1_000_000:.2f}M")
        trainable_params = sum(
            p.numel() for p in self.model.unet.parameters() if p.requires_grad
        )
        print(f"Trainable parameters: {trainable_params / 1_000_000:.2f}M")

        # LR scheduler
        lr_func = IterExponential(
            total_iter_length=self.cfg.lr_scheduler.kwargs.total_iter,
            final_ratio=self.cfg.lr_scheduler.kwargs.final_ratio,
            warmup_steps=self.cfg.lr_scheduler.kwargs.warmup_steps,
        )
        self.lr_scheduler = LambdaLR(optimizer=self.optimizer, lr_lambda=lr_func)

        # Loss
        self.loss = get_loss(loss_name=self.cfg.loss.name, **self.cfg.loss.kwargs)

        # Training noise schedulers.
        # NOTE(review): the RGB scheduler is loaded from the *depth* scheduler's
        # pretrained_path, so both share one config — confirm this is intended.
        self.rgb_training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
            cfg.trainer.depth_training_noise_scheduler.pretrained_path, subfolder="scheduler")
        self.depth_training_noise_scheduler: DDPMScheduler = DDPMScheduler.from_pretrained(
            cfg.trainer.depth_training_noise_scheduler.pretrained_path, subfolder="scheduler")
        self.rgb_prediction_type = self.rgb_training_noise_scheduler.config.prediction_type
        self.depth_prediction_type = self.depth_training_noise_scheduler.config.prediction_type
        # Training and inference must agree on what the UNet predicts for depth.
        assert (
            self.depth_prediction_type == self.model.depth_scheduler.config.prediction_type
        ), "Different prediction types"
        self.scheduler_timesteps = self.rgb_training_noise_scheduler.config.num_train_timesteps

        # Settings
        self.max_epoch = self.cfg.max_epoch
        self.max_iter = self.cfg.max_iter
        self.gradient_accumulation_steps = accumulation_steps
        self.gt_depth_type = self.cfg.gt_depth_type
        self.gt_mask_type = self.cfg.gt_mask_type
        self.save_period = self.cfg.trainer.save_period
        self.backup_period = self.cfg.trainer.backup_period
        self.val_period = self.cfg.trainer.validation_period
        self.vis_period = self.cfg.trainer.visualization_period

        # Multi-resolution noise (settings are read but not applied in parallel_train)
        self.apply_multi_res_noise = self.cfg.multi_res_noise is not None
        if self.apply_multi_res_noise:
            self.mr_noise_strength = self.cfg.multi_res_noise.strength
            self.annealed_mr_noise = self.cfg.multi_res_noise.annealed
            self.mr_noise_downscale_strategy = self.cfg.multi_res_noise.downscale_strategy

        # Internal variables
        self.epoch = 0
        self.n_batch_in_epoch = 0  # batch index in the epoch, used when resuming
        self.effective_iter = 0  # number of completed effective iterations
        self.in_evaluation = False
        # Written by save_checkpoint / restored by load_checkpoint.
        self.best_metric = None
        # Consistent global seed sequence, keeps randomness reproducible on resume.
        self.global_seed_sequence: List = []

    def _replace_unet_conv_in(self):
        """Double the input channels of ``unet.conv_in`` (assumed 4 -> 8).

        Pretrained weights are kept for the original input channels and the
        added channels are zero-initialised, so the layer initially ignores the
        new inputs. Also sets ``unet.config["in_channels"] = 8``.
        """
        _weight = self.model.unet.conv_in.weight.clone()  # [320, 4, 3, 3]
        _bias = self.model.unet.conv_in.bias.clone()  # [320]
        # zeros_like keeps the pretrained weight's dtype and device.
        _weight = torch.cat([_weight, torch.zeros_like(_weight)], dim=1)
        # new conv_in channel
        _n_convin_out_channel = self.model.unet.conv_in.out_channels
        _new_conv_in = Conv2d(
            8, _n_convin_out_channel, kernel_size=(3, 3), stride=(1, 1), padding=(1, 1)
        )
        _new_conv_in.weight = Parameter(_weight)
        _new_conv_in.bias = Parameter(_bias)
        self.model.unet.conv_in = _new_conv_in
        logging.info("Unet conv_in layer is replaced")
        # replace config
        self.model.unet.config["in_channels"] = 8
        logging.info("Unet config is updated")

    @staticmethod
    def _normalize_disparities(disparities, eps=1e-8):
        """Min-max normalise each disparity map independently to ``[-1, 1]``.

        Args:
            disparities: Tensor holding one 2-D map, or a stack of maps along
                dim 0 (presumably ``[B, H, W]``).
            eps: Lower bound on each map's value range, so a constant map
                yields ``-1`` instead of NaN.

        Returns:
            Tensor with the maps stacked along dim 0.
        """
        if len(disparities.shape) == 2:
            disparities = disparities.unsqueeze(0)
        normalized = []
        for disparity_map in disparities:
            d_min = disparity_map.min()
            d_range = (disparity_map.max() - d_min).clamp_min(eps)
            normalized.append((disparity_map - d_min) / d_range * 2 - 1)
        return torch.stack(normalized, dim=0)

    def _select_masks(self, mask):
        """Choose per-modality masks and loss flags according to ``timestep_method``.

        ``'joint'``: both modalities use ``mask``. ``'partition'``: 50% joint,
        25% full (unmasked) RGB with depth prediction, 25% RGB prediction from
        either full (half of those) or masked depth. Draws from ``random``.

        Returns:
            ``(rgb_mask, depth_mask, rgb_flag, depth_flag)``; a flag of ``0``
            excludes that modality's loss from the total.

        Raises:
            ValueError: if ``timestep_method`` is not ``'joint'`` or ``'partition'``.
        """
        if self.timestep_method == 'joint':
            return mask, mask, 1, 1
        if self.timestep_method == 'partition':
            rand_num = random.random()
            if rand_num < 0.5:  # joint prediction
                return mask, mask, 1, 1
            if rand_num < 0.75:  # full rgb; depth prediction
                return torch.zeros_like(mask), mask, 0, 1
            # rgb prediction from full depth (50%) or partial depth (50%)
            depth_mask = torch.zeros_like(mask) if random.random() < 0.5 else mask
            return mask, depth_mask, 1, 0
        raise ValueError(
            f"Unsupported timestep_method '{self.timestep_method}'; "
            "expected 'joint' or 'partition'"
        )

    @staticmethod
    def _get_prediction_target(prediction_type, scheduler, latent, noise, timesteps, name):
        """Return the regression target for a diffusers ``prediction_type``.

        ``"sample"`` -> the clean latent, ``"epsilon"`` -> the added noise,
        ``"v_prediction"`` -> ``scheduler.get_velocity(latent, noise, timesteps)``.

        Raises:
            ValueError: for any other prediction type (``name`` labels the modality).
        """
        if prediction_type == "sample":
            return latent
        if prediction_type == "epsilon":
            # An epsilon-parameterised model regresses the noise, not the latent.
            return noise
        if prediction_type == "v_prediction":
            return scheduler.get_velocity(latent, noise, timesteps)  # [B, 4, h, w]
        raise ValueError(f"Unknown {name} prediction type {prediction_type}")

    def parallel_train(self, t_end=None, accelerator=None):
        """Run the training loop under ``accelerator``.

        Resumes from ``<out_dir_ckpt>/latest`` when it exists, then trains over
        epochs ``self.epoch .. max_epoch`` and returns early once
        ``effective_iter >= max_iter`` (when ``max_iter > 0``). Each batch must
        provide ``"rgb_norm"``, ``"rgb_int"``, ``"mask"`` and ``"text"``.

        Args:
            t_end: Not used.
            accelerator: ``accelerate.Accelerator`` that prepares the model,
                optimizer, data loader and LR scheduler (required despite the
                ``None`` default).
        """
        logging.info("Start training")
        self.model, self.optimizer, self.train_loader, self.lr_scheduler = accelerator.prepare(
            self.model, self.optimizer, self.train_loader, self.lr_scheduler
        )
        self.depth_model = accelerator.prepare(self.depth_model)
        self.accelerator = accelerator

        latest_dir = os.path.join(self.out_dir_ckpt, 'latest')
        if os.path.exists(latest_dir):
            accelerator.load_state(latest_dir)
            self.load_miscs(latest_dir)

        self.train_metrics.reset()
        # Counts synchronised (non-accumulating) optimizer steps.
        # NOTE(review): if the Accelerator was also created with
        # gradient_accumulation_steps > 1, accumulation is applied twice — confirm.
        accumulated_step = 0
        for epoch in range(self.epoch, self.max_epoch + 1):
            self.epoch = epoch
            logging.debug(f"epoch: {self.epoch}")
            # Skip previous batches when resuming mid-epoch
            for batch in skip_first_batches(self.train_loader, self.n_batch_in_epoch):
                self.model.unet.train()
                # Globally consistent random generator (reproducible across resumes)
                if self.seed is not None:
                    rand_num_generator = torch.Generator(device=self.model.device)
                    rand_num_generator.manual_seed(self._get_next_seed())
                else:
                    rand_num_generator = None

                # --- Inputs ---
                rgb = batch["rgb_norm"].to(self.model.device)
                with torch.no_grad():
                    # `.cpu()` is a no-op on CPU tensors and keeps `.numpy()`
                    # working if the prepared loader placed the batch on GPU.
                    disparities = self.depth_model(
                        batch["rgb_int"].cpu().numpy().astype(np.uint8),
                        518,
                        device=self.model.device,
                    )
                # Per-sample disparity in [-1, 1], used as the depth target.
                depth_gt_for_latent = self._normalize_disparities(disparities)

                batch_size = rgb.shape[0]
                mask = self.model.mask_processor.preprocess(batch['mask'] * 255).to(self.model.device)
                rgb_timesteps = torch.randint(
                    0,
                    self.scheduler_timesteps,
                    (batch_size,),
                    device=self.model.device,
                    generator=rand_num_generator,
                ).long()  # [B]
                # RGB and depth share the same per-sample timestep.
                depth_timesteps = rgb_timesteps
                rgb_mask, depth_mask, rgb_flag, depth_flag = self._select_masks(mask)

                # Keep only pixels whose mask value is < 0.5; the rest are zeroed.
                masked_rgb = rgb * (rgb_mask < 0.5)
                masked_depth = depth_gt_for_latent * (depth_mask.squeeze() < 0.5)

                with torch.no_grad():
                    # Encode image
                    rgb_latent = self.model.encode_rgb(rgb)  # [B, 4, h, w]
                    mask_rgb_latent = self.model.encode_rgb(masked_rgb)
                    if depth_timesteps.sum() == 0:
                        gt_depth_latent = self.encode_depth(masked_depth)
                    else:
                        gt_depth_latent = self.encode_depth(depth_gt_for_latent)
                    mask_depth_latent = self.encode_depth(masked_depth)
                rgb_mask = torch.nn.functional.interpolate(rgb_mask, size=rgb_latent.shape[-2:])
                depth_mask = torch.nn.functional.interpolate(depth_mask, size=gt_depth_latent.shape[-2:])

                # Sample noise
                rgb_noise = torch.randn(
                    rgb_latent.shape,
                    device=self.model.device,
                    generator=rand_num_generator,
                )  # [B, 4, h, w]
                depth_noise = torch.randn(
                    gt_depth_latent.shape,
                    device=self.model.device,
                    generator=rand_num_generator,
                )  # [B, 4, h, w]
                if rgb_timesteps.sum() == 0:
                    noisy_rgb_latents = rgb_latent
                else:
                    noisy_rgb_latents = self.rgb_training_noise_scheduler.add_noise(
                        rgb_latent, rgb_noise, rgb_timesteps
                    )  # [B, 4, h, w]
                if depth_timesteps.sum() == 0:
                    noisy_depth_latents = gt_depth_latent
                else:
                    noisy_depth_latents = self.depth_training_noise_scheduler.add_noise(
                        gt_depth_latent, depth_noise, depth_timesteps
                    )  # [B, 4, h, w]
                # [B, 26, h, w]: per modality 4 noisy + 1 mask + 4 masked-RGB + 4 masked-depth
                noisy_latents = torch.cat(
                    [noisy_rgb_latents, rgb_mask, mask_rgb_latent, mask_depth_latent,
                     noisy_depth_latents, depth_mask, mask_rgb_latent, mask_depth_latent],
                    dim=1,
                ).float()

                # Text embedding
                input_ids = self.model.tokenizer(
                    batch['text'],
                    padding="max_length",
                    max_length=self.model.tokenizer.model_max_length,
                    truncation=True,
                    return_tensors="pt",
                )
                input_ids = {k: v.to(self.model.device) for k, v in input_ids.items()}
                text_embed = self.model.text_encoder(**input_ids)[0]

                model_pred = self.model.unet(
                    noisy_latents, rgb_timesteps, depth_timesteps, text_embed,
                    controlnet_connection=self.connection,
                ).sample  # [B, 8, h, w]
                if torch.isnan(model_pred).any():
                    logging.warning("model_pred contains NaN.")

                # Targets depend on each scheduler's prediction type
                rgb_target = self._get_prediction_target(
                    self.rgb_prediction_type, self.rgb_training_noise_scheduler,
                    rgb_latent, rgb_noise, rgb_timesteps, "rgb",
                )
                depth_target = self._get_prediction_target(
                    self.depth_prediction_type, self.depth_training_noise_scheduler,
                    gt_depth_latent, depth_noise, depth_timesteps, "depth",
                )

                # Latent loss over the full latent; channels 0-3 are RGB, 4+ depth.
                with accelerator.accumulate(self.model):
                    rgb_loss = self.loss(model_pred[:, 0:4, :, :].float(), rgb_target.float())
                    depth_loss = self.loss(model_pred[:, 4:, :, :].float(), depth_target.float())
                    if rgb_flag == 0:
                        loss = depth_loss
                    elif depth_flag == 0:
                        loss = rgb_loss
                    else:
                        loss = (
                            self.cfg.loss.depth_factor * depth_loss
                            + (1 - self.cfg.loss.depth_factor) * rgb_loss
                        )
                    self.train_metrics.update("loss", loss.item())
                    self.train_metrics.update("rgb_loss", rgb_loss.item())
                    self.train_metrics.update("depth_loss", depth_loss.item())
                    accelerator.backward(loss)
                    self.optimizer.step()
                    self.optimizer.zero_grad()
                self.n_batch_in_epoch += 1
                # The LR is a function of the effective iteration, not the batch count.
                self.lr_scheduler.step(self.effective_iter)

                if accelerator.sync_gradients:
                    accumulated_step += 1
                if accumulated_step < self.gradient_accumulation_steps:
                    continue

                # >>> Effective iteration >>>
                accumulated_step = 0
                self.effective_iter += 1
                if accelerator.is_main_process:
                    if self.effective_iter == 1:
                        self._inpaint_rgbd()
                    # Log to tensorboard
                    metrics = self.train_metrics.result()
                    tb_logger.log_dic(
                        {f"train/{k}": v for k, v in metrics.items()},
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "lr",
                        self.lr_scheduler.get_last_lr()[0],
                        global_step=self.effective_iter,
                    )
                    tb_logger.writer.add_scalar(
                        "n_batch_in_epoch",
                        self.n_batch_in_epoch,
                        global_step=self.effective_iter,
                    )
                    logging.info(
                        f"iter {self.effective_iter:5d} (epoch {epoch:2d}): "
                        f"loss={metrics['loss']:.5f}, rgb_loss={metrics['rgb_loss']:.5f}, "
                        f"depth_loss={metrics['depth_loss']:.5f}"
                    )
                # Reset on every rank so logged averages cover one effective iteration.
                self.train_metrics.reset()
                accelerator.wait_for_everyone()

                if self.save_period > 0 and 0 == self.effective_iter % self.save_period:
                    accelerator.save_state(output_dir=latest_dir)
                    unwrapped_model = accelerator.unwrap_model(self.model)
                    if accelerator.is_main_process:
                        accelerator.save_model(
                            unwrapped_model.unet, latest_dir, safe_serialization=False
                        )
                        self.save_miscs('latest')
                        self._inpaint_rgbd()
                    accelerator.wait_for_everyone()

                if self.backup_period > 0 and 0 == self.effective_iter % self.backup_period:
                    unwrapped_model = accelerator.unwrap_model(self.model)
                    if accelerator.is_main_process:
                        accelerator.save_model(
                            unwrapped_model.unet,
                            os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()),
                            safe_serialization=False,
                        )
                    accelerator.wait_for_everyone()

                # End of training
                if self.max_iter > 0 and self.effective_iter >= self.max_iter:
                    unwrapped_model = accelerator.unwrap_model(self.model)
                    if accelerator.is_main_process:
                        unwrapped_model.unet.save_pretrained(
                            os.path.join(self.out_dir_ckpt, self._get_backup_ckpt_name()))
                    accelerator.wait_for_everyone()
                    return

                torch.cuda.empty_cache()
                # <<< Effective iteration end <<<

            # Epoch end
            self.n_batch_in_epoch = 0

    def _inpaint_rgbd(self, image_paths=None, prompts=None):
        """Render joint RGB-D inpainting previews and log them to TensorBoard.

        For each image, the mask is the union of the first 5-10 annotations
        (fewer if the image has fewer) after a random shuffle, decoded with
        ``pycocotools`` from the sibling ``.json``. Image, depth-model output
        and mask are resized/center-cropped to 512x512 and passed to
        ``self.model._rgbd_inpaint(..., mode='joint_inpaint')``; the result is
        logged with the prompt as tag at step ``self.effective_iter``.

        Args:
            image_paths: ``.jpg`` paths; defaults to ``_VIS_IMAGE_PATHS``.
            prompts: One text prompt per image; defaults to ``_VIS_PROMPTS``.

        Raises:
            ValueError: if ``image_paths`` and ``prompts`` differ in length.

        Uses the global ``random`` module, so it advances the RNG state that
        ``_select_masks`` also draws from.
        """
        image_paths = list(self._VIS_IMAGE_PATHS if image_paths is None else image_paths)
        prompts = list(self._VIS_PROMPTS if prompts is None else prompts)
        if len(image_paths) != len(prompts):
            raise ValueError(
                f"Got {len(image_paths)} images but {len(prompts)} prompts"
            )

        imgs = []
        for path in image_paths:
            with Image.open(path) as pil_img:
                imgs.append(pil_to_tensor(pil_img))
        with torch.no_grad():
            depth_imgs = [self.depth_model(img.unsqueeze(0).cpu().numpy()) for img in imgs]

        masks = []
        for rgb_path, img in zip(image_paths, imgs):
            with open(rgb_path.replace('.jpg', '.json')) as f:
                anno = json.load(f)['annotations']
            random.shuffle(anno)
            object_num = random.randint(5, 10)
            # Binary (0/1) union of the selected object masks.
            mask = None
            for single_anno in anno[:object_num]:
                obj_mask = np.asarray(coco_mask.decode(single_anno['segmentation']), dtype=np.uint8)
                mask = obj_mask if mask is None else np.maximum(mask, obj_mask)
            if mask is None:  # no annotations: nothing to inpaint
                mask = np.zeros(tuple(img.shape[-2:]), dtype=np.uint8)
            masks.append(torch.from_numpy(mask))

        resize_transform = transforms.Compose([
            Resize(size=512, interpolation=InterpolationMode.NEAREST_EXACT),
            CenterCrop(size=[512, 512]),
        ])
        imgs = [resize_transform(img) for img in imgs]
        depth_imgs = [resize_transform(depth_img.unsqueeze(0)) for depth_img in depth_imgs]
        masks = [resize_transform(mask.unsqueeze(0)) for mask in masks]

        for img, depth_img, mask, text in zip(imgs, depth_imgs, masks, prompts):
            output_image = self.model._rgbd_inpaint(
                img, depth_img, mask, [text], processing_res=512, mode='joint_inpaint'
            )
            tb_logger.writer.add_image(f'{text}', pil_to_tensor(output_image), self.effective_iter)

    def encode_depth(self, depth_in):
        """Encode a depth map with the RGB VAE encoder after 3-channel stacking."""
        stacked = self.stack_depth_images(depth_in)
        return self.model.encode_rgb(stacked)

    @staticmethod
    def stack_depth_images(depth_in):
        """Replicate a depth map into 3 channels so it can go through the RGB VAE.

        ``[H, W]`` and ``[B, H, W]`` inputs get a channel dim inserted at
        dim 1; a 4-D input is repeated 3x along dim 1 (expected to have one
        channel there). Returns a 4-D tensor.

        Raises:
            ValueError: for inputs that are not 2-, 3- or 4-dimensional.
        """
        ndim = len(depth_in.shape)
        if ndim == 4:
            stacked = depth_in
        elif ndim == 3:
            stacked = depth_in.unsqueeze(1)
        elif ndim == 2:
            stacked = depth_in.unsqueeze(0).unsqueeze(0)
        else:
            raise ValueError(f"Expected a 2-, 3- or 4-D depth tensor, got {ndim}-D")
        return stacked.repeat(1, 3, 1, 1)

    def visualize(self):
        """Run ``validate_single_dataset`` on each visualisation loader, saving outputs.

        NOTE(review): ``validate_single_dataset`` and ``self.val_metrics`` are
        not defined in this class, so this method fails unless they are
        provided elsewhere (e.g. by a subclass) — confirm before use.
        """
        for val_loader in self.vis_loaders:
            vis_dataset_name = val_loader.dataset.disp_name
            vis_out_dir = os.path.join(
                self.out_dir_vis, self._get_backup_ckpt_name(), vis_dataset_name
            )
            os.makedirs(vis_out_dir, exist_ok=True)
            _ = self.validate_single_dataset(
                data_loader=val_loader,
                metric_tracker=self.val_metrics,
                save_to_dir=vis_out_dir,
            )

    def _get_next_seed(self):
        """Pop the next seed, (re)generating the sequence from ``self.seed`` when empty."""
        if 0 == len(self.global_seed_sequence):
            self.global_seed_sequence = generate_seed_sequence(
                initial_seed=self.seed,
                length=self.max_iter * self.gradient_accumulation_steps,
            )
            logging.info(
                f"Global seed sequence is generated, length={len(self.global_seed_sequence)}"
            )
        return self.global_seed_sequence.pop()

    def save_miscs(self, ckpt_name):
        """Save iteration counters and the seed sequence to ``<ckpt_name>/trainer.ckpt``."""
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        os.makedirs(ckpt_dir, exist_ok=True)
        state = {
            "config": self.cfg,
            "effective_iter": self.effective_iter,
            "epoch": self.epoch,
            "n_batch_in_epoch": self.n_batch_in_epoch,
            "global_seed_sequence": self.global_seed_sequence,
        }
        train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
        torch.save(state, train_state_path)
        logging.info(f"Misc state is saved to: {train_state_path}")

    def load_miscs(self, ckpt_path):
        """Restore the counters and seed sequence written by ``save_miscs``."""
        checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
        self.effective_iter = checkpoint["effective_iter"]
        self.epoch = checkpoint["epoch"]
        self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
        self.global_seed_sequence = checkpoint["global_seed_sequence"]
        logging.info(f"Misc state is loaded from {ckpt_path}")

    def save_checkpoint(self, ckpt_name, save_train_state):
        """Save the UNet (and optionally trainer state) to ``<out_dir_ckpt>/<ckpt_name>``.

        An existing checkpoint directory is renamed to ``_old_<ckpt_name>``
        first and deleted only after the new checkpoint has been written. When
        ``save_train_state`` is true, the counters, best metric, seed sequence,
        optimizer and LR-scheduler state are written to ``trainer.ckpt`` in
        the format ``load_checkpoint`` reads, plus an empty ``iter_XXXXXX``
        marker file.
        """
        ckpt_dir = os.path.join(self.out_dir_ckpt, ckpt_name)
        logging.info(f"Saving checkpoint to: {ckpt_dir}")
        # Backup previous checkpoint
        temp_ckpt_dir = None
        if os.path.exists(ckpt_dir) and os.path.isdir(ckpt_dir):
            temp_ckpt_dir = os.path.join(
                os.path.dirname(ckpt_dir), f"_old_{os.path.basename(ckpt_dir)}"
            )
            if os.path.exists(temp_ckpt_dir):
                shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            os.rename(ckpt_dir, temp_ckpt_dir)
            logging.debug(f"Old checkpoint is backed up at: {temp_ckpt_dir}")

        # Save UNet
        unet_path = os.path.join(ckpt_dir, "unet")
        self.model.unet.save_pretrained(unet_path, safe_serialization=False)
        logging.info(f"UNet is saved to: {unet_path}")

        if save_train_state:
            state = {
                "config": self.cfg,
                "effective_iter": self.effective_iter,
                "epoch": self.epoch,
                "n_batch_in_epoch": self.n_batch_in_epoch,
                "best_metric": self.best_metric,
                "in_evaluation": self.in_evaluation,
                "global_seed_sequence": self.global_seed_sequence,
                # Required by load_checkpoint.
                "optimizer": self.optimizer.state_dict(),
                "lr_scheduler": self.lr_scheduler.state_dict(),
            }
            train_state_path = os.path.join(ckpt_dir, "trainer.ckpt")
            torch.save(state, train_state_path)
            # iteration indicator (empty marker file)
            with open(os.path.join(ckpt_dir, self._get_backup_ckpt_name()), "w"):
                pass
            logging.info(f"Trainer state is saved to: {train_state_path}")

        # Remove temp ckpt
        if temp_ckpt_dir is not None and os.path.exists(temp_ckpt_dir):
            shutil.rmtree(temp_ckpt_dir, ignore_errors=True)
            logging.debug("Old checkpoint backup is removed.")

    def load_checkpoint(
        self, ckpt_path, load_trainer_state=True, resume_lr_scheduler=True
    ):
        """Load UNet weights (and optionally trainer state) written by ``save_checkpoint``.

        Raises:
            KeyError: if ``trainer.ckpt`` lacks an expected entry.
        """
        logging.info(f"Loading checkpoint from: {ckpt_path}")
        # Load UNet
        _model_path = os.path.join(ckpt_path, "unet", "diffusion_pytorch_model.bin")
        self.model.unet.load_state_dict(
            torch.load(_model_path, map_location=self.device)
        )
        self.model.unet.to(self.device)
        logging.info(f"UNet parameters are loaded from {_model_path}")

        # Load training states
        if load_trainer_state:
            checkpoint = torch.load(os.path.join(ckpt_path, "trainer.ckpt"))
            self.effective_iter = checkpoint["effective_iter"]
            self.epoch = checkpoint["epoch"]
            self.n_batch_in_epoch = checkpoint["n_batch_in_epoch"]
            self.in_evaluation = checkpoint["in_evaluation"]
            self.global_seed_sequence = checkpoint["global_seed_sequence"]
            self.best_metric = checkpoint["best_metric"]
            self.optimizer.load_state_dict(checkpoint["optimizer"])
            logging.info(f"optimizer state is loaded from {ckpt_path}")
            if resume_lr_scheduler:
                self.lr_scheduler.load_state_dict(checkpoint["lr_scheduler"])
                logging.info(f"LR scheduler state is loaded from {ckpt_path}")

        logging.info(
            f"Checkpoint loaded from: {ckpt_path}. Resume from iteration {self.effective_iter} (epoch {self.epoch})"
        )

    def _get_backup_ckpt_name(self):
        """Return the backup directory name for the current iteration, e.g. ``iter_000100``."""
        return f"iter_{self.effective_iter:06d}"