LN3Diff / nsr /train_nv_util.py
NIRVANALAN
release file
87c126b
import copy
# import imageio.v3
import functools
import json
import os
from pathlib import Path
from pdb import set_trace as st
from einops import rearrange
import webdataset as wds
import traceback
import blobfile as bf
import imageio
import numpy as np
# from sympy import O
import torch as th
import torch.distributed as dist
import torchvision
from PIL import Image
from torch.nn.parallel.distributed import DistributedDataParallel as DDP
from torch.optim import AdamW
from torch.utils.tensorboard import SummaryWriter
from tqdm import tqdm
from guided_diffusion import dist_util, logger
from guided_diffusion.fp16_util import MixedPrecisionTrainer
from guided_diffusion.nn import update_ema
from guided_diffusion.resample import LossAwareSampler, UniformSampler
from guided_diffusion.train_util import (calc_average_loss,
find_ema_checkpoint,
find_resume_checkpoint,
get_blob_logdir, log_rec3d_loss_dict,
parse_resume_step_from_filename)
from .camera_utils import LookAtPoseSampler, FOV_to_intrinsics
from .train_util import TrainLoop3DRec
class TrainLoop3DRecNV(TrainLoop3DRec):
# supervise the training of novel view
def __init__(self,
*,
rec_model,
loss_class,
data,
eval_data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
eval_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=0.001,
weight_decay=0,
lr_anneal_steps=0,
iterations=10001,
load_submodule_name='',
ignore_resume_opt=False,
model_name='rec',
use_amp=False,
**kwargs):
super().__init__(rec_model=rec_model,
loss_class=loss_class,
data=data,
eval_data=eval_data,
batch_size=batch_size,
microbatch=microbatch,
lr=lr,
ema_rate=ema_rate,
log_interval=log_interval,
eval_interval=eval_interval,
save_interval=save_interval,
resume_checkpoint=resume_checkpoint,
use_fp16=use_fp16,
fp16_scale_growth=fp16_scale_growth,
weight_decay=weight_decay,
lr_anneal_steps=lr_anneal_steps,
iterations=iterations,
load_submodule_name=load_submodule_name,
ignore_resume_opt=ignore_resume_opt,
model_name=model_name,
use_amp=use_amp,
**kwargs)
self.rec_cano = True
def forward_backward(self, batch, *args, **kwargs):
# return super().forward_backward(batch, *args, **kwargs)
self.mp_trainer_rec.zero_grad()
batch_size = batch['img_to_encoder'].shape[0]
for i in range(0, batch_size, self.microbatch):
# st()
micro = {
k: v[i:i + self.microbatch].to(dist_util.dev())
for k, v in batch.items()
}
# ! concat novel-view? next version. also add self reconstruction, patch-based loss in the next version. verify novel-view prediction first.
# wrap forward within amp
with th.autocast(device_type='cuda',
dtype=th.float16,
enabled=self.mp_trainer_rec.use_amp):
target_nvs = {}
target_cano = {}
latent = self.rec_model(img=micro['img_to_encoder'],
behaviour='enc_dec_wo_triplane')
pred = self.rec_model(
latent=latent,
c=micro['nv_c'], # predict novel view here
behaviour='triplane_dec')
for k, v in micro.items():
if k[:2] == 'nv':
orig_key = k.replace('nv_', '')
target_nvs[orig_key] = v
target_cano[orig_key] = micro[orig_key]
with self.rec_model.no_sync(): # type: ignore
loss, loss_dict, fg_mask = self.loss_class(
pred,
target_nvs,
step=self.step + self.resume_step,
test_mode=False,
return_fg_mask=True,
conf_sigma_l1=None,
conf_sigma_percl=None)
log_rec3d_loss_dict(loss_dict)
if self.rec_cano:
pred_cano = self.rec_model(latent=latent,
c=micro['c'],
behaviour='triplane_dec')
with self.rec_model.no_sync(): # type: ignore
fg_mask = target_cano['depth_mask'].unsqueeze(
1).repeat_interleave(3, 1).float()
loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
pred_cano['image_raw'],
target_cano['img'],
fg_mask,
step=self.step + self.resume_step,
test_mode=False,
)
loss = loss + loss_cano
# remove redundant log
log_rec3d_loss_dict({
f'cano_{k}': v
for k, v in loss_cano_dict.items()
# if "loss" in k
})
self.mp_trainer_rec.backward(loss)
if dist_util.get_rank() == 0 and self.step % 500 == 0:
if self.rec_cano:
self.log_img(micro, pred, pred_cano)
else:
self.log_img(micro, pred, None)
@th.inference_mode()
def log_img(self, micro, pred, pred_cano):
# gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
def norm_depth(pred_depth): # to [-1,1]
# pred_depth = pred['image_depth']
pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
pred_depth.min())
return -(pred_depth * 2 - 1)
pred_img = pred['image_raw']
gt_img = micro['img']
# infer novel view also
# if self.loss_class.opt.symmetry_loss:
# pred_nv_img = nvs_pred
# else:
# ! replace with novel view prediction
# ! log another novel-view prediction
# pred_nv_img = self.rec_model(
# img=micro['img_to_encoder'],
# c=self.novel_view_poses) # pred: (B, 3, 64, 64)
# if 'depth' in micro:
gt_depth = micro['depth']
if gt_depth.ndim == 3:
gt_depth = gt_depth.unsqueeze(1)
gt_depth = norm_depth(gt_depth)
# gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
# gt_depth.min())
# if True:
fg_mask = pred['image_mask'] * 2 - 1 # 0-1
input_fg_mask = pred_cano['image_mask'] * 2 - 1 # 0-1
if 'image_depth' in pred:
pred_depth = norm_depth(pred['image_depth'])
pred_nv_depth = norm_depth(pred_cano['image_depth'])
else:
pred_depth = th.zeros_like(gt_depth)
pred_nv_depth = th.zeros_like(gt_depth)
if 'image_sr' in pred:
if pred['image_sr'].shape[-1] == 512:
pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']],
dim=-1)
gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']],
dim=-1)
pred_depth = self.pool_512(pred_depth)
gt_depth = self.pool_512(gt_depth)
elif pred['image_sr'].shape[-1] == 256:
pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']],
dim=-1)
gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']],
dim=-1)
pred_depth = self.pool_256(pred_depth)
gt_depth = self.pool_256(gt_depth)
else:
pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']],
dim=-1)
gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']],
dim=-1)
gt_depth = self.pool_128(gt_depth)
pred_depth = self.pool_128(pred_depth)
else:
gt_img = self.pool_64(gt_img)
gt_depth = self.pool_64(gt_depth)
pred_vis = th.cat([
pred_img,
pred_depth.repeat_interleave(3, dim=1),
fg_mask.repeat_interleave(3, dim=1),
],
dim=-1) # B, 3, H, W
pred_vis_nv = th.cat([
pred_cano['image_raw'],
pred_nv_depth.repeat_interleave(3, dim=1),
input_fg_mask.repeat_interleave(3, dim=1),
],
dim=-1) # B, 3, H, W
pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2) # cat in H dim
gt_vis = th.cat([
gt_img,
gt_depth.repeat_interleave(3, dim=1),
th.zeros_like(gt_img)
],
dim=-1) # TODO, fail to load depth. range [0, 1]
if 'conf_sigma' in pred:
gt_vis = th.cat([gt_vis, fg_mask], dim=-1) # placeholder
# vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
vis = th.cat([gt_vis, pred_vis], dim=-2)
# .permute(
# 0, 2, 3, 1).cpu()
vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
64) # HWC
torchvision.utils.save_image(
vis_tensor,
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
value_range=(-1, 1),
normalize=True)
# vis = vis.numpy() * 127.5 + 127.5
# vis = vis.clip(0, 255).astype(np.uint8)
# Image.fromarray(vis).save(
# f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
logger.log('log vis to: ',
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
# self.writer.add_image(f'images',
# vis,
# self.step + self.resume_step,
# dataformats='HWC')
# return pred
class TrainLoop3DRecNVPatch(TrainLoop3DRecNV):
# add patch rendering
def __init__(self,
*,
rec_model,
loss_class,
data,
eval_data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
eval_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=0.001,
weight_decay=0,
lr_anneal_steps=0,
iterations=10001,
load_submodule_name='',
ignore_resume_opt=False,
model_name='rec',
use_amp=False,
**kwargs):
super().__init__(rec_model=rec_model,
loss_class=loss_class,
data=data,
eval_data=eval_data,
batch_size=batch_size,
microbatch=microbatch,
lr=lr,
ema_rate=ema_rate,
log_interval=log_interval,
eval_interval=eval_interval,
save_interval=save_interval,
resume_checkpoint=resume_checkpoint,
use_fp16=use_fp16,
fp16_scale_growth=fp16_scale_growth,
weight_decay=weight_decay,
lr_anneal_steps=lr_anneal_steps,
iterations=iterations,
load_submodule_name=load_submodule_name,
ignore_resume_opt=ignore_resume_opt,
model_name=model_name,
use_amp=use_amp,
**kwargs)
# the rendrer
self.eg3d_model = self.rec_model.module.decoder.triplane_decoder # type: ignore
# self.rec_cano = False
self.rec_cano = True
def forward_backward(self, batch, *args, **kwargs):
# add patch sampling
self.mp_trainer_rec.zero_grad()
batch_size = batch['img_to_encoder'].shape[0]
for i in range(0, batch_size, self.microbatch):
micro = {
k: v[i:i + self.microbatch].to(dist_util.dev())
for k, v in batch.items()
}
# ! sample rendering patch
target = {
**self.eg3d_model(
c=micro['nv_c'], # type: ignore
ws=None,
planes=None,
sample_ray_only=True,
fg_bbox=micro['nv_bbox']), # rays o / dir
}
patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
'patch_rendering_resolution'] # type: ignore
cropped_target = {
k:
th.empty_like(v)
[..., :patch_rendering_resolution, :patch_rendering_resolution]
if k not in [
'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
'nv_img_sr', 'c'
] else v
for k, v in micro.items()
}
# crop according to uv sampling
for j in range(micro['img'].shape[0]):
top, left, height, width = target['ray_bboxes'][
j] # list of tuple
# for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
for key in ('img', 'depth_mask', 'depth'): # type: ignore
# target[key][i:i+1] = torchvision.transforms.functional.crop(
# cropped_target[key][
# j:j + 1] = torchvision.transforms.functional.crop(
# micro[key][j:j + 1], top, left, height, width)
cropped_target[f'{key}'][ # ! no nv_ here
j:j + 1] = torchvision.transforms.functional.crop(
micro[f'nv_{key}'][j:j + 1], top, left, height,
width)
# target.update(cropped_target)
# wrap forward within amp
with th.autocast(device_type='cuda',
dtype=th.float16,
enabled=self.mp_trainer_rec.use_amp):
# target_nvs = {}
# target_cano = {}
latent = self.rec_model(img=micro['img_to_encoder'],
behaviour='enc_dec_wo_triplane')
pred_nv = self.rec_model(
latent=latent,
c=micro['nv_c'], # predict novel view here
behaviour='triplane_dec',
ray_origins=target['ray_origins'],
ray_directions=target['ray_directions'],
)
# ! directly retrieve from target
# for k, v in target.items():
# if k[:2] == 'nv':
# orig_key = k.replace('nv_', '')
# target_nvs[orig_key] = v
# target_cano[orig_key] = target[orig_key]
with self.rec_model.no_sync(): # type: ignore
loss, loss_dict, _ = self.loss_class(pred_nv,
cropped_target,
step=self.step +
self.resume_step,
test_mode=False,
return_fg_mask=True,
conf_sigma_l1=None,
conf_sigma_percl=None)
log_rec3d_loss_dict(loss_dict)
if self.rec_cano:
cano_target = {
**self.eg3d_model(
c=micro['c'], # type: ignore
ws=None,
planes=None,
sample_ray_only=True,
fg_bbox=micro['bbox']), # rays o / dir
}
cano_cropped_target = {
k: th.empty_like(v)
for k, v in cropped_target.items()
}
for j in range(micro['img'].shape[0]):
top, left, height, width = cano_target['ray_bboxes'][
j] # list of tuple
# for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
for key in ('img', 'depth_mask',
'depth'): # type: ignore
# target[key][i:i+1] = torchvision.transforms.functional.crop(
cano_cropped_target[key][
j:j +
1] = torchvision.transforms.functional.crop(
micro[key][j:j + 1], top, left, height,
width)
# cano_target.update(cano_cropped_target)
pred_cano = self.rec_model(
latent=latent,
c=micro['c'],
behaviour='triplane_dec',
ray_origins=cano_target['ray_origins'],
ray_directions=cano_target['ray_directions'],
)
with self.rec_model.no_sync(): # type: ignore
fg_mask = cano_cropped_target['depth_mask'].unsqueeze(
1).repeat_interleave(3, 1).float()
loss_cano, loss_cano_dict = self.loss_class.calc_2d_rec_loss(
pred_cano['image_raw'],
cano_cropped_target['img'],
fg_mask,
step=self.step + self.resume_step,
test_mode=False,
)
loss = loss + loss_cano
# remove redundant log
log_rec3d_loss_dict({
f'cano_{k}': v
for k, v in loss_cano_dict.items()
# if "loss" in k
})
self.mp_trainer_rec.backward(loss)
if dist_util.get_rank() == 0 and self.step % 500 == 0:
self.log_patch_img(cropped_target, pred_nv, pred_cano)
@th.inference_mode()
def log_patch_img(self, micro, pred, pred_cano):
# gt_vis = th.cat([batch['img'], batch['depth']], dim=-1)
def norm_depth(pred_depth): # to [-1,1]
# pred_depth = pred['image_depth']
pred_depth = (pred_depth - pred_depth.min()) / (pred_depth.max() -
pred_depth.min())
return -(pred_depth * 2 - 1)
pred_img = pred['image_raw']
gt_img = micro['img']
# infer novel view also
# if self.loss_class.opt.symmetry_loss:
# pred_nv_img = nvs_pred
# else:
# ! replace with novel view prediction
# ! log another novel-view prediction
# pred_nv_img = self.rec_model(
# img=micro['img_to_encoder'],
# c=self.novel_view_poses) # pred: (B, 3, 64, 64)
# if 'depth' in micro:
gt_depth = micro['depth']
if gt_depth.ndim == 3:
gt_depth = gt_depth.unsqueeze(1)
gt_depth = norm_depth(gt_depth)
# gt_depth = (gt_depth - gt_depth.min()) / (gt_depth.max() -
# gt_depth.min())
# if True:
fg_mask = pred['image_mask'] * 2 - 1 # 0-1
input_fg_mask = pred_cano['image_mask'] * 2 - 1 # 0-1
if 'image_depth' in pred:
pred_depth = norm_depth(pred['image_depth'])
pred_cano_depth = norm_depth(pred_cano['image_depth'])
else:
pred_depth = th.zeros_like(gt_depth)
pred_cano_depth = th.zeros_like(gt_depth)
# if 'image_sr' in pred:
# if pred['image_sr'].shape[-1] == 512:
# pred_img = th.cat([self.pool_512(pred_img), pred['image_sr']],
# dim=-1)
# gt_img = th.cat([self.pool_512(micro['img']), micro['img_sr']],
# dim=-1)
# pred_depth = self.pool_512(pred_depth)
# gt_depth = self.pool_512(gt_depth)
# elif pred['image_sr'].shape[-1] == 256:
# pred_img = th.cat([self.pool_256(pred_img), pred['image_sr']],
# dim=-1)
# gt_img = th.cat([self.pool_256(micro['img']), micro['img_sr']],
# dim=-1)
# pred_depth = self.pool_256(pred_depth)
# gt_depth = self.pool_256(gt_depth)
# else:
# pred_img = th.cat([self.pool_128(pred_img), pred['image_sr']],
# dim=-1)
# gt_img = th.cat([self.pool_128(micro['img']), micro['img_sr']],
# dim=-1)
# gt_depth = self.pool_128(gt_depth)
# pred_depth = self.pool_128(pred_depth)
# else:
# gt_img = self.pool_64(gt_img)
# gt_depth = self.pool_64(gt_depth)
pred_vis = th.cat([
pred_img,
pred_depth.repeat_interleave(3, dim=1),
fg_mask.repeat_interleave(3, dim=1),
],
dim=-1) # B, 3, H, W
pred_vis_nv = th.cat([
pred_cano['image_raw'],
pred_cano_depth.repeat_interleave(3, dim=1),
input_fg_mask.repeat_interleave(3, dim=1),
],
dim=-1) # B, 3, H, W
pred_vis = th.cat([pred_vis, pred_vis_nv], dim=-2) # cat in H dim
gt_vis = th.cat([
gt_img,
gt_depth.repeat_interleave(3, dim=1),
th.zeros_like(gt_img)
],
dim=-1) # TODO, fail to load depth. range [0, 1]
# if 'conf_sigma' in pred:
# gt_vis = th.cat([gt_vis, fg_mask], dim=-1) # placeholder
# vis = th.cat([gt_vis, pred_vis], dim=-2)[0].permute(
# st()
vis = th.cat([gt_vis, pred_vis], dim=-2)
# .permute(
# 0, 2, 3, 1).cpu()
vis_tensor = torchvision.utils.make_grid(vis, nrow=vis.shape[-1] //
64) # HWC
torchvision.utils.save_image(
vis_tensor,
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
value_range=(-1, 1),
normalize=True)
logger.log('log vis to: ',
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
# self.writer.add_image(f'images',
# vis,
# self.step + self.resume_step,
# dataformats='HWC')
class TrainLoop3DRecNVPatchSingleForward(TrainLoop3DRecNVPatch):
def __init__(self,
*,
rec_model,
loss_class,
data,
eval_data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
eval_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=0.001,
weight_decay=0,
lr_anneal_steps=0,
iterations=10001,
load_submodule_name='',
ignore_resume_opt=False,
model_name='rec',
use_amp=False,
**kwargs):
super().__init__(rec_model=rec_model,
loss_class=loss_class,
data=data,
eval_data=eval_data,
batch_size=batch_size,
microbatch=microbatch,
lr=lr,
ema_rate=ema_rate,
log_interval=log_interval,
eval_interval=eval_interval,
save_interval=save_interval,
resume_checkpoint=resume_checkpoint,
use_fp16=use_fp16,
fp16_scale_growth=fp16_scale_growth,
weight_decay=weight_decay,
lr_anneal_steps=lr_anneal_steps,
iterations=iterations,
load_submodule_name=load_submodule_name,
ignore_resume_opt=ignore_resume_opt,
model_name=model_name,
use_amp=use_amp,
**kwargs)
def forward_backward(self, batch, *args, **kwargs):
# add patch sampling
self.mp_trainer_rec.zero_grad()
batch_size = batch['img_to_encoder'].shape[0]
batch.pop('caption') # not required
batch.pop('ins') # not required
# batch.pop('nv_caption') # not required
for i in range(0, batch_size, self.microbatch):
micro = {
k:
v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
v, th.Tensor) else v[i:i + self.microbatch]
for k, v in batch.items()
}
# ! sample rendering patch
target = {
**self.eg3d_model(
c=micro['nv_c'], # type: ignore
ws=None,
planes=None,
sample_ray_only=True,
fg_bbox=micro['nv_bbox']), # rays o / dir
}
patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
'patch_rendering_resolution'] # type: ignore
cropped_target = {
k:
th.empty_like(v)
[..., :patch_rendering_resolution, :patch_rendering_resolution]
if k not in [
'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
'nv_img_sr', 'c', 'caption', 'nv_caption'
] else v
for k, v in micro.items()
}
# crop according to uv sampling
for j in range(micro['img'].shape[0]):
top, left, height, width = target['ray_bboxes'][
j] # list of tuple
# for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
for key in ('img', 'depth_mask', 'depth'): # type: ignore
# target[key][i:i+1] = torchvision.transforms.functional.crop(
# cropped_target[key][
# j:j + 1] = torchvision.transforms.functional.crop(
# micro[key][j:j + 1], top, left, height, width)
cropped_target[f'{key}'][ # ! no nv_ here
j:j + 1] = torchvision.transforms.functional.crop(
micro[f'nv_{key}'][j:j + 1], top, left, height,
width)
# ! cano view loss
cano_target = {
**self.eg3d_model(
c=micro['c'], # type: ignore
ws=None,
planes=None,
sample_ray_only=True,
fg_bbox=micro['bbox']), # rays o / dir
}
# cano_cropped_target = {
# k: th.empty_like(v)
# for k, v in cropped_target.items()
# }
# for j in range(micro['img'].shape[0]):
# top, left, height, width = cano_target['ray_bboxes'][
# j] # list of tuple
# # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
# for key in ('img', 'depth_mask', 'depth'): # type: ignore
# # target[key][i:i+1] = torchvision.transforms.functional.crop(
# cano_cropped_target[key][
# j:j + 1] = torchvision.transforms.functional.crop(
# micro[key][j:j + 1], top, left, height, width)
# ! vit no amp
latent = self.rec_model(img=micro['img_to_encoder'],
behaviour='enc_dec_wo_triplane')
# wrap forward within amp
with th.autocast(device_type='cuda',
dtype=th.float16,
enabled=self.mp_trainer_rec.use_amp):
# c = th.cat([micro['nv_c'], micro['c']]), # predict novel view here
# c = th.cat([micro['nv_c'].repeat(3, 1), micro['c']]), # predict novel view here
instance_mv_num = batch_size // 4 # 4 pairs by default
# instance_mv_num = 4
# ! roll views for multi-view supervision
c = th.cat([
micro['nv_c'].roll(instance_mv_num * i, dims=0)
for i in range(1, 4)
]
# + [micro['c']]
) # predict novel view here
ray_origins = th.cat(
[
target['ray_origins'].roll(instance_mv_num * i, dims=0)
for i in range(1, 4)
]
# + [cano_target['ray_origins'] ]
,
0)
ray_directions = th.cat([
target['ray_directions'].roll(instance_mv_num * i, dims=0)
for i in range(1, 4)
]
# + [cano_target['ray_directions'] ]
)
pred_nv_cano = self.rec_model(
# latent=latent.expand(2,),
latent={
'latent_after_vit': # ! triplane for rendering
# latent['latent_after_vit'].repeat(2, 1, 1, 1)
latent['latent_after_vit'].repeat(3, 1, 1, 1)
},
c=c,
behaviour='triplane_dec',
# ray_origins=target['ray_origins'],
# ray_directions=target['ray_directions'],
ray_origins=ray_origins,
ray_directions=ray_directions,
)
pred_nv_cano.update(
latent
) # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True)
# gt = {
# k: th.cat([v, cano_cropped_target[k]], 0)
# for k, v in cropped_target.items()
# }
gt = {
k:
th.cat(
[
v.roll(instance_mv_num * i, dims=0)
for i in range(1, 4)
]
# + [cano_cropped_target[k] ]
,
0)
for k, v in cropped_target.items()
} # torchvision.utils.save_image(gt['img'], 'gt.png', normalize=True)
with self.rec_model.no_sync(): # type: ignore
loss, loss_dict, _ = self.loss_class(
pred_nv_cano,
gt, # prepare merged data
step=self.step + self.resume_step,
test_mode=False,
return_fg_mask=True,
conf_sigma_l1=None,
conf_sigma_percl=None)
log_rec3d_loss_dict(loss_dict)
self.mp_trainer_rec.backward(loss)
# for name, p in self.rec_model.named_parameters():
# if p.grad is None:
# logger.log(f"found rec unused param: {name}")
if dist_util.get_rank() == 0 and self.step % 500 == 0:
micro_bs = micro['img_to_encoder'].shape[0]
self.log_patch_img( # record one cano view and one novel view
cropped_target,
{
k: pred_nv_cano[k][-micro_bs:]
for k in ['image_raw', 'image_depth', 'image_mask']
},
{
k: pred_nv_cano[k][:micro_bs]
for k in ['image_raw', 'image_depth', 'image_mask']
},
)
def eval_loop(self):
return super().eval_loop()
@th.inference_mode()
# def eval_loop(self, c_list:list):
def eval_novelview_loop_old(self, camera=None):
# novel view synthesis given evaluation camera trajectory
all_loss_dict = []
novel_view_micro = {}
# ! randomly inference an instance
export_mesh = True
if export_mesh:
Path(f'{logger.get_dir()}/FID_Cals/').mkdir(parents=True,
exist_ok=True)
# for i in range(0, len(c_list), 1): # TODO, larger batch size for eval
batch = {}
# if camera is not None:
# # batch['c'] = camera.to(batch['c'].device())
# batch['c'] = camera.clone()
# else:
# batch =
for eval_idx, render_reference in enumerate(tqdm(self.eval_data)):
if eval_idx > 500:
break
video_out = imageio.get_writer(
f'{logger.get_dir()}/video_novelview_{self.step+self.resume_step}_{eval_idx}.mp4',
mode='I',
fps=25,
codec='libx264')
with open(
f'{logger.get_dir()}/triplane_{self.step+self.resume_step}_{eval_idx}_caption.txt',
'w') as f:
f.write(render_reference['caption'])
for key in ['ins', 'bbox', 'caption']:
if key in render_reference:
render_reference.pop(key)
real_flag = False
mv_flag = False # TODO, use full-instance for evaluation? Calculate the metrics.
if render_reference['c'].shape[:2] == (1, 40):
real_flag = True
# real img monocular reconstruction
# compat lst for enumerate
render_reference = [{
k: v[0][idx:idx + 1]
for k, v in render_reference.items()
} for idx in range(40)]
elif render_reference['c'].shape[0] == 8:
mv_flag = True
render_reference = {
k: v[:4]
for k, v in render_reference.items()
}
# save gt
torchvision.utils.save_image(
render_reference[0:4]['img'],
logger.get_dir() + '/FID_Cals/{}_inp.png'.format(eval_idx),
padding=0,
normalize=True,
value_range=(-1, 1),
)
# torchvision.utils.save_image(render_reference[4:8]['img'],
# logger.get_dir() + '/FID_Cals/{}_inp2.png'.format(eval_idx),
# padding=0,
# normalize=True,
# value_range=(-1,1),
# )
else:
# compat lst for enumerate
st()
render_reference = [{
k: v[idx:idx + 1]
for k, v in render_reference.items()
} for idx in range(40)]
# ! single-view version
render_reference[0]['img_to_encoder'] = render_reference[14][
'img_to_encoder'] # encode side view
render_reference[0]['img'] = render_reference[14][
'img'] # encode side view
# save gt
torchvision.utils.save_image(
render_reference[0]['img'],
logger.get_dir() + '/FID_Cals/{}_gt.png'.format(eval_idx),
padding=0,
normalize=True,
value_range=(-1, 1))
# ! TODO, merge with render_video_given_triplane later
for i, batch in enumerate(render_reference):
# for i in range(0, 8, self.microbatch):
# c = c_list[i].to(dist_util.dev()).reshape(1, -1)
micro = {k: v.to(dist_util.dev()) for k, v in batch.items()}
st()
if i == 0:
if mv_flag:
novel_view_micro = None
else:
novel_view_micro = {
k:
v[0:1].to(dist_util.dev()).repeat_interleave(
# v[14:15].to(dist_util.dev()).repeat_interleave(
micro['img'].shape[0],
0) if isinstance(v, th.Tensor) else v[0:1]
for k, v in batch.items()
}
else:
if i == 1:
# ! output mesh
if export_mesh:
# ! get planes first
# self.latent_name = 'latent_normalized' # normalized triplane latent
# ddpm_latent = {
# self.latent_name: planes,
# }
# ddpm_latent.update(self.rec_model(latent=ddpm_latent, behaviour='decode_after_vae_no_render'))
# mesh_size = 512
# mesh_size = 256
mesh_size = 384
# mesh_size = 320
# mesh_thres = 3 # TODO, requires tuning
# mesh_thres = 5 # TODO, requires tuning
mesh_thres = 10 # TODO, requires tuning
import mcubes
import trimesh
dump_path = f'{logger.get_dir()}/mesh/'
os.makedirs(dump_path, exist_ok=True)
grid_out = self.rec_model(
latent=pred,
grid_size=mesh_size,
behaviour='triplane_decode_grid',
)
vtx, faces = mcubes.marching_cubes(
grid_out['sigma'].squeeze(0).squeeze(
-1).cpu().numpy(), mesh_thres)
vtx = vtx / (mesh_size - 1) * 2 - 1
# vtx_tensor = th.tensor(vtx, dtype=th.float32, device=dist_util.dev()).unsqueeze(0)
# vtx_colors = self.model.synthesizer.forward_points(planes, vtx_tensor)['rgb'].squeeze(0).cpu().numpy() # (0, 1)
# vtx_colors = (vtx_colors * 255).astype(np.uint8)
# mesh = trimesh.Trimesh(vertices=vtx, faces=faces, vertex_colors=vtx_colors)
mesh = trimesh.Trimesh(
vertices=vtx,
faces=faces,
)
mesh_dump_path = os.path.join(
dump_path, f'{eval_idx}.ply')
mesh.export(mesh_dump_path, 'ply')
print(f"Mesh dumped to {dump_path}")
del grid_out, mesh
th.cuda.empty_cache()
# return
# st()
# if novel_view_micro['c'].shape[0] < micro['img'].shape[0]:
novel_view_micro = {
k:
v[0:1].to(dist_util.dev()).repeat_interleave(
micro['img'].shape[0], 0)
for k, v in novel_view_micro.items()
}
pred = self.rec_model(img=novel_view_micro['img_to_encoder'],
c=micro['c']) # pred: (B, 3, 64, 64)
# target = {
# 'img': micro['img'],
# 'depth': micro['depth'],
# 'depth_mask': micro['depth_mask']
# }
# targe
# if not export_mesh:
if not real_flag:
_, loss_dict = self.loss_class(pred, micro, test_mode=True)
all_loss_dict.append(loss_dict)
# ! move to other places, add tensorboard
# pred_vis = th.cat([
# pred['image_raw'],
# -pred['image_depth'].repeat_interleave(3, dim=1)
# ],
# dim=-1)
# normalize depth
# if True:
pred_depth = pred['image_depth']
pred_depth = (pred_depth - pred_depth.min()) / (
pred_depth.max() - pred_depth.min())
if 'image_sr' in pred:
if pred['image_sr'].shape[-1] == 512:
pred_vis = th.cat([
micro['img_sr'],
self.pool_512(pred['image_raw']), pred['image_sr'],
self.pool_512(pred_depth).repeat_interleave(3,
dim=1)
],
dim=-1)
elif pred['image_sr'].shape[-1] == 256:
pred_vis = th.cat([
micro['img_sr'],
self.pool_256(pred['image_raw']), pred['image_sr'],
self.pool_256(pred_depth).repeat_interleave(3,
dim=1)
],
dim=-1)
else:
pred_vis = th.cat([
micro['img_sr'],
self.pool_128(pred['image_raw']),
self.pool_128(pred['image_sr']),
self.pool_128(pred_depth).repeat_interleave(3,
dim=1)
],
dim=-1)
else:
# pred_vis = th.cat([
# self.pool_64(micro['img']), pred['image_raw'],
# pred_depth.repeat_interleave(3, dim=1)
# ],
# dim=-1) # B, 3, H, W
pooled_depth = self.pool_128(pred_depth).repeat_interleave(
3, dim=1)
pred_vis = th.cat(
[
# self.pool_128(micro['img']),
self.pool_128(novel_view_micro['img']
), # use the input here
self.pool_128(pred['image_raw']),
pooled_depth,
],
dim=-1) # B, 3, H, W
vis = pred_vis.permute(0, 2, 3, 1).cpu().numpy()
vis = vis * 127.5 + 127.5
vis = vis.clip(0, 255).astype(np.uint8)
if export_mesh:
# save image
torchvision.utils.save_image(
pred['image_raw'],
logger.get_dir() +
'/FID_Cals/{}_{}.png'.format(eval_idx, i),
padding=0,
normalize=True,
value_range=(-1, 1))
torchvision.utils.save_image(
pooled_depth,
logger.get_dir() +
'/FID_Cals/{}_{}_dpeth.png'.format(eval_idx, i),
padding=0,
normalize=True,
value_range=(0, 1))
# st()
for j in range(vis.shape[0]):
video_out.append_data(vis[j])
video_out.close()
# if not export_mesh:
if not real_flag or mv_flag:
val_scores_for_logging = calc_average_loss(all_loss_dict)
with open(os.path.join(logger.get_dir(), 'scores_novelview.json'),
'a') as f:
json.dump({'step': self.step, **val_scores_for_logging}, f)
# * log to tensorboard
for k, v in val_scores_for_logging.items():
self.writer.add_scalar(f'Eval/NovelView/{k}', v,
self.step + self.resume_step)
del video_out
# del pred_vis
# del pred
th.cuda.empty_cache()
@th.inference_mode()
# def eval_loop(self, c_list:list):
def eval_novelview_loop(self, camera=None, save_latent=False):
# novel view synthesis given evaluation camera trajectory
if save_latent: # for diffusion learning
latent_dir = Path(f'{logger.get_dir()}/latent_dir')
latent_dir.mkdir(exist_ok=True, parents=True)
# wds_path = os.path.join(logger.get_dir(), 'latent_dir',
# f'wds-%06d.tar')
# sink = wds.ShardWriter(wds_path, start_shard=0)
# eval_batch_size = 20
# eval_batch_size = 1
eval_batch_size = 40 # ! for i23d
for eval_idx, micro in enumerate(tqdm(self.eval_data)):
# if eval_idx > 500:
# break
latent = self.rec_model(
img=micro['img_to_encoder'][:4],
behaviour='encoder_vae') # pred: (B, 3, 64, 64)
# torchvision.utils.save_image(micro['img'], 'inp.jpg')
if micro['img'].shape[0] == 40:
assert eval_batch_size == 40
if save_latent:
# np.save(f'{logger.get_dir()}/latent_dir/{eval_idx}.npy', latent[self.latent_name].cpu().numpy())
latent_save_dir = f'{logger.get_dir()}/latent_dir/{micro["ins"][0]}'
Path(latent_save_dir).mkdir(parents=True, exist_ok=True)
np.save(f'{latent_save_dir}/latent.npy',
latent[self.latent_name][0].cpu().numpy())
assert all([
micro['ins'][0] == micro['ins'][i]
for i in range(micro['c'].shape[0])
]) # ! assert same instance
# for i in range(micro['img'].shape[0]):
# compressed_sample = {
# 'latent':latent[self.latent_name][0].cpu().numpy(), # 12 32 32
# 'caption': micro['caption'][0].encode('utf-8'),
# 'ins': micro['ins'][0].encode('utf-8'),
# 'c': micro['c'][i].cpu().numpy(),
# 'img': micro['img'][i].cpu().numpy() # 128x128, for diffusion log
# }
# sink.write({
# "__key__": f"sample_{eval_idx*eval_batch_size+i:07d}",
# 'sample.pyd': compressed_sample
# })
if eval_idx < 50:
# if False:
self.render_video_given_triplane(
latent[self.latent_name], # B 12 32 32
self.rec_model, # compatible with join_model
name_prefix=f'{self.step + self.resume_step}_{eval_idx}',
save_img=False,
render_reference={'c': camera},
save_mesh=True)
class TrainLoop3DRecNVPatchSingleForwardMV(TrainLoop3DRecNVPatchSingleForward):
def __init__(self,
*,
rec_model,
loss_class,
data,
eval_data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
eval_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=0.001,
weight_decay=0,
lr_anneal_steps=0,
iterations=10001,
load_submodule_name='',
ignore_resume_opt=False,
model_name='rec',
use_amp=False,
**kwargs):
super().__init__(rec_model=rec_model,
loss_class=loss_class,
data=data,
eval_data=eval_data,
batch_size=batch_size,
microbatch=microbatch,
lr=lr,
ema_rate=ema_rate,
log_interval=log_interval,
eval_interval=eval_interval,
save_interval=save_interval,
resume_checkpoint=resume_checkpoint,
use_fp16=use_fp16,
fp16_scale_growth=fp16_scale_growth,
weight_decay=weight_decay,
lr_anneal_steps=lr_anneal_steps,
iterations=iterations,
load_submodule_name=load_submodule_name,
ignore_resume_opt=ignore_resume_opt,
model_name=model_name,
use_amp=use_amp,
**kwargs)
def forward_backward(self, batch, behaviour='g_step', *args, **kwargs):
# add patch sampling
self.mp_trainer_rec.zero_grad()
batch_size = batch['img_to_encoder'].shape[0]
batch.pop('caption') # not required
batch.pop('ins') # not required
if '__key__' in batch.keys():
batch.pop('__key__')
for i in range(0, batch_size, self.microbatch):
micro = {
k:
v[i:i + self.microbatch].to(dist_util.dev()) if isinstance(
v, th.Tensor) else v[i:i + self.microbatch]
for k, v in batch.items()
}
# ! sample rendering patch
# nv_c = th.cat([micro['nv_c'], micro['c']])
nv_c = th.cat([micro['nv_c'], micro['c']])
# nv_c = micro['nv_c']
target = {
**self.eg3d_model(
c=nv_c, # type: ignore
ws=None,
planes=None,
sample_ray_only=True,
fg_bbox=th.cat([micro['nv_bbox'], micro['bbox']])), # rays o / dir
}
patch_rendering_resolution = self.eg3d_model.rendering_kwargs[
'patch_rendering_resolution'] # type: ignore
cropped_target = {
k:
th.empty_like(v).repeat_interleave(2, 0)
# th.empty_like(v).repeat_interleave(1, 0)
[..., :patch_rendering_resolution, :patch_rendering_resolution]
if k not in [
'ins_idx', 'img_to_encoder', 'img_sr', 'nv_img_to_encoder',
'nv_img_sr', 'c', 'caption', 'nv_caption'
] else v
for k, v in micro.items()
}
# crop according to uv sampling
for j in range(2 * self.microbatch):
top, left, height, width = target['ray_bboxes'][
j] # list of tuple
# for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
for key in ('img', 'depth_mask', 'depth'): # type: ignore
if j < self.microbatch:
cropped_target[f'{key}'][ # ! no nv_ here
j:j + 1] = torchvision.transforms.functional.crop(
micro[f'nv_{key}'][j:j + 1], top, left, height,
width)
else:
cropped_target[f'{key}'][ # ! no nv_ here
j:j + 1] = torchvision.transforms.functional.crop(
micro[f'{key}'][j - self.microbatch:j -
self.microbatch + 1], top,
left, height, width)
# for j in range(batch_size, 2*batch_size, 1):
# top, left, height, width = target['ray_bboxes'][
# j] # list of tuple
# # for key in ('img', 'depth_mask', 'depth', 'depth_mask_sr'): # type: ignore
# for key in ('img', 'depth_mask', 'depth'): # type: ignore
# cropped_target[f'{key}'][ # ! no nv_ here
# j:j + 1] = torchvision.transforms.functional.crop(
# micro[f'{key}'][j-batch_size:j-batch_size + 1], top, left, height,
# width)
# ! vit no amp
latent = self.rec_model(img=micro['img_to_encoder'],
behaviour='enc_dec_wo_triplane')
# wrap forward within amp
with th.autocast(device_type='cuda',
dtype=th.float16,
enabled=self.mp_trainer_rec.use_amp):
# c = th.cat([micro['nv_c'], micro['c']]), # predict novel view here
# c = th.cat([micro['nv_c'].repeat(3, 1), micro['c']]), # predict novel view here
# instance_mv_num = batch_size // 4 # 4 pairs by default
# instance_mv_num = 4
# ! roll views for multi-view supervision
# c = micro['nv_c']
ray_origins = target['ray_origins']
ray_directions = target['ray_directions']
pred_nv_cano = self.rec_model(
# latent=latent.expand(2,),
latent={
'latent_after_vit': # ! triplane for rendering
latent['latent_after_vit'].repeat_interleave(4, dim=0).repeat(2,1,1,1) # NV=4
# latent['latent_after_vit'].repeat_interleave(8, dim=0) # NV=4
},
c=nv_c,
behaviour='triplane_dec',
ray_origins=ray_origins,
ray_directions=ray_directions,
)
pred_nv_cano.update(
latent
) # torchvision.utils.save_image(pred_nv_cano['image_raw'], 'pred.png', normalize=True)
gt = cropped_target
with self.rec_model.no_sync(): # type: ignore
loss, loss_dict, _ = self.loss_class(
pred_nv_cano,
gt, # prepare merged data
step=self.step + self.resume_step,
test_mode=False,
return_fg_mask=True,
behaviour=behaviour,
conf_sigma_l1=None,
conf_sigma_percl=None)
log_rec3d_loss_dict(loss_dict)
self.mp_trainer_rec.backward(loss)
# for name, p in self.rec_model.named_parameters():
# if p.grad is None:
# logger.log(f"found rec unused param: {name}")
# torchvision.utils.save_image(cropped_target['img'], 'gt.png', normalize=True)
# torchvision.utils.save_image( pred_nv_cano['image_raw'], 'pred.png', normalize=True)
if dist_util.get_rank() == 0 and self.step % 500 == 0 and i == 0:
try:
torchvision.utils.save_image(
th.cat(
[cropped_target['img'], pred_nv_cano['image_raw']
], ),
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg',
normalize=True)
logger.log(
'log vis to: ',
f'{logger.get_dir()}/{self.step+self.resume_step}.jpg')
except Exception as e:
logger.log(e)
# micro_bs = micro['img_to_encoder'].shape[0]
# self.log_patch_img( # record one cano view and one novel view
# cropped_target,
# {
# k: pred_nv_cano[k][0:1]
# for k in ['image_raw', 'image_depth', 'image_mask']
# },
# {
# k: pred_nv_cano[k][1:2]
# for k in ['image_raw', 'image_depth', 'image_mask']
# },
# )
# def save(self):
# return super().save()
class TrainLoop3DRecNVPatchSingleForwardMVAdvLoss(
TrainLoop3DRecNVPatchSingleForwardMV):
def __init__(self,
*,
rec_model,
loss_class,
data,
eval_data,
batch_size,
microbatch,
lr,
ema_rate,
log_interval,
eval_interval,
save_interval,
resume_checkpoint,
use_fp16=False,
fp16_scale_growth=0.001,
weight_decay=0,
lr_anneal_steps=0,
iterations=10001,
load_submodule_name='',
ignore_resume_opt=False,
model_name='rec',
use_amp=False,
**kwargs):
super().__init__(rec_model=rec_model,
loss_class=loss_class,
data=data,
eval_data=eval_data,
batch_size=batch_size,
microbatch=microbatch,
lr=lr,
ema_rate=ema_rate,
log_interval=log_interval,
eval_interval=eval_interval,
save_interval=save_interval,
resume_checkpoint=resume_checkpoint,
use_fp16=use_fp16,
fp16_scale_growth=fp16_scale_growth,
weight_decay=weight_decay,
lr_anneal_steps=lr_anneal_steps,
iterations=iterations,
load_submodule_name=load_submodule_name,
ignore_resume_opt=ignore_resume_opt,
model_name=model_name,
use_amp=use_amp,
**kwargs)
# create discriminator
disc_params = self.loss_class.get_trainable_parameters()
self.mp_trainer_disc = MixedPrecisionTrainer(
model=self.loss_class.discriminator,
use_fp16=self.use_fp16,
fp16_scale_growth=fp16_scale_growth,
model_name='disc',
use_amp=use_amp,
model_params=disc_params)
# st() # check self.lr
self.opt_disc = AdamW(
self.mp_trainer_disc.master_params,
lr=self.lr, # follow sd code base
betas=(0, 0.999),
eps=1e-8)
# TODO, is loss cls already in the DDP?
if self.use_ddp:
self.ddp_disc = DDP(
self.loss_class.discriminator,
device_ids=[dist_util.dev()],
output_device=dist_util.dev(),
broadcast_buffers=False,
bucket_cap_mb=128,
find_unused_parameters=False,
)
else:
self.ddp_disc = self.loss_class.discriminator
# def run_st
# def run_step(self, batch, *args):
# self.forward_backward(batch)
# took_step = self.mp_trainer_rec.optimize(self.opt)
# if took_step:
# self._update_ema()
# self._anneal_lr()
# self.log_step()
def save(self, mp_trainer=None, model_name='rec'):
if mp_trainer is None:
mp_trainer = self.mp_trainer_rec
def save_checkpoint(rate, params):
state_dict = mp_trainer.master_params_to_state_dict(params)
if dist_util.get_rank() == 0:
logger.log(f"saving model {model_name} {rate}...")
if not rate:
filename = f"model_{model_name}{(self.step+self.resume_step):07d}.pt"
else:
filename = f"ema_{model_name}_{rate}_{(self.step+self.resume_step):07d}.pt"
with bf.BlobFile(bf.join(get_blob_logdir(), filename),
"wb") as f:
th.save(state_dict, f)
save_checkpoint(0, mp_trainer.master_params)
dist.barrier()
def run_step(self, batch, step='g_step'):
# self.forward_backward(batch)
if step == 'g_step':
self.forward_backward(batch, behaviour='g_step')
took_step_g_rec = self.mp_trainer_rec.optimize(self.opt)
if took_step_g_rec:
self._update_ema() # g_ema
elif step == 'd_step':
self.forward_backward(batch, behaviour='d_step')
_ = self.mp_trainer_disc.optimize(self.opt_disc)
self._anneal_lr()
self.log_step()
def run_loop(self, batch=None):
while (not self.lr_anneal_steps
or self.step + self.resume_step < self.lr_anneal_steps):
batch = next(self.data)
self.run_step(batch, 'g_step')
batch = next(self.data)
self.run_step(batch, 'd_step')
if self.step % 1000 == 0:
dist_util.synchronize()
if self.step % 10000 == 0:
th.cuda.empty_cache() # avoid memory leak
if self.step % self.log_interval == 0 and dist_util.get_rank(
) == 0:
out = logger.dumpkvs()
# * log to tensorboard
for k, v in out.items():
self.writer.add_scalar(f'Loss/{k}', v,
self.step + self.resume_step)
if self.step % self.eval_interval == 0 and self.step != 0:
if dist_util.get_rank() == 0:
try:
self.eval_loop()
except Exception as e:
logger.log(e)
dist_util.synchronize()
# if self.step % self.save_interval == 0 and self.step != 0:
if self.step % self.save_interval == 0:
self.save()
self.save(self.mp_trainer_disc,
self.mp_trainer_disc.model_name)
dist_util.synchronize()
# Run for a finite amount of time in integration tests.
if os.environ.get("DIFFUSION_TRAINING_TEST",
"") and self.step > 0:
return
self.step += 1
if self.step > self.iterations:
logger.log('reached maximum iterations, exiting')
# Save the last checkpoint if it wasn't already saved.
if (self.step -
1) % self.save_interval != 0 and self.step != 1:
self.save()
exit()
# Save the last checkpoint if it wasn't already saved.
# if (self.step - 1) % self.save_interval != 0 and self.step != 1:
if (self.step - 1) % self.save_interval != 0:
self.save() # save rec
self.save(self.mp_trainer_disc, self.mp_trainer_disc.model_name)