thewhole's picture
Upload 245 files
2fa4776
raw
history blame
7.7 kB
from dataclasses import dataclass, field
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
import threestudio
from threestudio.models.geometry.base import BaseImplicitGeometry, contract_to_unisphere
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *
@threestudio.register("volume-grid")
class VolumeGrid(BaseImplicitGeometry):
@dataclass
class Config(BaseImplicitGeometry.Config):
grid_size: Tuple[int, int, int] = field(default_factory=lambda: (100, 100, 100))
n_feature_dims: int = 3
density_activation: Optional[str] = "softplus"
density_bias: Union[float, str] = "blob"
density_blob_scale: float = 5.0
density_blob_std: float = 0.5
normal_type: Optional[
str
] = "finite_difference" # in ['pred', 'finite_difference', 'finite_difference_laplacian']
# automatically determine the threshold
isosurface_threshold: Union[float, str] = "auto"
cfg: Config
def configure(self) -> None:
super().configure()
self.grid_size = self.cfg.grid_size
self.grid = nn.Parameter(
torch.zeros(1, self.cfg.n_feature_dims + 1, *self.grid_size)
)
if self.cfg.density_bias == "blob":
self.register_buffer("density_scale", torch.tensor(0.0))
else:
self.density_scale = nn.Parameter(torch.tensor(0.0))
if self.cfg.normal_type == "pred":
self.normal_grid = nn.Parameter(torch.zeros(1, 3, *self.grid_size))
def get_density_bias(self, points: Float[Tensor, "*N Di"]):
if self.cfg.density_bias == "blob":
# density_bias: Float[Tensor, "*N 1"] = self.cfg.density_blob_scale * torch.exp(-0.5 * (points ** 2).sum(dim=-1) / self.cfg.density_blob_std ** 2)[...,None]
density_bias: Float[Tensor, "*N 1"] = (
self.cfg.density_blob_scale
* (
1
- torch.sqrt((points.detach() ** 2).sum(dim=-1))
/ self.cfg.density_blob_std
)[..., None]
)
return density_bias
elif isinstance(self.cfg.density_bias, float):
return self.cfg.density_bias
else:
raise AttributeError(f"Unknown density bias {self.cfg.density_bias}")
def get_trilinear_feature(
self, points: Float[Tensor, "*N Di"], grid: Float[Tensor, "1 Df G1 G2 G3"]
) -> Float[Tensor, "*N Df"]:
points_shape = points.shape[:-1]
df = grid.shape[1]
di = points.shape[-1]
out = F.grid_sample(
grid, points.view(1, 1, 1, -1, di), align_corners=False, mode="bilinear"
)
out = out.reshape(df, -1).T.reshape(*points_shape, df)
return out
def forward(
self, points: Float[Tensor, "*N Di"], output_normal: bool = False
) -> Dict[str, Float[Tensor, "..."]]:
points_unscaled = points # points in the original scale
points = contract_to_unisphere(
points, self.bbox, self.unbounded
) # points normalized to (0, 1)
points = points * 2 - 1 # convert to [-1, 1] for grid sample
out = self.get_trilinear_feature(points, self.grid)
density, features = out[..., 0:1], out[..., 1:]
density = density * torch.exp(self.density_scale) # exp scaling in DreamFusion
# breakpoint()
density = get_activation(self.cfg.density_activation)(
density + self.get_density_bias(points_unscaled)
)
output = {
"density": density,
"features": features,
}
if output_normal:
if (
self.cfg.normal_type == "finite_difference"
or self.cfg.normal_type == "finite_difference_laplacian"
):
eps = 1.0e-3
if self.cfg.normal_type == "finite_difference_laplacian":
offsets: Float[Tensor, "6 3"] = torch.as_tensor(
[
[eps, 0.0, 0.0],
[-eps, 0.0, 0.0],
[0.0, eps, 0.0],
[0.0, -eps, 0.0],
[0.0, 0.0, eps],
[0.0, 0.0, -eps],
]
).to(points_unscaled)
points_offset: Float[Tensor, "... 6 3"] = (
points_unscaled[..., None, :] + offsets
).clamp(-self.cfg.radius, self.cfg.radius)
density_offset: Float[Tensor, "... 6 1"] = self.forward_density(
points_offset
)
normal = (
-0.5
* (density_offset[..., 0::2, 0] - density_offset[..., 1::2, 0])
/ eps
)
else:
offsets: Float[Tensor, "3 3"] = torch.as_tensor(
[[eps, 0.0, 0.0], [0.0, eps, 0.0], [0.0, 0.0, eps]]
).to(points_unscaled)
points_offset: Float[Tensor, "... 3 3"] = (
points_unscaled[..., None, :] + offsets
).clamp(-self.cfg.radius, self.cfg.radius)
density_offset: Float[Tensor, "... 3 1"] = self.forward_density(
points_offset
)
normal = -(density_offset[..., 0::1, 0] - density) / eps
normal = F.normalize(normal, dim=-1)
elif self.cfg.normal_type == "pred":
normal = self.get_trilinear_feature(points, self.normal_grid)
normal = F.normalize(normal, dim=-1)
else:
raise AttributeError(f"Unknown normal type {self.cfg.normal_type}")
output.update({"normal": normal, "shading_normal": normal})
return output
def forward_density(self, points: Float[Tensor, "*N Di"]) -> Float[Tensor, "*N 1"]:
points_unscaled = points
points = contract_to_unisphere(points_unscaled, self.bbox, self.unbounded)
points = points * 2 - 1 # convert to [-1, 1] for grid sample
out = self.get_trilinear_feature(points, self.grid)
density = out[..., 0:1]
density = density * torch.exp(self.density_scale)
density = get_activation(self.cfg.density_activation)(
density + self.get_density_bias(points_unscaled)
)
return density
def forward_field(
self, points: Float[Tensor, "*N Di"]
) -> Tuple[Float[Tensor, "*N 1"], Optional[Float[Tensor, "*N 3"]]]:
if self.cfg.isosurface_deformable_grid:
threestudio.warn(
f"{self.__class__.__name__} does not support isosurface_deformable_grid. Ignoring."
)
density = self.forward_density(points)
return density, None
def forward_level(
self, field: Float[Tensor, "*N 1"], threshold: float
) -> Float[Tensor, "*N 1"]:
return -(field - threshold)
def export(self, points: Float[Tensor, "*N Di"], **kwargs) -> Dict[str, Any]:
out: Dict[str, Any] = {}
if self.cfg.n_feature_dims == 0:
return out
points_unscaled = points
points = contract_to_unisphere(points, self.bbox, self.unbounded)
points = points * 2 - 1 # convert to [-1, 1] for grid sample
features = self.get_trilinear_feature(points, self.grid)[..., 1:]
out.update(
{
"features": features,
}
)
return out