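"""Multiview camera data module for threestudio.

Loads a multiview capture described by a `transforms.json` file (OPENCV
camera model, per-frame intrinsics and camera-to-world extrinsics) and
exposes PyTorch Lightning dataloaders for training, validation, and testing.
"""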
import json
import os
from dataclasses import dataclass
import cv2
import numpy as np
import pytorch_lightning as pl
import torch
from scipy.spatial.transform import Rotation as Rot
from scipy.spatial.transform import Slerp
from torch.utils.data import DataLoader, Dataset, IterableDataset
from tqdm import tqdm
import threestudio
from threestudio import register
from threestudio.utils.config import parse_structured
from threestudio.utils.ops import get_mvp_matrix, get_ray_directions, get_rays
from threestudio.utils.typing import *
def convert_pose(C2W):
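    """Flip the Y and Z camera axes of a camera-to-world matrix (OpenCV <-> OpenGL convention)."""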
flip_yz = torch.eye(4)
flip_yz[1, 1] = -1
flip_yz[2, 2] = -1
C2W = torch.matmul(C2W, flip_yz)
return C2W
def convert_proj(K, H, W, near, far):
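    """Convert OpenCV intrinsics K to a 4x4 OpenGL-style projection matrix.

    H and W are the image height and width; near and far are the clip planes.
    """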
return [
[2 * K[0, 0] / W, -2 * K[0, 1] / W, (W - 2 * K[0, 2]) / W, 0],
[0, -2 * K[1, 1] / H, (H - 2 * K[1, 2]) / H, 0],
[0, 0, (-far - near) / (far - near), -2 * far * near / (far - near)],
[0, 0, -1, 0],
]
def inter_pose(pose_0, pose_1, ratio):
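    """Interpolate between two camera-to-world poses.

    Rotations are interpolated with Slerp and translations linearly, both in
    world-to-camera space; the result is returned as a camera-to-world matrix.
    """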
pose_0 = pose_0.detach().cpu().numpy()
pose_1 = pose_1.detach().cpu().numpy()
pose_0 = np.linalg.inv(pose_0)
pose_1 = np.linalg.inv(pose_1)
rot_0 = pose_0[:3, :3]
rot_1 = pose_1[:3, :3]
rots = Rot.from_matrix(np.stack([rot_0, rot_1]))
key_times = [0, 1]
slerp = Slerp(key_times, rots)
rot = slerp(ratio)
pose = np.diag([1.0, 1.0, 1.0, 1.0])
pose = pose.astype(np.float32)
pose[:3, :3] = rot.as_matrix()
pose[:3, 3] = ((1.0 - ratio) * pose_0 + ratio * pose_1)[:3, 3]
pose = np.linalg.inv(pose)
return pose
@dataclass
class MultiviewsDataModuleConfig:
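    """Configuration for loading a multiview capture from transforms.json."""
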
dataroot: str = ""
    train_downsample_resolution: int = 4  # integer downsampling factor for H, W, and intrinsics
    eval_downsample_resolution: int = 4  # integer downsampling factor for H, W, and intrinsics
train_data_interval: int = 1
eval_data_interval: int = 1
batch_size: int = 1
eval_batch_size: int = 1
camera_layout: str = "around"
    camera_distance: float = -1  # must be > 0 when camera_layout == "front"
    eval_interpolation: Optional[Tuple[int, int, int]] = None  # (start_idx, end_idx, n_steps), e.g. (0, 1, 30)
class MultiviewIterableDataset(IterableDataset):
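    """Infinite training stream over a multiview capture; collate() samples one random frame per step."""
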
def __init__(self, cfg: Any) -> None:
super().__init__()
self.cfg: MultiviewsDataModuleConfig = cfg
assert self.cfg.batch_size == 1
scale = self.cfg.train_downsample_resolution
        with open(os.path.join(self.cfg.dataroot, "transforms.json"), "r") as f:
            camera_dict = json.load(f)
assert camera_dict["camera_model"] == "OPENCV"
frames = camera_dict["frames"]
frames = frames[:: self.cfg.train_data_interval]
frames_proj = []
frames_c2w = []
frames_position = []
frames_direction = []
frames_img = []
self.frame_w = frames[0]["w"] // scale
self.frame_h = frames[0]["h"] // scale
threestudio.info("Loading frames...")
self.n_frames = len(frames)
c2w_list = []
for frame in tqdm(frames):
extrinsic: Float[Tensor, "4 4"] = torch.as_tensor(
frame["transform_matrix"], dtype=torch.float32
)
c2w = extrinsic
c2w_list.append(c2w)
c2w_list = torch.stack(c2w_list, dim=0)
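        # Recenter the cameras at their mean position. For the "front" layout,
        # additionally translate the scene along the mean viewing direction so
        # the origin lies roughly cfg.camera_distance in front of the cameras.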
if self.cfg.camera_layout == "around":
c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
elif self.cfg.camera_layout == "front":
assert self.cfg.camera_distance > 0
c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
z_vector = torch.zeros(c2w_list.shape[0], 3, 1)
z_vector[:, 2, :] = -1
rot_z_vector = c2w_list[:, :3, :3] @ z_vector
rot_z_vector = torch.mean(rot_z_vector, dim=0).unsqueeze(0)
c2w_list[:, :3, 3] -= rot_z_vector[:, :, 0] * self.cfg.camera_distance
        else:
            raise ValueError(
                f"Unknown camera layout {self.cfg.camera_layout}. "
                "Only 'around' and 'front' are currently supported."
            )
for idx, frame in tqdm(enumerate(frames)):
intrinsic: Float[Tensor, "4 4"] = torch.eye(4)
intrinsic[0, 0] = frame["fl_x"] / scale
intrinsic[1, 1] = frame["fl_y"] / scale
intrinsic[0, 2] = frame["cx"] / scale
intrinsic[1, 2] = frame["cy"] / scale
frame_path = os.path.join(self.cfg.dataroot, frame["file_path"])
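            # OpenCV loads images as BGR; reverse the channel order to RGB.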
img = cv2.imread(frame_path)[:, :, ::-1].copy()
img = cv2.resize(img, (self.frame_w, self.frame_h))
img: Float[Tensor, "H W 3"] = torch.FloatTensor(img) / 255
frames_img.append(img)
direction: Float[Tensor, "H W 3"] = get_ray_directions(
self.frame_h,
self.frame_w,
(intrinsic[0, 0], intrinsic[1, 1]),
(intrinsic[0, 2], intrinsic[1, 2]),
use_pixel_centers=False,
)
c2w = c2w_list[idx]
camera_position: Float[Tensor, "3"] = c2w[:3, 3:].reshape(-1)
near = 0.1
far = 1000.0
proj = convert_proj(intrinsic, self.frame_h, self.frame_w, near, far)
proj: Float[Tensor, "4 4"] = torch.FloatTensor(proj)
frames_proj.append(proj)
frames_c2w.append(c2w)
frames_position.append(camera_position)
frames_direction.append(direction)
threestudio.info("Loaded frames.")
self.frames_proj: Float[Tensor, "B 4 4"] = torch.stack(frames_proj, dim=0)
self.frames_c2w: Float[Tensor, "B 4 4"] = torch.stack(frames_c2w, dim=0)
self.frames_position: Float[Tensor, "B 3"] = torch.stack(frames_position, dim=0)
self.frames_direction: Float[Tensor, "B H W 3"] = torch.stack(
frames_direction, dim=0
)
self.frames_img: Float[Tensor, "B H W 3"] = torch.stack(frames_img, dim=0)
self.rays_o, self.rays_d = get_rays(
self.frames_direction, self.frames_c2w, keepdim=True
)
self.mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(
self.frames_c2w, self.frames_proj
)
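        # Light sources are placed at the scene origin for every frame.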
self.light_positions: Float[Tensor, "B 3"] = torch.zeros_like(
self.frames_position
)
def __iter__(self):
while True:
yield {}
def collate(self, batch):
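        # The iterable yields empty dicts; sample one random frame per step here.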
index = torch.randint(0, self.n_frames, (1,)).item()
return {
"index": index,
"rays_o": self.rays_o[index : index + 1],
"rays_d": self.rays_d[index : index + 1],
"mvp_mtx": self.mvp_mtx[index : index + 1],
"c2w": self.frames_c2w[index : index + 1],
"camera_positions": self.frames_position[index : index + 1],
"light_positions": self.light_positions[index : index + 1],
"gt_rgb": self.frames_img[index : index + 1],
"height": self.frame_h,
"width": self.frame_w,
}
class MultiviewDataset(Dataset):
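    """Map-style dataset over the captured frames for validation/testing.

    If cfg.eval_interpolation is set, yields poses interpolated between two
    reference frames instead of the captured views.
    """
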
def __init__(self, cfg: Any, split: str) -> None:
super().__init__()
self.cfg: MultiviewsDataModuleConfig = cfg
assert self.cfg.eval_batch_size == 1
scale = self.cfg.eval_downsample_resolution
        with open(os.path.join(self.cfg.dataroot, "transforms.json"), "r") as f:
            camera_dict = json.load(f)
assert camera_dict["camera_model"] == "OPENCV"
frames = camera_dict["frames"]
frames = frames[:: self.cfg.eval_data_interval]
frames_proj = []
frames_c2w = []
frames_position = []
frames_direction = []
frames_img = []
self.frame_w = frames[0]["w"] // scale
self.frame_h = frames[0]["h"] // scale
threestudio.info("Loading frames...")
self.n_frames = len(frames)
c2w_list = []
for frame in tqdm(frames):
extrinsic: Float[Tensor, "4 4"] = torch.as_tensor(
frame["transform_matrix"], dtype=torch.float32
)
c2w = extrinsic
c2w_list.append(c2w)
c2w_list = torch.stack(c2w_list, dim=0)
if self.cfg.camera_layout == "around":
c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
elif self.cfg.camera_layout == "front":
assert self.cfg.camera_distance > 0
c2w_list[:, :3, 3] -= torch.mean(c2w_list[:, :3, 3], dim=0).unsqueeze(0)
z_vector = torch.zeros(c2w_list.shape[0], 3, 1)
z_vector[:, 2, :] = -1
rot_z_vector = c2w_list[:, :3, :3] @ z_vector
rot_z_vector = torch.mean(rot_z_vector, dim=0).unsqueeze(0)
c2w_list[:, :3, 3] -= rot_z_vector[:, :, 0] * self.cfg.camera_distance
        else:
            raise ValueError(
                f"Unknown camera layout {self.cfg.camera_layout}. "
                "Only 'around' and 'front' are currently supported."
            )
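        # eval_interpolation = (idx0, idx1, n): render n camera poses
        # interpolated between frames idx0 and idx1 (ground-truth images are
        # zero-filled placeholders). Otherwise evaluate on the captured frames.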
        if self.cfg.eval_interpolation is not None:
idx0 = self.cfg.eval_interpolation[0]
idx1 = self.cfg.eval_interpolation[1]
eval_nums = self.cfg.eval_interpolation[2]
frame = frames[idx0]
intrinsic: Float[Tensor, "4 4"] = torch.eye(4)
intrinsic[0, 0] = frame["fl_x"] / scale
intrinsic[1, 1] = frame["fl_y"] / scale
intrinsic[0, 2] = frame["cx"] / scale
intrinsic[1, 2] = frame["cy"] / scale
for ratio in np.linspace(0, 1, eval_nums):
img: Float[Tensor, "H W 3"] = torch.zeros(
(self.frame_h, self.frame_w, 3)
)
frames_img.append(img)
direction: Float[Tensor, "H W 3"] = get_ray_directions(
self.frame_h,
self.frame_w,
(intrinsic[0, 0], intrinsic[1, 1]),
(intrinsic[0, 2], intrinsic[1, 2]),
use_pixel_centers=False,
)
c2w = torch.FloatTensor(
inter_pose(c2w_list[idx0], c2w_list[idx1], ratio)
)
camera_position: Float[Tensor, "3"] = c2w[:3, 3:].reshape(-1)
near = 0.1
far = 1000.0
proj = convert_proj(intrinsic, self.frame_h, self.frame_w, near, far)
proj: Float[Tensor, "4 4"] = torch.FloatTensor(proj)
frames_proj.append(proj)
frames_c2w.append(c2w)
frames_position.append(camera_position)
frames_direction.append(direction)
else:
for idx, frame in tqdm(enumerate(frames)):
intrinsic: Float[Tensor, "4 4"] = torch.eye(4)
intrinsic[0, 0] = frame["fl_x"] / scale
intrinsic[1, 1] = frame["fl_y"] / scale
intrinsic[0, 2] = frame["cx"] / scale
intrinsic[1, 2] = frame["cy"] / scale
frame_path = os.path.join(self.cfg.dataroot, frame["file_path"])
img = cv2.imread(frame_path)[:, :, ::-1].copy()
img = cv2.resize(img, (self.frame_w, self.frame_h))
img: Float[Tensor, "H W 3"] = torch.FloatTensor(img) / 255
frames_img.append(img)
direction: Float[Tensor, "H W 3"] = get_ray_directions(
self.frame_h,
self.frame_w,
(intrinsic[0, 0], intrinsic[1, 1]),
(intrinsic[0, 2], intrinsic[1, 2]),
use_pixel_centers=False,
)
c2w = c2w_list[idx]
camera_position: Float[Tensor, "3"] = c2w[:3, 3:].reshape(-1)
near = 0.1
far = 1000.0
                # Reuse the shared OpenCV-to-OpenGL projection helper instead
                # of rebuilding the same matrix inline.
                proj = convert_proj(intrinsic, self.frame_h, self.frame_w, near, far)
proj: Float[Tensor, "4 4"] = torch.FloatTensor(proj)
frames_proj.append(proj)
frames_c2w.append(c2w)
frames_position.append(camera_position)
frames_direction.append(direction)
threestudio.info("Loaded frames.")
self.frames_proj: Float[Tensor, "B 4 4"] = torch.stack(frames_proj, dim=0)
self.frames_c2w: Float[Tensor, "B 4 4"] = torch.stack(frames_c2w, dim=0)
self.frames_position: Float[Tensor, "B 3"] = torch.stack(frames_position, dim=0)
self.frames_direction: Float[Tensor, "B H W 3"] = torch.stack(
frames_direction, dim=0
)
self.frames_img: Float[Tensor, "B H W 3"] = torch.stack(frames_img, dim=0)
self.rays_o, self.rays_d = get_rays(
self.frames_direction, self.frames_c2w, keepdim=True
)
self.mvp_mtx: Float[Tensor, "B 4 4"] = get_mvp_matrix(
self.frames_c2w, self.frames_proj
)
self.light_positions: Float[Tensor, "B 3"] = torch.zeros_like(
self.frames_position
)
def __len__(self):
return self.frames_proj.shape[0]
def __getitem__(self, index):
return {
"index": index,
"rays_o": self.rays_o[index],
"rays_d": self.rays_d[index],
"mvp_mtx": self.mvp_mtx[index],
"c2w": self.frames_c2w[index],
"camera_positions": self.frames_position[index],
"light_positions": self.light_positions[index],
"gt_rgb": self.frames_img[index],
}
def collate(self, batch):
batch = torch.utils.data.default_collate(batch)
batch.update({"height": self.frame_h, "width": self.frame_w})
return batch
@register("multiview-camera-datamodule")
class MultiviewDataModule(pl.LightningDataModule):
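    """Lightning data module wiring the multiview datasets into train/val/test/predict dataloaders."""
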
cfg: MultiviewsDataModuleConfig
def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
super().__init__()
self.cfg = parse_structured(MultiviewsDataModuleConfig, cfg)
def setup(self, stage=None) -> None:
if stage in [None, "fit"]:
self.train_dataset = MultiviewIterableDataset(self.cfg)
if stage in [None, "fit", "validate"]:
self.val_dataset = MultiviewDataset(self.cfg, "val")
if stage in [None, "test", "predict"]:
self.test_dataset = MultiviewDataset(self.cfg, "test")
def prepare_data(self):
pass
def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
return DataLoader(
dataset,
num_workers=1, # type: ignore
batch_size=batch_size,
collate_fn=collate_fn,
)
def train_dataloader(self) -> DataLoader:
return self.general_loader(
self.train_dataset, batch_size=None, collate_fn=self.train_dataset.collate
)
def val_dataloader(self) -> DataLoader:
return self.general_loader(
self.val_dataset, batch_size=1, collate_fn=self.val_dataset.collate
)
def test_dataloader(self) -> DataLoader:
return self.general_loader(
self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
)
def predict_dataloader(self) -> DataLoader:
return self.general_loader(
self.test_dataset, batch_size=1, collate_fn=self.test_dataset.collate
)