|
import copy |
|
|
|
import functools |
|
import json |
|
import os |
|
from pathlib import Path |
|
from pdb import set_trace as st |
|
from einops import rearrange |
|
import webdataset as wds |
|
|
|
import traceback |
|
import blobfile as bf |
|
import imageio |
|
import numpy as np |
|
|
|
import torch as th |
|
import torch.distributed as dist |
|
import torchvision |
|
from PIL import Image |
|
from torch.nn.parallel.distributed import DistributedDataParallel as DDP |
|
from torch.optim import AdamW |
|
from torch.utils.tensorboard import SummaryWriter |
|
from tqdm import tqdm |
|
|
|
from guided_diffusion import dist_util, logger |
|
from guided_diffusion.fp16_util import MixedPrecisionTrainer |
|
from guided_diffusion.nn import update_ema |
|
from guided_diffusion.resample import LossAwareSampler, UniformSampler |
|
from guided_diffusion.train_util import (calc_average_loss, |
|
find_ema_checkpoint, |
|
find_resume_checkpoint, |
|
get_blob_logdir, log_rec3d_loss_dict, |
|
parse_resume_step_from_filename) |
|
|
|
from .camera_utils import LookAtPoseSampler, FOV_to_intrinsics |
|
|
|
from .train_util import TrainLoop3DRec |
|
|
|
|
|
class TrainLoop3DRecNV(TrainLoop3DRec):
    """Reconstruction loop that also renders each latent from a second camera.

    For every microbatch the encoder latent is decoded with the camera stored
    under ``nv_c`` and compared with the ``nv_*`` entries of the batch. When
    ``self.rec_cano`` is set it is decoded a second time with the camera ``c``
    and compared with the un-prefixed entries.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        """Forward every argument to ``TrainLoop3DRec`` and enable ``rec_cano``."""
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)
        # When true, the rendering from the input camera ``c`` is supervised too.
        self.rec_cano = True

    def forward_backward(self, batch, *args, **kwargs):
        """Accumulate gradients of the rendering losses for one batch.

        Args:
            batch: dict whose values are sliced along dim 0 and moved to
                ``dist_util.dev()``; must contain ``img_to_encoder``, ``nv_c``
                and, when ``rec_cano`` is set, ``c``, ``img`` and
                ``depth_mask``.
            *args, **kwargs: accepted and ignored.
        """
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }
            # Stays None when the canonical rendering is disabled.
            pred_cano = None

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                # Split the batch into ``nv_*`` targets and their un-prefixed
                # counterparts.
                target_nvs = {}
                target_cano = {}
                for k, v in micro.items():
                    if k[:2] == 'nv':
                        orig_key = k.replace('nv_', '')
                        target_nvs[orig_key] = v
                        target_cano[orig_key] = micro[orig_key]

                latent = self.rec_model(img=micro['img_to_encoder'],
                                        behaviour='enc_dec_wo_triplane')

                pred = self.rec_model(latent=latent,
                                      c=micro['nv_c'],
                                      behaviour='triplane_dec')

                with self.rec_model.no_sync():
                    loss, loss_dict, fg_mask = self.loss_class(
                        pred,
                        target_nvs,
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

                if self.rec_cano:
                    pred_cano = self.rec_model(latent=latent,
                                               c=micro['c'],
                                               behaviour='triplane_dec')

                    with self.rec_model.no_sync():
                        # Mask is expanded to three channels for the 2D loss.
                        fg_mask = target_cano['depth_mask'].unsqueeze(
                            1).repeat_interleave(3, 1).float()

                        loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
                            pred_cano['image_raw'],
                            target_cano['img'],
                            fg_mask,
                            step=self.step + self.resume_step,
                            test_mode=False,
                        )

                    loss = loss + loss_cano

                    log_rec3d_loss_dict({
                        f'cano_{k}': v
                        for k, v in loss_cano_dict.items()
                    })

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                self.log_img(micro, pred, pred_cano)

    @th.inference_mode()
    def log_img(self, micro, pred, pred_cano):
        """Write a ground-truth / prediction comparison image to the log dir.

        Args:
            micro: microbatch dict; reads ``img``, ``depth`` and, when
                ``pred`` has ``image_sr``, ``img_sr``.
            pred: decoder output; reads ``image_raw``, ``image_mask`` and the
                optional ``image_depth``, ``image_sr``, ``conf_sigma``.
            pred_cano: second decoder output with the same keys, or ``None``
                to omit its row from the image.
        """

        def norm_depth(depth):
            # Min-max normalise, then flip and rescale to [-1, 1].
            lo = depth.min()
            span = depth.max() - lo
            # A constant map has zero span; divide by 1 instead of producing NaN.
            span = th.where(span > 0, span, th.ones_like(span))
            return -((depth - lo) / span * 2 - 1)

        pred_img = pred['image_raw']
        gt_img = micro['img']

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = norm_depth(gt_depth)

        fg_mask = pred['image_mask'] * 2 - 1
        has_depth = 'image_depth' in pred
        pred_depth = (norm_depth(pred['image_depth'])
                      if has_depth else th.zeros_like(gt_depth))

        # Build the second prediction row before gt_depth is pooled below,
        # so its zero fallback keeps the un-pooled size.
        cano_row = None
        if pred_cano is not None:
            cano_depth = (norm_depth(pred_cano['image_depth'])
                          if has_depth else th.zeros_like(gt_depth))
            cano_mask = pred_cano['image_mask'] * 2 - 1
            cano_row = th.cat([
                pred_cano['image_raw'],
                cano_depth.repeat_interleave(3, dim=1),
                cano_mask.repeat_interleave(3, dim=1),
            ],
                              dim=-1)

        if 'image_sr' in pred:
            # Pick the pooling module that matches the SR output width.
            sr_width = pred['image_sr'].shape[-1]
            if sr_width == 512:
                pool = self.pool_512
            elif sr_width == 256:
                pool = self.pool_256
            else:
                pool = self.pool_128

            pred_img = th.cat([pool(pred_img), pred['image_sr']], dim=-1)
            gt_img = th.cat([pool(micro['img']), micro['img_sr']], dim=-1)
            pred_depth = pool(pred_depth)
            gt_depth = pool(gt_depth)
        else:
            gt_img = self.pool_64(gt_img)
            gt_depth = self.pool_64(gt_depth)

        pred_vis = th.cat([
            pred_img,
            pred_depth.repeat_interleave(3, dim=1),
            fg_mask.repeat_interleave(3, dim=1),
        ],
                          dim=-1)
        if cano_row is not None:
            pred_vis = th.cat([pred_vis, cano_row], dim=-2)

        gt_vis = th.cat([
            gt_img,
            gt_depth.repeat_interleave(3, dim=1),
            th.zeros_like(gt_img),
        ],
                        dim=-1)

        if 'conf_sigma' in pred:
            gt_vis = th.cat([gt_vis, fg_mask], dim=-1)

        # Ground truth on top, predictions below.
        vis = th.cat([gt_vis, pred_vis], dim=-2)

        out_path = f'{logger.get_dir()}/{self.step+self.resume_step}.jpg'
        vis_tensor = torchvision.utils.make_grid(vis,
                                                 nrow=vis.shape[-1] // 64)
        torchvision.utils.save_image(vis_tensor,
                                     out_path,
                                     value_range=(-1, 1),
                                     normalize=True)

        logger.log('log vis to: ', out_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainLoop3DRecNVPatch(TrainLoop3DRecNV):
    """Patch-based variant: renders and supervises cropped ray patches.

    Rays and their bounding boxes come from ``self.eg3d_model`` called with
    ``sample_ray_only=True``; the matching image regions are cropped from the
    batch and used as loss targets.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        """Forward every argument to the parent and cache the triplane decoder."""
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

        # Shortcut to the decoder used for ray sampling; ``rec_model`` is
        # expected to expose ``.module`` (it is unwrapped here).
        self.eg3d_model = self.rec_model.module.decoder.triplane_decoder

        self.rec_cano = True

    @staticmethod
    def _crop_to_ray_bboxes(dst, src, ray_bboxes, src_prefix=''):
        """Write per-sample crops of ``src`` img/depth_mask/depth into ``dst``."""
        for j in range(src['img'].shape[0]):
            top, left, height, width = ray_bboxes[j]
            for key in ('img', 'depth_mask', 'depth'):
                dst[key][j:j + 1] = torchvision.transforms.functional.crop(
                    src[f'{src_prefix}{key}'][j:j + 1], top, left, height,
                    width)

    def forward_backward(self, batch, *args, **kwargs):
        """Accumulate gradients of the patch rendering losses for one batch.

        Args:
            batch: dict of tensors sliced along dim 0; reads
                ``img_to_encoder``, ``nv_c``, ``nv_bbox``, ``nv_img``,
                ``nv_depth_mask``, ``nv_depth`` and, when ``rec_cano`` is
                set, the un-prefixed counterparts plus ``c`` and ``bbox``.
            *args, **kwargs: accepted and ignored.
        """
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k: v[i:i + self.microbatch].to(dist_util.dev())
                for k, v in batch.items()
            }
            # Stays None when the canonical rendering is disabled.
            pred_cano = None

            # Sample rays only; no rendering happens in this call.
            target = {
                **self.eg3d_model(c=micro['nv_c'],
                                  ws=None,
                                  planes=None,
                                  sample_ray_only=True,
                                  fg_bbox=micro['nv_bbox']),
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']
            # Listed keys are passed through; every other entry becomes an
            # uninitialised buffer truncated to the patch size in its last
            # two dims.
            passthrough_keys = [
                'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                'nv_img_sr', 'c'
            ]
            cropped_target = {
                k:
                v if k in passthrough_keys else th.empty_like(v)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                for k, v in micro.items()
            }

            self._crop_to_ray_bboxes(cropped_target,
                                     micro,
                                     target['ray_bboxes'],
                                     src_prefix='nv_')

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                latent = self.rec_model(img=micro['img_to_encoder'],
                                        behaviour='enc_dec_wo_triplane')

                pred_nv = self.rec_model(
                    latent=latent,
                    c=micro['nv_c'],
                    behaviour='triplane_dec',
                    ray_origins=target['ray_origins'],
                    ray_directions=target['ray_directions'],
                )

                with self.rec_model.no_sync():
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv,
                        cropped_target,
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

                if self.rec_cano:
                    cano_target = {
                        **self.eg3d_model(c=micro['c'],
                                          ws=None,
                                          planes=None,
                                          sample_ray_only=True,
                                          fg_bbox=micro['bbox']),
                    }

                    cano_cropped_target = {
                        k: th.empty_like(v)
                        for k, v in cropped_target.items()
                    }
                    self._crop_to_ray_bboxes(cano_cropped_target, micro,
                                             cano_target['ray_bboxes'])

                    pred_cano = self.rec_model(
                        latent=latent,
                        c=micro['c'],
                        behaviour='triplane_dec',
                        ray_origins=cano_target['ray_origins'],
                        ray_directions=cano_target['ray_directions'],
                    )

                    with self.rec_model.no_sync():
                        fg_mask = cano_cropped_target['depth_mask'].unsqueeze(
                            1).repeat_interleave(3, 1).float()

                        loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
                            pred_cano['image_raw'],
                            cano_cropped_target['img'],
                            fg_mask,
                            step=self.step + self.resume_step,
                            test_mode=False,
                        )

                    loss = loss + loss_cano

                    log_rec3d_loss_dict({
                        f'cano_{k}': v
                        for k, v in loss_cano_dict.items()
                    })

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                self.log_patch_img(cropped_target, pred_nv, pred_cano)

    @th.inference_mode()
    def log_patch_img(self, micro, pred, pred_cano):
        """Write a ground-truth / prediction patch comparison image.

        Args:
            micro: target dict; reads ``img`` and ``depth``.
            pred: decoder output; reads ``image_raw``, ``image_mask`` and the
                optional ``image_depth``.
            pred_cano: second decoder output with the same keys, or ``None``
                to omit its row from the image.
        """

        def norm_depth(depth):
            # Min-max normalise, then flip and rescale to [-1, 1].
            lo = depth.min()
            span = depth.max() - lo
            # A constant map has zero span; divide by 1 instead of producing NaN.
            span = th.where(span > 0, span, th.ones_like(span))
            return -((depth - lo) / span * 2 - 1)

        pred_img = pred['image_raw']
        gt_img = micro['img']

        gt_depth = micro['depth']
        if gt_depth.ndim == 3:
            gt_depth = gt_depth.unsqueeze(1)
        gt_depth = norm_depth(gt_depth)

        fg_mask = pred['image_mask'] * 2 - 1
        has_depth = 'image_depth' in pred
        pred_depth = (norm_depth(pred['image_depth'])
                      if has_depth else th.zeros_like(gt_depth))

        pred_vis = th.cat([
            pred_img,
            pred_depth.repeat_interleave(3, dim=1),
            fg_mask.repeat_interleave(3, dim=1),
        ],
                          dim=-1)

        if pred_cano is not None:
            pred_cano_depth = (norm_depth(pred_cano['image_depth'])
                               if has_depth else th.zeros_like(gt_depth))
            input_fg_mask = pred_cano['image_mask'] * 2 - 1
            pred_vis_cano = th.cat([
                pred_cano['image_raw'],
                pred_cano_depth.repeat_interleave(3, dim=1),
                input_fg_mask.repeat_interleave(3, dim=1),
            ],
                                   dim=-1)
            pred_vis = th.cat([pred_vis, pred_vis_cano], dim=-2)

        gt_vis = th.cat([
            gt_img,
            gt_depth.repeat_interleave(3, dim=1),
            th.zeros_like(gt_img),
        ],
                        dim=-1)

        # Ground truth on top, predictions below.
        vis = th.cat([gt_vis, pred_vis], dim=-2)

        out_path = f'{logger.get_dir()}/{self.step+self.resume_step}.jpg'
        vis_tensor = torchvision.utils.make_grid(vis,
                                                 nrow=vis.shape[-1] // 64)
        torchvision.utils.save_image(vis_tensor,
                                     out_path,
                                     value_range=(-1, 1),
                                     normalize=True)

        logger.log('log vis to: ', out_path)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainLoop3DRecNVPatchSingleForward(TrainLoop3DRecNVPatch):
    """Patch loop that decodes all supervised views in one decoder call.

    The encoder latent is tiled three times and paired with cameras, rays and
    targets rolled along the batch dimension, so a single ``triplane_dec``
    call covers every supervised view.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        """Forward every argument unchanged to the parent constructor."""
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

    def forward_backward(self, batch, *args, **kwargs):
        """Accumulate gradients of the rolled multi-view patch loss.

        Args:
            batch: dict sliced along dim 0. ``caption`` and ``ins`` are
                removed from it in place if present. Tensor values are moved
                to ``dist_util.dev()``; other values are only sliced.
            *args, **kwargs: accepted and ignored.
        """
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        # Metadata entries are not used below; tolerate their absence.
        for meta_key in ('caption', 'ins'):
            batch.pop(meta_key, None)

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k:
                v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
                    v, th.Tensor) else v[i:i + self.microbatch]
                for k, v in batch.items()
            }

            # Sample rays only; no rendering happens in this call.
            target = {
                **self.eg3d_model(c=micro['nv_c'],
                                  ws=None,
                                  planes=None,
                                  sample_ray_only=True,
                                  fg_bbox=micro['nv_bbox']),
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']
            # Listed keys are passed through; every other entry becomes an
            # uninitialised buffer truncated to the patch size in its last
            # two dims.
            passthrough_keys = [
                'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                'nv_img_sr', 'c', 'caption', 'nv_caption'
            ]
            cropped_target = {
                k:
                v if k in passthrough_keys else th.empty_like(v)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                for k, v in micro.items()
            }

            for j in range(micro['img'].shape[0]):
                top, left, height, width = target['ray_bboxes'][j]
                for key in ('img', 'depth_mask', 'depth'):
                    cropped_target[f'{key}'][
                        j:j + 1] = torchvision.transforms.functional.crop(
                            micro[f'nv_{key}'][j:j + 1], top, left, height,
                            width)

            # NOTE(review): the result is not read anywhere in this method;
            # the call is kept in case the sampler has side effects (e.g.
            # RNG consumption) -- confirm before removing.
            cano_target = {
                **self.eg3d_model(c=micro['c'],
                                  ws=None,
                                  planes=None,
                                  sample_ray_only=True,
                                  fg_bbox=micro['bbox']),
            }

            latent = self.rec_model(img=micro['img_to_encoder'],
                                    behaviour='enc_dec_wo_triplane')

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                instance_mv_num = batch_size // 4

                def roll_views(tensor):
                    # Concatenate three copies shifted by 1, 2 and 3 groups
                    # of ``instance_mv_num`` along the batch dimension.
                    return th.cat([
                        tensor.roll(instance_mv_num * shift, dims=0)
                        for shift in range(1, 4)
                    ], 0)

                c = roll_views(micro['nv_c'])
                ray_origins = roll_views(target['ray_origins'])
                ray_directions = roll_views(target['ray_directions'])

                pred_nv_cano = self.rec_model(
                    latent={
                        # Tiled three times to match the rolled cameras.
                        'latent_after_vit':
                        latent['latent_after_vit'].repeat(3, 1, 1, 1)
                    },
                    c=c,
                    behaviour='triplane_dec',
                    ray_origins=ray_origins,
                    ray_directions=ray_directions,
                )

                # Expose the encoder outputs to the loss alongside the renders.
                pred_nv_cano.update(latent)

                gt = {k: roll_views(v) for k, v in cropped_target.items()}

                with self.rec_model.no_sync():
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv_cano,
                        gt,
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0:
                micro_bs = micro['img_to_encoder'].shape[0]
                vis_keys = ['image_raw', 'image_depth', 'image_mask']
                self.log_patch_img(
                    cropped_target,
                    {k: pred_nv_cano[k][-micro_bs:]
                     for k in vis_keys},
                    {k: pred_nv_cano[k][:micro_bs]
                     for k in vis_keys},
                )

    def eval_loop(self):
        """Delegate to the parent evaluation loop."""
        return super().eval_loop()

    @th.inference_mode()
    def eval_novelview_loop_old(self, camera=None):
        """Legacy evaluation: render views, export a mesh, log scores.

        For up to 501 items of ``self.eval_data`` this writes a caption file,
        a video, PNG frames under ``FID_Cals/`` and a ``.ply`` mesh under
        ``mesh/`` in the logger directory, then appends averaged losses to
        ``scores_novelview.json`` and the summary writer.

        Args:
            camera: unused by this method.
        """
        all_loss_dict = []
        novel_view_micro = {}

        export_mesh = True
        if export_mesh:
            Path(f'{logger.get_dir()}/FID_Cals/').mkdir(parents=True,
                                                        exist_ok=True)

        batch = {}

        for eval_idx, render_reference in enumerate(tqdm(self.eval_data)):

            if eval_idx > 500:
                break

            video_out = imageio.get_writer(
                f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_{eval_idx}.mp4',
                mode='I',
                fps=25,
                codec='libx264')

            with open(
                    f'{logger.get_dir()}/triplane_{self.step+self.resume_step}_{eval_idx}_caption.txt',
                    'w') as f:
                f.write(render_reference['caption'])

            for key in ['ins', 'bbox', 'caption']:
                if key in render_reference:
                    render_reference.pop(key)

            real_flag = False
            mv_flag = False
            if render_reference['c'].shape[:2] == (1, 40):
                real_flag = True
                # Split the single 40-view sample into 40 one-view dicts.
                render_reference = [{
                    k: v[0][idx:idx + 1]
                    for k, v in render_reference.items()
                } for idx in range(40)]

            elif render_reference['c'].shape[0] == 8:
                mv_flag = True

                render_reference = {
                    k: v[:4]
                    for k, v in render_reference.items()
                }

                # NOTE(review): ``render_reference`` is a dict at this point,
                # so the slice subscript below raises, and the per-view loop
                # further down expects a list of dicts. This multi-view path
                # looks unfinished; left untouched pending confirmation of
                # the intended behaviour.
                torchvision.utils.save_image(
                    render_reference[0:4]['img'],
                    logger.get_dir() + '/FID_Cals/{}_inp.png'.format(eval_idx),
                    padding=0,
                    normalize=True,
                    value_range=(-1, 1),
                )

            else:
                # A leftover pdb breakpoint was removed here: it blocked
                # unattended evaluation runs.
                render_reference = [{
                    k: v[idx:idx + 1]
                    for k, v in render_reference.items()
                } for idx in range(40)]

                # View 14 replaces view 0 as the encoder input.
                render_reference[0]['img_to_encoder'] = render_reference[14][
                    'img_to_encoder']
                render_reference[0]['img'] = render_reference[14]['img']

                torchvision.utils.save_image(
                    render_reference[0]['img'],
                    logger.get_dir() + '/FID_Cals/{}_gt.png'.format(eval_idx),
                    padding=0,
                    normalize=True,
                    value_range=(-1, 1))

            for i, batch in enumerate(render_reference):

                micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}

                # A second leftover pdb breakpoint (hit once per view) was
                # removed here.
                if i == 0:
                    if mv_flag:
                        novel_view_micro = None
                    else:
                        # The first view becomes the encoder input for all
                        # later views of this item.
                        novel_view_micro = {
                            k:
                            v[0:1].to(dist_util.dev()).repeat_interleave(
                                micro['img'].shape[0], 0) if isinstance(
                                    v, th.Tensor) else v[0:1]
                            for k, v in batch.items()
                        }

                else:
                    if i == 1:
                        if export_mesh:
                            mesh_size = 384
                            mesh_thres = 10
                            import mcubes
                            import trimesh
                            dump_path = f'{logger.get_dir()}/mesh/'

                            os.makedirs(dump_path, exist_ok=True)

                            # ``pred`` is the output of the i == 0 iteration.
                            grid_out = self.rec_model(
                                latent=pred,
                                grid_size=mesh_size,
                                behaviour='triplane_decode_grid',
                            )

                            vtx, faces = mcubes.marching_cubes(
                                grid_out['sigma'].squeeze(0).squeeze(
                                    -1).cpu().numpy(), mesh_thres)
                            # Map grid indices to [-1, 1].
                            vtx = vtx / (mesh_size - 1) * 2 - 1

                            mesh = trimesh.Trimesh(
                                vertices=vtx,
                                faces=faces,
                            )

                            mesh_dump_path = os.path.join(
                                dump_path, f'{eval_idx}.ply')
                            mesh.export(mesh_dump_path, 'ply')

                            print(f"Mesh dumped to {dump_path}")
                            del grid_out, mesh
                            th.cuda.empty_cache()

                    novel_view_micro = {
                        k:
                        v[0:1].to(dist_util.dev()).repeat_interleave(
                            micro['img'].shape[0], 0)
                        for k, v in novel_view_micro.items()
                    }

                pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
                                      c=micro['c'])

                if not real_flag:
                    _, loss_dict = self.loss_class(pred, micro, test_mode=True)
                    all_loss_dict.append(loss_dict)

                pred_depth = pred['image_depth']
                pred_depth = (pred_depth - pred_depth.min()) / (
                    pred_depth.max() - pred_depth.min())

                # Needed by the depth PNG below whichever branch runs; it
                # used to be bound only in the branch without ``image_sr``.
                pooled_depth = self.pool_128(pred_depth).repeat_interleave(
                    3, dim=1)

                if 'image_sr' in pred:

                    if pred['image_sr'].shape[-1] == 512:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_512(pred['image_raw']), pred['image_sr'],
                            self.pool_512(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)

                    elif pred['image_sr'].shape[-1] == 256:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_256(pred['image_raw']), pred['image_sr'],
                            self.pool_256(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)

                    else:
                        pred_vis = th.cat([
                            micro['img_sr'],
                            self.pool_128(pred['image_raw']),
                            self.pool_128(pred['image_sr']),
                            self.pool_128(pred_depth).repeat_interleave(3,
                                                                        dim=1)
                        ],
                                          dim=-1)

                else:
                    pred_vis = th.cat([
                        self.pool_128(novel_view_micro['img']),
                        self.pool_128(pred['image_raw']),
                        pooled_depth,
                    ],
                                      dim=-1)

                # NCHW in [-1, 1] -> NHWC uint8 for the video writer.
                vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
                vis = vis * 127.5 + 127.5
                vis = vis.clip(0, 255).astype(np.uint8)

                if export_mesh:
                    torchvision.utils.save_image(
                        pred['image_raw'],
                        logger.get_dir() +
                        '/FID_Cals/{}_{}.png'.format(eval_idx, i),
                        padding=0,
                        normalize=True,
                        value_range=(-1, 1))

                    torchvision.utils.save_image(
                        pooled_depth,
                        logger.get_dir() +
                        '/FID_Cals/{}_{}_dpeth.png'.format(eval_idx, i),
                        padding=0,
                        normalize=True,
                        value_range=(0, 1))

                for j in range(vis.shape[0]):
                    video_out.append_data(vis[j])

            video_out.close()

        # The flags tested here are those of the last item processed.
        if not real_flag or mv_flag:
            val_scores_for_logging = calc_average_loss(all_loss_dict)
            with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
                      'a') as f:
                json.dump({'step': self.step, **val_scores_for_logging}, f)

            for k, v in val_scores_for_logging.items():
                self.writer.add_scalar(f'Eval/NovelView/{k}', v,
                                       self.step + self.resume_step)

        del video_out

        th.cuda.empty_cache()

    @th.inference_mode()
    def eval_novelview_loop(self, camera=None, save_latent=False):
        """Encode evaluation items and render a video for the first 50.

        Args:
            camera: passed to ``render_video_given_triplane`` as
                ``render_reference={'c': camera}``.
            save_latent: when true, the first latent of each item is saved
                as ``latent_dir/<ins>/latent.npy`` under the logger dir.
        """
        if save_latent:
            latent_dir = Path(f'{logger.get_dir()}/latent_dir')
            latent_dir.mkdir(exist_ok=True, parents=True)

        eval_batch_size = 40

        for eval_idx, micro in enumerate(tqdm(self.eval_data)):

            # Only the first four images are fed to the encoder.
            latent = self.rec_model(img=micro['img_to_encoder'][:4],
                                    behaviour='encoder_vae')

            if micro['img'].shape[0] == 40:
                assert eval_batch_size == 40

            if save_latent:
                latent_save_dir = f'{logger.get_dir()}/latent_dir/{micro["ins"][0]}'
                Path(latent_save_dir).mkdir(parents=True, exist_ok=True)

                np.save(f'{latent_save_dir}/latent.npy',
                        latent[self.latent_name][0].cpu().numpy())
                # Every view in the item must belong to the same instance.
                assert all(micro['ins'][0] == micro['ins'][i]
                           for i in range(micro['c'].shape[0]))

            if eval_idx < 50:
                self.render_video_given_triplane(
                    latent[self.latent_name],
                    self.rec_model,
                    name_prefix=f'{self.step + self.resume_step}_{eval_idx}',
                    save_img=False,
                    render_reference={'c': camera},
                    save_mesh=True)
|
|
|
|
|
class TrainLoop3DRecNVPatchSingleForwardMV(TrainLoop3DRecNVPatchSingleForward):
    """Single-forward patch loop that stacks ``nv_c`` and ``c`` cameras.

    Rays for both camera sets are sampled in one call and the cropped targets
    are laid out as ``[nv views, input views]`` along the batch dimension.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        """Forward every argument unchanged to the parent constructor."""
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

    def forward_backward(self, batch, behaviour='g_step', *args, **kwargs):
        """Accumulate gradients of the stacked-view patch loss.

        Args:
            batch: dict sliced along dim 0. ``caption``, ``ins`` and
                ``__key__`` are removed from it in place if present.
            behaviour: forwarded to ``self.loss_class`` as ``behaviour``.
            *args, **kwargs: accepted and ignored.
        """
        self.mp_trainer_rec.zero_grad()
        batch_size = batch['img_to_encoder'].shape[0]

        # Metadata entries are not used below; tolerate their absence.
        for meta_key in ('caption', 'ins', '__key__'):
            batch.pop(meta_key, None)

        for i in range(0, batch_size, self.microbatch):
            micro = {
                k:
                v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
                    v, th.Tensor) else v[i:i + self.microbatch]
                for k, v in batch.items()
            }
            # Actual slice length; shorter than ``self.microbatch`` for a
            # trailing partial microbatch.
            micro_bs = micro['img'].shape[0]

            # First half: ``nv_c`` cameras, second half: ``c`` cameras.
            nv_c = th.cat([micro['nv_c'], micro['c']])

            target = {
                **self.eg3d_model(c=nv_c,
                                  ws=None,
                                  planes=None,
                                  sample_ray_only=True,
                                  fg_bbox=th.cat(
                                      [micro['nv_bbox'], micro['bbox']])),
            }

            patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
                'patch_rendering_resolution']
            # Listed keys are passed through; every other entry becomes an
            # uninitialised buffer of twice the batch length, truncated to
            # the patch size in its last two dims.
            passthrough_keys = [
                'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
                'nv_img_sr', 'c', 'caption', 'nv_caption'
            ]
            cropped_target = {
                k:
                v if k in passthrough_keys else
                th.empty_like(v).repeat_interleave(2, 0)
                [..., :patch_rendering_resolution, :patch_rendering_resolution]
                for k, v in micro.items()
            }

            for j in range(2 * micro_bs):
                top, left, height, width = target['ray_bboxes'][j]

                # Rows below ``micro_bs`` take ``nv_*`` sources, the rest
                # take the un-prefixed sources.
                if j < micro_bs:
                    prefix, src_idx = 'nv_', j
                else:
                    prefix, src_idx = '', j - micro_bs

                for key in ('img', 'depth_mask', 'depth'):
                    cropped_target[f'{key}'][
                        j:j + 1] = torchvision.transforms.functional.crop(
                            micro[f'{prefix}{key}'][src_idx:src_idx + 1], top,
                            left, height, width)

            latent = self.rec_model(img=micro['img_to_encoder'],
                                    behaviour='enc_dec_wo_triplane')

            with th.autocast(device_type='cuda',
                             dtype=th.float16,
                             enabled=self.mp_trainer_rec.use_amp):

                ray_origins = target['ray_origins']
                ray_directions = target['ray_directions']

                pred_nv_cano = self.rec_model(
                    latent={
                        # Each latent is repeated 4x, then the whole set is
                        # tiled twice to line up with ``nv_c``.
                        'latent_after_vit':
                        latent['latent_after_vit'].repeat_interleave(
                            4, dim=0).repeat(2, 1, 1, 1)
                    },
                    c=nv_c,
                    behaviour='triplane_dec',
                    ray_origins=ray_origins,
                    ray_directions=ray_directions,
                )

                # Expose the encoder outputs to the loss alongside the renders.
                pred_nv_cano.update(latent)
                gt = cropped_target

                with self.rec_model.no_sync():
                    loss, loss_dict, _ = self.loss_class(
                        pred_nv_cano,
                        gt,
                        step=self.step + self.resume_step,
                        test_mode=False,
                        return_fg_mask=True,
                        behaviour=behaviour,
                        conf_sigma_l1=None,
                        conf_sigma_percl=None)
                    log_rec3d_loss_dict(loss_dict)

            self.mp_trainer_rec.backward(loss)

            if dist_util.get_rank() == 0 and self.step % 500 == 0 and i == 0:
                # Visual logging is best effort: failures are logged and
                # training continues.
                try:
                    out_path = f'{logger.get_dir()}/{self.step+self.resume_step}.jpg'
                    torchvision.utils.save_image(th.cat(
                        [cropped_target['img'], pred_nv_cano['image_raw']], ),
                                                 out_path,
                                                 normalize=True)

                    logger.log('log vis to: ', out_path)
                except Exception as e:
                    logger.log(e)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
