import bisect
import math
import os
from dataclasses import dataclass, field
import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset
import threestudio
from threestudio import register
from threestudio.data.uncond import (
RandomCameraDataModuleConfig,
RandomCameraDataset,
RandomCameraIterableDataset,
)
from threestudio.utils.base import Updateable
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_rank
from threestudio.utils.ops import (
get_mvp_matrix,
get_projection_matrix,
get_ray_directions,
get_rays,
)
from threestudio.utils.typing import *


@dataclass
class SingleImageDataModuleConfig:
# height and width should be Union[int, List[int]]
# but OmegaConf does not support Union of containers
height: Any = 96
width: Any = 96
resolution_milestones: List[int] = field(default_factory=lambda: [])
default_elevation_deg: float = 0.0
default_azimuth_deg: float = -180.0
default_camera_distance: float = 1.2
default_fovy_deg: float = 60.0
image_path: str = ""
use_random_camera: bool = True
random_camera: dict = field(default_factory=dict)
rays_noise_scale: float = 2e-3
batch_size: int = 1
requires_depth: bool = False
requires_normal: bool = False
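
# A hypothetical YAML snippet that would populate this config via
# parse_structured (field names come from the dataclass above; the values and
# the surrounding `data:` key are illustrative assumptions, not taken from
# this file):
#
#   data:
#     image_path: ./load/images/example_rgba.png
#     height: [128, 256]
#     width: [128, 256]
#     resolution_milestones: [200]
#     default_elevation_deg: 0.0
#     default_camera_distance: 1.2
#     random_camera:
#       batch_size: 1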


class SingleImageDataBase:
def setup(self, cfg, split):
self.split = split
self.rank = get_rank()
self.cfg: SingleImageDataModuleConfig = cfg
if self.cfg.use_random_camera:
random_camera_cfg = parse_structured(
RandomCameraDataModuleConfig, self.cfg.get("random_camera", {})
)
if split == "train":
self.random_pose_generator = RandomCameraIterableDataset(
random_camera_cfg
)
else:
self.random_pose_generator = RandomCameraDataset(
random_camera_cfg, split
)
elevation_deg = torch.FloatTensor([self.cfg.default_elevation_deg])
azimuth_deg = torch.FloatTensor([self.cfg.default_azimuth_deg])
camera_distance = torch.FloatTensor([self.cfg.default_camera_distance])
elevation = elevation_deg * math.pi / 180
azimuth = azimuth_deg * math.pi / 180
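        # spherical-to-Cartesian with z up: azimuth sweeps the xy-plane and
        # elevation is measured from the xy-plane towards +z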
camera_position: Float[Tensor, "1 3"] = torch.stack(
[
camera_distance * torch.cos(elevation) * torch.cos(azimuth),
camera_distance * torch.cos(elevation) * torch.sin(azimuth),
camera_distance * torch.sin(elevation),
],
dim=-1,
)
center: Float[Tensor, "1 3"] = torch.zeros_like(camera_position)
up: Float[Tensor, "1 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[None]
light_position: Float[Tensor, "1 3"] = camera_position
lookat: Float[Tensor, "1 3"] = F.normalize(center - camera_position, dim=-1)
right: Float[Tensor, "1 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
up = F.normalize(torch.cross(right, lookat), dim=-1)
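        # build an OpenGL-style camera-to-world matrix: rotation columns are
        # [right, up, -lookat], translation is the camera position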
self.c2w: Float[Tensor, "1 3 4"] = torch.cat(
[torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]],
dim=-1,
)
self.camera_position = camera_position
self.light_position = light_position
self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
self.camera_distance = camera_distance
self.fovy = torch.deg2rad(torch.FloatTensor([self.cfg.default_fovy_deg]))
self.heights: List[int] = (
[self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
)
self.widths: List[int] = (
[self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
)
assert len(self.heights) == len(self.widths)
self.resolution_milestones: List[int]
if len(self.heights) == 1 and len(self.widths) == 1:
if len(self.cfg.resolution_milestones) > 0:
threestudio.warn(
"Ignoring resolution_milestones since height and width are not changing"
)
self.resolution_milestones = [-1]
else:
assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
self.resolution_milestones = [-1] + self.cfg.resolution_milestones
self.directions_unit_focals = [
get_ray_directions(H=height, W=width, focal=1.0)
for (height, width) in zip(self.heights, self.widths)
]
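        # pinhole model: focal length in pixels is 0.5 * H / tan(fovy / 2)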
self.focal_lengths = [
0.5 * height / torch.tan(0.5 * self.fovy) for height in self.heights
]
self.height: int = self.heights[0]
self.width: int = self.widths[0]
self.directions_unit_focal = self.directions_unit_focals[0]
self.focal_length = self.focal_lengths[0]
self.set_rays()
self.load_images()
        self.prev_height = self.height
        self.prev_width = self.width

    def set_rays(self):
        # get directions by dividing directions_unit_focal by focal length;
        # clone so the cached unit-focal grid is not mutated in place through
        # the view created by [None]
        directions: Float[Tensor, "1 H W 3"] = self.directions_unit_focal[None].clone()
        directions[:, :, :, :2] = directions[:, :, :, :2] / self.focal_length
rays_o, rays_d = get_rays(
directions, self.c2w, keepdim=True, noise_scale=self.cfg.rays_noise_scale
)
        proj_mtx: Float[Tensor, "1 4 4"] = get_projection_matrix(
            self.fovy, self.width / self.height, 0.1, 100.0
        )  # FIXME: hard-coded near and far
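        # compose the projection with the world-to-camera transform to get the
        # full model-view-projection matrix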
        mvp_mtx: Float[Tensor, "1 4 4"] = get_mvp_matrix(self.c2w, proj_mtx)
self.rays_o, self.rays_d = rays_o, rays_d
self.mvp_mtx = mvp_mtx

    def load_images(self):
# load image
assert os.path.exists(
self.cfg.image_path
), f"Could not find image {self.cfg.image_path}!"
rgba = cv2.cvtColor(
cv2.imread(self.cfg.image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA
)
rgba = (
cv2.resize(
rgba, (self.width, self.height), interpolation=cv2.INTER_AREA
).astype(np.float32)
/ 255.0
)
rgb = rgba[..., :3]
self.rgb: Float[Tensor, "1 H W 3"] = (
torch.from_numpy(rgb).unsqueeze(0).contiguous().to(self.rank)
)
        self.mask: Bool[Tensor, "1 H W 1"] = (
            torch.from_numpy(rgba[..., 3:] > 0.5).unsqueeze(0).to(self.rank)
        )
        threestudio.info(
            f"single image dataset: load image {self.cfg.image_path} {self.rgb.shape}"
        )
# load depth
if self.cfg.requires_depth:
depth_path = self.cfg.image_path.replace("_rgba.png", "_depth.png")
            assert os.path.exists(depth_path), f"Could not find depth image {depth_path}!"
depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
depth = cv2.resize(
depth, (self.width, self.height), interpolation=cv2.INTER_AREA
)
self.depth: Float[Tensor, "1 H W 1"] = (
torch.from_numpy(depth.astype(np.float32) / 255.0)
.unsqueeze(0)
.to(self.rank)
)
            threestudio.info(
                f"single image dataset: load depth {depth_path} {self.depth.shape}"
            )
else:
self.depth = None
# load normal
if self.cfg.requires_normal:
normal_path = self.cfg.image_path.replace("_rgba.png", "_normal.png")
            assert os.path.exists(normal_path), f"Could not find normal image {normal_path}!"
normal = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED)
normal = cv2.resize(
normal, (self.width, self.height), interpolation=cv2.INTER_AREA
)
self.normal: Float[Tensor, "1 H W 3"] = (
torch.from_numpy(normal.astype(np.float32) / 255.0)
.unsqueeze(0)
.to(self.rank)
)
            threestudio.info(
                f"single image dataset: load normal {normal_path} {self.normal.shape}"
            )
else:
self.normal = None

    def get_all_images(self):
return self.rgb

    def update_step_(self, epoch: int, global_step: int, on_load_weights: bool = False):
        size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
        self.height = self.heights[size_ind]
        self.width = self.widths[size_ind]
        # skip the expensive ray/image refresh when the resolution is unchanged;
        # compare both dimensions so a width-only change is not missed
        if self.height == self.prev_height and self.width == self.prev_width:
            return
        self.prev_height = self.height
        self.prev_width = self.width
        self.directions_unit_focal = self.directions_unit_focals[size_ind]
        self.focal_length = self.focal_lengths[size_ind]
        threestudio.debug(f"Training height: {self.height}, width: {self.width}")
        self.set_rays()
        self.load_images()


class SingleImageIterableDataset(IterableDataset, SingleImageDataBase, Updateable):
    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.setup(cfg, split)

    def collate(self, batch) -> Dict[str, Any]:
batch = {
"rays_o": self.rays_o,
"rays_d": self.rays_d,
"mvp_mtx": self.mvp_mtx,
"camera_positions": self.camera_position,
"light_positions": self.light_position,
"elevation": self.elevation_deg,
"azimuth": self.azimuth_deg,
"camera_distances": self.camera_distance,
"rgb": self.rgb,
"ref_depth": self.depth,
"ref_normal": self.normal,
"mask": self.mask,
"height": self.cfg.height,
"width": self.cfg.width,
}
if self.cfg.use_random_camera:
batch["random_camera"] = self.random_pose_generator.collate(None)
return batch

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        self.update_step_(epoch, global_step, on_load_weights)
        # the random pose generator only exists when use_random_camera is set
        if self.cfg.use_random_camera:
            self.random_pose_generator.update_step(epoch, global_step, on_load_weights)

    def __iter__(self):
        # the actual batch content is assembled in `collate`; yielding empty
        # dicts just keeps the iterable dataset running indefinitely
        while True:
            yield {}


class SingleImageDataset(Dataset, SingleImageDataBase):
    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.setup(cfg, split)

    def __len__(self):
        return len(self.random_pose_generator)

    def __getitem__(self, index):
        # validation/testing renders novel views from the random pose dataset
        return self.random_pose_generator[index]
@register("single-image-datamodule")
class SingleImageDataModule(pl.LightningDataModule):
cfg: SingleImageDataModuleConfig
def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
super().__init__()
self.cfg = parse_structured(SingleImageDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        if stage in [None, "fit"]:
            self.train_dataset = SingleImageIterableDataset(self.cfg, "train")
        if stage in [None, "fit", "validate"]:
            self.val_dataset = SingleImageDataset(self.cfg, "val")
        if stage in [None, "test", "predict"]:
            self.test_dataset = SingleImageDataset(self.cfg, "test")

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
        return DataLoader(
            dataset, num_workers=0, batch_size=batch_size, collate_fn=collate_fn
        )

    def train_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.train_dataset,
            batch_size=self.cfg.batch_size,
            collate_fn=self.train_dataset.collate,
        )

    def val_dataloader(self) -> DataLoader:
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, batch_size=1)
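

# Minimal usage sketch (illustrative only; the image path is a placeholder
# and assumes an RGBA image exists at that location):
#
#   dm = SingleImageDataModule({"image_path": "./example_rgba.png"})
#   dm.setup("fit")
#   batch = next(iter(dm.train_dataloader()))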