thewhole's picture
Upload 245 files
2fa4776
raw
history blame
No virus
3.84 kB
from dataclasses import dataclass
import torch
import torch.nn.functional as F
import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.models.geometry.base import BaseImplicitGeometry
from threestudio.models.materials.base import BaseMaterial
from threestudio.models.renderers.base import VolumeRenderer
from threestudio.utils.typing import *
@threestudio.register("patch-renderer")
class PatchRenderer(VolumeRenderer):
    """Renderer wrapper that mixes a coarse global render with a sharp patch.

    While the wrapped renderer is in training mode, ``forward`` renders the
    rays twice through that same wrapped renderer: once on a spatially
    downsampled copy of the full ray grid, and once on a randomly placed
    full-resolution patch. The coarse outputs are bilinearly upsampled back
    to ``(H, W)`` and the patch outputs are pasted over the matching region.
    Outside training mode the call is forwarded to the wrapped renderer
    unchanged.
    """

    @dataclass
    class Config(VolumeRenderer.Config):
        # Side length (in pixels) of the square patch rendered at full
        # resolution; clamped to the image size in ``forward``.
        patch_size: int = 128
        # Registry name handed to ``threestudio.find`` to locate the class of
        # the wrapped renderer.
        base_renderer_type: str = ""
        # Configuration object passed as first argument to the wrapped
        # renderer's constructor.
        base_renderer: Optional[VolumeRenderer.Config] = None
        # When True, the upsampled global render is detached before the patch
        # is pasted in, so gradients only flow through the patch render.
        global_detach: bool = False
        # Integer divisor applied to H and W for the coarse global render.
        global_downsample: int = 4

    cfg: Config

    def configure(
        self,
        geometry: BaseImplicitGeometry,
        material: BaseMaterial,
        background: BaseBackground,
    ) -> None:
        """Build the wrapped renderer.

        The renderer class is looked up by ``cfg.base_renderer_type`` and
        constructed with ``cfg.base_renderer`` plus the given geometry,
        material and background.
        """
        renderer_cls = threestudio.find(self.cfg.base_renderer_type)
        self.base_renderer = renderer_cls(
            self.cfg.base_renderer,
            geometry=geometry,
            material=material,
            background=background,
        )

    @staticmethod
    def _resize_bhwc(x, size):
        """Bilinearly resize a channels-last ``(B, H, W, C)`` tensor to ``size``."""
        # F.interpolate expects channels-first input, so permute in and out.
        resized = F.interpolate(x.permute(0, 3, 1, 2), size, mode="bilinear")
        return resized.permute(0, 2, 3, 1)

    def forward(
        self,
        rays_o: Float[Tensor, "B H W 3"],
        rays_d: Float[Tensor, "B H W 3"],
        light_positions: Float[Tensor, "B 3"],
        bg_color: Optional[Tensor] = None,
        **kwargs
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Render the rays, using the global+patch scheme while training.

        Args:
            rays_o: Ray origins laid out as ``(B, H, W, 3)``.
            rays_d: Ray directions laid out as ``(B, H, W, 3)``.
            light_positions: Forwarded untouched to the wrapped renderer.
            bg_color: Forwarded untouched to the wrapped renderer.
            **kwargs: Forwarded untouched to the wrapped renderer.

        Returns:
            The wrapped renderer's output dictionary. In training mode this is
            the dictionary of the global render, where every image-like entry
            (same rank and leading dimensions as the patch's ``"comp_rgb"``)
            has been upsampled to ``(H, W)`` and overwritten inside the patch
            region by the full-resolution patch render. All other entries
            come from the global render as-is.
        """
        _, H, W, _ = rays_o.shape

        # The mode is read from the wrapped renderer because ``train``/``eval``
        # on this wrapper only forward the switch to it.
        if not self.base_renderer.training:
            return self.base_renderer(
                rays_o, rays_d, light_positions, bg_color, **kwargs
            )

        # --- coarse render of the whole view ---------------------------------
        downsample = self.cfg.global_downsample
        # Clamp to one pixel so tiny inputs do not request an empty image.
        global_size = (max(H // downsample, 1), max(W // downsample, 1))
        out_global = self.base_renderer(
            self._resize_bhwc(rays_o, global_size),
            self._resize_bhwc(rays_d, global_size),
            light_positions,
            bg_color,
            **kwargs
        )

        # --- full-resolution render of a random patch ------------------------
        # The patch can never be larger than the image itself.
        patch_h = min(self.cfg.patch_size, H)
        patch_w = min(self.cfg.patch_size, W)
        # torch.randint's upper bound is exclusive: "+ 1" makes the last valid
        # offset reachable and keeps the range non-empty when the patch spans
        # the whole side.
        patch_x = torch.randint(0, W - patch_w + 1, (1,)).item()
        patch_y = torch.randint(0, H - patch_h + 1, (1,)).item()
        rows = slice(patch_y, patch_y + patch_h)
        cols = slice(patch_x, patch_x + patch_w)
        out_patch = self.base_renderer(
            rays_o[:, rows, cols],
            rays_d[:, rows, cols],
            light_positions,
            bg_color,
            **kwargs
        )

        # Only tensors shaped like the rendered image (same rank, same
        # leading dimensions as "comp_rgb") can be pasted spatially.
        ref_shape = out_patch["comp_rgb"].shape
        valid_patch_keys = [
            key
            for key, value in out_patch.items()
            if torch.is_tensor(value)
            and value.dim() == len(ref_shape)
            and value.shape[:-1] == ref_shape[:-1]
        ]

        # --- composite: upsample global, paste the patch on top --------------
        for key in valid_patch_keys:
            upsampled = self._resize_bhwc(out_global[key], (H, W))
            if self.cfg.global_detach:
                upsampled = upsampled.detach()
            upsampled[:, rows, cols] = out_patch[key]
            out_global[key] = upsampled

        return out_global

    def update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ) -> None:
        """Forward the step update to the wrapped renderer."""
        self.base_renderer.update_step(epoch, global_step, on_load_weights)

    def train(self, mode=True):
        """Switch the wrapped renderer's mode and return its result.

        NOTE(review): this returns what the wrapped renderer's ``train``
        returns, not ``self``, and does not touch this wrapper's own
        ``training`` attribute; ``forward`` compensates by reading
        ``self.base_renderer.training``. Confirm no caller relies on
        ``renderer.train()`` returning the wrapper.
        """
        return self.base_renderer.train(mode)

    def eval(self):
        """Put the wrapped renderer in eval mode and return its result.

        NOTE(review): same caveat as ``train`` - the return value is whatever
        the wrapped renderer's ``eval`` returns.
        """
        return self.base_renderer.eval()