class TrainLoop3DRecNVPatchSingleForwardMVAdvLoss(
        TrainLoop3DRecNVPatchSingleForwardMV):
    """Multi-view patch loop with an additional discriminator optimiser.

    Each iteration of ``run_loop`` runs one ``g_step`` (reconstruction model)
    and one ``d_step`` (``loss_class.discriminator``) on separate batches.
    """

    def __init__(self,
                 *,
                 rec_model,
                 loss_class,
                 data,
                 eval_data,
                 batch_size,
                 microbatch,
                 lr,
                 ema_rate,
                 log_interval,
                 eval_interval,
                 save_interval,
                 resume_checkpoint,
                 use_fp16=False,
                 fp16_scale_growth=0.001,
                 weight_decay=0,
                 lr_anneal_steps=0,
                 iterations=10001,
                 load_submodule_name='',
                 ignore_resume_opt=False,
                 model_name='rec',
                 use_amp=False,
                 **kwargs):
        """Build the parent loop, then the discriminator trainer/optimiser."""
        super().__init__(rec_model=rec_model,
                         loss_class=loss_class,
                         data=data,
                         eval_data=eval_data,
                         batch_size=batch_size,
                         microbatch=microbatch,
                         lr=lr,
                         ema_rate=ema_rate,
                         log_interval=log_interval,
                         eval_interval=eval_interval,
                         save_interval=save_interval,
                         resume_checkpoint=resume_checkpoint,
                         use_fp16=use_fp16,
                         fp16_scale_growth=fp16_scale_growth,
                         weight_decay=weight_decay,
                         lr_anneal_steps=lr_anneal_steps,
                         iterations=iterations,
                         load_submodule_name=load_submodule_name,
                         ignore_resume_opt=ignore_resume_opt,
                         model_name=model_name,
                         use_amp=use_amp,
                         **kwargs)

        disc_params = self.loss_class.get_trainable_parameters()

        self.mp_trainer_disc = MixedPrecisionTrainer(
            model=self.loss_class.discriminator,
            use_fp16=self.use_fp16,
            fp16_scale_growth=fp16_scale_growth,
            model_name='disc',
            use_amp=use_amp,
            model_params=disc_params)

        self.opt_disc = AdamW(self.mp_trainer_disc.master_params,
                              lr=self.lr,
                              betas=(0, 0.999),
                              eps=1e-8)

        if self.use_ddp:
            self.ddp_disc = DDP(
                self.loss_class.discriminator,
                device_ids=[dist_util.dev()],
                output_device=dist_util.dev(),
                broadcast_buffers=False,
                bucket_cap_mb=128,
                find_unused_parameters=False,
            )
        else:
            self.ddp_disc = self.loss_class.discriminator

    def save(self, mp_trainer=None, model_name='rec'):
        """Write a checkpoint of ``mp_trainer``'s master params on rank 0.

        Args:
            mp_trainer: trainer whose parameters are saved; defaults to
                ``self.mp_trainer_rec``.
            model_name: tag embedded in the checkpoint file name.
        """
        if mp_trainer is None:
            mp_trainer = self.mp_trainer_rec

        def save_checkpoint(rate, params):
            state_dict = mp_trainer.master_params_to_state_dict(params)
            if dist_util.get_rank() == 0:
                logger.log(f"saving model {model_name} {rate}...")
                if not rate:
                    filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
                else:
                    filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
                with bf.BlobFile(bf.join(get_blob_logdir(), filename),
                                 "wb") as f:
                    th.save(state_dict, f)

        # Only the non-EMA parameters (rate 0) are written here.
        save_checkpoint(0, mp_trainer.master_params)

        dist.barrier()

    def run_step(self, batch, step='g_step'):
        """Run one optimisation step of the kind named by ``step``.

        ``'g_step'`` updates the reconstruction model (and EMA when a step
        was taken); ``'d_step'`` updates the discriminator. Any other value
        skips both but still anneals the LR and logs the step.
        """
        if step == 'g_step':
            self.forward_backward(batch, behaviour='g_step')
            took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)

            if took_step_g_rec:
                self._update_ema()

        elif step == 'd_step':
            self.forward_backward(batch, behaviour='d_step')
            _ = self.mp_trainer_disc.optimize(self.opt_disc)

        self._anneal_lr()
        self.log_step()

    def run_loop(self, batch=None):
        """Alternate generator and discriminator steps until finished.

        The loop ends when ``lr_anneal_steps`` is reached (if non-zero) or
        exits the process once ``self.step`` exceeds ``self.iterations``.

        Args:
            batch: ignored; batches are drawn from ``self.data``.
        """
        # Function-scope import: ``sys`` is not imported at module level.
        import sys

        while (not self.lr_anneal_steps
               or self.step + self.resume_step < self.lr_anneal_steps):

            batch = next(self.data)
            self.run_step(batch, 'g_step')

            batch = next(self.data)
            self.run_step(batch, 'd_step')

            if self.step % 1000 == 0:
                dist_util.synchronize()
                if self.step % 10000 == 0:
                    th.cuda.empty_cache()

            if self.step % self.log_interval == 0 and dist_util.get_rank(
            ) == 0:
                out = logger.dumpkvs()
                for k, v in out.items():
                    self.writer.add_scalar(f'Loss/{k}', v,
                                           self.step + self.resume_step)

            if self.step % self.eval_interval == 0 and self.step != 0:
                if dist_util.get_rank() == 0:
                    # Evaluation is best effort: failures are logged and
                    # training continues.
                    try:
                        self.eval_loop()
                    except Exception as e:
                        logger.log(e)
                dist_util.synchronize()

            if self.step % self.save_interval == 0:
                self.save()
                self.save(self.mp_trainer_disc,
                          self.mp_trainer_disc.model_name)
                dist_util.synchronize()

                # Stop early when this environment variable is set.
                if os.environ.get("DIFFUSION_TRAINING_TEST",
                                  "") and self.step > 0:
                    return

            self.step += 1

            if self.step > self.iterations:
                logger.log('reached maximum iterations, exiting')

                # Skip if the periodic save above already covered this step.
                if (self.step -
                        1) % self.save_interval != 0 and self.step != 1:
                    self.save()
                    # Keep the discriminator checkpoint in step with the
                    # generator, as at every other save point.
                    self.save(self.mp_trainer_disc,
                              self.mp_trainer_disc.model_name)

                sys.exit()

        # Final save unless the last completed step was already saved.
        if (self.step - 1) % self.save_interval != 0:
            self.save()
            self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name)
|
|