import math
import random
from einops import rearrange
import torch
from torch import nn
from torch.nn import functional as F
import numpy as np
from tqdm import trange

from functools import partial

from nsr.networks_stylegan2 import Generator as StyleGAN2Backbone
from nsr.volumetric_rendering.renderer import ImportanceRenderer, ImportanceRendererfg_bg
from nsr.volumetric_rendering.ray_sampler import RaySampler
from nsr.triplane import OSGDecoder, Triplane, Triplane_fg_bg_plane
# from nsr.losses.helpers import ResidualBlock
# from vit.vision_transformer import TriplaneFusionBlockv4_nested, VisionTransformer, TriplaneFusionBlockv4_nested_init_from_dino
from vit.vision_transformer import (
    TriplaneFusionBlockv4_nested,
    TriplaneFusionBlockv4_nested_init_from_dino_lite,
    TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout,
    VisionTransformer,
    TriplaneFusionBlockv4_nested_init_from_dino,
)
from .vision_transformer import Block, VisionTransformer
from .utils import trunc_normal_

from guided_diffusion import dist_util, logger

from pdb import set_trace as st

from ldm.modules.diffusionmodules.model import Encoder, Decoder
from utils.torch_utils.components import (
    PixelShuffleUpsample,
    ResidualBlock,
    Upsample,
    PixelUnshuffleUpsample,
    Conv3x3TriplaneTransformation,
)
from utils.torch_utils.distributions.distributions import DiagonalGaussianDistribution
from nsr.superresolution import SuperresolutionHybrid2X, SuperresolutionHybrid4X

from torch.nn.parameter import Parameter, UninitializedParameter, UninitializedBuffer

from nsr.common_blks import ResMlp
from .vision_transformer import *

from dit.dit_models import get_2d_sincos_pos_embed
from torch import _assert
from itertools import repeat
import collections.abc


# From PyTorch internals
def _ntuple(n):
    """Return a converter that turns a scalar into an ``n``-tuple.

    Non-string iterables are converted with ``tuple()`` unchanged (no length
    check); anything else is repeated ``n`` times.
    """

    def parse(x):
        if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
            return tuple(x)
        return tuple(repeat(x, n))

    return parse


to_1tuple = _ntuple(1)
to_2tuple = _ntuple(2)


