# SPDX-FileCopyrightText: Copyright (c) 2021-2022 NVIDIA CORPORATION & AFFILIATES. All rights reserved. # SPDX-License-Identifier: LicenseRef-NvidiaProprietary # # NVIDIA CORPORATION, its affiliates and licensors retain all intellectual # property and proprietary rights in and to this material, related # documentation and any modifications thereto. Any use, reproduction, # disclosure or distribution of this material and related documentation # without an express license agreement from NVIDIA CORPORATION or # its affiliates is strictly prohibited. from threading import local import torch import torch.nn as nn from utils.torch_utils import persistence from .networks_stylegan2 import Generator as StyleGAN2Backbone from .networks_stylegan2 import ToRGBLayer, SynthesisNetwork, MappingNetwork from .volumetric_rendering.renderer import ImportanceRenderer from .volumetric_rendering.ray_sampler import RaySampler, PatchRaySampler import dnnlib from pdb import set_trace as st import math import torch.nn.functional as F import itertools from ldm.modules.diffusionmodules.model import SimpleDecoder, Decoder @persistence.persistent_class class TriPlaneGenerator(torch.nn.Module): def __init__( self, z_dim, # Input latent (Z) dimensionality. c_dim, # Conditioning label (C) dimensionality. w_dim, # Intermediate latent (W) dimensionality. img_resolution, # Output resolution. img_channels, # Number of output color channels. sr_num_fp16_res=0, mapping_kwargs={}, # Arguments for MappingNetwork. rendering_kwargs={}, sr_kwargs={}, bcg_synthesis_kwargs={}, # pifu_kwargs={}, # ada_kwargs={}, # not used, place holder **synthesis_kwargs, # Arguments for SynthesisNetwork. ): super().__init__() self.z_dim = z_dim self.c_dim = c_dim self.w_dim = w_dim self.img_resolution = img_resolution self.img_channels = img_channels self.renderer = ImportanceRenderer() # if 'PatchRaySampler' in rendering_kwargs: # self.ray_sampler = PatchRaySampler() # else: # self.ray_sampler = RaySampler() self.backbone = StyleGAN2Backbone(z_dim, c_dim, w_dim, img_resolution=256, img_channels=32 * 3, mapping_kwargs=mapping_kwargs, **synthesis_kwargs) self.superresolution = dnnlib.util.construct_class_by_name( class_name=rendering_kwargs['superresolution_module'], channels=32, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs) # self.bcg_synthesis = None if rendering_kwargs.get('use_background', False): self.bcg_synthesis = SynthesisNetwork( w_dim, img_resolution=self.superresolution.input_resolution, img_channels=32, **bcg_synthesis_kwargs) self.bcg_mapping = MappingNetwork(z_dim=z_dim, c_dim=c_dim, w_dim=w_dim, num_ws=self.num_ws, **mapping_kwargs) # New mapping network for self-adaptive camera pose, dim = 3 self.decoder = OSGDecoder( 32, { 'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), 'decoder_output_dim': 32 }) self.neural_rendering_resolution = 64 self.rendering_kwargs = rendering_kwargs self._last_planes = None self.pool_256 = torch.nn.AdaptiveAvgPool2d((256, 256)) def mapping(self, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False): if self.rendering_kwargs['c_gen_conditioning_zero']: c = torch.zeros_like(c) return self.backbone.mapping(z, c * self.rendering_kwargs.get('c_scale', 0), truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) def synthesis(self, ws, c, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, return_meta=False, return_raw_only=False, **synthesis_kwargs): return_sampling_details_flag = self.rendering_kwargs.get( 'return_sampling_details_flag', False) if return_sampling_details_flag: return_meta = True cam2world_matrix = c[:, :16].view(-1, 4, 4) # cam2world_matrix = torch.eye(4, device=c.device).unsqueeze(0).repeat_interleave(c.shape[0], dim=0) # c[:, :16] = cam2world_matrix.view(-1, 16) intrinsics = c[:, 16:25].view(-1, 3, 3) if neural_rendering_resolution is None: neural_rendering_resolution = self.neural_rendering_resolution else: self.neural_rendering_resolution = neural_rendering_resolution H = W = self.neural_rendering_resolution # Create a batch of rays for volume rendering ray_origins, ray_directions = self.ray_sampler( cam2world_matrix, intrinsics, neural_rendering_resolution) # Create triplanes by running StyleGAN backbone N, M, _ = ray_origins.shape if use_cached_backbone and self._last_planes is not None: planes = self._last_planes else: planes = self.backbone.synthesis( ws[:, :self.backbone.num_ws, :], # ws, BS 14 512 update_emas=update_emas, **synthesis_kwargs) if cache_backbone: self._last_planes = planes # Reshape output into three 32-channel planes planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1]) # BS 96 256 256 # Perform volume rendering # st() rendering_details = self.renderer( planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs, # return_meta=True) return_meta=return_meta) # calibs = create_calib_matrix(c) # all_coords = rendering_details['all_coords'] # B, num_rays, S, _ = all_coords.shape # all_coords_B3N = all_coords.reshape(B, -1, 3).permute(0,2,1) # homo_coords = torch.cat([all_coords, torch.zeros_like(all_coords[..., :1])], -1) # homo_coords[..., -1] = 1 # homo_coords = homo_coords.reshape(homo_coords.shape[0], -1, 4) # homo_coords = homo_coords.permute(0,2,1) # xyz = calibs @ homo_coords # xyz = xyz.permute(0,2,1).reshape(B, H, W, S, 4) # st() # xyz_proj = perspective(all_coords_B3N, calibs) # xyz_proj = xyz_proj.permute(0,2,1).reshape(B, H, W, S, 3) # [0,0] - [1,1] # st() feature_samples, depth_samples, weights_samples = ( rendering_details[k] for k in ['feature_samples', 'depth_samples', 'weights_samples']) if return_sampling_details_flag: shape_synthesized = rendering_details['shape_synthesized'] else: shape_synthesized = None # Reshape into 'raw' neural-rendered image feature_image = feature_samples.permute(0, 2, 1).reshape( N, feature_samples.shape[-1], H, W).contiguous() # B 32 H W depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) # Run superresolution to get final image rgb_image = feature_image[:, :3] # B 3 H W if not return_raw_only: sr_image = self.superresolution( rgb_image, feature_image, ws[:, -1:, :], # only use the last layer noise_mode=self.rendering_kwargs['superresolution_noise_mode'], **{ k: synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode' }) else: sr_image = rgb_image ret_dict = { 'image': sr_image, 'image_raw': rgb_image, 'image_depth': depth_image, 'weights_samples': weights_samples, 'shape_synthesized': shape_synthesized } if return_meta: ret_dict.update({ # 'feature_image': feature_image, 'feature_volume': rendering_details['feature_volume'], 'all_coords': rendering_details['all_coords'], 'weights': rendering_details['weights'], }) return ret_dict def sample(self, coordinates, directions, z, c, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs): # Compute RGB features, density for arbitrary 3D coordinates. Mostly used for extracting shapes. ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs) planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1]) return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs) def sample_mixed(self, coordinates, directions, ws, truncation_psi=1, truncation_cutoff=None, update_emas=False, **synthesis_kwargs): # Same as sample, but expects latent vectors 'ws' instead of Gaussian noise 'z' planes = self.backbone.synthesis(ws, update_emas=update_emas, **synthesis_kwargs) planes = planes.view(len(planes), 3, 32, planes.shape[-2], planes.shape[-1]) return self.renderer.run_model(planes, self.decoder, coordinates, directions, self.rendering_kwargs) def forward(self, z, c, truncation_psi=1, truncation_cutoff=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, **synthesis_kwargs): # Render a batch of generated images. ws = self.mapping(z, c, truncation_psi=truncation_psi, truncation_cutoff=truncation_cutoff, update_emas=update_emas) return self.synthesis( ws, c, update_emas=update_emas, neural_rendering_resolution=neural_rendering_resolution, cache_backbone=cache_backbone, use_cached_backbone=use_cached_backbone, **synthesis_kwargs) from .networks_stylegan2 import FullyConnectedLayer # class OSGDecoder(torch.nn.Module): # def __init__(self, n_features, options): # super().__init__() # self.hidden_dim = 64 # self.output_dim = options['decoder_output_dim'] # self.n_features = n_features # self.net = torch.nn.Sequential( # FullyConnectedLayer(n_features, # self.hidden_dim, # lr_multiplier=options['decoder_lr_mul']), # torch.nn.Softplus(), # FullyConnectedLayer(self.hidden_dim, # 1 + options['decoder_output_dim'], # lr_multiplier=options['decoder_lr_mul'])) # def forward(self, sampled_features, ray_directions): # # Aggregate features # sampled_features = sampled_features.mean(1) # x = sampled_features # N, M, C = x.shape # x = x.view(N * M, C) # x = self.net(x) # x = x.view(N, M, -1) # rgb = torch.sigmoid(x[..., 1:]) * ( # 1 + 2 * 0.001) - 0.001 # Uses sigmoid clamping from MipNeRF # sigma = x[..., 0:1] # return {'rgb': rgb, 'sigma': sigma} @persistence.persistent_class class OSGDecoder(torch.nn.Module): def __init__(self, n_features, options): super().__init__() self.hidden_dim = 64 self.decoder_output_dim = options['decoder_output_dim'] self.net = torch.nn.Sequential( FullyConnectedLayer(n_features, self.hidden_dim, lr_multiplier=options['decoder_lr_mul']), torch.nn.Softplus(), FullyConnectedLayer(self.hidden_dim, 1 + options['decoder_output_dim'], lr_multiplier=options['decoder_lr_mul'])) self.activation = options.get('decoder_activation', 'sigmoid') def forward(self, sampled_features, ray_directions): # Aggregate features sampled_features = sampled_features.mean(1) x = sampled_features N, M, C = x.shape x = x.view(N * M, C) x = self.net(x) x = x.view(N, M, -1) rgb = x[..., 1:] sigma = x[..., 0:1] if self.activation == "sigmoid": # Original EG3D rgb = torch.sigmoid(rgb) * (1 + 2 * 0.001) - 0.001 elif self.activation == "lrelu": # StyleGAN2-style, use with toRGB rgb = torch.nn.functional.leaky_relu(rgb, 0.2, inplace=True) * math.sqrt(2) return {'rgb': rgb, 'sigma': sigma} class LRMOSGDecoder(nn.Module): """ Triplane decoder that gives RGB and sigma values from sampled features. Using ReLU here instead of Softplus in the original implementation. Reference: EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112 """ def __init__(self, n_features: int, hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU): super().__init__() self.decoder_output_dim = 3 self.net = nn.Sequential( nn.Linear(3 * n_features, hidden_dim), activation(), *itertools.chain(*[[ nn.Linear(hidden_dim, hidden_dim), activation(), ] for _ in range(num_layers - 2)]), nn.Linear(hidden_dim, 1 + self.decoder_output_dim), ) # init all bias to zero for m in self.modules(): if isinstance(m, nn.Linear): nn.init.zeros_(m.bias) def forward(self, sampled_features, ray_directions): # Aggregate features by mean # sampled_features = sampled_features.mean(1) # Aggregate features by concatenation _N, n_planes, _M, _C = sampled_features.shape sampled_features = sampled_features.permute(0, 2, 1, 3).reshape(_N, _M, n_planes*_C) x = sampled_features N, M, C = x.shape x = x.contiguous().view(N*M, C) x = self.net(x) x = x.view(N, M, -1) rgb = torch.sigmoid(x[..., 1:])*(1 + 2*0.001) - 0.001 # Uses sigmoid clamping from MipNeRF sigma = x[..., 0:1] return {'rgb': rgb, 'sigma': sigma} class Triplane(torch.nn.Module): def __init__( self, c_dim=25, # Conditioning label (C) dimensionality. img_resolution=128, # Output resolution. img_channels=3, # Number of output color channels. out_chans=96, triplane_size=224, rendering_kwargs={}, decoder_in_chans=32, decoder_output_dim=32, sr_num_fp16_res=0, sr_kwargs={}, create_triplane=False, # for overfitting single instance study bcg_synthesis_kwargs={}, lrm_decoder=False, ): super().__init__() self.c_dim = c_dim self.img_resolution = img_resolution # TODO self.img_channels = img_channels self.triplane_size = triplane_size self.decoder_in_chans = decoder_in_chans self.out_chans = out_chans self.renderer = ImportanceRenderer() if 'PatchRaySampler' in rendering_kwargs: self.ray_sampler = PatchRaySampler() else: self.ray_sampler = RaySampler() if lrm_decoder: self.decoder = LRMOSGDecoder( decoder_in_chans,) else: self.decoder = OSGDecoder( decoder_in_chans, { 'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1), # 'decoder_output_dim': 32 'decoder_output_dim': decoder_output_dim }) self.neural_rendering_resolution = img_resolution # TODO # self.neural_rendering_resolution = 128 # TODO self.rendering_kwargs = rendering_kwargs self.create_triplane = create_triplane if create_triplane: self.planes = nn.Parameter(torch.randn(1, out_chans, 256, 256)) if bool(sr_kwargs): # check whether empty assert decoder_in_chans == decoder_output_dim, 'tradition' if rendering_kwargs['superresolution_module'] in [ 'utils.torch_utils.components.PixelUnshuffleUpsample', 'utils.torch_utils.components.NearestConvSR', 'utils.torch_utils.components.NearestConvSR_Residual' ]: self.superresolution = dnnlib.util.construct_class_by_name( class_name=rendering_kwargs['superresolution_module'], # * for PixelUnshuffleUpsample sr_ratio=2, # 2x SR, 128 -> 256 output_dim=decoder_output_dim, num_out_ch=3, ) else: self.superresolution = dnnlib.util.construct_class_by_name( class_name=rendering_kwargs['superresolution_module'], # * for stylegan upsample channels=decoder_output_dim, img_resolution=img_resolution, sr_num_fp16_res=sr_num_fp16_res, sr_antialias=rendering_kwargs['sr_antialias'], **sr_kwargs) else: self.superresolution = None self.bcg_synthesis = None # * pure reconstruction def forward( self, planes=None, # img, c=None, ws=None, ray_origins=None, ray_directions=None, z_bcg=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, return_meta=False, return_raw_only=False, sample_ray_only=False, fg_bbox=None, **synthesis_kwargs): cam2world_matrix = c[:, :16].reshape(-1, 4, 4) # cam2world_matrix = torch.eye(4, device=c.device).unsqueeze(0).repeat_interleave(c.shape[0], dim=0) # c[:, :16] = cam2world_matrix.view(-1, 16) intrinsics = c[:, 16:25].reshape(-1, 3, 3) if neural_rendering_resolution is None: neural_rendering_resolution = self.neural_rendering_resolution else: self.neural_rendering_resolution = neural_rendering_resolution if ray_directions is None: # when output video H = W = self.neural_rendering_resolution # Create a batch of rays for volume rendering # ray_origins, ray_directions, ray_bboxes = self.ray_sampler( # cam2world_matrix, intrinsics, neural_rendering_resolution) if sample_ray_only: # ! for sampling ray_origins, ray_directions, ray_bboxes = self.ray_sampler( cam2world_matrix, intrinsics, self.rendering_kwargs.get( 'patch_rendering_resolution' ), self.neural_rendering_resolution, fg_bbox) # for patch supervision ret_dict = { 'ray_origins': ray_origins, 'ray_directions': ray_directions, 'ray_bboxes': ray_bboxes, } return ret_dict else: # ! for rendering ray_origins, ray_directions, _ = self.ray_sampler( cam2world_matrix, intrinsics, self.neural_rendering_resolution, self.neural_rendering_resolution) else: assert ray_origins is not None H = W = int(ray_directions.shape[1]** 0.5) # dynamically set patch resolution # ! match the batch size, if not returned if planes is None: assert self.planes is not None planes = self.planes.repeat_interleave(c.shape[0], dim=0) return_sampling_details_flag = self.rendering_kwargs.get( 'return_sampling_details_flag', False) if return_sampling_details_flag: return_meta = True # Create triplanes by running StyleGAN backbone N, M, _ = ray_origins.shape # Reshape output into three 32-channel planes if planes.shape[1] == 3 * 2 * self.decoder_in_chans: # if isinstance(planes, tuple): # N *= 2 triplane_bg = True # planes = torch.cat(planes, 0) # inference in parallel # ray_origins = ray_origins.repeat(2,1,1) # ray_directions = ray_directions.repeat(2,1,1) else: triplane_bg = False # assert not triplane_bg # ! hard coded, will fix later # if planes.shape[1] == 3 * self.decoder_in_chans: # else: # planes = planes.view(len(planes), 3, self.decoder_in_chans, planes = planes.reshape( len(planes), 3, -1, # ! support background plane planes.shape[-2], planes.shape[-1]) # BS 96 256 256 # Perform volume rendering rendering_details = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs, return_meta=return_meta) feature_samples, depth_samples, weights_samples = ( rendering_details[k] for k in ['feature_samples', 'depth_samples', 'weights_samples']) if return_sampling_details_flag: shape_synthesized = rendering_details['shape_synthesized'] else: shape_synthesized = None # Reshape into 'raw' neural-rendered image feature_image = feature_samples.permute(0, 2, 1).reshape( N, feature_samples.shape[-1], H, W).contiguous() # B 32 H W, in [-1,1] depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W) # Generate Background # if self.bcg_synthesis: # # bg composition # # if self.decoder.activation == "sigmoid": # # feature_image = feature_image * 2 - 1 # Scale to (-1, 1), taken from ray marcher # assert isinstance( # z_bcg, torch.Tensor # ) # 512 latents after reparmaterization, reuse the name # # ws_bcg = ws[:,:self.bcg_synthesis.num_ws] if ws_bcg is None else ws_bcg[:,:self.bcg_synthesis.num_ws] # with torch.autocast(device_type='cuda', # dtype=torch.float16, # enabled=False): # ws_bcg = self.bcg_mapping(z_bcg, c=None) # reuse the name # if ws_bcg.size(1) < self.bcg_synthesis.num_ws: # ws_bcg = torch.cat([ # ws_bcg, ws_bcg[:, -1:].repeat( # 1, self.bcg_synthesis.num_ws - ws_bcg.size(1), 1) # ], 1) # bcg_image = self.bcg_synthesis(ws_bcg, # update_emas=update_emas, # **synthesis_kwargs) # bcg_image = torch.nn.functional.interpolate( # bcg_image, # size=feature_image.shape[2:], # mode='bilinear', # align_corners=False, # antialias=self.rendering_kwargs['sr_antialias']) # feature_image = feature_image + (1 - weights_samples) * bcg_image # # Generate Raw image # assert self.torgb # rgb_image = self.torgb(feature_image, # ws_bcg[:, -1], # fused_modconv=False) # rgb_image = rgb_image.to(dtype=torch.float32, # memory_format=torch.contiguous_format) # # st() # else: mask_image = weights_samples * (1 + 2 * 0.001) - 0.001 if triplane_bg: # true_bs = N // 2 # weights_samples = weights_samples[:true_bs] # mask_image = mask_image[:true_bs] # feature_image = feature_image[:true_bs] * mask_image + feature_image[true_bs:] * (1-mask_image) # the first is foreground # depth_image = depth_image[:true_bs] # ! composited colors # rgb_final = ( # 1 - fg_ret_dict['weights'] # ) * bg_ret_dict['rgb_final'] + fg_ret_dict[ # 'feature_samples'] # https://github.com/SizheAn/PanoHead/blob/17ad915941c7e2703d5aa3eb5ff12eac47c90e53/training/triplane.py#L127C45-L127C64 # ret_dict.update({ # 'feature_samples': rgb_final, # }) # st() feature_image = (1 - mask_image) * rendering_details[ 'bg_ret_dict']['rgb_final'] + feature_image rgb_image = feature_image[:, :3] # # Run superresolution to get final image if self.superresolution is not None and not return_raw_only: # assert ws is not None, 'feed in [cls] token here for SR module' if ws is not None and ws.ndim == 2: ws = ws.unsqueeze( 1)[:, -1:, :] # follow stylegan tradition, B, N, C sr_image = self.superresolution( rgb=rgb_image, x=feature_image, base_x=rgb_image, ws=ws, # only use the last layer noise_mode=self. rendering_kwargs['superresolution_noise_mode'], # none **{ k: synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode' }) else: # sr_image = rgb_image sr_image = None if shape_synthesized is not None: shape_synthesized.update({ 'image_depth': depth_image, }) # for 3D loss easy computation, wrap all 3D in a single dict ret_dict = { 'feature_image': feature_image, # 'image_raw': feature_image[:, :3], 'image_raw': rgb_image, 'image_depth': depth_image, 'weights_samples': weights_samples, # 'silhouette': mask_image, # 'silhouette_normalized_3channel': (mask_image*2-1).repeat_interleave(3,1), # N 3 H W 'shape_synthesized': shape_synthesized, "image_mask": mask_image, } if sr_image is not None: ret_dict.update({ 'image_sr': sr_image, }) if return_meta: ret_dict.update({ 'feature_volume': rendering_details['feature_volume'], 'all_coords': rendering_details['all_coords'], 'weights': rendering_details['weights'], }) return ret_dict class Triplane_fg_bg_plane(Triplane): # a separate background plane def __init__(self, c_dim=25, img_resolution=128, img_channels=3, out_chans=96, triplane_size=224, rendering_kwargs={}, decoder_in_chans=32, decoder_output_dim=32, sr_num_fp16_res=0, sr_kwargs={}, bcg_synthesis_kwargs={}): super().__init__(c_dim, img_resolution, img_channels, out_chans, triplane_size, rendering_kwargs, decoder_in_chans, decoder_output_dim, sr_num_fp16_res, sr_kwargs, bcg_synthesis_kwargs) self.bcg_decoder = Decoder( ch=64, # half channel size out_ch=32, # ch_mult=(1, 2, 4), ch_mult=(1, 2), # use res=64 for now num_res_blocks=2, dropout=0.0, attn_resolutions=(), z_channels=4, resolution=64, in_channels=3, ) # * pure reconstruction def forward( self, planes, bg_plane, # img, c, ws=None, z_bcg=None, neural_rendering_resolution=None, update_emas=False, cache_backbone=False, use_cached_backbone=False, return_meta=False, return_raw_only=False, **synthesis_kwargs): # ! match the batch size if planes is None: assert self.planes is not None planes = self.planes.repeat_interleave(c.shape[0], dim=0) return_sampling_details_flag = self.rendering_kwargs.get( 'return_sampling_details_flag', False) if return_sampling_details_flag: return_meta = True cam2world_matrix = c[:, :16].reshape(-1, 4, 4) # cam2world_matrix = torch.eye(4, device=c.device).unsqueeze(0).repeat_interleave(c.shape[0], dim=0) # c[:, :16] = cam2world_matrix.view(-1, 16) intrinsics = c[:, 16:25].reshape(-1, 3, 3) if neural_rendering_resolution is None: neural_rendering_resolution = self.neural_rendering_resolution else: self.neural_rendering_resolution = neural_rendering_resolution H = W = self.neural_rendering_resolution # Create a batch of rays for volume rendering ray_origins, ray_directions, _ = self.ray_sampler( cam2world_matrix, intrinsics, neural_rendering_resolution) # Create triplanes by running StyleGAN backbone N, M, _ = ray_origins.shape # # Reshape output into three 32-channel planes # if planes.shape[1] == 3 * 2 * self.decoder_in_chans: # # if isinstance(planes, tuple): # # N *= 2 # triplane_bg = True # # planes = torch.cat(planes, 0) # inference in parallel # # ray_origins = ray_origins.repeat(2,1,1) # # ray_directions = ray_directions.repeat(2,1,1) # else: # triplane_bg = False # assert not triplane_bg planes = planes.view( len(planes), 3, -1, # ! support background plane planes.shape[-2], planes.shape[-1]) # BS 96 256 256 # Perform volume rendering rendering_details = self.renderer(planes, self.decoder, ray_origins, ray_directions, self.rendering_kwargs, return_meta=return_meta) feature_samples, depth_samples, weights_samples = ( rendering_details[k] for k in ['feature_samples', 'depth_samples', 'weights_samples']) if return_sampling_details_flag: shape_synthesized = rendering_details['shape_synthesized'] else: shape_synthesized = None # Reshape into 'raw' neural-rendered image feature_image = feature_samples.permute(0, 2, 1).reshape( N, feature_samples.shape[-1], H, W).contiguous() # B 32 H W, in [-1,1] depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W) weights_samples = weights_samples.permute(0, 2, 1).reshape(N, 1, H, W) bcg_image = self.bcg_decoder(bg_plane) bcg_image = torch.nn.functional.interpolate( bcg_image, size=feature_image.shape[2:], mode='bilinear', align_corners=False, antialias=self.rendering_kwargs['sr_antialias']) mask_image = weights_samples * (1 + 2 * 0.001) - 0.001 # ! fuse fg/bg model output feature_image = feature_image + (1 - weights_samples) * bcg_image rgb_image = feature_image[:, :3] # # Run superresolution to get final image if self.superresolution is not None and not return_raw_only: # assert ws is not None, 'feed in [cls] token here for SR module' if ws is not None and ws.ndim == 2: ws = ws.unsqueeze( 1)[:, -1:, :] # follow stylegan tradition, B, N, C sr_image = self.superresolution( rgb=rgb_image, x=feature_image, base_x=rgb_image, ws=ws, # only use the last layer noise_mode=self. rendering_kwargs['superresolution_noise_mode'], # none **{ k: synthesis_kwargs[k] for k in synthesis_kwargs.keys() if k != 'noise_mode' }) else: # sr_image = rgb_image sr_image = None if shape_synthesized is not None: shape_synthesized.update({ 'image_depth': depth_image, }) # for 3D loss easy computation, wrap all 3D in a single dict ret_dict = { 'feature_image': feature_image, # 'image_raw': feature_image[:, :3], 'image_raw': rgb_image, 'image_depth': depth_image, 'weights_samples': weights_samples, # 'silhouette': mask_image, # 'silhouette_normalized_3channel': (mask_image*2-1).repeat_interleave(3,1), # N 3 H W 'shape_synthesized': shape_synthesized, "image_mask": mask_image, } if sr_image is not None: ret_dict.update({ 'image_sr': sr_image, }) if return_meta: ret_dict.update({ 'feature_volume': rendering_details['feature_volume'], 'all_coords': rendering_details['all_coords'], 'weights': rendering_details['weights'], }) return ret_dict