import math

import tinycudann as tcnn
import torch
import torch.nn as nn
import torch.nn.functional as F

import threestudio
from threestudio.utils.base import Updateable
from threestudio.utils.config import config_to_primitive
from threestudio.utils.misc import get_rank
from threestudio.utils.ops import get_activation
from threestudio.utils.typing import *


class ProgressiveBandFrequency(nn.Module, Updateable):
    def __init__(self, in_channels: int, config: dict):
        super().__init__()
        self.N_freqs = config["n_frequencies"]
        self.in_channels, self.n_input_dims = in_channels, in_channels
        self.funcs = [torch.sin, torch.cos]
        self.freq_bands = 2 ** torch.linspace(0, self.N_freqs - 1, self.N_freqs)
        self.n_output_dims = self.in_channels * (len(self.funcs) * self.N_freqs)
        self.n_masking_step = config.get("n_masking_step", 0)
        self.update_step(
            None, None
        )  # mask should be updated at the beginning of each step

    def forward(self, x):
        out = []
        for freq, mask in zip(self.freq_bands, self.mask):
            for func in self.funcs:
                out += [func(freq * x) * mask]
        return torch.cat(out, -1)

    def update_step(self, epoch, global_step, on_load_weights=False):
        if self.n_masking_step <= 0 or global_step is None:
            self.mask = torch.ones(self.N_freqs, dtype=torch.float32)
        else:
            # smoothly unmask higher-frequency bands as training progresses
            self.mask = (
                1.0
                - torch.cos(
                    math.pi
                    * (
                        global_step / self.n_masking_step * self.N_freqs
                        - torch.arange(0, self.N_freqs)
                    ).clamp(0, 1)
                )
            ) / 2.0
            threestudio.debug(
                f"Update mask: {global_step}/{self.n_masking_step} {self.mask}"
            )
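
# Usage sketch (illustrative only, not part of the module): with a config of
# {"n_frequencies": 4, "n_masking_step": 1000}, the encoding maps 3 input channels
# to 3 * 2 * 4 = 24 features, and update_step() gradually unmasks higher bands:
#
#     enc = ProgressiveBandFrequency(3, {"n_frequencies": 4, "n_masking_step": 1000})
#     enc.update_step(epoch=0, global_step=500)  # first two of four bands fully on
#     feats = enc(torch.rand(8, 3))              # shape (8, 24)
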
class TCNNEncoding(nn.Module):
    def __init__(self, in_channels, config, dtype=torch.float32) -> None:
        super().__init__()
        self.n_input_dims = in_channels
        with torch.cuda.device(get_rank()):
            self.encoding = tcnn.Encoding(in_channels, config, dtype=dtype)
        self.n_output_dims = self.encoding.n_output_dims

    def forward(self, x):
        return self.encoding(x)


class ProgressiveBandHashGrid(nn.Module, Updateable):
    def __init__(self, in_channels, config, dtype=torch.float32):
        super().__init__()
        self.n_input_dims = in_channels
        encoding_config = config.copy()
        encoding_config["otype"] = "Grid"
        encoding_config["type"] = "Hash"
        with torch.cuda.device(get_rank()):
            self.encoding = tcnn.Encoding(in_channels, encoding_config, dtype=dtype)
        self.n_output_dims = self.encoding.n_output_dims
        self.n_level = config["n_levels"]
        self.n_features_per_level = config["n_features_per_level"]
        self.start_level, self.start_step, self.update_steps = (
            config["start_level"],
            config["start_step"],
            config["update_steps"],
        )
        self.current_level = self.start_level
        self.mask = torch.zeros(
            self.n_level * self.n_features_per_level,
            dtype=torch.float32,
            device=get_rank(),
        )

    def forward(self, x):
        enc = self.encoding(x)
        enc = enc * self.mask
        return enc

    def update_step(self, epoch, global_step, on_load_weights=False):
        # unmask one additional grid level every `update_steps` steps after `start_step`
        current_level = min(
            self.start_level
            + max(global_step - self.start_step, 0) // self.update_steps,
            self.n_level,
        )
        if current_level > self.current_level:
            threestudio.debug(f"Update current level to {current_level}")
        self.current_level = current_level
        self.mask[: self.current_level * self.n_features_per_level] = 1.0


class CompositeEncoding(nn.Module, Updateable):
    def __init__(self, encoding, include_xyz=False, xyz_scale=2.0, xyz_offset=-1.0):
        super(CompositeEncoding, self).__init__()
        self.encoding = encoding
        self.include_xyz, self.xyz_scale, self.xyz_offset = (
            include_xyz,
            xyz_scale,
            xyz_offset,
        )
        self.n_output_dims = (
            int(self.include_xyz) * self.encoding.n_input_dims
            + self.encoding.n_output_dims
        )

    def forward(self, x, *args):
        return (
            self.encoding(x, *args)
            if not self.include_xyz
            else torch.cat(
                [x * self.xyz_scale + self.xyz_offset, self.encoding(x, *args)], dim=-1
            )
        )


def get_encoding(n_input_dims: int, config) -> nn.Module:
    # input is assumed to be in range [0, 1]
    encoding: nn.Module
    if config.otype == "ProgressiveBandFrequency":
        encoding = ProgressiveBandFrequency(n_input_dims, config_to_primitive(config))
    elif config.otype == "ProgressiveBandHashGrid":
        encoding = ProgressiveBandHashGrid(n_input_dims, config_to_primitive(config))
    else:
        encoding = TCNNEncoding(n_input_dims, config_to_primitive(config))
    encoding = CompositeEncoding(
        encoding,
        include_xyz=config.get("include_xyz", False),
        xyz_scale=2.0,
        xyz_offset=-1.0,
    )  # FIXME: hard coded
    return encoding


class VanillaMLP(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, config: dict):
        super().__init__()
        self.n_neurons, self.n_hidden_layers = (
            config["n_neurons"],
            config["n_hidden_layers"],
        )
        layers = [
            self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False),
            self.make_activation(),
        ]
        for i in range(self.n_hidden_layers - 1):
            layers += [
                self.make_linear(
                    self.n_neurons, self.n_neurons, is_first=False, is_last=False
                ),
                self.make_activation(),
            ]
        layers += [
            self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)
        ]
        self.layers = nn.Sequential(*layers)
        self.output_activation = get_activation(config.get("output_activation", None))

    def forward(self, x):
        # disable autocast
        # strangely, the parameters get empty gradients if autocast is enabled in AMP
        with torch.cuda.amp.autocast(enabled=False):
            x = self.layers(x)
            x = self.output_activation(x)
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        layer = nn.Linear(dim_in, dim_out, bias=False)
        return layer

    def make_activation(self):
        return nn.ReLU(inplace=True)


class SphereInitVanillaMLP(nn.Module):
    def __init__(self, dim_in, dim_out, config):
        super().__init__()
        self.n_neurons, self.n_hidden_layers = (
            config["n_neurons"],
            config["n_hidden_layers"],
        )

        self.sphere_init, self.weight_norm = True, True
        self.sphere_init_radius = config["sphere_init_radius"]
        self.sphere_init_inside_out = config["inside_out"]

        self.layers = [
            self.make_linear(dim_in, self.n_neurons, is_first=True, is_last=False),
            self.make_activation(),
        ]
        for i in range(self.n_hidden_layers - 1):
            self.layers += [
                self.make_linear(
                    self.n_neurons, self.n_neurons, is_first=False, is_last=False
                ),
                self.make_activation(),
            ]
        self.layers += [
            self.make_linear(self.n_neurons, dim_out, is_first=False, is_last=True)
        ]
        self.layers = nn.Sequential(*self.layers)
        self.output_activation = get_activation(config.get("output_activation", None))

    def forward(self, x):
        # disable autocast
        # strangely, the parameters get empty gradients if autocast is enabled in AMP
        with torch.cuda.amp.autocast(enabled=False):
            x = self.layers(x)
            x = self.output_activation(x)
        return x

    def make_linear(self, dim_in, dim_out, is_first, is_last):
        layer = nn.Linear(dim_in, dim_out, bias=True)

        if is_last:
            if not self.sphere_init_inside_out:
                torch.nn.init.constant_(layer.bias, -self.sphere_init_radius)
                torch.nn.init.normal_(
                    layer.weight,
                    mean=math.sqrt(math.pi) / math.sqrt(dim_in),
                    std=0.0001,
                )
            else:
                torch.nn.init.constant_(layer.bias, self.sphere_init_radius)
                torch.nn.init.normal_(
                    layer.weight,
                    mean=-math.sqrt(math.pi) / math.sqrt(dim_in),
                    std=0.0001,
                )
        elif is_first:
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.constant_(layer.weight[:, 3:], 0.0)
            torch.nn.init.normal_(
                layer.weight[:, :3], 0.0, math.sqrt(2) / math.sqrt(dim_out)
            )
        else:
            torch.nn.init.constant_(layer.bias, 0.0)
            torch.nn.init.normal_(layer.weight, 0.0, math.sqrt(2) / math.sqrt(dim_out))

        if self.weight_norm:
            layer = nn.utils.weight_norm(layer)
        return layer

    def make_activation(self):
        return nn.Softplus(beta=100)


class TCNNNetwork(nn.Module):
    def __init__(self, dim_in: int, dim_out: int, config: dict) -> None:
        super().__init__()
        with torch.cuda.device(get_rank()):
            self.network = tcnn.Network(dim_in, dim_out, config)

    def forward(self, x):
        return self.network(x).float()  # transform to float32


def get_mlp(n_input_dims, n_output_dims, config) -> nn.Module:
    network: nn.Module
    if config.otype == "VanillaMLP":
        network = VanillaMLP(n_input_dims, n_output_dims, config_to_primitive(config))
    elif config.otype == "SphereInitVanillaMLP":
        network = SphereInitVanillaMLP(
            n_input_dims, n_output_dims, config_to_primitive(config)
        )
    else:
        assert (
            config.get("sphere_init", False) is False
        ), "sphere_init=True only supported by VanillaMLP"
        network = TCNNNetwork(n_input_dims, n_output_dims, config_to_primitive(config))
    return network


class NetworkWithInputEncoding(nn.Module, Updateable):
    def __init__(self, encoding, network):
        super().__init__()
        self.encoding, self.network = encoding, network

    def forward(self, x):
        return self.network(self.encoding(x))


class TCNNNetworkWithInputEncoding(nn.Module):
    def __init__(
        self,
        n_input_dims: int,
        n_output_dims: int,
        encoding_config: dict,
        network_config: dict,
    ) -> None:
        super().__init__()
        with torch.cuda.device(get_rank()):
            self.network_with_input_encoding = tcnn.NetworkWithInputEncoding(
                n_input_dims=n_input_dims,
                n_output_dims=n_output_dims,
                encoding_config=encoding_config,
                network_config=network_config,
            )

    def forward(self, x):
        return self.network_with_input_encoding(x).float()  # transform to float32


def create_network_with_input_encoding(
    n_input_dims: int, n_output_dims: int, encoding_config, network_config
) -> nn.Module:
    # input is assumed to be in range [0, 1]
    network_with_input_encoding: nn.Module
    if encoding_config.otype in [
        "VanillaFrequency",
        "ProgressiveBandHashGrid",
    ] or network_config.otype in ["VanillaMLP", "SphereInitVanillaMLP"]:
        encoding = get_encoding(n_input_dims, encoding_config)
        network = get_mlp(encoding.n_output_dims, n_output_dims, network_config)
        network_with_input_encoding = NetworkWithInputEncoding(encoding, network)
    else:
        network_with_input_encoding = TCNNNetworkWithInputEncoding(
            n_input_dims=n_input_dims,
            n_output_dims=n_output_dims,
            encoding_config=config_to_primitive(encoding_config),
            network_config=config_to_primitive(network_config),
        )
    return network_with_input_encoding


class ToDTypeWrapper(nn.Module):
    def __init__(self, module: nn.Module, dtype: torch.dtype):
        super().__init__()
        self.module = module
        self.dtype = dtype

    def forward(self, x: Float[Tensor, "..."]) -> Float[Tensor, "..."]:
        return self.module(x).to(self.dtype)
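

# Optional smoke test (an illustrative sketch, not part of the threestudio API).
# It only exercises the pure-PyTorch modules; the tcnn-backed classes need a CUDA
# device with tinycudann installed, and importing this file already assumes the
# threestudio/tinycudann environment. The config values below are made up.
if __name__ == "__main__":
    mlp = VanillaMLP(3, 1, {"n_neurons": 64, "n_hidden_layers": 2})
    freq = ProgressiveBandFrequency(3, {"n_frequencies": 4})
    x = torch.rand(4, 3)
    print(mlp(x).shape)   # torch.Size([4, 1])
    print(freq(x).shape)  # torch.Size([4, 24])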