from dataclasses import dataclass
import torch
import torch.nn.functional as F
import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.models.geometry.base import BaseImplicitGeometry
from threestudio.models.materials.base import BaseMaterial
from threestudio.models.renderers.base import VolumeRenderer
from threestudio.utils.GAN.discriminator import NLayerDiscriminator, weights_init
from threestudio.utils.GAN.distribution import DiagonalGaussianDistribution
from threestudio.utils.GAN.mobilenet import MobileNetV3 as GlobalEncoder
from threestudio.utils.GAN.vae import Decoder as Generator
from threestudio.utils.GAN.vae import Encoder as LocalEncoder
from threestudio.utils.typing import *


@threestudio.register("gan-volume-renderer")
class GANVolumeRenderer(VolumeRenderer):
    @dataclass
    class Config(VolumeRenderer.Config):
        base_renderer_type: str = ""
        base_renderer: Optional[VolumeRenderer.Config] = None

    cfg: Config

    def configure(
        self,
        geometry: BaseImplicitGeometry,
        material: BaseMaterial,
        background: BaseBackground,
    ) -> None:
        # Base volume renderer; its comp_rgb output carries RGB in the first 3
        # channels and latent features in the remaining channels.
        self.base_renderer = threestudio.find(self.cfg.base_renderer_type)(
            self.cfg.base_renderer,
            geometry=geometry,
            material=material,
            background=background,
        )
        # 2 ** (len(ch_mult) - 1) = 4x scale factor between the rendered
        # low-resolution image and the generator output.
        self.ch_mult = [1, 2, 4]
        self.generator = Generator(
            ch=64,
            out_ch=3,
            ch_mult=self.ch_mult,
            num_res_blocks=1,
            attn_resolutions=[],
            dropout=0.0,
            resamp_with_conv=True,
            in_channels=7,  # 3 RGB channels + 4 latent channels
            resolution=512,
            z_channels=4,
        )
        self.local_encoder = LocalEncoder(
            ch=32,
            out_ch=3,
            ch_mult=self.ch_mult,
            num_res_blocks=1,
            attn_resolutions=[],
            dropout=0.0,
            resamp_with_conv=True,
            in_channels=3,
            resolution=512,
            z_channels=4,
        )
        self.global_encoder = GlobalEncoder(n_class=64)
        self.discriminator = NLayerDiscriminator(
            input_nc=3, n_layers=3, use_actnorm=False, ndf=64
        ).apply(weights_init)
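
    # forward(): render low-resolution RGB + latent features with the base
    # renderer, then refine them with the VAE-style generator. When
    # multi_level_guidance is enabled and a ground-truth image is given, a
    # guidance level in {0, 1, 2} is sampled:
    #   0 - global code from the low-res render, latent sampled from the render;
    #   1 - global code encoded from the ground-truth image;
    #   2 - global code and latent map both encoded from the ground-truth image.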
    def forward(
        self,
        rays_o: Float[Tensor, "B H W 3"],
        rays_d: Float[Tensor, "B H W 3"],
        light_positions: Float[Tensor, "B 3"],
        bg_color: Optional[Tensor] = None,
        gt_rgb: Optional[Float[Tensor, "B H W 3"]] = None,
        multi_level_guidance: bool = False,
        **kwargs
    ) -> Dict[str, Float[Tensor, "..."]]:
        B, H, W, _ = rays_o.shape

        if gt_rgb is not None and multi_level_guidance:
            # Randomly pick a guidance level and render a sparse 1/8-strided
            # subset of rays for direct supervision against the ground truth.
            generator_level = torch.randint(0, 3, (1,)).item()
            interval_x = torch.randint(0, 8, (1,)).item()
            interval_y = torch.randint(0, 8, (1,)).item()
            int_rays_o = rays_o[:, interval_y::8, interval_x::8]
            int_rays_d = rays_d[:, interval_y::8, interval_x::8]
            out = self.base_renderer(
                int_rays_o, int_rays_d, light_positions, bg_color, **kwargs
            )
            comp_int_rgb = out["comp_rgb"][..., :3]
            comp_gt_rgb = gt_rgb[:, interval_y::8, interval_x::8]
        else:
            generator_level = 0

        # Render at reduced resolution; the generator brings the result back
        # up to the requested (H, W).
        scale_ratio = 2 ** (len(self.ch_mult) - 1)
        rays_o = F.interpolate(
            rays_o.permute(0, 3, 1, 2),
            (H // scale_ratio, W // scale_ratio),
            mode="bilinear",
        ).permute(0, 2, 3, 1)
        rays_d = F.interpolate(
            rays_d.permute(0, 3, 1, 2),
            (H // scale_ratio, W // scale_ratio),
            mode="bilinear",
        ).permute(0, 2, 3, 1)
        out = self.base_renderer(rays_o, rays_d, light_positions, bg_color, **kwargs)
        comp_rgb = out["comp_rgb"][..., :3]
        latent = out["comp_rgb"][..., 3:]
        out["comp_lr_rgb"] = comp_rgb.clone()

        # Treat the extra channels as parameters of a diagonal Gaussian latent.
        posterior = DiagonalGaussianDistribution(latent.permute(0, 3, 1, 2))
        if multi_level_guidance:
            z_map = posterior.sample()
        else:
            z_map = posterior.mode()

        lr_rgb = comp_rgb.permute(0, 3, 1, 2)
        if generator_level == 0:
            # Level 0: global style code from the low-resolution render itself.
            g_code_rgb = self.global_encoder(F.interpolate(lr_rgb, (224, 224)))
            comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb)
        elif generator_level == 1:
            # Level 1: global style code from the ground-truth image.
            g_code_rgb = self.global_encoder(
                F.interpolate(gt_rgb.permute(0, 3, 1, 2), (224, 224))
            )
            comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb)
        elif generator_level == 2:
            # Level 2: global code and latent map both encoded from the
            # ground-truth image.
            g_code_rgb = self.global_encoder(
                F.interpolate(gt_rgb.permute(0, 3, 1, 2), (224, 224))
            )
            l_code_rgb = self.local_encoder(gt_rgb.permute(0, 3, 1, 2))
            posterior = DiagonalGaussianDistribution(l_code_rgb)
            z_map = posterior.sample()
            comp_gan_rgb = self.generator(torch.cat([lr_rgb, z_map], dim=1), g_code_rgb)

        # Upsample both outputs to the requested resolution.
        comp_rgb = F.interpolate(comp_rgb.permute(0, 3, 1, 2), (H, W), mode="bilinear")
        comp_gan_rgb = F.interpolate(comp_gan_rgb, (H, W), mode="bilinear")

        out.update(
            {
                "posterior": posterior,
                "comp_gan_rgb": comp_gan_rgb.permute(0, 2, 3, 1),
                "comp_rgb": comp_rgb.permute(0, 2, 3, 1),
                "generator_level": generator_level,
            }
        )
        if gt_rgb is not None and multi_level_guidance:
            out.update({"comp_int_rgb": comp_int_rgb, "comp_gt_rgb": comp_gt_rgb})
        return out
    def update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ) -> None:
        self.base_renderer.update_step(epoch, global_step, on_load_weights)

    def train(self, mode=True):
        return self.base_renderer.train(mode)

    def eval(self):
        return self.base_renderer.eval()
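
# A minimal usage sketch (not part of the original file), mirroring the
# construction pattern used for the base renderer above. The base renderer type
# "nerf-volume-renderer" and the geometry/material/background objects are
# illustrative assumptions; in practice they come from the experiment config.
#
#   renderer = GANVolumeRenderer(
#       GANVolumeRenderer.Config(
#           base_renderer_type="nerf-volume-renderer",
#           base_renderer=...,  # config of the wrapped renderer (left elided)
#       ),
#       geometry=geometry,
#       material=material,
#       background=background,
#   )
#   out = renderer(rays_o, rays_d, light_positions, gt_rgb=gt_rgb,
#                  multi_level_guidance=True)
#   # out["comp_gan_rgb"]: GAN-refined image at (H, W)
#   # out["comp_rgb"]: bilinearly upsampled low-resolution render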