Spaces:
Runtime error
Runtime error
File size: 3,835 Bytes
2fa4776 |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 89 90 91 92 93 94 95 96 97 98 99 100 101 102 103 104 105 106 107 |
from dataclasses import dataclass
import torch
import torch.nn.functional as F
import threestudio
from threestudio.models.background.base import BaseBackground
from threestudio.models.geometry.base import BaseImplicitGeometry
from threestudio.models.materials.base import BaseMaterial
from threestudio.models.renderers.base import VolumeRenderer
from threestudio.utils.typing import *
@threestudio.register("patch-renderer")
class PatchRenderer(VolumeRenderer):
    """Two-pass patch-based volume renderer.

    During training it renders (1) a low-resolution "global" view obtained by
    downsampling the rays by ``global_downsample`` and (2) one randomly
    positioned full-resolution square patch of side ``patch_size``. The global
    outputs are upsampled back to (H, W) and the patch result is pasted over
    them, so losses can mix whole-image and high-detail supervision. Outside
    training it simply delegates to the wrapped base renderer.
    """

    @dataclass
    class Config(VolumeRenderer.Config):
        # Side length (pixels) of the square high-resolution training patch.
        patch_size: int = 128
        # Registry name of the wrapped renderer (resolved via threestudio.find).
        base_renderer_type: str = ""
        # Config object forwarded to the wrapped renderer.
        base_renderer: Optional[VolumeRenderer.Config] = None
        # If True, stop gradients through the upsampled global pass.
        global_detach: bool = False
        # Downsample factor applied to the rays for the global pass.
        global_downsample: int = 4

    cfg: Config

    def configure(
        self,
        geometry: BaseImplicitGeometry,
        material: BaseMaterial,
        background: BaseBackground,
    ) -> None:
        """Instantiate the wrapped base renderer from the registry."""
        self.base_renderer = threestudio.find(self.cfg.base_renderer_type)(
            self.cfg.base_renderer,
            geometry=geometry,
            material=material,
            background=background,
        )

    def forward(
        self,
        rays_o: Float[Tensor, "B H W 3"],
        rays_d: Float[Tensor, "B H W 3"],
        light_positions: Float[Tensor, "B 3"],
        bg_color: Optional[Tensor] = None,
        **kwargs
    ) -> Dict[str, Float[Tensor, "..."]]:
        """Render either a (global + patch) composite (training) or the full
        image (eval), returning the base renderer's output dict."""
        B, H, W, _ = rays_o.shape
        if self.base_renderer.training:
            # --- global pass: render a downsampled view of the full image ---
            downsample = self.cfg.global_downsample
            global_rays_o = F.interpolate(
                rays_o.permute(0, 3, 1, 2),
                (H // downsample, W // downsample),
                mode="bilinear",
            ).permute(0, 2, 3, 1)
            global_rays_d = F.interpolate(
                rays_d.permute(0, 3, 1, 2),
                (H // downsample, W // downsample),
                mode="bilinear",
            ).permute(0, 2, 3, 1)
            out_global = self.base_renderer(
                global_rays_o, global_rays_d, light_positions, bg_color, **kwargs
            )

            # --- patch pass: render one random full-resolution patch ---
            # NOTE(review): assumes H >= patch_size and W >= patch_size —
            # confirm against the training resolution schedule.
            PS = self.cfg.patch_size
            # randint's high bound is exclusive: use W - PS + 1 / H - PS + 1 so
            # every valid top-left corner (including the right/bottom-most one)
            # can be sampled, and so the call does not fail when W == PS.
            patch_x = torch.randint(0, W - PS + 1, (1,)).item()
            patch_y = torch.randint(0, H - PS + 1, (1,)).item()
            patch_rays_o = rays_o[:, patch_y : patch_y + PS, patch_x : patch_x + PS]
            patch_rays_d = rays_d[:, patch_y : patch_y + PS, patch_x : patch_x + PS]
            out = self.base_renderer(
                patch_rays_o, patch_rays_d, light_positions, bg_color, **kwargs
            )

            # Only tensors shaped like comp_rgb (same rank, same B/H/W — the
            # channel count may differ) can be composited patch-into-global.
            valid_patch_key = []
            for key in out:
                if torch.is_tensor(out[key]):
                    if len(out[key].shape) == len(out["comp_rgb"].shape):
                        if out[key][..., 0].shape == out["comp_rgb"][..., 0].shape:
                            valid_patch_key.append(key)

            # Upsample each global output to (H, W) and paste the patch in.
            for key in valid_patch_key:
                out_global[key] = F.interpolate(
                    out_global[key].permute(0, 3, 1, 2), (H, W), mode="bilinear"
                ).permute(0, 2, 3, 1)
                if self.cfg.global_detach:
                    out_global[key] = out_global[key].detach()
                out_global[key][
                    :, patch_y : patch_y + PS, patch_x : patch_x + PS
                ] = out[key]
            out = out_global
        else:
            # Inference: a single full-resolution render, no patching.
            out = self.base_renderer(
                rays_o, rays_d, light_positions, bg_color, **kwargs
            )
        return out

    def update_step(
        self, epoch: int, global_step: int, on_load_weights: bool = False
    ) -> None:
        """Forward per-step updates to the wrapped renderer."""
        self.base_renderer.update_step(epoch, global_step, on_load_weights)

    def train(self, mode=True):
        # Training mode is tracked on the wrapped renderer (forward checks
        # self.base_renderer.training), so only it is toggled here.
        return self.base_renderer.train(mode)

    def eval(self):
        return self.base_renderer.eval()
|