import bisect
import math
import random
from dataclasses import dataclass, field

import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset

import threestudio
from threestudio import register
from threestudio.utils.base import Updateable
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_device
from threestudio.utils.ops import (
    get_mvp_matrix,
    get_projection_matrix,
    get_ray_directions,
    get_rays,
)
from threestudio.utils.typing import *


def safe_normalize(x, eps=1e-20):
    return x / torch.sqrt(torch.clamp(torch.sum(x * x, -1, keepdim=True), min=eps))


def convert_camera_to_world_transform(transform):
    # Convert a right-handed transform matrix into a left-handed one.
    # Copy the original transform matrix.
    converted_transform = transform.clone()
    # Flip the viewing direction (negate the third column).
    converted_transform[:, 2] *= -1
    # Swap the first and third rows.
    converted_transform[[0, 2], :] = converted_transform[[2, 0], :]
    return converted_transform


def circle_poses(
    device,
    radius=torch.tensor([3.2]),
    theta=torch.tensor([60]),
    phi=torch.tensor([0]),
):
    theta = theta / 180 * np.pi
    phi = phi / 180 * np.pi

    centers = torch.stack(
        [
            radius * torch.sin(theta) * torch.sin(phi),
            radius * torch.cos(theta),
            radius * torch.sin(theta) * torch.cos(phi),
        ],
        dim=-1,
    )  # [B, 3]

    # lookat
    forward_vector = safe_normalize(centers)
    up_vector = (
        torch.FloatTensor([0, 1, 0]).to(device).unsqueeze(0).repeat(len(centers), 1)
    )
    right_vector = safe_normalize(torch.cross(forward_vector, up_vector, dim=-1))
    up_vector = safe_normalize(torch.cross(right_vector, forward_vector, dim=-1))

    poses = (
        torch.eye(4, dtype=torch.float, device=device)
        .unsqueeze(0)
        .repeat(len(centers), 1, 1)
    )
    poses[:, :3, :3] = torch.stack((right_vector, up_vector, forward_vector), dim=-1)
    poses[:, :3, 3] = centers

    return poses
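
# Quick sanity check (a sketch, not part of the module): with the default
# arguments, circle_poses returns a single camera-to-world pose whose
# translation sits exactly `radius` away from the origin:
#
#   poses = circle_poses(torch.device("cpu"))  # [1, 4, 4]
#   assert torch.allclose(poses[0, :3, 3].norm(), torch.tensor(3.2))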
trans_t = lambda t: torch.Tensor(
    [
        [1, 0, 0, 0],
        [0, 1, 0, 0],
        [0, 0, 1, t],
        [0, 0, 0, 1],
    ]
).float()

rot_phi = lambda phi: torch.Tensor(
    [
        [1, 0, 0, 0],
        [0, np.cos(phi), -np.sin(phi), 0],
        [0, np.sin(phi), np.cos(phi), 0],
        [0, 0, 0, 1],
    ]
).float()

rot_theta = lambda th: torch.Tensor(
    [
        [np.cos(th), 0, -np.sin(th), 0],
        [0, 1, 0, 0],
        [np.sin(th), 0, np.cos(th), 0],
        [0, 0, 0, 1],
    ]
).float()


def rodrigues_mat_to_rot(R):
    eps = 1e-16
    trc = np.trace(R)
    trc2 = (trc - 1.0) / 2.0
    # sin(theta) would be np.sqrt(1 - trc2 * trc2)
    s = np.array([R[2, 1] - R[1, 2], R[0, 2] - R[2, 0], R[1, 0] - R[0, 1]])
    if (1 - trc2 * trc2) >= eps:
        tHeta = np.arccos(trc2)
        tHetaf = tHeta / (2 * (np.sin(tHeta)))
    else:
        tHeta = np.real(np.arccos(trc2))
        tHetaf = 0.5 / (1 - tHeta / 6)
    omega = tHetaf * s
    return omega


def rodrigues_rot_to_mat(r):
    wx, wy, wz = r
    theta = np.sqrt(wx * wx + wy * wy + wz * wz)
    a = np.cos(theta)
    b = (1 - np.cos(theta)) / (theta * theta)
    c = np.sin(theta) / theta
    R = np.zeros([3, 3])
    R[0, 0] = a + b * (wx * wx)
    R[0, 1] = b * wx * wy - c * wz
    R[0, 2] = b * wx * wz + c * wy
    R[1, 0] = b * wx * wy + c * wz
    R[1, 1] = a + b * (wy * wy)
    R[1, 2] = b * wy * wz - c * wx
    R[2, 0] = b * wx * wz - c * wy
    R[2, 1] = b * wz * wy + c * wx
    R[2, 2] = a + b * (wz * wz)
    return R


def pose_spherical(theta, phi, radius):
    c2w = trans_t(radius)
    c2w = rot_phi(phi / 180.0 * np.pi) @ c2w
    c2w = rot_theta(theta / 180.0 * np.pi) @ c2w
    c2w = (
        torch.Tensor(
            np.array([[-1, 0, 0, 0], [0, 0, 1, 0], [0, 1, 0, 0], [0, 0, 0, 1]])
        )
        @ c2w
    )
    return c2w


def convert_camera_pose(camera_pose):
    # Clone the tensor to avoid in-place operations
    colmap_pose = camera_pose.clone()
    # Extract rotation and translation components
    rotation = colmap_pose[:, :3, :3]
    translation = colmap_pose[:, :3, 3]
    # Change rotation orientation
    rotation[:, 0, :] *= -1
    rotation[:, 1, :] *= -1
    # Change translation position
    translation[:, 0] *= -1
    translation[:, 1] *= -1
    return colmap_pose
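
# Sanity checks (a sketch): rodrigues_rot_to_mat and rodrigues_mat_to_rot are
# inverses of each other away from theta = 0, and pose_spherical places the
# camera exactly `radius` away from the origin (every left-multiplied matrix,
# including the final axis swap, is orthogonal, so the translation norm is
# preserved):
#
#   r = np.array([0.1, 0.2, 0.3])
#   assert np.allclose(rodrigues_mat_to_rot(rodrigues_rot_to_mat(r)), r)
#
#   c2w = pose_spherical(30.0, -45.0, 4.0)
#   assert torch.allclose(c2w[:3, 3].norm(), torch.tensor(4.0))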
@dataclass
class RandomCameraDataModuleConfig:
    # height, width, and batch_size should be Union[int, List[int]]
    # but OmegaConf does not support Union of containers
    height: Any = 512
    width: Any = 512
    batch_size: Any = 1
    resolution_milestones: List[int] = field(default_factory=lambda: [])
    eval_height: int = 512
    eval_width: int = 512
    eval_batch_size: int = 1
    n_val_views: int = 1
    n_test_views: int = 120
    elevation_range: Tuple[float, float] = (-10, 60)
    azimuth_range: Tuple[float, float] = (-180, 180)
    camera_distance_range: Tuple[float, float] = (4.0, 6.0)
    fovy_range: Tuple[float, float] = (
        40,
        70,
    )  # in degrees, in vertical direction (along height)
    camera_perturb: float = 0.0
    center_perturb: float = 0.0
    up_perturb: float = 0.0
    light_position_perturb: float = 1.0
    light_distance_range: Tuple[float, float] = (0.8, 1.5)
    eval_elevation_deg: float = 15.0
    eval_camera_distance: float = 6.0
    eval_fovy_deg: float = 70.0
    light_sample_strategy: str = "dreamfusion"
    batch_uniform_azimuth: bool = True
    progressive_until: int = 0  # progressive ranges for elevation, azimuth, r, fovy


class RandomCameraIterableDataset(IterableDataset, Updateable):
    def __init__(self, cfg: Any) -> None:
        super().__init__()
        self.cfg: RandomCameraDataModuleConfig = cfg
        self.heights: List[int] = (
            [self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
        )
        self.widths: List[int] = (
            [self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
        )
        self.batch_sizes: List[int] = (
            [self.cfg.batch_size]
            if isinstance(self.cfg.batch_size, int)
            else self.cfg.batch_size
        )
        assert len(self.heights) == len(self.widths) == len(self.batch_sizes)
        self.resolution_milestones: List[int]
        if (
            len(self.heights) == 1
            and len(self.widths) == 1
            and len(self.batch_sizes) == 1
        ):
            if len(self.cfg.resolution_milestones) > 0:
                threestudio.warn(
                    "Ignoring resolution_milestones since height and width are not changing"
                )
            self.resolution_milestones = [-1]
        else:
            assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
            self.resolution_milestones = [-1] + self.cfg.resolution_milestones

        self.directions_unit_focals = [
            get_ray_directions(H=height, W=width, focal=1.0)
            for (height, width) in zip(self.heights, self.widths)
        ]
        self.height: int = self.heights[0]
        self.width: int = self.widths[0]
        self.batch_size: int = self.batch_sizes[0]
        self.directions_unit_focal = self.directions_unit_focals[0]
        self.elevation_range = self.cfg.elevation_range
        self.azimuth_range = self.cfg.azimuth_range
        self.camera_distance_range = self.cfg.camera_distance_range
        self.fovy_range = self.cfg.fovy_range

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
        self.height = self.heights[size_ind]
        self.width = self.widths[size_ind]
        self.batch_size = self.batch_sizes[size_ind]
        self.directions_unit_focal = self.directions_unit_focals[size_ind]
        threestudio.debug(
            f"Training height: {self.height}, width: {self.width}, batch_size: {self.batch_size}"
        )
        # progressive view
        self.progressive_view(global_step)
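
    # Example (hypothetical config values): a coarse-to-fine schedule such as
    #
    #   height: [64, 256, 512]
    #   width: [64, 256, 512]
    #   batch_size: [8, 4, 1]
    #   resolution_milestones: [5000, 10000]
    #
    # yields self.resolution_milestones == [-1, 5000, 10000], so bisect_right
    # above picks index 0 for steps in [0, 5000), index 1 for [5000, 10000),
    # and index 2 from step 10000 on.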
    def __iter__(self):
        while True:
            yield {}

    def progressive_view(self, global_step):
        pass
        # r = min(1.0, global_step / (self.cfg.progressive_until + 1))
        # self.elevation_range = [
        #     (1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[0],
        #     (1 - r) * self.cfg.eval_elevation_deg + r * self.cfg.elevation_range[1],
        # ]
        # self.azimuth_range = [
        #     (1 - r) * 0.0 + r * self.cfg.azimuth_range[0],
        #     (1 - r) * 0.0 + r * self.cfg.azimuth_range[1],
        # ]
        # self.camera_distance_range = [
        #     (1 - r) * self.cfg.eval_camera_distance
        #     + r * self.cfg.camera_distance_range[0],
        #     (1 - r) * self.cfg.eval_camera_distance
        #     + r * self.cfg.camera_distance_range[1],
        # ]
        # self.fovy_range = [
        #     (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[0],
        #     (1 - r) * self.cfg.eval_fovy_deg + r * self.cfg.fovy_range[1],
        # ]

    def collate(self, batch) -> Dict[str, Any]:
        # sample elevation angles
        elevation_deg: Float[Tensor, "B"]
        elevation: Float[Tensor, "B"]
        if random.random() < 0.5:
            # sample elevation angles uniformly with a probability 0.5 (biased towards poles)
            elevation_deg = (
                torch.rand(self.batch_size)
                * (self.elevation_range[1] - self.elevation_range[0])
                + self.elevation_range[0]
            )
            elevation = elevation_deg * math.pi / 180
        else:
            # otherwise sample uniformly on sphere
            elevation_range_percent = [
                (self.elevation_range[0] + 90.0) / 180.0,
                (self.elevation_range[1] + 90.0) / 180.0,
            ]
            # inverse transform sampling: uniform over the sphere means
            # sin(elevation) is uniform in [-1, 1]
            elevation = torch.asin(
                2
                * (
                    torch.rand(self.batch_size)
                    * (elevation_range_percent[1] - elevation_range_percent[0])
                    + elevation_range_percent[0]
                )
                - 1.0
            )
            elevation_deg = elevation / math.pi * 180.0

        # sample azimuth angles from a uniform distribution bounded by azimuth_range
        azimuth_deg: Float[Tensor, "B"]
        if self.cfg.batch_uniform_azimuth:
            # ensures sampled azimuth angles in a batch cover the whole range
            azimuth_deg = (
                torch.rand(self.batch_size) + torch.arange(self.batch_size)
            ) / self.batch_size * (
                self.azimuth_range[1] - self.azimuth_range[0]
            ) + self.azimuth_range[0]
        else:
            # simple random sampling
            azimuth_deg = (
                torch.rand(self.batch_size)
                * (self.azimuth_range[1] - self.azimuth_range[0])
                + self.azimuth_range[0]
            )
        azimuth = azimuth_deg * math.pi / 180

        # sample distances from a uniform distribution bounded by distance_range
        camera_distances: Float[Tensor, "B"] = (
            torch.rand(self.batch_size)
            * (self.camera_distance_range[1] - self.camera_distance_range[0])
            + self.camera_distance_range[0]
        )

        # convert spherical coordinates to cartesian coordinates
        # right hand coordinate system, x back, y right, z up
        # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
        camera_positions: Float[Tensor, "B 3"] = torch.stack(
            [
                camera_distances * torch.cos(elevation) * torch.cos(azimuth),
                camera_distances * torch.cos(elevation) * torch.sin(azimuth),
                camera_distances * torch.sin(elevation),
            ],
            dim=-1,
        )

        # default scene center at origin
        center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
        # default camera up direction as +z
        up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
            None, :
        ].repeat(self.batch_size, 1)

        # sample camera perturbations from a uniform distribution [-camera_perturb, camera_perturb]
        camera_perturb: Float[Tensor, "B 3"] = (
            torch.rand(self.batch_size, 3) * 2 * self.cfg.camera_perturb
            - self.cfg.camera_perturb
        )
        camera_positions = camera_positions + camera_perturb
        # sample center perturbations from a normal distribution with mean 0 and std center_perturb
        center_perturb: Float[Tensor, "B 3"] = (
            torch.randn(self.batch_size, 3) * self.cfg.center_perturb
        )
        center = center + center_perturb
        # sample up perturbations from a normal distribution with mean 0 and std up_perturb
        up_perturb: Float[Tensor, "B 3"] = (
            torch.randn(self.batch_size, 3) * self.cfg.up_perturb
        )
        up = up + up_perturb

        # sample fovs from a uniform distribution bounded by fov_range
        fovy_deg: Float[Tensor, "B"] = (
            torch.rand(self.batch_size) * (self.fovy_range[1] - self.fovy_range[0])
            + self.fovy_range[0]
        )
        fovy = fovy_deg * math.pi / 180

        # sample light distance from a uniform distribution bounded by light_distance_range
        light_distances: Float[Tensor, "B"] = (
            torch.rand(self.batch_size)
            * (self.cfg.light_distance_range[1] - self.cfg.light_distance_range[0])
            + self.cfg.light_distance_range[0]
        )
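
        # Two light sampling strategies are handled below: "dreamfusion" (and
        # its "dreamfusion3dgs" variant) perturbs the camera direction with
        # Gaussian noise and rescales to the sampled light distance, while
        # "magic3d" samples the light in the camera's local frame with
        # elevation restricted to [pi/6, pi/2] above the local horizon.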
        if self.cfg.light_sample_strategy in ("dreamfusion", "dreamfusion3dgs"):
            # sample light direction from a normal distribution with mean camera_position and std light_position_perturb
            light_direction: Float[Tensor, "B 3"] = F.normalize(
                camera_positions
                + torch.randn(self.batch_size, 3) * self.cfg.light_position_perturb,
                dim=-1,
            )
            # get light position by scaling light direction by light distance
            light_positions: Float[Tensor, "B 3"] = (
                light_direction * light_distances[:, None]
            )
        elif self.cfg.light_sample_strategy == "magic3d":
            # sample light direction within restricted angle range (pi/3)
            local_z = F.normalize(camera_positions, dim=-1)
            local_x = F.normalize(
                torch.stack(
                    [local_z[:, 1], -local_z[:, 0], torch.zeros_like(local_z[:, 0])],
                    dim=-1,
                ),
                dim=-1,
            )
            local_y = F.normalize(torch.cross(local_z, local_x, dim=-1), dim=-1)
            rot = torch.stack([local_x, local_y, local_z], dim=-1)
            light_azimuth = (
                torch.rand(self.batch_size) * math.pi * 2 - math.pi
            )  # [-pi, pi]
            light_elevation = (
                torch.rand(self.batch_size) * math.pi / 3 + math.pi / 6
            )  # [pi/6, pi/2]
            light_positions_local = torch.stack(
                [
                    light_distances
                    * torch.cos(light_elevation)
                    * torch.cos(light_azimuth),
                    light_distances
                    * torch.cos(light_elevation)
                    * torch.sin(light_azimuth),
                    light_distances * torch.sin(light_elevation),
                ],
                dim=-1,
            )
            light_positions = (rot @ light_positions_local[:, :, None])[:, :, 0]
        else:
            raise ValueError(
                f"Unknown light sample strategy: {self.cfg.light_sample_strategy}"
            )

        lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
        right: Float[Tensor, "B 3"] = F.normalize(
            torch.cross(lookat, up, dim=-1), dim=-1
        )
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
        c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
            dim=-1,
        )
        c2w: Float[Tensor, "B 4 4"] = torch.cat(
            [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
        )
        c2w[:, 3, 3] = 1.0

        # get directions by dividing directions_unit_focal by focal length
        focal_length: Float[Tensor, "B"] = 0.5 * self.height / torch.tan(0.5 * fovy)
        directions: Float[Tensor, "B H W 3"] = self.directions_unit_focal[
            None, :, :, :
        ].repeat(self.batch_size, 1, 1, 1)
        directions[:, :, :, :2] = (
            directions[:, :, :, :2] / focal_length[:, None, None, None]
        )

        proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix(
            fovy, self.width / self.height, 0.1, 1000.0
        )  # FIXME: hard-coded near and far
        mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, proj_mtx)

        # build camera-to-world matrices in the 3DGS (COLMAP-style) convention
        c2w_3dgs = []
        for i in range(self.batch_size):
            render_pose = pose_spherical(
                azimuth_deg[i] + 180.0, -elevation_deg[i], camera_distances[i]
            )
            matrix = torch.linalg.inv(render_pose)
            R = -torch.transpose(matrix[:3, :3], 0, 1)
            R[:, 0] = -R[:, 0]
            T = -matrix[:3, 3]
            c2w_single = torch.cat([R, T[:, None]], 1)
            c2w_single = torch.cat(
                [c2w_single, torch.tensor([[0, 0, 0, 1]], dtype=torch.float32)], 0
            )
            # c2w_single = convert_camera_to_world_transform(c2w_single)
            c2w_3dgs.append(c2w_single)
        c2w_3dgs = torch.stack(c2w_3dgs, 0)

        return {
            "mvp_mtx": mvp_mtx,
            "camera_positions": camera_positions,
            "c2w": c2w,
            "c2w_3dgs": c2w_3dgs,
            "light_positions": light_positions,
            "elevation": elevation_deg,
            "azimuth": azimuth_deg,
            "camera_distances": camera_distances,
            "height": self.height,
            "width": self.width,
            "fovy": fovy,
        }
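
# Usage note (a sketch): RandomCameraIterableDataset.__iter__ yields empty
# dicts forever; the actual batch is produced by collate(), which a DataLoader
# with batch_size=None calls once per yielded item. So one training batch can
# be drawn directly, assuming a parsed config `cfg`:
#
#   ds = RandomCameraIterableDataset(cfg)
#   batch = ds.collate(None)  # dict with "mvp_mtx": [B, 4, 4], "c2w": [B, 4, 4], ...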
class RandomCameraDataset(Dataset):
    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.cfg: RandomCameraDataModuleConfig = cfg
        self.split = split

        if split == "val":
            self.n_views = self.cfg.n_val_views
        else:
            self.n_views = self.cfg.n_test_views

        azimuth_deg: Float[Tensor, "B"]
        if self.split == "val":
            # make sure the first and last view are not the same
            azimuth_deg = torch.linspace(-180.0, 180.0, self.n_views + 1)[
                : self.n_views
            ]
        else:
            azimuth_deg = torch.linspace(-180.0, 180.0, self.n_views)
        elevation_deg: Float[Tensor, "B"] = torch.full_like(
            azimuth_deg, self.cfg.eval_elevation_deg
        )
        camera_distances: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_camera_distance
        )

        elevation = elevation_deg * math.pi / 180
        azimuth = azimuth_deg * math.pi / 180

        # convert spherical coordinates to cartesian coordinates
        # right hand coordinate system, x back, y right, z up
        # elevation in (-90, 90), azimuth from +x to +y in (-180, 180)
        camera_positions: Float[Tensor, "B 3"] = torch.stack(
            [
                camera_distances * torch.cos(elevation) * torch.cos(azimuth),
                camera_distances * torch.cos(elevation) * torch.sin(azimuth),
                camera_distances * torch.sin(elevation),
            ],
            dim=-1,
        )

        # default scene center at origin
        center: Float[Tensor, "B 3"] = torch.zeros_like(camera_positions)
        # default camera up direction as +z
        up: Float[Tensor, "B 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
            None, :
        ].repeat(self.cfg.eval_batch_size, 1)

        fovy_deg: Float[Tensor, "B"] = torch.full_like(
            elevation_deg, self.cfg.eval_fovy_deg
        )
        fovy = fovy_deg * math.pi / 180
        light_positions: Float[Tensor, "B 3"] = camera_positions

        lookat: Float[Tensor, "B 3"] = F.normalize(center - camera_positions, dim=-1)
        right: Float[Tensor, "B 3"] = F.normalize(
            torch.cross(lookat, up, dim=-1), dim=-1
        )
        up = F.normalize(torch.cross(right, lookat, dim=-1), dim=-1)
        c2w3x4: Float[Tensor, "B 3 4"] = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_positions[:, :, None]],
            dim=-1,
        )
        c2w: Float[Tensor, "B 4 4"] = torch.cat(
            [c2w3x4, torch.zeros_like(c2w3x4[:, :1])], dim=1
        )
        c2w[:, 3, 3] = 1.0

        # get directions by dividing directions_unit_focal by focal length
        focal_length: Float[Tensor, "B"] = (
            0.5 * self.cfg.eval_height / torch.tan(0.5 * fovy)
        )
        directions_unit_focal = get_ray_directions(
            H=self.cfg.eval_height, W=self.cfg.eval_width, focal=1.0
        )
        directions: Float[Tensor, "B H W 3"] = directions_unit_focal[
            None, :, :, :
        ].repeat(self.n_views, 1, 1, 1)
        directions[:, :, :, :2] = (
            directions[:, :, :, :2] / focal_length[:, None, None, None]
        )

        proj_mtx: Float[Tensor, "B 4 4"] = get_projection_matrix(
            fovy, self.cfg.eval_width / self.cfg.eval_height, 0.1, 1000.0
        )  # FIXME: hard-coded near and far
        mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(c2w, proj_mtx)

        # build camera-to-world matrices in the 3DGS (COLMAP-style) convention
        c2w_3dgs = []
        for i in range(self.n_views):
            render_pose = pose_spherical(
                azimuth_deg[i] + 180.0, -elevation_deg[i], camera_distances[i]
            )
            matrix = torch.linalg.inv(render_pose)
            R = -torch.transpose(matrix[:3, :3], 0, 1)
            R[:, 0] = -R[:, 0]
            T = -matrix[:3, 3]
            c2w_single = torch.cat([R, T[:, None]], 1)
            c2w_single = torch.cat(
                [c2w_single, torch.tensor([[0, 0, 0, 1]], dtype=torch.float32)], 0
            )
            # c2w_single = convert_camera_to_world_transform(c2w_single)
            c2w_3dgs.append(c2w_single)
        c2w_3dgs = torch.stack(c2w_3dgs, 0)

        self.mvp_mtx = mvp_mtx
        self.c2w = c2w
        self.c2w_3dgs = c2w_3dgs
        self.camera_positions = camera_positions
        self.light_positions = light_positions
        self.elevation, self.azimuth = elevation, azimuth
        self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
        self.camera_distances = camera_distances
        self.fovy = fovy

    def __len__(self):
        return self.n_views

    def __getitem__(self, index):
        return {
            "index": index,
            "mvp_mtx": self.mvp_mtx[index],
            "c2w": self.c2w[index],
            "c2w_3dgs": self.c2w_3dgs[index],
            "camera_positions": self.camera_positions[index],
            "light_positions": self.light_positions[index],
            "elevation": self.elevation_deg[index],
            "azimuth": self.azimuth_deg[index],
            "camera_distances": self.camera_distances[index],
            "height": self.cfg.eval_height,
            "width": self.cfg.eval_width,
            "fovy": self.fovy[index],
        }

    def collate(self, batch):
        batch = torch.utils.data.default_collate(batch)
        batch.update({"height": self.cfg.eval_height, "width": self.cfg.eval_width})
        return batch
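
# Worked example: for split="val" with n_val_views=4, the azimuths are
# torch.linspace(-180, 180, 5)[:4] == [-180, -90, 0, 90], so the sweep avoids
# rendering the identical -180/+180 view twice; the test split keeps all
# n_test_views points of the full turn.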
@register("random-camera-datamodule")
class RandomCameraDataModule(pl.LightningDataModule):
    cfg: RandomCameraDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        super().__init__()
        self.cfg = parse_structured(RandomCameraDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        if stage in [None, "fit"]:
            self.train_dataset = RandomCameraIterableDataset(self.cfg)
        if stage in [None, "fit", "validate"]:
            self.val_dataset = RandomCameraDataset(self.cfg, "val")
        if stage in [None, "test", "predict"]:
            self.test_dataset = RandomCameraDataset(self.cfg, "test")

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
        return DataLoader(
            dataset,
            # very important to disable multi-processing if you want to change self attributes at runtime!
            # (for example setting self.width and self.height in update_step)
            num_workers=0,  # type: ignore
            batch_size=batch_size,
            collate_fn=collate_fn,
        )

    def train_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate
        )

    def val_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.val_dataset, batch_size=1, collate_fn=self.val_dataset.collate
        )

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
        )

    def predict_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
        )
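
if __name__ == "__main__":
    # Minimal smoke test (a sketch, assuming a working threestudio install;
    # the config values here are illustrative, not canonical): build the
    # datamodule, draw one random training batch, and print tensor shapes.
    dm = RandomCameraDataModule({"height": 64, "width": 64, "batch_size": 2})
    dm.setup("fit")
    batch = dm.train_dataset.collate(None)
    for k, v in batch.items():
        if torch.is_tensor(v):
            print(k, tuple(v.shape))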