import gzip
import json
import os
import warnings
from dataclasses import dataclass, field
from typing import List

import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torchvision.transforms.functional as TF
from PIL import Image
from torch.utils.data import DataLoader, Dataset, IterableDataset

from threestudio import register
from threestudio.data.uncond import (
    RandomCameraDataModuleConfig,
    RandomCameraDataset,
    RandomCameraIterableDataset,
)
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_rank
from threestudio.utils.ops import (
    get_mvp_matrix,
    get_projection_matrix,
    get_ray_directions,
    get_rays,
)
from threestudio.utils.typing import *


def _load_16big_png_depth(depth_png) -> np.ndarray:
    with Image.open(depth_png) as depth_pil:
        # the image is stored with 16-bit depth but PIL reads it as I (32 bit).
        # we cast it to uint16, then reinterpret as float16, then cast to float32
        depth = (
            np.frombuffer(np.array(depth_pil, dtype=np.uint16), dtype=np.float16)
            .astype(np.float32)
            .reshape((depth_pil.size[1], depth_pil.size[0]))
        )
    return depth


def _load_depth(path, scale_adjustment) -> np.ndarray:
    if not path.lower().endswith(".png"):
        raise ValueError('unsupported depth file name "%s"' % path)

    d = _load_16big_png_depth(path) * scale_adjustment
    d[~np.isfinite(d)] = 0.0

    return d[None]  # fake feature channel


# Code adapted from https://github.com/eldar/snes/blob/473ff2b1f6/3rdparty/co3d/dataset/co3d_dataset.py
def _get_1d_bounds(arr):
    nz = np.flatnonzero(arr)
    return nz[0], nz[-1]


def get_bbox_from_mask(mask, thr, decrease_quant=0.05):
    # bbox in xywh
    masks_for_box = np.zeros_like(mask)
    while masks_for_box.sum() <= 1.0:
        masks_for_box = (mask > thr).astype(np.float32)
        thr -= decrease_quant
        if thr <= 0.0:
            warnings.warn(f"Empty masks_for_bbox (thr={thr}) => using full image.")

    x0, x1 = _get_1d_bounds(masks_for_box.sum(axis=-2))
    y0, y1 = _get_1d_bounds(masks_for_box.sum(axis=-1))

    return x0, y0, x1 - x0, y1 - y0


def get_clamp_bbox(bbox, box_crop_context=0.0, impath=""):
    # box_crop_context: rate of expansion for bbox
    # returns possibly expanded bbox xyxy as float

    # increase box size
    if box_crop_context > 0.0:
        c = box_crop_context
        bbox = bbox.astype(np.float32)
        bbox[0] -= bbox[2] * c / 2
        bbox[1] -= bbox[3] * c / 2
        bbox[2] += bbox[2] * c
        bbox[3] += bbox[3] * c

    if (bbox[2:] <= 1.0).any():
        warnings.warn(f"squashed image {impath}!!")
        return None

    # bbox[2:] = np.clip(bbox[2:], 2, )
    bbox[2:] = np.maximum(bbox[2:], 2)
    bbox[2:] += bbox[0:2] + 1  # convert to [xmin, ymin, xmax, ymax]
    # +1 because upper bound is not inclusive

    return bbox


def crop_around_box(tensor, bbox, impath=""):
    bbox[[0, 2]] = np.clip(bbox[[0, 2]], 0.0, tensor.shape[-2])
    bbox[[1, 3]] = np.clip(bbox[[1, 3]], 0.0, tensor.shape[-3])
    bbox = bbox.round().astype(np.longlong)
    return tensor[bbox[1] : bbox[3], bbox[0] : bbox[2], ...]


def resize_image(image, height, width, mode="bilinear"):
    if image.shape[:2] == (height, width):
        return image, 1.0, np.ones_like(image[..., :1])

    image = torch.from_numpy(image).permute(2, 0, 1)
    minscale = min(height / image.shape[-2], width / image.shape[-1])
    imre = torch.nn.functional.interpolate(
        image[None],
        scale_factor=minscale,
        mode=mode,
        align_corners=False if mode == "bilinear" else None,
        recompute_scale_factor=True,
    )[0]

    # pyre-fixme[19]: Expected 1 positional argument.
    imre_ = torch.zeros(image.shape[0], height, width)
    imre_[:, 0 : imre.shape[1], 0 : imre.shape[2]] = imre
    # pyre-fixme[6]: For 2nd param expected `int` but got `Optional[int]`.
    # pyre-fixme[6]: For 3rd param expected `int` but got `Optional[int]`.
    mask = torch.zeros(1, height, width)
    mask[:, 0 : imre.shape[1], 0 : imre.shape[2]] = 1.0

    return imre_.permute(1, 2, 0).numpy(), minscale, mask.permute(1, 2, 0).numpy()


# Code adapted from https://github.com/POSTECH-CVLab/PeRFception/data_util/co3d.py
def similarity_from_cameras(c2w, fix_rot=False, radius=1.0):
    """
    Get a similarity transform to normalize dataset
    from c2w (OpenCV convention) cameras
    :param c2w: (N, 4, 4)
    :return T (4, 4), scale (float)
    """
    t = c2w[:, :3, 3]
    R = c2w[:, :3, :3]

    # (1) Rotate the world so that z+ is the up axis
    # we estimate the up axis by averaging the camera up axes
    ups = np.sum(R * np.array([0, -1.0, 0]), axis=-1)
    world_up = np.mean(ups, axis=0)
    world_up /= np.linalg.norm(world_up)

    up_camspace = np.array([0.0, 0.0, 1.0])
    c = (up_camspace * world_up).sum()
    cross = np.cross(world_up, up_camspace)
    skew = np.array(
        [
            [0.0, -cross[2], cross[1]],
            [cross[2], 0.0, -cross[0]],
            [-cross[1], cross[0], 0.0],
        ]
    )
    if c > -1:
        R_align = np.eye(3) + skew + (skew @ skew) * 1 / (1 + c)
    else:
        # In the unlikely case the original data has y+ up axis,
        # rotate 180-deg about x axis
        R_align = np.array([[-1.0, 0.0, 0.0], [0.0, 1.0, 0.0], [0.0, 0.0, 1.0]])

    if fix_rot:
        R_align = np.eye(3)
        R = np.eye(3)
    else:
        R = R_align @ R

    fwds = np.sum(R * np.array([0, 0.0, 1.0]), axis=-1)
    t = (R_align @ t[..., None])[..., 0]

    # (2) Recenter the scene using camera center rays
    # find the closest point to the origin for each camera's center ray
    nearest = t + (fwds * -t).sum(-1)[:, None] * fwds

    # median for more robustness
    translate = -np.median(nearest, axis=0)

    # translate = -np.mean(t, axis=0)  # DEBUG

    transform = np.eye(4)
    transform[:3, 3] = translate
    transform[:3, :3] = R_align

    # (3) Rescale the scene using camera distances
    scale = radius / np.median(np.linalg.norm(t + translate, axis=-1))
    return transform, scale


@dataclass
class Co3dDataModuleConfig:
    root_dir: str = ""
    batch_size: int = 1
    height: int = 256
    width: int = 256
    load_preprocessed: bool = False
    cam_scale_factor: float = 0.95
    max_num_frames: int = 300
    v2_mode: bool = True
    use_mask: bool = True
    box_crop: bool = True
    box_crop_mask_thr: float = 0.4
    box_crop_context: float = 0.3
    train_num_rays: int = -1
    train_views: Optional[list] = None
    train_split: str = "train"
    val_split: str = "val"
    test_split: str = "test"
    scale_radius: float = 1.0
    use_random_camera: bool = True
    random_camera: dict = field(default_factory=dict)
    rays_noise_scale: float = 0.0
    render_path: str = "circle"


class Co3dDatasetBase:
    def setup(self, cfg, split):
        self.split = split
        self.rank = get_rank()
        self.cfg: Co3dDataModuleConfig = cfg

        if self.cfg.use_random_camera:
            random_camera_cfg = parse_structured(
                RandomCameraDataModuleConfig, self.cfg.get("random_camera", {})
            )
            if split == "train":
                self.random_pose_generator = RandomCameraIterableDataset(
                    random_camera_cfg
                )
            else:
                self.random_pose_generator = RandomCameraDataset(
                    random_camera_cfg, split
                )

        self.use_mask = self.cfg.use_mask
        cam_scale_factor = self.cfg.cam_scale_factor

        assert os.path.exists(self.cfg.root_dir), f"{self.cfg.root_dir} doesn't exist!"
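        # CO3D layout assumed by the path joins below (directory names are only
        # illustrative): cfg.root_dir points at a single sequence directory, e.g.
        #   <dataset_root>/<category>/<sequence>/
        # frame_annotations.jgz lives one level up (in the category directory), and
        # the image/depth/mask paths stored inside it are relative to <dataset_root>.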
        cam_trans = np.diag(np.array([-1, -1, 1, 1], dtype=np.float32))
        scene_number = self.cfg.root_dir.split("/")[-1]
        json_path = os.path.join(self.cfg.root_dir, "..", "frame_annotations.jgz")
        with gzip.open(json_path, "r") as fp:
            all_frames_data = json.load(fp)

        frame_data, images, intrinsics, extrinsics, image_sizes = [], [], [], [], []
        masks = []
        depths = []

        for temporal_data in all_frames_data:
            if temporal_data["sequence_name"] == scene_number:
                frame_data.append(temporal_data)

        self.all_directions = []
        self.all_fg_masks = []

        for frame in frame_data:
            if "unseen" in frame["meta"]["frame_type"]:
                continue
            img = cv2.imread(
                os.path.join(self.cfg.root_dir, "..", "..", frame["image"]["path"])
            )
            img = cv2.cvtColor(img, cv2.COLOR_BGR2RGB) / 255.0

            # TODO: use estimated depth
            depth = _load_depth(
                os.path.join(self.cfg.root_dir, "..", "..", frame["depth"]["path"]),
                frame["depth"]["scale_adjustment"],
            )[0]

            H, W = frame["image"]["size"]
            image_size = np.array([H, W])
            fxy = np.array(frame["viewpoint"]["focal_length"])
            cxy = np.array(frame["viewpoint"]["principal_point"])
            R = np.array(frame["viewpoint"]["R"])
            T = np.array(frame["viewpoint"]["T"])

            if self.cfg.v2_mode:
                min_HW = min(W, H)
                image_size_half = np.array([W * 0.5, H * 0.5], dtype=np.float32)
                scale_arr = np.array([min_HW * 0.5, min_HW * 0.5], dtype=np.float32)
                fxy_x = fxy * scale_arr
                prp_x = np.array([W * 0.5, H * 0.5], dtype=np.float32) - cxy * scale_arr
                cxy = (image_size_half - prp_x) / image_size_half
                fxy = fxy_x / image_size_half

            scale_arr = np.array([W * 0.5, H * 0.5], dtype=np.float32)
            focal = fxy * scale_arr
            prp = -1.0 * (cxy - 1.0) * scale_arr

            pose = np.eye(4)
            pose[:3, :3] = R
            pose[:3, 3:] = -R @ T[..., None]
            # original camera: x left, y up, z in (PyTorch3D)
            # transformed camera: x right, y down, z in (OpenCV)
            pose = pose @ cam_trans

            intrinsic = np.array(
                [
                    [focal[0], 0.0, prp[0], 0.0],
                    [0.0, focal[1], prp[1], 0.0],
                    [0.0, 0.0, 1.0, 0.0],
                    [0.0, 0.0, 0.0, 1.0],
                ]
            )

            if any([np.all(pose == _pose) for _pose in extrinsics]):
                continue

            image_sizes.append(image_size)
            intrinsics.append(intrinsic)
            extrinsics.append(pose)
            images.append(img)
            depths.append(depth)
            self.all_directions.append(get_ray_directions(W, H, focal, prp))

            # vis_utils.vis_depth_pcd([depth], [pose], intrinsic, [(img * 255).astype(np.uint8)])

            if self.use_mask:
                mask = np.array(
                    Image.open(
                        os.path.join(
                            self.cfg.root_dir, "..", "..", frame["mask"]["path"]
                        )
                    )
                )
                mask = mask.astype(np.float32) / 255.0  # (h, w)
            else:
                # img is a numpy array here, so build the dummy mask with numpy
                # (a torch tensor would break the np.stack over all_fg_masks below)
                mask = np.ones_like(img[..., 0])
            self.all_fg_masks.append(mask)

        intrinsics = np.stack(intrinsics)
        extrinsics = np.stack(extrinsics)
        image_sizes = np.stack(image_sizes)
        self.all_directions = torch.stack(self.all_directions, dim=0)
        self.all_fg_masks = np.stack(self.all_fg_masks, 0)

        H_median, W_median = np.median(
            np.stack([image_size for image_size in image_sizes]), axis=0
        )

        H_inlier = np.abs(image_sizes[:, 0] - H_median) / H_median < 0.1
        W_inlier = np.abs(image_sizes[:, 1] - W_median) / W_median < 0.1
        inlier = np.logical_and(H_inlier, W_inlier)
        dists = np.linalg.norm(
            extrinsics[:, :3, 3] - np.median(extrinsics[:, :3, 3], axis=0), axis=-1
        )
        med = np.median(dists)
        good_mask = dists < (med * 5.0)
        inlier = np.logical_and(inlier, good_mask)

        if inlier.sum() != 0:
            intrinsics = intrinsics[inlier]
            extrinsics = extrinsics[inlier]
            image_sizes = image_sizes[inlier]
            images = [images[i] for i in range(len(inlier)) if inlier[i]]
            depths = [depths[i] for i in range(len(inlier)) if inlier[i]]
            self.all_directions = self.all_directions[inlier]
            self.all_fg_masks = self.all_fg_masks[inlier]
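        # Normalize the reconstruction to a canonical frame: similarity_from_cameras
        # (above) aligns the average camera "up" with +z, recenters on the point the
        # camera center rays converge to, and rescales so the median camera distance
        # equals cfg.scale_radius. The same scale is applied to the depths so they
        # stay consistent with the rescaled poses.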
        extrinsics = np.stack(extrinsics)
        T, sscale = similarity_from_cameras(extrinsics, radius=self.cfg.scale_radius)
        extrinsics = T @ extrinsics
        extrinsics[:, :3, 3] *= sscale * cam_scale_factor

        depths = [depth * sscale * cam_scale_factor for depth in depths]

        num_frames = len(extrinsics)
        if self.cfg.max_num_frames < num_frames:
            num_frames = self.cfg.max_num_frames
            extrinsics = extrinsics[:num_frames]
            intrinsics = intrinsics[:num_frames]
            image_sizes = image_sizes[:num_frames]
            images = images[:num_frames]
            depths = depths[:num_frames]
            self.all_directions = self.all_directions[:num_frames]
            self.all_fg_masks = self.all_fg_masks[:num_frames]

        if self.cfg.box_crop:
            print("cropping...")
            crop_masks = []
            crop_imgs = []
            crop_depths = []
            crop_directions = []
            crop_xywhs = []
            max_sl = 0
            for i in range(num_frames):
                bbox_xywh = np.array(
                    get_bbox_from_mask(self.all_fg_masks[i], self.cfg.box_crop_mask_thr)
                )
                clamp_bbox_xywh = get_clamp_bbox(bbox_xywh, self.cfg.box_crop_context)
                max_sl = max(clamp_bbox_xywh[2] - clamp_bbox_xywh[0], max_sl)
                max_sl = max(clamp_bbox_xywh[3] - clamp_bbox_xywh[1], max_sl)
                mask = crop_around_box(self.all_fg_masks[i][..., None], clamp_bbox_xywh)
                img = crop_around_box(images[i], clamp_bbox_xywh)
                depth = crop_around_box(depths[i][..., None], clamp_bbox_xywh)

                # resize to the same shape
                mask, _, _ = resize_image(mask, self.cfg.height, self.cfg.width)
                depth, _, _ = resize_image(depth, self.cfg.height, self.cfg.width)
                img, scale, _ = resize_image(img, self.cfg.height, self.cfg.width)
                fx, fy, cx, cy = (
                    intrinsics[i][0, 0],
                    intrinsics[i][1, 1],
                    intrinsics[i][0, 2],
                    intrinsics[i][1, 2],
                )

                crop_masks.append(mask)
                crop_imgs.append(img)
                crop_depths.append(depth)
                crop_xywhs.append(clamp_bbox_xywh)
                crop_directions.append(
                    get_ray_directions(
                        self.cfg.height,
                        self.cfg.width,
                        (fx * scale, fy * scale),
                        (
                            (cx - clamp_bbox_xywh[0]) * scale,
                            (cy - clamp_bbox_xywh[1]) * scale,
                        ),
                    )
                )

            # # pad all images to the same shape
            # for i in range(num_frames):
            #     uh = (max_sl - crop_imgs[i].shape[0]) // 2  # h
            #     dh = max_sl - crop_imgs[i].shape[0] - uh
            #     lw = (max_sl - crop_imgs[i].shape[1]) // 2
            #     rw = max_sl - crop_imgs[i].shape[1] - lw
            #     crop_masks[i] = np.pad(crop_masks[i], pad_width=((uh, dh), (lw, rw), (0, 0)), mode='constant', constant_values=0.)
            #     crop_imgs[i] = np.pad(crop_imgs[i], pad_width=((uh, dh), (lw, rw), (0, 0)), mode='constant', constant_values=1.)
            #     crop_depths[i] = np.pad(crop_depths[i], pad_width=((uh, dh), (lw, rw), (0, 0)), mode='constant', constant_values=0.)
            #     fx, fy, cx, cy = intrinsics[i][0, 0], intrinsics[i][1, 1], intrinsics[i][0, 2], intrinsics[i][1, 2]
            #     crop_directions.append(get_ray_directions(max_sl, max_sl, (fx, fy), (cx - crop_xywhs[i][0] + lw, cy - crop_xywhs[i][1] + uh)))
            # self.w, self.h = max_sl, max_sl

            images = crop_imgs
            depths = crop_depths
            self.all_fg_masks = np.stack(crop_masks, 0)
            self.all_directions = torch.from_numpy(np.stack(crop_directions, 0))
            # self.width, self.height = self.w, self.h

        self.all_c2w = torch.from_numpy(
            (
                extrinsics
                @ np.diag(np.array([1, -1, -1, 1], dtype=np.float32))[None, ...]
            )[..., :3, :4]
        )
        self.all_images = torch.from_numpy(np.stack(images, axis=0))
        self.all_depths = torch.from_numpy(np.stack(depths, axis=0))

        # self.all_c2w = []
        # self.all_images = []
        # for i in range(num_frames):
        #     # convert to: x right, y up, z back (OpenGL)
        #     c2w = torch.from_numpy(extrinsics[i] @ np.diag(np.array([1, -1, -1, 1], dtype=np.float32)))[:3, :4]
        #     self.all_c2w.append(c2w)
        #     img = torch.from_numpy(images[i])
        #     self.all_images.append(img)

        # TODO: save data for fast loading next time
        if self.cfg.load_preprocessed and os.path.exists(
            os.path.join(self.cfg.root_dir, "nerf_preprocessed.npy")
        ):
            pass

        i_all = np.arange(num_frames)

        if self.cfg.train_views is None:
            i_test = i_all[::10]
            i_val = i_test
            i_train = np.array([i for i in i_all if i not in i_test])
        else:
            # use provided views
            i_train = self.cfg.train_views
            i_test = np.array([i for i in i_all if i not in i_train])
            i_val = i_test

        if self.split == "train":
            print("[INFO] num of train views: ", len(i_train))
            print("[INFO] train view ids = ", i_train)

        i_split = {"train": i_train, "val": i_val, "test": i_all}

        # if self.split == 'test':
        #     self.all_c2w = create_spheric_poses(self.all_c2w[:,:,3], n_steps=self.cfg.n_test_traj_steps)
        #     self.all_images = torch.zeros((self.cfg.n_test_traj_steps, self.h, self.w, 3), dtype=torch.float32)
        #     self.all_fg_masks = torch.zeros((self.cfg.n_test_traj_steps, self.h, self.w), dtype=torch.float32)
        #     self.directions = self.directions[0].to(self.rank)
        # else:
        self.all_images, self.all_c2w = (
            self.all_images[i_split[self.split]],
            self.all_c2w[i_split[self.split]],
        )
        self.all_directions = self.all_directions[i_split[self.split]].to(self.rank)
        self.all_fg_masks = torch.from_numpy(self.all_fg_masks)[i_split[self.split]]
        self.all_depths = self.all_depths[i_split[self.split]]

        # if render_random_pose:
        #     render_poses = random_pose(extrinsics[i_all], 50)
        # elif render_scene_interp:
        #     render_poses = pose_interp(extrinsics[i_all], interp_fac)
        # render_poses = spherical_poses(sscale * cam_scale_factor * np.eye(4))

        # near, far = 0., 1.
        # ndc_coeffs = (-1., -1.)
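        # All selected frames are kept in memory and moved to the target device up
        # front; sequences are small after the max_num_frames truncation above.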
        self.all_c2w, self.all_images, self.all_fg_masks = (
            self.all_c2w.float().to(self.rank),
            self.all_images.float().to(self.rank),
            self.all_fg_masks.float().to(self.rank),
        )
        # self.all_c2w, self.all_images, self.all_fg_masks = \
        #     self.all_c2w.float(), \
        #     self.all_images.float(), \
        #     self.all_fg_masks.float()
        self.all_depths = self.all_depths.float().to(self.rank)

    def get_all_images(self):
        return self.all_images


class Co3dDataset(Dataset, Co3dDatasetBase):
    def __init__(self, cfg, split):
        self.setup(cfg, split)

    def __len__(self):
        if self.split == "test":
            if self.cfg.render_path == "circle":
                return len(self.random_pose_generator)
            else:
                return len(self.all_images)
        else:
            return len(self.random_pose_generator)
            # return len(self.all_images)

    def prepare_data(self, index):
        # prepare batch data here
        c2w = self.all_c2w[index]
        light_positions = c2w[..., :3, -1]
        directions = self.all_directions[index]
        rays_o, rays_d = get_rays(
            directions, c2w, keepdim=True, noise_scale=self.cfg.rays_noise_scale
        )
        rgb = self.all_images[index]
        depth = self.all_depths[index]
        mask = self.all_fg_masks[index]

        # TODO: get projection matrix and mvp matrix
        # proj_mtx = get_projection_matrix()

        batch = {
            "rays_o": rays_o,
            "rays_d": rays_d,
            "mvp_mtx": 0,
            "camera_positions": c2w[..., :3, -1],
            "light_positions": light_positions,
            "elevation": 0,
            "azimuth": 0,
            "camera_distances": 0,
            "rgb": rgb,
            "depth": depth,
            "mask": mask,
        }

        # c2w = self.all_c2w[index]
        # return {
        #     'index': index,
        #     'c2w': c2w,
        #     'light_positions': c2w[:3, -1],
        #     'H': self.h,
        #     'W': self.w
        # }

        return batch

    def __getitem__(self, index):
        if self.split == "test":
            if self.cfg.render_path == "circle":
                return self.random_pose_generator[index]
            else:
                return self.prepare_data(index)
        else:
            return self.random_pose_generator[index]


class Co3dIterableDataset(IterableDataset, Co3dDatasetBase):
    def __init__(self, cfg, split):
        self.setup(cfg, split)
        self.idx = 0
        self.image_perm = torch.randperm(len(self.all_images))

    def __iter__(self):
        while True:
            yield {}

    def collate(self, batch) -> Dict[str, Any]:
        idx = self.image_perm[self.idx]
        # prepare batch data here
        c2w = self.all_c2w[idx][None]
        light_positions = c2w[..., :3, -1]
        directions = self.all_directions[idx][None]
        rays_o, rays_d = get_rays(
            directions, c2w, keepdim=True, noise_scale=self.cfg.rays_noise_scale
        )
        rgb = self.all_images[idx][None]
        depth = self.all_depths[idx][None]
        mask = self.all_fg_masks[idx][None]

        if (
            self.cfg.train_num_rays != -1
            and self.cfg.train_num_rays < self.cfg.height * self.cfg.width
        ):
            _, height, width, _ = rays_o.shape
            x = torch.randint(
                0, width, size=(self.cfg.train_num_rays,), device=rays_o.device
            )
            y = torch.randint(
                0, height, size=(self.cfg.train_num_rays,), device=rays_o.device
            )

            rays_o = rays_o[:, y, x].unsqueeze(-2)
            rays_d = rays_d[:, y, x].unsqueeze(-2)
            directions = directions[:, y, x].unsqueeze(-2)
            rgb = rgb[:, y, x].unsqueeze(-2)
            mask = mask[:, y, x].unsqueeze(-2)
            depth = depth[:, y, x].unsqueeze(-2)

        # TODO: get projection matrix and mvp matrix
        # proj_mtx = get_projection_matrix()

        batch = {
            "rays_o": rays_o,
            "rays_d": rays_d,
            "mvp_mtx": None,
            "camera_positions": c2w[..., :3, -1],
            "light_positions": light_positions,
            "elevation": None,
            "azimuth": None,
            "camera_distances": None,
            "rgb": rgb,
            "depth": depth,
            "mask": mask,
        }

        if self.cfg.use_random_camera:
            batch["random_camera"] = self.random_pose_generator.collate(None)

        # prepare batch data in system
        # c2w = self.all_c2w[idx][None]
        # batch = {
        #     'index': torch.tensor([idx]),
        #     'c2w': c2w,
        #     'light_positions': c2w[..., :3, -1],
        #     'H': self.h,
        #     'W': self.w
        # }
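        # Advance the epoch-style permutation: every training view is visited once
        # per cycle, after which the visiting order is reshuffled.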
        self.idx += 1
        if self.idx == len(self.all_images):
            self.idx = 0
            self.image_perm = torch.randperm(len(self.all_images))
        # self.idx = (self.idx + 1) % len(self.all_images)

        return batch


@register("co3d-datamodule")
class Co3dDataModule(pl.LightningDataModule):
    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        super().__init__()
        self.cfg = parse_structured(Co3dDataModuleConfig, cfg)

    def setup(self, stage=None):
        if stage in [None, "fit"]:
            self.train_dataset = Co3dIterableDataset(self.cfg, self.cfg.train_split)
        if stage in [None, "fit", "validate"]:
            self.val_dataset = Co3dDataset(self.cfg, self.cfg.val_split)
        if stage in [None, "test", "predict"]:
            self.test_dataset = Co3dDataset(self.cfg, self.cfg.test_split)

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
        sampler = None
        return DataLoader(
            dataset,
            num_workers=0,
            batch_size=batch_size,
            # pin_memory=True,
            collate_fn=collate_fn,
        )

    def train_dataloader(self):
        return self.general_loader(
            self.train_dataset, batch_size=1, collate_fn=self.train_dataset.collate
        )

    def val_dataloader(self):
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self):
        return self.general_loader(self.test_dataset, batch_size=1)
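

# Minimal usage sketch (the sequence path is hypothetical and the CO3D data must
# exist on disk; in practice the module is constructed by threestudio's config
# system via the "co3d-datamodule" registry name):
#
#   dm = Co3dDataModule(
#       {
#           "root_dir": "/data/co3d/hydrant/106_12648_23157",  # hypothetical path
#           "height": 256,
#           "width": 256,
#           "use_random_camera": True,
#       }
#   )
#   dm.setup("fit")
#   train_batch = next(iter(dm.train_dataloader()))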