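# Spaces: Runtime error
# NOTE(review): the line above appears to be a page-status banner captured
# together with the source, not part of the module -- confirm and remove.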
import random | |
from dataclasses import dataclass, field | |
import torch | |
import torch.nn as nn | |
import torch.nn.functional as F | |
import threestudio | |
from threestudio.models.materials.base import BaseMaterial | |
from threestudio.models.networks import get_encoding, get_mlp | |
from threestudio.utils.ops import dot, get_activation | |
from threestudio.utils.typing import * | |
class HybridRGBLatentMaterial(BaseMaterial):
    """Material that activates the first three feature channels.

    ``forward`` applies the activation selected by ``cfg.color_activation``
    to ``features[..., :3]`` and returns every remaining channel unchanged.
    """

    # NOTE(review): no decorator is visible on ``Config`` in this view of the
    # file; confirm whether ``@dataclass`` is expected here.
    class Config(BaseMaterial.Config):
        # Size the last dimension of ``features`` must have in ``forward``.
        n_output_dims: int = 3
        # Name handed to ``get_activation`` to pick the colour activation.
        color_activation: str = "sigmoid"
        # Copied onto the instance as ``requires_normal`` by ``configure``.
        requires_normal: bool = True

    cfg: Config

    def configure(self) -> None:
        """Copy ``cfg.requires_normal`` onto the instance."""
        self.requires_normal = self.cfg.requires_normal

    def forward(
        self, features: Float[Tensor, "B ... Nf"], **kwargs
    ) -> Float[Tensor, "B ... Nc"]:
        """Return ``features`` with its first three channels activated.

        Args:
            features: Tensor whose last dimension must equal
                ``cfg.n_output_dims``. It is not modified.
            **kwargs: Accepted and ignored.

        Returns:
            A new tensor with the same shape and dtype as ``features``:
            channels ``[..., :3]`` hold the activated values, any channels
            beyond the third are copied through untouched.

        Raises:
            AssertionError: If the last dimension of ``features`` differs
                from ``cfg.n_output_dims``.
        """
        assert (
            features.shape[-1] == self.cfg.n_output_dims
        ), f"Expected {self.cfg.n_output_dims} output dims, only got {features.shape[-1]} dims input."
        activation = get_activation(self.cfg.color_activation)
        # Write into a copy rather than into ``features`` itself: assigning
        # into the input would silently change the caller's tensor (so a
        # second call on the same tensor would activate it twice) and can
        # trip autograd's in-place checks.
        color = features.clone()
        color[..., :3] = activation(features[..., :3])
        return color