class PatchEmbedTriplane(nn.Module):
    """GroupConv patch embedder on a triplane.

    A single ``nn.Conv2d`` with ``groups=3`` maps ``in_chans`` input channels to
    ``3 * embed_dim`` output channels using non-overlapping ``patch_size``
    patches. With ``flatten=True`` the result is returned as a token sequence of
    shape ``(B, 3 * token_H * token_W, embed_dim)``.
    """

    def __init__(
        self,
        img_size=32,
        patch_size=2,
        in_chans=4,
        embed_dim=768,
        norm_layer=None,
        flatten=True,
        bias=True,
    ):
        super().__init__()
        img_size = to_2tuple(img_size)
        patch_size = to_2tuple(patch_size)
        self.img_size = img_size
        self.patch_size = patch_size
        self.grid_size = (img_size[0] // patch_size[0], img_size[1] // patch_size[1])
        self.num_patches = self.grid_size[0] * self.grid_size[1]
        self.flatten = flatten

        self.proj = nn.Conv2d(in_chans,
                              embed_dim * 3,
                              kernel_size=patch_size,
                              stride=patch_size,
                              bias=bias,
                              groups=3)
        self.norm = norm_layer(embed_dim) if norm_layer else nn.Identity()

    def forward(self, x):
        """Patchify ``x`` of shape (B, C, H, W); H and W must equal ``img_size``."""
        B, C, H, W = x.shape
        _assert(
            H == self.img_size[0],
            f"Input image height ({H}) doesn't match model ({self.img_size[0]})."
        )
        _assert(
            W == self.img_size[1],
            f"Input image width ({W}) doesn't match model ({self.img_size[1]})."
        )
        x = self.proj(x)  # B 3*C token_H token_W

        # NOTE(review): with groups=3 the conv output channels are group-major
        # (group g occupies channels [g*C, (g+1)*C)), while this reshape reads
        # them as (C, 3), so the "3" axis does not line up with the conv groups.
        # Left unchanged because trained checkpoints depend on this layout.
        x = x.reshape(B, x.shape[1] // 3, 3, x.shape[-2], x.shape[-1])  # B C 3 H W

        if self.flatten:
            x = x.flatten(2).transpose(1, 2)  # BC3HW -> B 3HW C
        x = self.norm(x)
        return x


class PatchEmbedTriplaneRodin(PatchEmbedTriplane):
    """``PatchEmbedTriplane`` whose projection is ``RodinRollOutConv3D_GroupConv``.

    ``bias`` is forwarded to the parent constructor only; the replacement
    projection is built without a ``bias`` argument.
    """

    def __init__(self,
                 img_size=32,
                 patch_size=2,
                 in_chans=4,
                 embed_dim=768,
                 norm_layer=None,
                 flatten=True,
                 bias=True):
        super().__init__(img_size, patch_size, in_chans, embed_dim, norm_layer, flatten, bias)
        self.proj = RodinRollOutConv3D_GroupConv(in_chans,
                                                 embed_dim * 3,
                                                 kernel_size=patch_size,
                                                 stride=patch_size,
                                                 padding=0)


class ViTTriplaneDecomposed(nn.Module):
    """ViT decoder followed by a triplane renderer.

    The ViT blocks of ``vit_decoder`` are regrouped into ``fusion_blk`` modules
    (see ``create_fusion_blks``); decoded tokens are projected by
    ``decoder_pred`` and unpatchified into a spatial feature map that is passed
    to ``triplane_decoder``.
    """

    def __init__(
            self,
            vit_decoder,
            triplane_decoder: Triplane,
            cls_token=False,
            decoder_pred_size=-1,
            unpatchify_out_chans=-1,
            # * uvit arch
            channel_multiplier=4,
            use_fusion_blk=True,
            fusion_blk_depth=4,
            fusion_blk=TriplaneFusionBlock,
            fusion_blk_start=0,  # appy fusion blk start with?
            ldm_z_channels=4,  #
            ldm_embed_dim=4,
            vae_p=2,
            token_size=None,
            w_avg=torch.zeros([512]),
            patch_size=None,
            **kwargs,
    ) -> None:
        super().__init__()
        self.superresolution = nn.ModuleDict({})

        self.decomposed_IN = False

        self.decoder_pred_3d = None
        self.transformer_3D_blk = None
        self.logvar = None
        self.channel_multiplier = channel_multiplier

        self.cls_token = cls_token
        self.vit_decoder = vit_decoder
        self.triplane_decoder = triplane_decoder

        if patch_size is None:
            self.patch_size = self.vit_decoder.patch_embed.patch_size
        else:
            self.patch_size = patch_size

        if isinstance(self.patch_size, tuple):  # dino-v2
            self.patch_size = self.patch_size[0]

        # NOTE: self.img_size is intentionally not set here (the original
        # assignment from vit_decoder.patch_embed.img_size was disabled).

        if unpatchify_out_chans == -1:
            self.unpatchify_out_chans = self.triplane_decoder.out_chans
        else:
            self.unpatchify_out_chans = unpatchify_out_chans

        # ! mlp decoder from mae/dino
        if decoder_pred_size == -1:
            decoder_pred_size = self.patch_size**2 * self.triplane_decoder.out_chans

        self.decoder_pred = nn.Linear(
            self.vit_decoder.embed_dim,
            decoder_pred_size,
            bias=True)  # decoder to pat

        # triplane
        self.plane_n = 3

        # ! vae
        self.ldm_z_channels = ldm_z_channels
        self.ldm_embed_dim = ldm_embed_dim
        self.vae_p = vae_p
        self.token_size = 16  # use dino-v2 dim tradition here
        self.vae_res = self.vae_p * self.token_size

        # ! uvit
        # token_size argument is currently ignored; token_size is fixed to 16 above.
        self.vit_decoder.pos_embed = nn.Parameter(
            torch.zeros(1, 3 * (self.token_size**2 + self.cls_token),
                        vit_decoder.embed_dim))

        self.fusion_blk_start = fusion_blk_start
        self.create_fusion_blks(fusion_blk_depth, use_fusion_blk, fusion_blk)

        # ! placeholder, not used here
        # Clone so the buffer never aliases the default-argument tensor, which is
        # created once at function definition and would otherwise be shared (and
        # mutated by in-place loads such as load_state_dict) across instances.
        self.register_buffer(
            'w_avg', None if w_avg is None else w_avg.clone())  # will replace externally
        self.rendering_kwargs = self.triplane_decoder.rendering_kwargs

    @torch.inference_mode()
    def forward_points(self,
                       planes,
                       points: torch.Tensor,
                       chunk_size: int = 2**16):
        """Query the triplane decoder at arbitrary points, chunked along dim 1.

        Args:
            planes: triplane features, (N, 3, D', H', W'); a 4-D tensor is
                reshaped to (N, 3, -1, H', W').
            points: query coordinates, (N, P, 3).
            chunk_size: number of points per ``_run_model`` call.

        Returns:
            dict with the keys of the renderer's ``_run_model`` output, each value
            the per-chunk results concatenated along dim 1.
        """
        if planes.ndim == 4:
            planes = planes.reshape(
                len(planes),
                3,
                -1,  # ! support background plane
                planes.shape[-2],
                planes.shape[-1])  # BS 96 256 256

        # query triplane in chunks
        outs = []
        for i in trange(0, points.shape[1], chunk_size):
            chunk_points = points[:, i:i + chunk_size]

            # query triplane
            chunk_out = self.triplane_decoder.renderer._run_model(  # type: ignore
                planes=planes,
                decoder=self.triplane_decoder.decoder,
                sample_coordinates=chunk_points,
                sample_directions=torch.zeros_like(chunk_points),
                options=self.rendering_kwargs,
            )

            outs.append(chunk_out)
            torch.cuda.empty_cache()

        # concatenate the outputs
        point_features = {
            k: torch.cat([out[k] for out in outs], dim=1)
            for k in outs[0].keys()
        }
        return point_features

    def triplane_decode_grid(self,
                             vit_decode_out,
                             grid_size,
                             aabb: torch.Tensor = None,
                             **kwargs):
        """Sample triplane features on a dense ``grid_size**3`` lattice.

        Args:
            vit_decode_out: dict holding the triplane under 'latent_after_vit'.
            grid_size: number of samples per axis.
            aabb: optional (N, 2, 3) per-sample box (min row, max row). When
                omitted it is built from rendering_kwargs 'sampler_bbox_min' /
                'sampler_bbox_max', or from +-'box_warp'/2 otherwise.

        Returns:
            dict of features reshaped to (N, grid_size, grid_size, grid_size, -1).
        """
        assert isinstance(vit_decode_out, dict)
        planes = vit_decode_out['latent_after_vit']

        # aabb: (N, 2, 3)
        if aabb is None:
            if 'sampler_bbox_min' in self.rendering_kwargs:
                aabb = torch.tensor([
                    [self.rendering_kwargs['sampler_bbox_min']] * 3,
                    [self.rendering_kwargs['sampler_bbox_max']] * 3,
                ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)
            else:  # shapenet dataset, follow eg3d
                aabb = torch.tensor([  # https://github.com/NVlabs/eg3d/blob/7cf1fd1e99e1061e8b6ba850f91c94fe56e7afe4/eg3d/gen_samples.py#L188
                    [-self.rendering_kwargs['box_warp'] / 2] * 3,
                    [self.rendering_kwargs['box_warp'] / 2] * 3,
                ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat(planes.shape[0], 1, 1)

        assert planes.shape[0] == aabb.shape[0], "Batch size mismatch for planes and aabb"
        N = planes.shape[0]

        # create grid points for triplane query
        grid_points = []
        for i in range(N):
            grid_points.append(torch.stack(torch.meshgrid(
                torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device),
                torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device),
                torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device),
                indexing='ij',
            ), dim=-1).reshape(-1, 3))
        cube_grid = torch.stack(grid_points, dim=0).to(planes.device)  # N, grid_size**3, 3

        features = self.forward_points(planes, cube_grid)

        # reshape into grid
        features = {
            k: v.reshape(N, grid_size, grid_size, grid_size, -1)
            for k, v in features.items()
        }

        return features

    def create_uvit_arch(self):
        """Attach a zero-initialised ``skip_linear`` (2*dim -> dim) to every block
        in the second half of ``vit_decoder.blocks`` (U-ViT long skips)."""
        logger.log(
            f'length of vit_decoder.blocks: {len(self.vit_decoder.blocks)}')
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            blk.skip_linear = nn.Linear(2 * self.vit_decoder.embed_dim,
                                        self.vit_decoder.embed_dim)

            # zero init so the skip path starts as a no-op
            nn.init.constant_(blk.skip_linear.weight, 0)
            if isinstance(blk.skip_linear, nn.Linear) and blk.skip_linear.bias is not None:
                nn.init.constant_(blk.skip_linear.bias, 0)

    def vit_decode_backbone(self, latent, img_size):
        """Run the ViT decoder stack on ``latent`` (see ``forward_vit_decoder``)."""
        return self.forward_vit_decoder(latent, img_size)  # pred_vit_latent

    def init_weights(self):
        """Initialise ``vit_decoder.pos_embed`` with a 2D sin-cos embedding over a
        (3 * token_size, token_size) grid.

        NOTE(review): the embedding has 3 * token_size**2 rows, so this copy only
        fits when pos_embed has exactly that many tokens (i.e. no cls token).
        """
        p = self.token_size
        D = self.vit_decoder.pos_embed.shape[-1]
        grid_size = (3 * p, p)
        pos_embed = get_2d_sincos_pos_embed(D, grid_size).reshape(3 * p * p, D)  # H*W, D
        self.vit_decoder.pos_embed.data.copy_(
            torch.from_numpy(pos_embed).float().unsqueeze(0))
        logger.log('init pos_embed with sincos')

    # !
    def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk):
        """Regroup the 12 ViT blocks into ``fusion_blk`` modules.

        Blocks before ``self.fusion_blk_start`` are kept as-is; the remainder are
        grouped ``fusion_blk_depth`` at a time and wrapped by
        ``fusion_blk(group, num_heads, embed_dim, use_fusion_blk)``.
        """
        vit_decoder_blks = self.vit_decoder.blocks
        assert len(vit_decoder_blks) == 12, 'ViT-B by default'

        nh = self.vit_decoder.blocks[0].attn.num_heads
        dim = self.vit_decoder.embed_dim

        fusion_blk_start = self.fusion_blk_start
        triplane_fusion_vit_blks = nn.ModuleList()

        if fusion_blk_start != 0:
            for i in range(0, fusion_blk_start):
                triplane_fusion_vit_blks.append(
                    vit_decoder_blks[i])  # append all vit blocks in the front

        for i in range(fusion_blk_start, len(vit_decoder_blks), fusion_blk_depth):
            vit_blks_group = vit_decoder_blks[i:i + fusion_blk_depth]  # moduleList
            triplane_fusion_vit_blks.append(
                fusion_blk(vit_blks_group, nh, dim, use_fusion_blk))

        self.vit_decoder.blocks = triplane_fusion_vit_blks

    def triplane_decode(self, latent, c):
        """Render ``latent`` with the triplane decoder; adds 'latent' to the result."""
        ret_dict = self.triplane_decoder(latent, c)  # triplane latent -> imgs
        ret_dict.update({'latent': latent})
        return ret_dict

    def triplane_renderer(self, latent, coordinates, directions):
        """Evaluate the triplane decoder's ``renderer.run_model`` at the given
        coordinates/directions after viewing ``latent`` as (B, 3, C, H, W)."""
        planes = latent.view(len(latent), 3,
                             self.triplane_decoder.decoder_in_chans,
                             latent.shape[-2],
                             latent.shape[-1])  # BS 96 256 256

        ret_dict = self.triplane_decoder.renderer.run_model(
            planes, self.triplane_decoder.decoder, coordinates, directions,
            self.triplane_decoder.rendering_kwargs)  # triplane latent -> imgs
        return ret_dict

    # * increase encoded encoded latent dim to match decoder

    # ! util functions
    def unpatchify_triplane(self, x, p=None, unpatchify_out_chans=None):
        """Turn triplane tokens back into a stacked triplane feature map.

        x: (N, 3 * h * w [+1 if cls_token], p * p * unpatchify_out_chans)
        returns: (N, 3 * unpatchify_out_chans, h * p, w * p)

        ``unpatchify_out_chans`` defaults to ``self.unpatchify_out_chans // 3``
        and ``p`` to ``self.patch_size``. When ``self.cls_token`` is truthy the
        first token is dropped. The per-plane grid must be square.
        """
        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.unpatchify_out_chans // 3
        if self.cls_token:  # TODO, how to better use cls token
            x = x[:, 1:]

        if p is None:  # assign upsample patch size
            p = self.patch_size
        h = w = int((x.shape[1] // 3)**.5)
        assert h * w * 3 == x.shape[1]

        x = x.reshape(shape=(x.shape[0], 3, h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('ndhwpqc->ndchpwq', x)  # nplanes, C order in the renderer.py
        triplanes = x.reshape(shape=(x.shape[0], unpatchify_out_chans * 3, h * p, w * p))
        return triplanes

    def interpolate_pos_encoding(self, x, w, h):
        """Return ``vit_decoder.pos_embed`` unchanged.

        No interpolation is performed: the decoder's positional embedding is
        created at the token resolution it is used with, so ``x``, ``w`` and
        ``h`` are accepted for interface compatibility only.
        """
        return self.vit_decoder.pos_embed

    def forward_vit_decoder(self, x, img_size=None):
        """Add positional encoding and run all ViT decoder blocks plus final norm.

        x: (N, L, C) tokens, e.g. from a DINO/CLIP ViT encoder.
        Uses ``self.vit_decoder.interpolate_pos_encoding`` (the ViT's own), keeping
        the cls position only when ``self.cls_token`` is truthy.
        """
        if img_size is None:
            img_size = self.img_size

        if self.cls_token:
            x = x + self.vit_decoder.interpolate_pos_encoding(
                x, img_size, img_size)[:, :]  # B, L, C
        else:
            x = x + self.vit_decoder.interpolate_pos_encoding(
                x, img_size, img_size)[:, 1:]  # B, L, C

        for blk in self.vit_decoder.blocks:
            x = blk(x)
        x = self.vit_decoder.norm(x)

        return x

    def unpatchify(self, x, p=None, unpatchify_out_chans=None):
        """Turn tokens into an image-like map.

        x: (N, h * w [+1 if cls_token], p * p * unpatchify_out_chans)
        returns: (N, unpatchify_out_chans, h * p, w * p)
        """
        if unpatchify_out_chans is None:
            unpatchify_out_chans = self.unpatchify_out_chans
        if self.cls_token:  # TODO, how to better use cls token
            x = x[:, 1:]

        if p is None:  # assign upsample patch size
            p = self.patch_size
        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, unpatchify_out_chans))
        x = torch.einsum('nhwpqc->nchpwq', x)
        imgs = x.reshape(shape=(x.shape[0], unpatchify_out_chans, h * p, w * p))
        return imgs

    def forward(self, latent, c, img_size):
        """ViT-decode ``latent``, project/unpatchify it, and render with camera ``c``.

        Returns the triplane decoder's dict extended with 'latent' (the spatial
        map) and 'cls_token' (first token, or None).
        """
        latent = self.forward_vit_decoder(latent, img_size)  # pred_vit_latent

        if self.cls_token:
            cls_token = latent[:, :1]
        else:
            cls_token = None

        # ViT decoder projection, from MAE
        latent = self.decoder_pred(latent)  # pred_vit_latent -> patch or original size
        latent = self.unpatchify(latent)  # spatial_vit_latent, B, C, H, W (B, 96, 256,256)

        # TODO 2D convolutions -> Triplane
        # * triplane rendering
        ret_dict = self.triplane_decoder(planes=latent, c=c)
        ret_dict.update({'latent': latent, 'cls_token': cls_token})

        return ret_dict


class VAE_LDM_V4_vit3D_v3_conv3D_depth2_xformer_mha_PEinit_2d_sincos_uvit_RodinRollOutConv_4x4_lite_mlp_unshuffle_4XC_final(
        ViTTriplaneDecomposed):
    """U-ViT triplane VAE decoder with 4x4 SR and 2X channels.

    1. reuse attention proj layer from dino
    2. reuse attention; first self then 3D cross attention

    NOTE(review): ``vae_reparameterization`` calls ``self.unpatchify3D`` and
    ``self.vae_encode``, which are not defined on this class or on
    ``ViTTriplaneDecomposed`` in this file — confirm they are provided elsewhere.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane,
            cls_token,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            channel_multiplier=4,
            fusion_blk=TriplaneFusionBlockv3,
            **kwargs) -> None:
        super().__init__(
            vit_decoder,
            triplane_decoder,
            cls_token,
            fusion_blk=fusion_blk,  # type: ignore
            use_fusion_blk=use_fusion_blk,
            fusion_blk_depth=fusion_blk_depth,
            channel_multiplier=channel_multiplier,
            decoder_pred_size=(4 // 1)**2 *
            int(triplane_decoder.out_chans // 3 * channel_multiplier),
            **kwargs)

        self.reparameterization_soft_clamp = False

        # https://github.com/CompVis/latent-diffusion/blob/e66308c7f2e64cb581c6d27ab6fbeb846828253b/models/first_stage_models/kl-f16/config.yaml
        ldm_z_channels = ldm_embed_dim = triplane_decoder.out_chans

        self.superresolution.update(
            dict(
                after_vit_conv=nn.Conv2d(
                    int(triplane_decoder.out_chans * 2),
                    triplane_decoder.out_chans * 2,  # for vae features
                    3,
                    padding=1),
                quant_conv=torch.nn.Conv2d(2 * ldm_z_channels, 2 * ldm_embed_dim, 1),
                ldm_downsample=nn.Linear(
                    384,  # vit_decoder.embed_dim,
                    self.vae_p * self.vae_p * 3 * self.ldm_z_channels * 2,  # 48
                    bias=True),
                ldm_upsample=nn.Linear(self.vae_p * self.vae_p * self.ldm_z_channels * 1,
                                       vit_decoder.embed_dim,
                                       bias=True),  # ? too high dim upsample
                quant_mlp=Mlp(2 * self.ldm_z_channels,
                              out_features=2 * self.ldm_embed_dim),
                conv_sr=RodinConv3D4X_lite_mlp_as_residual(
                    int(triplane_decoder.out_chans * channel_multiplier),
                    int(triplane_decoder.out_chans * 1))))

        has_token = bool(self.cls_token)
        self.vit_decoder.pos_embed = nn.Parameter(
            torch.zeros(1, 3 * 16 * 16 + has_token, vit_decoder.embed_dim))

        self.init_weights()
        self.reparameterization_soft_clamp = True  # some instability in training VAE

        self.create_uvit_arch()

    def vae_reparameterization(self, latent, sample_posterior):
        """Encode ViT-encoder tokens into a diagonal-Gaussian posterior and sample.

        Returns a dict with 'normal_entropy', 'latent_normalized' (B, 3*L, C),
        'latent_normalized_2Ddiffusion' (B, -1, vae_res, vae_res), 'log_q',
        'posterior' and 'latent_name'.
        """
        # ! first downsample for VAE
        latents3D = self.superresolution['ldm_downsample'](latent)  # B L 24

        if self.vae_p > 1:
            latents3D = self.unpatchify3D(
                latents3D,
                p=self.vae_p,
                unpatchify_out_chans=self.ldm_z_channels *
                2)  # B 3 H W unpatchify_out_chans, H=W=16 now
            latents3D = latents3D.reshape(
                latents3D.shape[0], 3, -1, latents3D.shape[-1]
            )  # B 3 H*W C (H=self.vae_p*self.token_size)
        else:
            latents3D = latents3D.reshape(latents3D.shape[0],
                                          latents3D.shape[1], 3,
                                          2 * self.ldm_z_channels)  # B L 3 C
            latents3D = latents3D.permute(0, 2, 1, 3)  # B 3 L C

        # ! do VAE here
        posterior = self.vae_encode(latents3D)  # B self.ldm_z_channels 3 L

        if sample_posterior:
            latent = posterior.sample()
        else:
            latent = posterior.mode()  # B C 3 L

        log_q = posterior.log_p(latent)  # same shape as latent

        # ! for LSGM KL code
        latent_normalized_2Ddiffusion = latent.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16

        latent = latent.permute(0, 2, 3, 1)  # B C 3 L -> B 3 L C
        latent = latent.reshape(latent.shape[0], -1, latent.shape[-1])  # B 3*L C

        ret_dict = dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=latent,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
            latent_name='latent_normalized'  # for which latent to decode; could be modified externally
        )

        return ret_dict

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Project ViT tokens to a triplane, apply ``conv_sr``, and record outputs.

        Adds 'cls_token', 'latent_after_vit' and 'sr_w_code' (``w_avg`` repeated
        per batch item) to ``ret_dict`` and returns it.
        """
        if self.cls_token:
            cls_token = latent_from_vit[:, :1]
        else:
            cls_token = None

        # ViT decoder projection, from MAE
        latent = self.decoder_pred(
            latent_from_vit
        )  # pred_vit_latent -> patch or original size; B 768 384

        latent = self.unpatchify_triplane(
            latent,
            p=4,
            unpatchify_out_chans=int(
                self.channel_multiplier * self.unpatchify_out_chans //
                3))  # spatial_vit_latent, B, C, H, W (B, 96*2, 16, 16)

        # 4X SR with Rodin Conv 3D
        latent = self.superresolution['conv_sr'](latent)  # still B 3C H W

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))

        # include the w_avg for now
        sr_w_code = self.w_avg
        assert sr_w_code is not None
        ret_dict.update(
            dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave(
                latent_from_vit.shape[0], 0), ))  # type: ignore

        return ret_dict

    def forward_vit_decoder(self, x, img_size=None):
        """U-ViT forward over the fused blocks with long skip connections.

        x: (B, L, C) with L divisible by 3; internally viewed as (B, 3, L//3, C).
        """
        if img_size is None:
            # self.img_size is never assigned in this file; the value is only
            # forwarded to self.interpolate_pos_encoding, which ignores it.
            img_size = getattr(self, 'img_size', None)

        x = x + self.interpolate_pos_encoding(x, img_size, img_size)[:, :]  # B, L, C

        B, L, C = x.shape  # has [cls] token in N
        x = x.view(B, 3, L // 3, C)

        skips = [x]
        assert self.fusion_blk_start == 0

        # in blks
        for blk in self.vit_decoder.blocks[0:len(self.vit_decoder.blocks) // 2 - 1]:
            x = blk(x)  # B 3 N C
            skips.append(x)

        # mid blks
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2 - 1:
                                           len(self.vit_decoder.blocks) // 2]:
            x = blk(x)  # B 3 N C

        # out blks
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            x = x + blk.skip_linear(torch.cat([x, skips.pop()], dim=-1))  # long skip connections
            x = blk(x)  # B 3 N C

        x = self.vit_decoder.norm(x)

        # post process shape
        x = x.view(B, L, C)
        return x

    def triplane_decode(self,
                        vit_decode_out,
                        c,
                        return_raw_only=False,
                        **kwargs):
        """Render from a ``vit_decode`` dict (or a raw triplane tensor).

        Dict input reads 'latent_after_vit' and 'sr_w_code'; tensor input is used
        directly with ``ws=None``. The returned dict is merged with the input dict.
        """
        if isinstance(vit_decode_out, dict):
            latent_after_vit, sr_w_code = (vit_decode_out.get(k, None)
                                           for k in ('latent_after_vit', 'sr_w_code'))
        else:
            latent_after_vit = vit_decode_out
            sr_w_code = None
            vit_decode_out = dict(latent_normalized=latent_after_vit
                                  )  # for later dict update compatability

        # * triplane rendering
        ret_dict = self.triplane_decoder(latent_after_vit,
                                         c,
                                         ws=sr_w_code,
                                         return_raw_only=return_raw_only,
                                         **kwargs)  # triplane latent -> imgs
        ret_dict.update({
            'latent_after_vit': latent_after_vit,
            **vit_decode_out
        })

        return ret_dict

    def vit_decode_backbone(self, latent, img_size):
        """Lift VAE latents back to ViT width and run the ViT decoder.

        Accepts a dict (prefers 'latent_normalized', falls back to
        'latent_normalized_2Ddiffusion') or a tensor; non-3D tensors are
        reshaped from (B, C*3, H, W) to (B, 3*L, C).
        """
        if isinstance(latent, dict):
            if 'latent_normalized' not in latent:
                latent = latent['latent_normalized_2Ddiffusion']  # B, C*3, H, W
            else:
                latent = latent['latent_normalized']  # TODO, just for compatability now

        if latent.ndim != 3:  # B 3*4 16 16
            latent = latent.reshape(latent.shape[0], latent.shape[1] // 3, 3,
                                    (self.vae_p * self.token_size)**2).permute(
                                        0, 2, 3, 1)  # B C 3 L => B 3 L C
            latent = latent.reshape(latent.shape[0], -1, latent.shape[-1])  # B 3*L C

        assert latent.shape == (
            latent.shape[0], 3 * ((self.vae_p * self.token_size)**2),
            self.ldm_z_channels), f'latent.shape: {latent.shape}'

        latent = self.superresolution['ldm_upsample'](latent)

        return super().vit_decode_backbone(latent, img_size)  # torch.Size([8, 3072, 768])


class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn(
        ViTTriplaneDecomposed):
    """Lite version, no sd-bg; uses TriplaneFusionBlockv4_nested_init_from_dino.

    VAE path: ``ldm_downsample`` (Linear) -> ``unpatchify3D`` -> grouped 1x1
    ``quant_conv`` -> DiagonalGaussianDistribution; decode path:
    ``ldm_upsample`` (PatchEmbedTriplane) -> U-ViT -> ``decoder_pred`` ->
    ``unpatchify_triplane`` -> ``conv_sr``.
    """

    def __init__(
            self,
            vit_decoder: VisionTransformer,
            triplane_decoder: Triplane_fg_bg_plane,
            cls_token,
            use_fusion_blk=True,
            fusion_blk_depth=2,
            fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino,
            channel_multiplier=4,
            ldm_z_channels=4,  #
            ldm_embed_dim=4,
            vae_p=2,
            **kwargs) -> None:
        super().__init__(
            vit_decoder,
            triplane_decoder,
            cls_token,
            channel_multiplier=channel_multiplier,
            use_fusion_blk=use_fusion_blk,
            fusion_blk_depth=fusion_blk_depth,
            fusion_blk=fusion_blk,
            ldm_z_channels=ldm_z_channels,
            ldm_embed_dim=ldm_embed_dim,
            vae_p=vae_p,
            decoder_pred_size=(4 // 1)**2 *
            int(triplane_decoder.out_chans // 3 * channel_multiplier),
            **kwargs)

        logger.log(
            f'length of vit_decoder.blocks: {len(self.vit_decoder.blocks)}')

        # latent vae modules
        self.superresolution.update(
            dict(
                ldm_downsample=nn.Linear(
                    384,
                    self.vae_p * self.vae_p * 3 * self.ldm_z_channels * 2,  # 48
                    bias=True),
                ldm_upsample=PatchEmbedTriplane(
                    self.vae_p * self.token_size,
                    self.vae_p,
                    3 * self.ldm_embed_dim,  # B 3 L C
                    vit_decoder.embed_dim,
                    bias=True),
                quant_conv=nn.Conv2d(2 * 3 * self.ldm_z_channels,
                                     2 * self.ldm_embed_dim * 3,
                                     kernel_size=1,
                                     groups=3),
                conv_sr=RodinConv3D4X_lite_mlp_as_residual_lite(
                    int(triplane_decoder.out_chans * channel_multiplier),
                    int(triplane_decoder.out_chans * 1))))

        # ! initialize weights
        self.init_weights()
        self.reparameterization_soft_clamp = True  # some instability in training VAE

        self.create_uvit_arch()  # create skip linear, adapted from uvit

    def vit_decode(self, latent, img_size, sample_posterior=True):
        """Full decode: VAE reparameterization -> ViT backbone -> postprocess."""
        ret_dict = self.vae_reparameterization(latent, sample_posterior)

        latent = self.vit_decode_backbone(ret_dict, img_size)
        return self.vit_decode_postprocess(latent, ret_dict)

    # # ! merge?
    def unpatchify3D(self, x, p, unpatchify_out_chans, plane_n=3):
        """Turn tokens into per-plane 3D latents.

        x: (N, h * w [+1 if cls_token], p * p * plane_n * unpatchify_out_chans)
        returns: (N, plane_n, h * p, w * p, unpatchify_out_chans)
        """
        if self.cls_token:  # TODO, how to better use cls token
            x = x[:, 1:]

        h = w = int(x.shape[1]**.5)
        assert h * w == x.shape[1]

        x = x.reshape(shape=(x.shape[0], h, w, p, p, plane_n, unpatchify_out_chans))

        x = torch.einsum(
            'nhwpqdc->ndhpwqc', x
        )  # nplanes, C little endian tradiition, as defined in the renderer.py

        latents3D = x.reshape(shape=(x.shape[0], plane_n, h * p, w * p, unpatchify_out_chans))
        return latents3D

    # ! merge?
    def vae_encode(self, h):
        """Apply ``quant_conv`` to ``h`` (B, C, H, W) and build the posterior.

        The moments are reshaped to (B, C', plane_n, H*W) before constructing
        ``DiagonalGaussianDistribution`` with the configured soft clamp.
        NOTE(review): quant_conv is grouped (groups=3) so its outputs are
        group-major, while the reshape reads them as (C', plane_n) — left as-is
        because trained checkpoints depend on this layout.
        """
        B, _, H, W = h.shape
        moments = self.superresolution['quant_conv'](h)

        moments = moments.reshape(
            B,
            moments.shape[1] // self.plane_n,
            self.plane_n,
            H,
            W,
        )  # B C 3 H W

        moments = moments.flatten(-2)  # B C 3 L

        posterior = DiagonalGaussianDistribution(
            moments, soft_clamp=self.reparameterization_soft_clamp)
        return posterior

    def vae_reparameterization(self, latent, sample_posterior):
        """Encode ViT-encoder tokens (B, 256, 384 per the original comment) into a
        posterior and sample (or take the mode).

        Returns a dict with 'normal_entropy', 'latent_normalized' (B, 3*L, C),
        'latent_normalized_2Ddiffusion' (B, -1, vae_res, vae_res), 'log_q' and
        'posterior'.
        """
        # ! first downsample for VAE
        latents3D = self.superresolution['ldm_downsample'](latent)  # latents3D: B 256 96

        assert self.vae_p > 1
        latents3D = self.unpatchify3D(
            latents3D,
            p=self.vae_p,
            unpatchify_out_chans=self.ldm_z_channels *
            2)  # B 3 H W unpatchify_out_chans, H=W=16 now

        B, _, H, W, C = latents3D.shape
        latents3D = latents3D.permute(0, 1, 4, 2, 3).reshape(B, -1, H, W)  # B 3C H W

        # ! do VAE here
        posterior = self.vae_encode(latents3D)  # B self.ldm_z_channels 3 L

        if sample_posterior:
            latent = posterior.sample()
        else:
            latent = posterior.mode()  # B C 3 L

        log_q = posterior.log_p(latent)  # same shape as latent

        # ! for LSGM KL code
        latent_normalized_2Ddiffusion = latent.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16

        latent = latent.permute(0, 2, 3, 1)  # B C 3 L -> B 3 L C
        latent = latent.reshape(latent.shape[0], -1, latent.shape[-1])  # B 3*L C

        ret_dict = dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=latent,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
        )

        return ret_dict

    def vit_decode_backbone(self, latent, img_size):
        """Patch-embed the 2D-diffusion latent with ``ldm_upsample`` and run the
        ViT decoder. A dict input is read at 'latent_normalized_2Ddiffusion'."""
        if isinstance(latent, dict):
            latent = latent['latent_normalized_2Ddiffusion']  # B, C*3, H, W

        # latent: B 12 32 32
        latent = self.superresolution['ldm_upsample'](  # ! B 768 (3*256) 768
            latent)  # torch.Size([8, 12, 32, 32]) => torch.Size([8, 256, 768])

        # ! directly feed to vit_decoder
        return self.forward_vit_decoder(latent, img_size)  # pred_vit_latent

    def triplane_decode(self,
                        vit_decode_out,
                        c,
                        return_raw_only=False,
                        **kwargs):
        """Render from a ``vit_decode`` dict (or a raw triplane tensor).

        Dict input reads 'latent_after_vit' and 'sr_w_code'; tensor input is used
        directly with ``ws=None``. The returned dict is merged with the input dict.
        """
        if isinstance(vit_decode_out, dict):
            latent_after_vit, sr_w_code = (vit_decode_out.get(k, None)
                                           for k in ('latent_after_vit', 'sr_w_code'))
        else:
            latent_after_vit = vit_decode_out
            sr_w_code = None
            vit_decode_out = dict(latent_normalized=latent_after_vit
                                  )  # for later dict update compatability

        # * triplane rendering
        ret_dict = self.triplane_decoder(latent_after_vit,
                                         c,
                                         ws=sr_w_code,
                                         return_raw_only=return_raw_only,
                                         **kwargs)  # triplane latent -> imgs
        ret_dict.update({
            'latent_after_vit': latent_after_vit,
            **vit_decode_out
        })

        return ret_dict

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Project ViT tokens to a triplane, apply ``conv_sr``, and record outputs.

        Adds 'cls_token', 'latent_after_vit' and 'sr_w_code' (``w_avg`` repeated
        per batch item) to ``ret_dict`` and returns it.
        """
        if self.cls_token:
            cls_token = latent_from_vit[:, :1]
        else:
            cls_token = None

        # ViT decoder projection, from MAE
        latent = self.decoder_pred(
            latent_from_vit
        )  # pred_vit_latent -> patch or original size; B 768 384

        latent = self.unpatchify_triplane(
            latent,
            p=4,
            unpatchify_out_chans=int(
                self.channel_multiplier * self.unpatchify_out_chans //
                3))  # spatial_vit_latent, B, C, H, W (B, 96*2, 16, 16)

        # 4X SR with Rodin Conv 3D
        latent = self.superresolution['conv_sr'](latent)  # still B 3C H W

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))

        # include the w_avg for now
        sr_w_code = self.w_avg
        assert sr_w_code is not None
        ret_dict.update(
            dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave(
                latent_from_vit.shape[0], 0), ))  # type: ignore

        return ret_dict

    def forward_vit_decoder(self, x, img_size=None):
        """U-ViT forward over the fused blocks with long skip connections.

        x: (B, L, C) with L divisible by 3; internally viewed as (B, 3, L//3, C).
        """
        if img_size is None:
            # self.img_size is never assigned in this file; the value is only
            # forwarded to self.interpolate_pos_encoding, which ignores it.
            img_size = getattr(self, 'img_size', None)

        x = x + self.interpolate_pos_encoding(x, img_size, img_size)[:, :]  # B, L, C

        B, L, C = x.shape  # has [cls] token in N
        x = x.view(B, 3, L // 3, C)

        skips = [x]
        assert self.fusion_blk_start == 0

        # in blks
        for blk in self.vit_decoder.blocks[0:len(self.vit_decoder.blocks) // 2 - 1]:
            x = blk(x)  # B 3 N C
            skips.append(x)

        # mid blks
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2 - 1:
                                           len(self.vit_decoder.blocks) // 2]:
            x = blk(x)  # B 3 N C

        # out blks
        for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]:
            x = x + blk.skip_linear(torch.cat([x, skips.pop()], dim=-1))  # long skip connections
            x = blk(x)  # B 3 N C

        x = self.vit_decoder.norm(x)

        # post process shape
        x = x.view(B, L, C)
        return x


# ! SD version
class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn):
    """Variant without ``ldm_downsample``: ``vae_reparameterization`` feeds the
    input latent straight into ``vae_encode``.

    ``normalize_feat`` and ``sr_ratio`` are accepted but not used.
    """

    def __init__(self,
                 vit_decoder: VisionTransformer,
                 triplane_decoder: Triplane_fg_bg_plane,
                 cls_token,
                 normalize_feat=True,
                 sr_ratio=2,
                 use_fusion_blk=True,
                 fusion_blk_depth=2,
                 fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino,
                 channel_multiplier=4,
                 **kwargs) -> None:
        super().__init__(vit_decoder,
                         triplane_decoder,
                         cls_token,
                         use_fusion_blk=use_fusion_blk,
                         fusion_blk_depth=fusion_blk_depth,
                         fusion_blk=fusion_blk,
                         channel_multiplier=channel_multiplier,
                         **kwargs)

        for k in [
                'ldm_downsample',
                # 'conv_sr'
        ]:
            del self.superresolution[k]

    def vae_reparameterization(self, latent, sample_posterior):
        """Build the posterior directly from ``latent`` (B, 24, 32, 32 per the
        original comment) and sample (or take the mode).

        Returns a dict with 'normal_entropy', 'latent_normalized' (B, 3*L, C),
        'latent_normalized_2Ddiffusion', 'log_q' and 'posterior'.
        """
        assert self.vae_p > 1

        # ! do VAE here
        posterior = self.vae_encode(latent)  # B self.ldm_z_channels 3 L

        if sample_posterior:
            latent = posterior.sample()
        else:
            latent = posterior.mode()  # B C 3 L

        log_q = posterior.log_p(latent)  # same shape as latent

        # ! for LSGM KL code
        latent_normalized_2Ddiffusion = latent.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16

        latent = latent.permute(0, 2, 3, 1)  # B C 3 L -> B 3 L C
        latent = latent.reshape(latent.shape[0], -1, latent.shape[-1])  # B 3*L C

        ret_dict = dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=latent,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
        )

        return ret_dict


class RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD_D(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD):
    """SD variant whose ``conv_sr`` is an LDM ``Decoder`` applied per plane;
    ``decoder_pred`` is disabled (tokens are un-patchembedded directly)."""

    def __init__(self,
                 vit_decoder: VisionTransformer,
                 triplane_decoder: Triplane_fg_bg_plane,
                 cls_token,
                 normalize_feat=True,
                 sr_ratio=2,
                 use_fusion_blk=True,
                 fusion_blk_depth=2,
                 fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino,
                 channel_multiplier=4,
                 **kwargs) -> None:
        super().__init__(vit_decoder, triplane_decoder, cls_token,
                         normalize_feat, sr_ratio, use_fusion_blk,
                         fusion_blk_depth, fusion_blk, channel_multiplier,
                         **kwargs)

        self.decoder_pred = None  # directly un-patchembed

        self.superresolution.update(
            dict(conv_sr=Decoder(  # serve as Deconv
                resolution=128,
                in_channels=3,
                ch=32,
                ch_mult=[1, 2, 2, 4],
                num_res_blocks=1,
                dropout=0.0,
                attn_resolutions=[],
                out_ch=32,
                z_channels=vit_decoder.embed_dim,
            )))

    # for SD Decoder, verify encoder first
    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):
        """Reshape ViT tokens to per-plane maps, decode each plane with
        ``conv_sr``, and restack them as (B, 3*C, H, W) in 'latent_after_vit'.

        Also records 'cls_token'. No 'sr_w_code' is added.
        """
        if self.cls_token:
            cls_token = latent_from_vit[:, :1]
        else:
            cls_token = None

        def unflatten_token(x, p=None):
            # (B, L, C) -> (B*3, C', h*p, w*p); planes are merged into the batch.
            B, L, C = x.shape
            x = x.reshape(B, 3, L // 3, C)

            if self.cls_token:  # TODO, how to better use cls token
                x = x[:, :, 1:]  # B 3 256 C

            h = w = int((x.shape[2])**.5)
            assert h * w == x.shape[2]

            if p is None:
                x = x.reshape(shape=(B, 3, h, w, -1))
                x = rearrange(
                    x, 'b n h w c->(b n) c h w'
                )  # merge plane into Batch and prepare for rendering
            else:
                x = x.reshape(shape=(B, 3, h, w, p, p, -1))
                x = rearrange(
                    x, 'b n h w p1 p2 c->(b n) c (h p1) (w p2)'
                )  # merge plane into Batch and prepare for rendering

            return x

        latent = unflatten_token(latent_from_vit)

        # ! SD SR
        latent = self.superresolution['conv_sr'](latent)  # still B 3C H W
        latent = rearrange(latent, '(b n) c h w->b (n c) h w', n=3)

        ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent))

        return ret_dict


class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_lite3DAttn(
        RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD):

    def __init__(self,
                 vit_decoder: VisionTransformer,
                 triplane_decoder: Triplane_fg_bg_plane,
                 cls_token,
                 normalize_feat=True,
                 sr_ratio=2,
                 use_fusion_blk=True,
                 fusion_blk_depth=2,
                 fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite,
                 channel_multiplier=4,
                 **kwargs) -> None:
        super().__init__(vit_decoder, triplane_decoder, cls_token,
                         normalize_feat, sr_ratio, use_fusion_blk,
                         fusion_blk_depth, fusion_blk, channel_multiplier,
                         **kwargs)

        # 1. convert output plane token to B L 3 C//3 shape
        # 2. change vit decoder fusion arch (fusion block)
        # 3. output follow B L 3 C//3 with decoder input dim C//3
        # TODO: ablate basic decoder design, on the metrics (input/novelview both)
        self.decoder_pred = nn.Linear(self.vit_decoder.embed_dim // 3,
                                      2048,
                                      bias=True)  # decoder to patch

        # st()
        self.superresolution.update(
            dict(ldm_upsample=PatchEmbedTriplaneRodin(
                self.vae_p * self.token_size,
                self.vae_p,
                3 * self.ldm_embed_dim,  # B 3 L C
                vit_decoder.embed_dim // 3,
                bias=True)))

        # ! original pos_embed
        has_token = bool(self.cls_token)
        self.vit_decoder.pos_embed = nn.Parameter(
            torch.zeros(1, 16 * 16 + has_token, vit_decoder.embed_dim))

    def forward(self, latent, c, img_size):

        latent_normalized = self.vit_decode(latent, img_size)
        return self.triplane_decode(latent_normalized, c)

    def vae_reparameterization(self, latent, sample_posterior):
        # latent: B 24 32 32

        assert self.vae_p > 1

        # ! do VAE here
        # st()
        posterior = self.vae_encode(latent)  # B self.ldm_z_channels 3 L

        if sample_posterior:
            latent = posterior.sample()
        else:
            latent = posterior.mode()  # B C 3 L

        log_q = posterior.log_p(latent)  # same shape as latent

        # ! for LSGM KL code
        latent_normalized_2Ddiffusion = latent.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16
        log_q_2Ddiffusion = log_q.reshape(
            latent.shape[0], -1, self.token_size * self.vae_p,
            self.token_size * self.vae_p)  # B, 3*4, 16 16

        # TODO, add a conv_after_quant

        # ! reshape for ViT decoder
        latent = latent.permute(0, 3, 1, 2)  # B C 3 L -> B L C 3
        latent = latent.reshape(*latent.shape[:2], -1)  # B L C3

        ret_dict = dict(
            normal_entropy=posterior.normal_entropy(),
            latent_normalized=latent,
            latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion,  #
            # log_q_2Ddiffusion=log_q_2Ddiffusion,
            log_q=log_q,
            posterior=posterior,
        )

        return ret_dict

    def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict):

        if self.cls_token:
            cls_token = latent_from_vit[:, :1]
        else:
            cls_token = None

        B, N, C = latent_from_vit.shape
        latent_from_vit = latent_from_vit.reshape(B, N, C // 3, 3).permute(
            0, 3, 1, 2)  # -> B 3 N C//3

        # !
remaining unchanged # ViT decoder projection, from MAE latent = self.decoder_pred( latent_from_vit ) # pred_vit_latent -> patch or original size; B 768 384 latent = latent.reshape(B, 3 * N, -1) # B L C latent = self.unpatchify_triplane( latent, p=4, unpatchify_out_chans=int( self.channel_multiplier * self.unpatchify_out_chans // 3)) # spatial_vit_latent, B, C, H, W (B, 96*2, 16, 16) # 4X SR with Rodin Conv 3D latent = self.superresolution['conv_sr'](latent) # still B 3C H W ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent)) # include the w_avg for now sr_w_code = self.w_avg assert sr_w_code is not None ret_dict.update( dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave( latent_from_vit.shape[0], 0), )) # type: ignore return ret_dict def vit_decode_backbone(self, latent, img_size): # assert x.ndim == 3 # N L C if isinstance(latent, dict): latent = latent['latent_normalized_2Ddiffusion'] # B, C*3, H, W # assert latent.shape == ( # latent.shape[0], 3 * (self.token_size * self.vae_p)**2, # self.ldm_z_channels), f'latent.shape: {latent.shape}' # st() # latent: B 12 32 32 latent = self.superresolution['ldm_upsample']( # ! B 768 (3*256) 768 latent) # torch.Size([8, 12, 32, 32]) => torch.Size([8, 256, 768]) # latent: torch.Size([8, 768, 768]) B, N3, C = latent.shape latent = latent.reshape(B, 3, N3 // 3, C).permute(0, 2, 3, 1) # B 3HW C -> B HW C 3 latent = latent.reshape(*latent.shape[:2], -1) # B HW C3 # ! directly feed to vit_decoder return self.forward_vit_decoder(latent, img_size) # pred_vit_latent def forward_vit_decoder(self, x, img_size=None): # latent: (N, L, C) from DINO/CLIP ViT encoder # * also dino ViT # add positional encoding to each token if img_size is None: img_size = self.img_size # if self.cls_token: x = x + self.interpolate_pos_encoding(x, img_size, img_size)[:, :] # B, L, C B, L, C = x.shape # has [cls] token in N # ! 
no need to reshape here # x = x.view(B, 3, L // 3, C) skips = [x] assert self.fusion_blk_start == 0 # in blks for blk in self.vit_decoder.blocks[0:len(self.vit_decoder.blocks) // 2 - 1]: x = blk(x) # B 3 N C skips.append(x) # mid blks # for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks)//2-1:len(self.vit_decoder.blocks)//2+1]: for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2 - 1:len(self.vit_decoder.blocks) // 2]: x = blk(x) # B 3 N C # out blks for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]: x = x + blk.skip_linear(torch.cat([x, skips.pop()], dim=-1)) # long skip connections x = blk(x) # B 3 N C x = self.vit_decoder.norm(x) # post process shape x = x.view(B, L, C) return x def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk): vit_decoder_blks = self.vit_decoder.blocks assert len(vit_decoder_blks) == 12, 'ViT-B by default' nh = self.vit_decoder.blocks[ 0].attn.num_heads // 3 # ! lighter, actually divisible by 4 dim = self.vit_decoder.embed_dim // 3 # ! 
separate fusion_blk_start = self.fusion_blk_start triplane_fusion_vit_blks = nn.ModuleList() if fusion_blk_start != 0: for i in range(0, fusion_blk_start): triplane_fusion_vit_blks.append( vit_decoder_blks[i]) # append all vit blocks in the front for i in range(fusion_blk_start, len(vit_decoder_blks), fusion_blk_depth): vit_blks_group = vit_decoder_blks[i:i + fusion_blk_depth] # moduleList triplane_fusion_vit_blks.append( # TriplaneFusionBlockv2(vit_blks_group, nh, dim, use_fusion_blk)) fusion_blk(vit_blks_group, nh, dim, use_fusion_blk)) self.vit_decoder.blocks = triplane_fusion_vit_blks # self.vit_decoder.blocks = triplane_fusion_vit_blks # default for objaverse class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder_S( RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn): def __init__( self, vit_decoder: VisionTransformer, triplane_decoder: Triplane_fg_bg_plane, cls_token, normalize_feat=True, sr_ratio=2, use_fusion_blk=True, fusion_blk_depth=2, fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout, channel_multiplier=4, **kwargs) -> None: super().__init__( vit_decoder, triplane_decoder, cls_token, use_fusion_blk=use_fusion_blk, fusion_blk_depth=fusion_blk_depth, fusion_blk=fusion_blk, channel_multiplier=channel_multiplier, patch_size=-1, # placeholder, since we use dit here token_size=2, **kwargs) self.D_roll_out_input = False for k in [ 'ldm_downsample', # 'conv_sr' ]: del self.superresolution[k] self.decoder_pred = None # directly un-patchembed self.superresolution.update( dict( conv_sr=Decoder( # serve as Deconv resolution=128, # resolution=256, in_channels=3, # ch=64, ch=32, # ch=16, ch_mult=[1, 2, 2, 4], # ch_mult=[1, 1, 2, 2], # num_res_blocks=2, # ch_mult=[1,2,4], # num_res_blocks=0, num_res_blocks=1, dropout=0.0, attn_resolutions=[], out_ch=32, # z_channels=vit_decoder.embed_dim//4, z_channels=vit_decoder.embed_dim, # z_channels=vit_decoder.embed_dim//2, ), # 
after_vit_upsampler=Upsample2D(channels=vit_decoder.embed_dim,use_conv=True, use_conv_transpose=False, out_channels=vit_decoder.embed_dim//2) )) # del skip_lienar for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]: del blk.skip_linear @torch.inference_mode() def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**16): # planes: (N, 3, D', H', W') # points: (N, P, 3) N, P = points.shape[:2] if planes.ndim == 4: planes = planes.reshape( len(planes), 3, -1, # ! support background plane planes.shape[-2], planes.shape[-1]) # BS 96 256 256 # query triplane in chunks outs = [] for i in trange(0, points.shape[1], chunk_size): chunk_points = points[:, i:i + chunk_size] # query triplane # st() chunk_out = self.triplane_decoder.renderer._run_model( # type: ignore planes=planes, decoder=self.triplane_decoder.decoder, sample_coordinates=chunk_points, sample_directions=torch.zeros_like(chunk_points), options=self.rendering_kwargs, ) # st() outs.append(chunk_out) torch.cuda.empty_cache() # st() # concatenate the outputs point_features = { k: torch.cat([out[k] for out in outs], dim=1) for k in outs[0].keys() } return point_features def triplane_decode_grid(self, vit_decode_out, grid_size, aabb: torch.Tensor = None, **kwargs): # planes: (N, 3, D', H', W') # grid_size: int assert isinstance(vit_decode_out, dict) planes = vit_decode_out['latent_after_vit'] # aabb: (N, 2, 3) if aabb is None: if 'sampler_bbox_min' in self.rendering_kwargs: aabb = torch.tensor([ [self.rendering_kwargs['sampler_bbox_min']] * 3, [self.rendering_kwargs['sampler_bbox_max']] * 3, ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat( planes.shape[0], 1, 1) else: # shapenet dataset, follow eg3d aabb = torch.tensor( [ # https://github.com/NVlabs/eg3d/blob/7cf1fd1e99e1061e8b6ba850f91c94fe56e7afe4/eg3d/gen_samples.py#L188 [-self.rendering_kwargs['box_warp'] / 2] * 3, [self.rendering_kwargs['box_warp'] / 2] * 3, ], device=planes.device, 
dtype=planes.dtype).unsqueeze(0).repeat( planes.shape[0], 1, 1) assert planes.shape[0] == aabb.shape[ 0], "Batch size mismatch for planes and aabb" N = planes.shape[0] # create grid points for triplane query grid_points = [] for i in range(N): grid_points.append( torch.stack(torch.meshgrid( torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), indexing='ij', ), dim=-1).reshape(-1, 3)) cube_grid = torch.stack(grid_points, dim=0).to(planes.device) # 1 N 3 # st() features = self.forward_points(planes, cube_grid) # reshape into grid features = { k: v.reshape(N, grid_size, grid_size, grid_size, -1) for k, v in features.items() } # st() return features def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk): # no need to fuse anymore pass def forward_vit_decoder(self, x, img_size=None): # st() return self.vit_decoder(x) def vit_decode_backbone(self, latent, img_size): # assert x.ndim == 3 # N L C if isinstance(latent, dict): latent = latent['latent_normalized_2Ddiffusion'] # B, C*3, H, W # assert latent.shape == ( # latent.shape[0], 3 * (self.token_size * self.vae_p)**2, # self.ldm_z_channels), f'latent.shape: {latent.shape}' # st() # latent: B 12 32 32 # st() latent = self.superresolution['ldm_upsample']( # ! B 768 (3*256) 768 latent) # torch.Size([8, 12, 32, 32]) => torch.Size([8, 256, 768]) # latent: torch.Size([8, 768, 768]) # ! 
directly feed to vit_decoder return self.forward_vit_decoder(latent, img_size) # pred_vit_latent def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict): if self.cls_token: cls_token = latent_from_vit[:, :1] else: cls_token = None def unflatten_token(x, p=None): B, L, C = x.shape x = x.reshape(B, 3, L // 3, C) if self.cls_token: # TODO, how to better use cls token x = x[:, :, 1:] # B 3 256 C h = w = int((x.shape[2])**.5) assert h * w == x.shape[2] if p is None: x = x.reshape(shape=(B, 3, h, w, -1)) if not self.D_roll_out_input: x = rearrange( x, 'b n h w c->(b n) c h w' ) # merge plane into Batch and prepare for rendering else: x = rearrange( x, 'b n h w c->b c h (n w)' ) # merge plane into Batch and prepare for rendering else: x = x.reshape(shape=(B, 3, h, w, p, p, -1)) if self.D_roll_out_input: x = rearrange( x, 'b n h w p1 p2 c->b c (h p1) (n w p2)' ) # merge plane into Batch and prepare for rendering else: x = rearrange( x, 'b n h w p1 p2 c->(b n) c (h p1) (w p2)' ) # merge plane into Batch and prepare for rendering return x latent = unflatten_token( latent_from_vit) # B 3 h w vit_decoder.embed_dim # ! x2 upsapmle, 16 -32 before sending into SD Decoder # latent = self.superresolution['after_vit_upsampler'](latent) # B*3 192 32 32 # latent = unflatten_token(latent_from_vit, p=2) # ! 
SD SR latent = self.superresolution['conv_sr'](latent) # still B 3C H W if not self.D_roll_out_input: latent = rearrange(latent, '(b n) c h w->b (n c) h w', n=3) else: latent = rearrange(latent, 'b c h (n w)->b (n c) h w', n=3) ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent)) # include the w_avg for now # sr_w_code = self.w_avg # assert sr_w_code is not None # ret_dict.update( # dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave( # latent_from_vit.shape[0], 0), )) # type: ignore return ret_dict def vae_reparameterization(self, latent, sample_posterior): # latent: B 24 32 32 assert self.vae_p > 1 # latents3D = self.unpatchify3D( # latents3D, # p=self.vae_p, # unpatchify_out_chans=self.ldm_z_channels * # 2) # B 3 H W unpatchify_out_chans, H=W=16 now # B, C3, H, W = latent.shape # latents3D = latent.reshape(B, 3, C3//3, H, W) # latents3D = latents3D.permute(0, 1, 4, 2, 3).reshape(B, -1, H, # W) # B 3C H W # ! do VAE here posterior = self.vae_encode(latent) # B self.ldm_z_channels 3 L if sample_posterior: latent = posterior.sample() else: latent = posterior.mode() # B C 3 L log_q = posterior.log_p(latent) # same shape as latent # ! 
for LSGM KL code latent_normalized_2Ddiffusion = latent.reshape( latent.shape[0], -1, self.token_size * self.vae_p, self.token_size * self.vae_p) # B, 3*4, 16 16 log_q_2Ddiffusion = log_q.reshape( latent.shape[0], -1, self.token_size * self.vae_p, self.token_size * self.vae_p) # B, 3*4, 16 16 # st() latent = latent.permute(0, 2, 3, 1) # B C 3 L -> B 3 L C latent = latent.reshape(latent.shape[0], -1, latent.shape[-1]) # B 3*L C ret_dict = dict( normal_entropy=posterior.normal_entropy(), latent_normalized=latent, latent_normalized_2Ddiffusion=latent_normalized_2Ddiffusion, # log_q_2Ddiffusion=log_q_2Ddiffusion, log_q=log_q, posterior=posterior, ) return ret_dict # objv class class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout( RodinSR_256_fusionv5_ConvQuant_liteSR_dinoInit3DAttn_SD): def __init__( self, vit_decoder: VisionTransformer, triplane_decoder: Triplane_fg_bg_plane, cls_token, normalize_feat=True, sr_ratio=2, use_fusion_blk=True, fusion_blk_depth=2, fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout, channel_multiplier=4, **kwargs) -> None: super().__init__(vit_decoder, triplane_decoder, cls_token, normalize_feat, sr_ratio, use_fusion_blk, fusion_blk_depth, fusion_blk, channel_multiplier, **kwargs) # final version, above + SD-Decoder class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D( RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout ): def __init__( self, vit_decoder: VisionTransformer, triplane_decoder: Triplane_fg_bg_plane, cls_token, normalize_feat=True, sr_ratio=2, use_fusion_blk=True, fusion_blk_depth=2, fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout, channel_multiplier=4, **kwargs) -> None: super().__init__(vit_decoder, triplane_decoder, cls_token, normalize_feat, sr_ratio, use_fusion_blk, fusion_blk_depth, fusion_blk, channel_multiplier, **kwargs) self.decoder_pred = None # directly 
un-patchembed self.superresolution.update( dict( conv_sr=Decoder( # serve as Deconv resolution=128, # resolution=256, in_channels=3, # ch=64, ch=32, # ch=16, ch_mult=[1, 2, 2, 4], # ch_mult=[1, 1, 2, 2], # num_res_blocks=2, # ch_mult=[1,2,4], # num_res_blocks=0, num_res_blocks=1, dropout=0.0, attn_resolutions=[], out_ch=32, # z_channels=vit_decoder.embed_dim//4, z_channels=vit_decoder.embed_dim, # z_channels=vit_decoder.embed_dim//2, ), # after_vit_upsampler=Upsample2D(channels=vit_decoder.embed_dim,use_conv=True, use_conv_transpose=False, out_channels=vit_decoder.embed_dim//2) )) self.D_roll_out_input = False # ''' # for SD Decoder def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict): if self.cls_token: cls_token = latent_from_vit[:, :1] else: cls_token = None def unflatten_token(x, p=None): B, L, C = x.shape x = x.reshape(B, 3, L // 3, C) if self.cls_token: # TODO, how to better use cls token x = x[:, :, 1:] # B 3 256 C h = w = int((x.shape[2])**.5) assert h * w == x.shape[2] if p is None: x = x.reshape(shape=(B, 3, h, w, -1)) if not self.D_roll_out_input: x = rearrange( x, 'b n h w c->(b n) c h w' ) # merge plane into Batch and prepare for rendering else: x = rearrange( x, 'b n h w c->b c h (n w)' ) # merge plane into Batch and prepare for rendering else: x = x.reshape(shape=(B, 3, h, w, p, p, -1)) if self.D_roll_out_input: x = rearrange( x, 'b n h w p1 p2 c->b c (h p1) (n w p2)' ) # merge plane into Batch and prepare for rendering else: x = rearrange( x, 'b n h w p1 p2 c->(b n) c (h p1) (w p2)' ) # merge plane into Batch and prepare for rendering return x latent = unflatten_token( latent_from_vit) # B 3 h w vit_decoder.embed_dim # ! x2 upsapmle, 16 -32 before sending into SD Decoder # latent = self.superresolution['after_vit_upsampler'](latent) # B*3 192 32 32 # latent = unflatten_token(latent_from_vit, p=2) # ! 
SD SR latent = self.superresolution['conv_sr'](latent) # still B 3C H W if not self.D_roll_out_input: latent = rearrange(latent, '(b n) c h w->b (n c) h w', n=3) else: latent = rearrange(latent, 'b c h (n w)->b (n c) h w', n=3) ret_dict.update(dict(cls_token=cls_token, latent_after_vit=latent)) # include the w_avg for now # sr_w_code = self.w_avg # assert sr_w_code is not None # ret_dict.update( # dict(sr_w_code=sr_w_code.reshape(1, 1, -1).repeat_interleave( # latent_from_vit.shape[0], 0), )) # type: ignore return ret_dict # ''' class RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D_ditDecoder( RodinSR_256_fusionv6_ConvQuant_liteSR_dinoInit3DAttn_SD_B_3L_C_withrollout_withSD_D ): def __init__( self, vit_decoder: VisionTransformer, triplane_decoder: Triplane_fg_bg_plane, cls_token, normalize_feat=True, sr_ratio=2, use_fusion_blk=True, fusion_blk_depth=2, fusion_blk=TriplaneFusionBlockv4_nested_init_from_dino_lite_merge_B_3L_C_withrollout, channel_multiplier=4, **kwargs) -> None: super().__init__(vit_decoder, triplane_decoder, cls_token, normalize_feat, sr_ratio, use_fusion_blk, fusion_blk_depth, fusion_blk, channel_multiplier, patch_size=-1, **kwargs) # del skip_lienar for blk in self.vit_decoder.blocks[len(self.vit_decoder.blocks) // 2:]: del blk.skip_linear @torch.inference_mode() def forward_points(self, planes, points: torch.Tensor, chunk_size: int = 2**16): # planes: (N, 3, D', H', W') # points: (N, P, 3) N, P = points.shape[:2] if planes.ndim == 4: planes = planes.reshape( len(planes), 3, -1, # ! 
support background plane planes.shape[-2], planes.shape[-1]) # BS 96 256 256 # query triplane in chunks outs = [] for i in trange(0, points.shape[1], chunk_size): chunk_points = points[:, i:i + chunk_size] # query triplane # st() chunk_out = self.triplane_decoder.renderer._run_model( # type: ignore planes=planes, decoder=self.triplane_decoder.decoder, sample_coordinates=chunk_points, sample_directions=torch.zeros_like(chunk_points), options=self.rendering_kwargs, ) # st() outs.append(chunk_out) torch.cuda.empty_cache() # st() # concatenate the outputs point_features = { k: torch.cat([out[k] for out in outs], dim=1) for k in outs[0].keys() } return point_features def triplane_decode_grid(self, vit_decode_out, grid_size, aabb: torch.Tensor = None, **kwargs): # planes: (N, 3, D', H', W') # grid_size: int assert isinstance(vit_decode_out, dict) planes = vit_decode_out['latent_after_vit'] # aabb: (N, 2, 3) if aabb is None: if 'sampler_bbox_min' in self.rendering_kwargs: aabb = torch.tensor([ [self.rendering_kwargs['sampler_bbox_min']] * 3, [self.rendering_kwargs['sampler_bbox_max']] * 3, ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat( planes.shape[0], 1, 1) else: # shapenet dataset, follow eg3d aabb = torch.tensor( [ # https://github.com/NVlabs/eg3d/blob/7cf1fd1e99e1061e8b6ba850f91c94fe56e7afe4/eg3d/gen_samples.py#L188 [-self.rendering_kwargs['box_warp'] / 2] * 3, [self.rendering_kwargs['box_warp'] / 2] * 3, ], device=planes.device, dtype=planes.dtype).unsqueeze(0).repeat( planes.shape[0], 1, 1) assert planes.shape[0] == aabb.shape[ 0], "Batch size mismatch for planes and aabb" N = planes.shape[0] # create grid points for triplane query grid_points = [] for i in range(N): grid_points.append( torch.stack(torch.meshgrid( torch.linspace(aabb[i, 0, 0], aabb[i, 1, 0], grid_size, device=planes.device), torch.linspace(aabb[i, 0, 1], aabb[i, 1, 1], grid_size, device=planes.device), torch.linspace(aabb[i, 0, 2], aabb[i, 1, 2], grid_size, device=planes.device), 
indexing='ij', ), dim=-1).reshape(-1, 3)) cube_grid = torch.stack(grid_points, dim=0).to(planes.device) # 1 N 3 # st() features = self.forward_points(planes, cube_grid) # reshape into grid features = { k: v.reshape(N, grid_size, grid_size, grid_size, -1) for k, v in features.items() } # st() return features def create_fusion_blks(self, fusion_blk_depth, use_fusion_blk, fusion_blk): # no need to fuse anymore pass def forward_vit_decoder(self, x, img_size=None): # st() return self.vit_decoder(x) def vit_decode_backbone(self, latent, img_size): return super().vit_decode_backbone(latent, img_size) # ! flag2 def vit_decode_postprocess(self, latent_from_vit, ret_dict: dict): return super().vit_decode_postprocess(latent_from_vit, ret_dict) def vae_reparameterization(self, latent, sample_posterior): return super().vae_reparameterization(latent, sample_posterior)