|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
from threading import local |
|
import torch |
|
import torch.nn as nn |
|
from utils.torch_utils import persistence |
|
from .networks_stylegan2 import Generator as StyleGAN2Backbone |
|
from .networks_stylegan2 import ToRGBLayer, SynthesisNetwork, MappingNetwork |
|
from .volumetric_rendering.renderer import ImportanceRenderer |
|
from .volumetric_rendering.ray_sampler import RaySampler, PatchRaySampler |
|
import dnnlib |
|
from pdb import set_trace as st |
|
import math |
|
|
|
import torch.nn.functional as F |
|
import itertools |
|
from ldm.modules.diffusionmodules.model import SimpleDecoder, Decoder |
|
|
|
|
|
@persistence.persistent_class
class TriPlaneGenerator(torch.nn.Module):
    """Tri-plane generator: StyleGAN2 backbone, volume renderer and SR head.

    ``forward`` maps ``(z, c)`` to ``ws`` through the backbone's mapping
    network and passes them to ``synthesis``, which reshapes the backbone
    output into three 32-channel planes, volume-renders a feature image from
    them and (optionally) runs the super-resolution module on the result.
    """

    def __init__(
        self,
        z_dim,
        c_dim,
        w_dim,
        img_resolution,
        img_channels,
        sr_num_fp16_res=0,
        mapping_kwargs={},
        rendering_kwargs={},
        sr_kwargs={},
        bcg_synthesis_kwargs={},
        **synthesis_kwargs,
    ):
        """Build the backbone, renderer, decoder and super-resolution module.

        Args:
            z_dim: Forwarded to the backbone (and background mapping network).
            c_dim: Forwarded to the backbone (and background mapping network).
            w_dim: Forwarded to the backbone and background networks.
            img_resolution: Stored, and forwarded to the super-resolution
                module.
            img_channels: Stored only; the backbone is always built with
                ``32 * 3`` output channels.
            sr_num_fp16_res: Forwarded to the super-resolution module.
            mapping_kwargs: Extra kwargs for the mapping network(s).
            rendering_kwargs: Rendering options. Must contain
                ``'superresolution_module'`` and ``'sr_antialias'``; the dict
                is stored by reference as ``self.rendering_kwargs``.
            sr_kwargs: Extra kwargs for the super-resolution module.
            bcg_synthesis_kwargs: Extra kwargs for the background synthesis
                network (only used when ``rendering_kwargs['use_background']``
                is truthy).
            **synthesis_kwargs: Forwarded to the backbone constructor.
        """
        super().__init__()
        self.z_dim = z_dim
        self.c_dim = c_dim
        self.w_dim = w_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.renderer = ImportanceRenderer()
        # FIX: synthesis() calls self.ray_sampler, but it was never assigned,
        # so every call raised AttributeError.
        self.ray_sampler = RaySampler()

        self.backbone = StyleGAN2Backbone(z_dim,
                                          c_dim,
                                          w_dim,
                                          img_resolution=256,
                                          img_channels=32 * 3,
                                          mapping_kwargs=mapping_kwargs,
                                          **synthesis_kwargs)
        self.superresolution = dnnlib.util.construct_class_by_name(
            class_name=rendering_kwargs['superresolution_module'],
            channels=32,
            img_resolution=img_resolution,
            sr_num_fp16_res=sr_num_fp16_res,
            sr_antialias=rendering_kwargs['sr_antialias'],
            **sr_kwargs)

        if rendering_kwargs.get('use_background', False):
            self.bcg_synthesis = SynthesisNetwork(
                w_dim,
                img_resolution=self.superresolution.input_resolution,
                img_channels=32,
                **bcg_synthesis_kwargs)
            # FIX: the original read ``self.num_ws``, which is never defined
            # on this class. The background mapping network feeds
            # ``bcg_synthesis``, so it is sized from that network.
            # NOTE(review): assumes SynthesisNetwork exposes ``num_ws`` as in
            # StyleGAN2-ADA -- confirm against networks_stylegan2.
            self.bcg_mapping = MappingNetwork(
                z_dim=z_dim,
                c_dim=c_dim,
                w_dim=w_dim,
                num_ws=self.bcg_synthesis.num_ws,
                **mapping_kwargs)

        self.decoder = OSGDecoder(
            32, {
                'decoder_lr_mul': rendering_kwargs.get('decoder_lr_mul', 1),
                'decoder_output_dim': 32
            })
        self.neural_rendering_resolution = 64
        self.rendering_kwargs = rendering_kwargs

        # Backbone output cached by synthesis(cache_backbone=True).
        self._last_planes = None
        self.pool_256 = torch.nn.AdaptiveAvgPool2d((256, 256))

    def mapping(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                update_emas=False):
        """Run the backbone mapping network on ``z`` and the scaled ``c``.

        ``c`` is replaced by zeros when
        ``rendering_kwargs['c_gen_conditioning_zero']`` is truthy, and is
        always multiplied by ``rendering_kwargs.get('c_scale', 0)``.
        """
        if self.rendering_kwargs['c_gen_conditioning_zero']:
            c = torch.zeros_like(c)
        c_scaled = c * self.rendering_kwargs.get('c_scale', 0)
        return self.backbone.mapping(z,
                                     c_scaled,
                                     truncation_psi=truncation_psi,
                                     truncation_cutoff=truncation_cutoff,
                                     update_emas=update_emas)

    def synthesis(self,
                  ws,
                  c,
                  neural_rendering_resolution=None,
                  update_emas=False,
                  cache_backbone=False,
                  use_cached_backbone=False,
                  return_meta=False,
                  return_raw_only=False,
                  **synthesis_kwargs):
        """Render an image dict from latents ``ws`` and camera labels ``c``.

        Args:
            ws: Latents; the first ``self.backbone.num_ws`` entries along
                dim 1 drive the backbone, the last entry drives the
                super-resolution module.
            c: Camera labels. Per row, entries ``[:16]`` are viewed as a 4x4
                cam2world matrix and ``[16:25]`` as a 3x3 intrinsics matrix.
            neural_rendering_resolution: Side length of the rendered feature
                image. When given it also overwrites
                ``self.neural_rendering_resolution``.
            update_emas: Forwarded to the backbone synthesis network.
            cache_backbone: Store the backbone output in
                ``self._last_planes``.
            use_cached_backbone: Reuse ``self._last_planes`` if one is stored.
            return_meta: Also return ``feature_volume``, ``all_coords`` and
                ``weights`` from the renderer. Forced on when
                ``rendering_kwargs['return_sampling_details_flag']`` is set.
            return_raw_only: Skip super-resolution; ``'image'`` is then the
                raw RGB image.
            **synthesis_kwargs: Forwarded to the backbone and (minus
                ``noise_mode``) to the super-resolution module.

        Returns:
            dict with ``image``, ``image_raw``, ``image_depth``,
            ``weights_samples`` and ``shape_synthesized`` (plus the meta keys
            when ``return_meta`` is true).
        """
        return_sampling_details_flag = self.rendering_kwargs.get(
            'return_sampling_details_flag', False)
        if return_sampling_details_flag:
            return_meta = True

        cam2world_matrix = c[:, :16].view(-1, 4, 4)
        intrinsics = c[:, 16:25].view(-1, 3, 3)

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        H = W = self.neural_rendering_resolution

        # The other classes in this file unpack three values from the ray
        # sampler; taking the first two keeps this call working whether the
        # sampler returns (origins, directions) or (origins, directions, extra).
        ray_origins, ray_directions = self.ray_sampler(
            cam2world_matrix, intrinsics, neural_rendering_resolution)[:2]

        N, _, _ = ray_origins.shape
        if use_cached_backbone and self._last_planes is not None:
            planes = self._last_planes
        else:
            planes = self.backbone.synthesis(
                ws[:, :self.backbone.num_ws, :],
                update_emas=update_emas,
                **synthesis_kwargs)
        if cache_backbone:
            self._last_planes = planes

        # Split the 96 backbone channels into three 32-channel planes.
        planes = planes.view(len(planes), 3, 32, planes.shape[-2],
                             planes.shape[-1])

        rendering_details = self.renderer(planes,
                                          self.decoder,
                                          ray_origins,
                                          ray_directions,
                                          self.rendering_kwargs,
                                          return_meta=return_meta)

        feature_samples = rendering_details['feature_samples']
        depth_samples = rendering_details['depth_samples']
        weights_samples = rendering_details['weights_samples']

        if return_sampling_details_flag:
            shape_synthesized = rendering_details['shape_synthesized']
        else:
            shape_synthesized = None

        # Per-ray samples -> image layout (N, C, H, W).
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)

        # The first three feature channels are used as the raw RGB image.
        rgb_image = feature_image[:, :3]
        if not return_raw_only:
            sr_kwargs = {
                k: v
                for k, v in synthesis_kwargs.items() if k != 'noise_mode'
            }
            sr_image = self.superresolution(
                rgb_image,
                feature_image,
                ws[:, -1:, :],
                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
                **sr_kwargs)
        else:
            sr_image = rgb_image

        ret_dict = {
            'image': sr_image,
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'weights_samples': weights_samples,
            'shape_synthesized': shape_synthesized
        }
        if return_meta:
            ret_dict.update({
                'feature_volume': rendering_details['feature_volume'],
                'all_coords': rendering_details['all_coords'],
                'weights': rendering_details['weights'],
            })

        return ret_dict

    def sample(self,
               coordinates,
               directions,
               z,
               c,
               truncation_psi=1,
               truncation_cutoff=None,
               update_emas=False,
               **synthesis_kwargs):
        """Map ``(z, c)`` to planes and query the renderer at ``coordinates``.

        Returns whatever ``self.renderer.run_model`` returns for the given
        coordinates and directions.
        """
        ws = self.mapping(z,
                          c,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        planes = self.backbone.synthesis(ws,
                                         update_emas=update_emas,
                                         **synthesis_kwargs)
        planes = planes.view(len(planes), 3, 32, planes.shape[-2],
                             planes.shape[-1])
        return self.renderer.run_model(planes, self.decoder, coordinates,
                                       directions, self.rendering_kwargs)

    def sample_mixed(self,
                     coordinates,
                     directions,
                     ws,
                     truncation_psi=1,
                     truncation_cutoff=None,
                     update_emas=False,
                     **synthesis_kwargs):
        """Like ``sample`` but starts from precomputed latents ``ws``.

        ``truncation_psi`` and ``truncation_cutoff`` are accepted but unused.
        """
        planes = self.backbone.synthesis(ws,
                                         update_emas=update_emas,
                                         **synthesis_kwargs)
        planes = planes.view(len(planes), 3, 32, planes.shape[-2],
                             planes.shape[-1])
        return self.renderer.run_model(planes, self.decoder, coordinates,
                                       directions, self.rendering_kwargs)

    def forward(self,
                z,
                c,
                truncation_psi=1,
                truncation_cutoff=None,
                neural_rendering_resolution=None,
                update_emas=False,
                cache_backbone=False,
                use_cached_backbone=False,
                **synthesis_kwargs):
        """Run ``mapping`` followed by ``synthesis``; returns its dict."""
        ws = self.mapping(z,
                          c,
                          truncation_psi=truncation_psi,
                          truncation_cutoff=truncation_cutoff,
                          update_emas=update_emas)
        return self.synthesis(
            ws,
            c,
            update_emas=update_emas,
            neural_rendering_resolution=neural_rendering_resolution,
            cache_backbone=cache_backbone,
            use_cached_backbone=use_cached_backbone,
            **synthesis_kwargs)
|
|
|
|
|
from .networks_stylegan2 import FullyConnectedLayer |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
@persistence.persistent_class
class OSGDecoder(torch.nn.Module):
    """Two-layer MLP turning plane-averaged features into ``rgb`` and ``sigma``.

    ``options`` must provide ``'decoder_output_dim'`` and ``'decoder_lr_mul'``;
    ``'decoder_activation'`` is optional and defaults to ``'sigmoid'``.
    """

    def __init__(self, n_features, options):
        super().__init__()
        self.hidden_dim = 64
        self.decoder_output_dim = options['decoder_output_dim']

        lr_mul = options['decoder_lr_mul']
        layers = [
            FullyConnectedLayer(n_features,
                                self.hidden_dim,
                                lr_multiplier=lr_mul),
            torch.nn.Softplus(),
            # One extra output channel: channel 0 is sigma, the rest is rgb.
            FullyConnectedLayer(self.hidden_dim,
                                1 + options['decoder_output_dim'],
                                lr_multiplier=lr_mul),
        ]
        self.net = torch.nn.Sequential(*layers)
        self.activation = options.get('decoder_activation', 'sigmoid')

    def forward(self, sampled_features, ray_directions):
        """Decode features; ``ray_directions`` is accepted but not used.

        ``sampled_features`` is averaged over dim 1 and must then be 3-D.
        Returns ``{'rgb': ..., 'sigma': ...}`` where ``sigma`` is output
        channel 0 and ``rgb`` the remaining channels after the activation.
        """
        pooled = sampled_features.mean(1)

        batch, n_points, channels = pooled.shape
        flat = pooled.view(batch * n_points, channels)
        out = self.net(flat).view(batch, n_points, -1)

        sigma = out[..., 0:1]
        rgb = out[..., 1:]
        if self.activation == "sigmoid":
            # Sigmoid stretched slightly beyond [0, 1] on both sides.
            rgb = torch.sigmoid(rgb) * (1 + 2 * 0.001) - 0.001
        elif self.activation == "lrelu":
            rgb = torch.nn.functional.leaky_relu(
                rgb, 0.2, inplace=True) * math.sqrt(2)
        return {'rgb': rgb, 'sigma': sigma}
|
|
|
|
|
class LRMOSGDecoder(nn.Module):
    """
    MLP decoder producing RGB and sigma from per-plane sampled features.

    The features of all planes are concatenated along the channel axis before
    entering the MLP. Uses ReLU by default instead of the Softplus of the
    original implementation.

    Reference:
    EG3D: https://github.com/NVlabs/eg3d/blob/main/eg3d/training/triplane.py#L112
    """

    def __init__(self, n_features: int,
                 hidden_dim: int = 64, num_layers: int = 4, activation: nn.Module = nn.ReLU):
        super().__init__()
        self.decoder_output_dim = 3

        # Input layer, (num_layers - 2) hidden layers, then the output layer
        # with one extra channel for sigma.
        layers = [nn.Linear(3 * n_features, hidden_dim), activation()]
        for _ in range(num_layers - 2):
            layers.append(nn.Linear(hidden_dim, hidden_dim))
            layers.append(activation())
        layers.append(nn.Linear(hidden_dim, 1 + self.decoder_output_dim))
        self.net = nn.Sequential(*layers)

        # Start every linear layer with a zero bias.
        for module in self.modules():
            if isinstance(module, nn.Linear):
                nn.init.zeros_(module.bias)

    def forward(self, sampled_features, ray_directions):
        """Decode 4-D ``sampled_features``; ``ray_directions`` is unused.

        Dims 1 and 3 of the input are merged into one channel axis after
        moving dim 2 in front of dim 1. Returns ``{'rgb': ..., 'sigma': ...}``
        with ``sigma`` taken from output channel 0.
        """
        batch, n_planes, n_points, channels = sampled_features.shape
        merged = sampled_features.permute(0, 2, 1, 3).reshape(
            batch, n_points, n_planes * channels)

        flat = merged.contiguous().view(batch * n_points,
                                        n_planes * channels)
        out = self.net(flat).view(batch, n_points, -1)

        return {
            # Sigmoid stretched slightly beyond [0, 1] on both sides.
            'rgb': torch.sigmoid(out[..., 1:]) * (1 + 2 * 0.001) - 0.001,
            'sigma': out[..., 0:1],
        }
|
|
|
|
|
class Triplane(torch.nn.Module):
    """Volume-render images from (given or stored) tri-plane feature maps.

    Holds an ``ImportanceRenderer``, a ray sampler, a feature decoder and an
    optional super-resolution module. ``forward`` turns planes plus camera
    labels (or precomputed rays) into a dict of rendered images.
    """

    def __init__(
        self,
        c_dim=25,
        img_resolution=128,
        img_channels=3,
        out_chans=96,
        triplane_size=224,
        rendering_kwargs={},
        decoder_in_chans=32,
        decoder_output_dim=32,
        sr_num_fp16_res=0,
        sr_kwargs={},
        create_triplane=False,
        bcg_synthesis_kwargs={},
        lrm_decoder=False,
    ):
        """Build the renderer, sampler, decoder and optional SR module.

        Args:
            c_dim, img_channels, triplane_size, out_chans: Stored as
                attributes; ``out_chans`` also sizes the stored planes.
            img_resolution: Stored, used as the default rendering resolution
                and forwarded to the fallback super-resolution module.
            rendering_kwargs: Rendering options, stored by reference. A
                ``'PatchRaySampler'`` key selects ``PatchRaySampler``.
            decoder_in_chans: Input feature width of the decoder.
            decoder_output_dim: Output feature width of ``OSGDecoder``.
            sr_num_fp16_res: Forwarded to the fallback SR module.
            sr_kwargs: When non-empty, a super-resolution module is built.
            create_triplane: Create a learnable ``(1, out_chans, 256, 256)``
                ``planes`` parameter.
            bcg_synthesis_kwargs: Accepted but unused here.
            lrm_decoder: Use ``LRMOSGDecoder`` instead of ``OSGDecoder``.
        """
        super().__init__()
        self.c_dim = c_dim
        self.img_resolution = img_resolution
        self.img_channels = img_channels
        self.triplane_size = triplane_size

        self.decoder_in_chans = decoder_in_chans
        self.out_chans = out_chans

        self.renderer = ImportanceRenderer()

        if 'PatchRaySampler' in rendering_kwargs:
            self.ray_sampler = PatchRaySampler()
        else:
            self.ray_sampler = RaySampler()

        if lrm_decoder:
            self.decoder = LRMOSGDecoder(decoder_in_chans)
        else:
            self.decoder = OSGDecoder(
                decoder_in_chans, {
                    'decoder_lr_mul':
                    rendering_kwargs.get('decoder_lr_mul', 1),
                    'decoder_output_dim': decoder_output_dim
                })

        self.neural_rendering_resolution = img_resolution

        self.rendering_kwargs = rendering_kwargs
        self.create_triplane = create_triplane
        if create_triplane:
            self.planes = nn.Parameter(torch.randn(1, out_chans, 256, 256))
        else:
            # FIX: previously left undefined, so the ``self.planes is not
            # None`` guard in forward() raised AttributeError instead of
            # failing the assertion. A plain ``None`` attribute does not add
            # anything to the state dict.
            self.planes = None

        if bool(sr_kwargs):
            assert decoder_in_chans == decoder_output_dim, 'tradition'
            sr_class_name = rendering_kwargs['superresolution_module']
            if sr_class_name in [
                    'utils.torch_utils.components.PixelUnshuffleUpsample',
                    'utils.torch_utils.components.NearestConvSR',
                    'utils.torch_utils.components.NearestConvSR_Residual'
            ]:
                self.superresolution = dnnlib.util.construct_class_by_name(
                    class_name=sr_class_name,
                    sr_ratio=2,
                    output_dim=decoder_output_dim,
                    num_out_ch=3,
                )
            else:
                self.superresolution = dnnlib.util.construct_class_by_name(
                    class_name=sr_class_name,
                    channels=decoder_output_dim,
                    img_resolution=img_resolution,
                    sr_num_fp16_res=sr_num_fp16_res,
                    sr_antialias=rendering_kwargs['sr_antialias'],
                    **sr_kwargs)
        else:
            self.superresolution = None

        self.bcg_synthesis = None

    def forward(
            self,
            planes=None,
            c=None,
            ws=None,
            ray_origins=None,
            ray_directions=None,
            z_bcg=None,
            neural_rendering_resolution=None,
            update_emas=False,
            cache_backbone=False,
            use_cached_backbone=False,
            return_meta=False,
            return_raw_only=False,
            sample_ray_only=False,
            fg_bbox=None,
            **synthesis_kwargs):
        """Render the planes from the given cameras or rays.

        Args:
            planes: Plane features, reshaped to ``(B, 3, -1, H, W)``. When
                ``None`` the stored ``self.planes`` parameter is repeated per
                batch item.
            c: Camera labels. Per row, entries ``[:16]`` are reshaped to a
                4x4 cam2world matrix and ``[16:25]`` to 3x3 intrinsics.
                Required only when rays have to be sampled.
            ws: Passed to the super-resolution module; a 2-D ``ws`` gets a
                singleton dim 1 inserted first.
            ray_origins, ray_directions: Precomputed rays. When
                ``ray_directions`` is given, ``ray_origins`` must be too, and
                the image side is ``int(sqrt(number of rays))``.
            neural_rendering_resolution: Rendering side length; when given it
                also overwrites ``self.neural_rendering_resolution``.
            return_meta: Also return ``feature_volume``, ``all_coords`` and
                ``weights``. Forced on when
                ``rendering_kwargs['return_sampling_details_flag']`` is set.
            return_raw_only: Skip super-resolution.
            sample_ray_only: Only sample rays (when none were given) and
                return them with their bounding boxes.
            fg_bbox: Forwarded to the ray sampler in ``sample_ray_only`` mode.
            z_bcg, update_emas, cache_backbone, use_cached_backbone:
                Accepted but unused here.
            **synthesis_kwargs: Forwarded (minus ``noise_mode``) to the
                super-resolution module.

        Returns:
            In ``sample_ray_only`` mode a dict with ``ray_origins``,
            ``ray_directions`` and ``ray_bboxes``; otherwise a dict with
            ``feature_image``, ``image_raw``, ``image_depth``,
            ``weights_samples``, ``shape_synthesized``, ``image_mask`` and,
            when available, ``image_sr`` and the meta keys.

        Raises:
            ValueError: If rays must be sampled but ``c`` is ``None``.
        """
        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        if ray_directions is None:
            # FIX: the camera matrices are only needed to sample rays, so
            # ``c`` is dereferenced here rather than unconditionally. Callers
            # that pass precomputed rays may now rely on the documented
            # ``c=None`` default.
            if c is None:
                raise ValueError(
                    'c is required when ray_directions is not provided')
            cam2world_matrix = c[:, :16].reshape(-1, 4, 4)
            intrinsics = c[:, 16:25].reshape(-1, 3, 3)

            H = W = self.neural_rendering_resolution

            if sample_ray_only:
                ray_origins, ray_directions, ray_bboxes = self.ray_sampler(
                    cam2world_matrix, intrinsics,
                    self.rendering_kwargs.get('patch_rendering_resolution'),
                    self.neural_rendering_resolution, fg_bbox)
                return {
                    'ray_origins': ray_origins,
                    'ray_directions': ray_directions,
                    'ray_bboxes': ray_bboxes,
                }

            ray_origins, ray_directions, _ = self.ray_sampler(
                cam2world_matrix, intrinsics,
                self.neural_rendering_resolution,
                self.neural_rendering_resolution)
        else:
            assert ray_origins is not None
            # Rays are laid out as a flattened square image.
            H = W = int(ray_directions.shape[1]**0.5)

        if planes is None:
            assert self.planes is not None, \
                'planes not given and create_triplane was False'
            # Without camera labels, take the batch size from the rays.
            batch_size = c.shape[0] if c is not None else ray_origins.shape[0]
            planes = self.planes.repeat_interleave(batch_size, dim=0)

        return_sampling_details_flag = self.rendering_kwargs.get(
            'return_sampling_details_flag', False)
        if return_sampling_details_flag:
            return_meta = True

        N, _, _ = ray_origins.shape

        # Twice the usual channel count signals that the renderer output
        # carries a 'bg_ret_dict' to blend in below.
        triplane_bg = planes.shape[1] == 3 * 2 * self.decoder_in_chans

        planes = planes.reshape(len(planes), 3, -1, planes.shape[-2],
                                planes.shape[-1])

        rendering_details = self.renderer(planes,
                                          self.decoder,
                                          ray_origins,
                                          ray_directions,
                                          self.rendering_kwargs,
                                          return_meta=return_meta)

        feature_samples = rendering_details['feature_samples']
        depth_samples = rendering_details['depth_samples']
        weights_samples = rendering_details['weights_samples']

        if return_sampling_details_flag:
            shape_synthesized = rendering_details['shape_synthesized']
        else:
            shape_synthesized = None

        # Per-ray samples -> image layout (N, C, H, W).
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
        weights_samples = weights_samples.permute(0, 2, 1).reshape(
            N, 1, H, W)

        # Weights stretched slightly beyond their range on both sides.
        mask_image = weights_samples * (1 + 2 * 0.001) - 0.001
        if triplane_bg:
            feature_image = (1 - mask_image) * rendering_details[
                'bg_ret_dict']['rgb_final'] + feature_image

        # The first three feature channels are used as the raw RGB image.
        rgb_image = feature_image[:, :3]

        if self.superresolution is not None and not return_raw_only:
            if ws is not None and ws.ndim == 2:
                ws = ws.unsqueeze(1)[:, -1:, :]

            sr_kwargs = {
                k: v
                for k, v in synthesis_kwargs.items() if k != 'noise_mode'
            }
            sr_image = self.superresolution(
                rgb=rgb_image,
                x=feature_image,
                base_x=rgb_image,
                ws=ws,
                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
                **sr_kwargs)
        else:
            sr_image = None

        if shape_synthesized is not None:
            shape_synthesized.update({
                'image_depth': depth_image,
            })

        ret_dict = {
            'feature_image': feature_image,
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'weights_samples': weights_samples,
            'shape_synthesized': shape_synthesized,
            "image_mask": mask_image,
        }

        if sr_image is not None:
            ret_dict.update({
                'image_sr': sr_image,
            })

        if return_meta:
            ret_dict.update({
                'feature_volume': rendering_details['feature_volume'],
                'all_coords': rendering_details['all_coords'],
                'weights': rendering_details['weights'],
            })

        return ret_dict
|
|
|
|
|
class Triplane_fg_bg_plane(Triplane):
    """``Triplane`` variant that composites a decoded 2-D background plane.

    The background plane is decoded by ``self.bcg_decoder``, resized to the
    rendered feature image and added where the foreground weights are low.
    """

    def __init__(self,
                 c_dim=25,
                 img_resolution=128,
                 img_channels=3,
                 out_chans=96,
                 triplane_size=224,
                 rendering_kwargs={},
                 decoder_in_chans=32,
                 decoder_output_dim=32,
                 sr_num_fp16_res=0,
                 sr_kwargs={},
                 bcg_synthesis_kwargs={}):
        """Build the parent ``Triplane`` and the background decoder.

        All arguments are forwarded to ``Triplane.__init__`` under the same
        names.
        """
        # FIX: these were passed positionally, which put
        # ``bcg_synthesis_kwargs`` into the parent's 11th parameter,
        # ``create_triplane``. A non-empty dict therefore silently created a
        # learnable ``planes`` parameter. Keywords bind each value to the
        # parameter it is named after.
        super().__init__(c_dim=c_dim,
                         img_resolution=img_resolution,
                         img_channels=img_channels,
                         out_chans=out_chans,
                         triplane_size=triplane_size,
                         rendering_kwargs=rendering_kwargs,
                         decoder_in_chans=decoder_in_chans,
                         decoder_output_dim=decoder_output_dim,
                         sr_num_fp16_res=sr_num_fp16_res,
                         sr_kwargs=sr_kwargs,
                         bcg_synthesis_kwargs=bcg_synthesis_kwargs)

        self.bcg_decoder = Decoder(
            ch=64,
            out_ch=32,
            ch_mult=(1, 2),
            num_res_blocks=2,
            dropout=0.0,
            attn_resolutions=(),
            z_channels=4,
            resolution=64,
            in_channels=3,
        )

    def forward(
            self,
            planes,
            bg_plane,
            c,
            ws=None,
            z_bcg=None,
            neural_rendering_resolution=None,
            update_emas=False,
            cache_backbone=False,
            use_cached_backbone=False,
            return_meta=False,
            return_raw_only=False,
            **synthesis_kwargs):
        """Render the foreground planes and blend in the decoded background.

        Args:
            planes: Foreground plane features, reshaped to
                ``(B, 3, -1, H, W)``. ``None`` falls back to a stored
                ``self.planes`` parameter, if one exists.
            bg_plane: Input to ``self.bcg_decoder``.
            c: Camera labels. Per row, entries ``[:16]`` are reshaped to a
                4x4 cam2world matrix and ``[16:25]`` to 3x3 intrinsics.
            ws: Passed to the super-resolution module; a 2-D ``ws`` gets a
                singleton dim 1 inserted first.
            neural_rendering_resolution: Rendering side length; when given it
                also overwrites ``self.neural_rendering_resolution``.
            return_meta: Also return ``feature_volume``, ``all_coords`` and
                ``weights``. Forced on when
                ``rendering_kwargs['return_sampling_details_flag']`` is set.
            return_raw_only: Skip super-resolution.
            z_bcg, update_emas, cache_backbone, use_cached_backbone:
                Accepted but unused here.
            **synthesis_kwargs: Forwarded (minus ``noise_mode``) to the
                super-resolution module.

        Returns:
            dict with ``feature_image``, ``image_raw``, ``image_depth``,
            ``weights_samples``, ``shape_synthesized``, ``image_mask`` and,
            when available, ``image_sr`` and the meta keys.
        """
        if planes is None:
            # getattr: the attribute may not exist at all when the parent was
            # built without ``create_triplane``.
            stored_planes = getattr(self, 'planes', None)
            assert stored_planes is not None, \
                'planes not given and no stored planes parameter exists'
            planes = stored_planes.repeat_interleave(c.shape[0], dim=0)

        return_sampling_details_flag = self.rendering_kwargs.get(
            'return_sampling_details_flag', False)
        if return_sampling_details_flag:
            return_meta = True

        cam2world_matrix = c[:, :16].reshape(-1, 4, 4)
        intrinsics = c[:, 16:25].reshape(-1, 3, 3)

        if neural_rendering_resolution is None:
            neural_rendering_resolution = self.neural_rendering_resolution
        else:
            self.neural_rendering_resolution = neural_rendering_resolution

        H = W = self.neural_rendering_resolution

        # NOTE(review): Triplane.forward passes the resolution twice to the
        # same sampler; confirm which call signature the sampler expects.
        ray_origins, ray_directions, _ = self.ray_sampler(
            cam2world_matrix, intrinsics, neural_rendering_resolution)

        N, _, _ = ray_origins.shape

        # reshape (rather than view) also accepts non-contiguous planes,
        # matching Triplane.forward.
        planes = planes.reshape(len(planes), 3, -1, planes.shape[-2],
                                planes.shape[-1])

        rendering_details = self.renderer(planes,
                                          self.decoder,
                                          ray_origins,
                                          ray_directions,
                                          self.rendering_kwargs,
                                          return_meta=return_meta)

        feature_samples = rendering_details['feature_samples']
        depth_samples = rendering_details['depth_samples']
        weights_samples = rendering_details['weights_samples']

        if return_sampling_details_flag:
            shape_synthesized = rendering_details['shape_synthesized']
        else:
            shape_synthesized = None

        # Per-ray samples -> image layout (N, C, H, W).
        feature_image = feature_samples.permute(0, 2, 1).reshape(
            N, feature_samples.shape[-1], H, W).contiguous()
        depth_image = depth_samples.permute(0, 2, 1).reshape(N, 1, H, W)
        weights_samples = weights_samples.permute(0, 2, 1).reshape(
            N, 1, H, W)

        # Decode the background and resize it to the rendered feature image.
        bcg_image = self.bcg_decoder(bg_plane)
        bcg_image = torch.nn.functional.interpolate(
            bcg_image,
            size=feature_image.shape[2:],
            mode='bilinear',
            align_corners=False,
            antialias=self.rendering_kwargs['sr_antialias'])

        # Weights stretched slightly beyond their range on both sides.
        mask_image = weights_samples * (1 + 2 * 0.001) - 0.001

        # Background fills in wherever the foreground weights are below one.
        feature_image = feature_image + (1 - weights_samples) * bcg_image

        # The first three feature channels are used as the raw RGB image.
        rgb_image = feature_image[:, :3]

        if self.superresolution is not None and not return_raw_only:
            if ws is not None and ws.ndim == 2:
                ws = ws.unsqueeze(1)[:, -1:, :]

            sr_kwargs = {
                k: v
                for k, v in synthesis_kwargs.items() if k != 'noise_mode'
            }
            sr_image = self.superresolution(
                rgb=rgb_image,
                x=feature_image,
                base_x=rgb_image,
                ws=ws,
                noise_mode=self.rendering_kwargs['superresolution_noise_mode'],
                **sr_kwargs)
        else:
            sr_image = None

        if shape_synthesized is not None:
            shape_synthesized.update({
                'image_depth': depth_image,
            })

        ret_dict = {
            'feature_image': feature_image,
            'image_raw': rgb_image,
            'image_depth': depth_image,
            'weights_samples': weights_samples,
            'shape_synthesized': shape_synthesized,
            "image_mask": mask_image,
        }

        if sr_image is not None:
            ret_dict.update({
                'image_sr': sr_image,
            })

        if return_meta:
            ret_dict.update({
                'feature_volume': rendering_details['feature_volume'],
                'all_coords': rendering_details['all_coords'],
                'weights': rendering_details['weights'],
            })

        return ret_dict
|
|