import bisect
import math
import os
from dataclasses import dataclass, field

import cv2
import numpy as np
import pytorch_lightning as pl
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader, Dataset, IterableDataset

import threestudio
from threestudio import register
from threestudio.data.uncond import (
    RandomCameraDataModuleConfig,
    RandomCameraDataset,
    RandomCameraIterableDataset,
)
from threestudio.utils.base import Updateable
from threestudio.utils.config import parse_structured
from threestudio.utils.misc import get_rank
from threestudio.utils.ops import (
    get_mvp_matrix,
    get_projection_matrix,
    get_ray_directions,
    get_rays,
)
from threestudio.utils.typing import *


@dataclass
class SingleImageDataModuleConfig:
    # height and width should be Union[int, List[int]]
    # but OmegaConf does not support Union of containers
    height: Any = 96
    width: Any = 96
    resolution_milestones: List[int] = field(default_factory=lambda: [])
    default_elevation_deg: float = 0.0
    default_azimuth_deg: float = -180.0
    default_camera_distance: float = 1.2
    default_fovy_deg: float = 60.0
    image_path: str = ""
    use_random_camera: bool = True
    random_camera: dict = field(default_factory=dict)
    rays_noise_scale: float = 2e-3
    batch_size: int = 1
    requires_depth: bool = False
    requires_normal: bool = False
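

# A hypothetical YAML fragment that maps onto this config (field names are
# real; the image path and values are only illustrative):
#
#   data_type: "single-image-datamodule"
#   data:
#     image_path: "load/images/example_rgba.png"
#     height: [128, 256]
#     width: [128, 256]
#     resolution_milestones: [3000]
#     use_random_camera: true
#     random_camera:
#       batch_size: 1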


class SingleImageDataBase:
    def setup(self, cfg, split):
        self.split = split
        self.rank = get_rank()
        self.cfg: SingleImageDataModuleConfig = cfg

        if self.cfg.use_random_camera:
            random_camera_cfg = parse_structured(
                RandomCameraDataModuleConfig, self.cfg.get("random_camera", {})
            )
            if split == "train":
                self.random_pose_generator = RandomCameraIterableDataset(
                    random_camera_cfg
                )
            else:
                self.random_pose_generator = RandomCameraDataset(
                    random_camera_cfg, split
                )
        elevation_deg = torch.FloatTensor([self.cfg.default_elevation_deg])
        azimuth_deg = torch.FloatTensor([self.cfg.default_azimuth_deg])
        camera_distance = torch.FloatTensor([self.cfg.default_camera_distance])

        elevation = elevation_deg * math.pi / 180
        azimuth = azimuth_deg * math.pi / 180
        camera_position: Float[Tensor, "1 3"] = torch.stack(
            [
                camera_distance * torch.cos(elevation) * torch.cos(azimuth),
                camera_distance * torch.cos(elevation) * torch.sin(azimuth),
                camera_distance * torch.sin(elevation),
            ],
            dim=-1,
        )

        center: Float[Tensor, "1 3"] = torch.zeros_like(camera_position)
        up: Float[Tensor, "1 3"] = torch.as_tensor([0, 0, 1], dtype=torch.float32)[
            None
        ]

        light_position: Float[Tensor, "1 3"] = camera_position
        lookat: Float[Tensor, "1 3"] = F.normalize(center - camera_position, dim=-1)
        right: Float[Tensor, "1 3"] = F.normalize(torch.cross(lookat, up), dim=-1)
        up = F.normalize(torch.cross(right, lookat), dim=-1)
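        # assemble the camera-to-world matrix: the rotation columns are the
        # camera's right / up / backward (-lookat) axes, followed by the
        # camera position as the translation column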
        self.c2w: Float[Tensor, "1 3 4"] = torch.cat(
            [torch.stack([right, up, -lookat], dim=-1), camera_position[:, :, None]],
            dim=-1,
        )

        self.camera_position = camera_position
        self.light_position = light_position
        self.elevation_deg, self.azimuth_deg = elevation_deg, azimuth_deg
        self.camera_distance = camera_distance
        self.fovy = torch.deg2rad(torch.FloatTensor([self.cfg.default_fovy_deg]))

        self.heights: List[int] = (
            [self.cfg.height] if isinstance(self.cfg.height, int) else self.cfg.height
        )
        self.widths: List[int] = (
            [self.cfg.width] if isinstance(self.cfg.width, int) else self.cfg.width
        )
        assert len(self.heights) == len(self.widths)
        self.resolution_milestones: List[int]
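        # resolution_milestones define a coarse-to-fine schedule: at each
        # milestone step the dataset switches to the next (height, width) pair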
        if len(self.heights) == 1 and len(self.widths) == 1:
            if len(self.cfg.resolution_milestones) > 0:
                threestudio.warn(
                    "Ignoring resolution_milestones since height and width are not changing"
                )
            self.resolution_milestones = [-1]
        else:
            assert len(self.heights) == len(self.cfg.resolution_milestones) + 1
            self.resolution_milestones = [-1] + self.cfg.resolution_milestones

        self.directions_unit_focals = [
            get_ray_directions(H=height, W=width, focal=1.0)
            for (height, width) in zip(self.heights, self.widths)
        ]
        self.focal_lengths = [
            0.5 * height / torch.tan(0.5 * self.fovy) for height in self.heights
        ]

        self.height: int = self.heights[0]
        self.width: int = self.widths[0]
        self.directions_unit_focal = self.directions_unit_focals[0]
        self.focal_length = self.focal_lengths[0]

        self.set_rays()
        self.load_images()
        self.prev_height = self.height
    def set_rays(self):
        # get directions by dividing directions_unit_focal by focal length;
        # clone first so the cached unit-focal directions are not mutated
        # in place if set_rays() runs more than once
        directions: Float[Tensor, "1 H W 3"] = self.directions_unit_focal[
            None
        ].clone()
        directions[:, :, :, :2] = directions[:, :, :, :2] / self.focal_length

        rays_o, rays_d = get_rays(
            directions, self.c2w, keepdim=True, noise_scale=self.cfg.rays_noise_scale
        )

        proj_mtx: Float[Tensor, "4 4"] = get_projection_matrix(
            self.fovy, self.width / self.height, 0.1, 100.0
        )  # FIXME: hard-coded near and far
        mvp_mtx: Float[Tensor, "4 4"] = get_mvp_matrix(self.c2w, proj_mtx)

        self.rays_o, self.rays_d = rays_o, rays_d
        self.mvp_mtx = mvp_mtx
    def load_images(self):
        # load image
        assert os.path.exists(
            self.cfg.image_path
        ), f"Could not find image {self.cfg.image_path}!"
        rgba = cv2.cvtColor(
            cv2.imread(self.cfg.image_path, cv2.IMREAD_UNCHANGED), cv2.COLOR_BGRA2RGBA
        )
        rgba = (
            cv2.resize(
                rgba, (self.width, self.height), interpolation=cv2.INTER_AREA
            ).astype(np.float32)
            / 255.0
        )
        rgb = rgba[..., :3]
        self.rgb: Float[Tensor, "1 H W 3"] = (
            torch.from_numpy(rgb).unsqueeze(0).contiguous().to(self.rank)
        )
        self.mask: Float[Tensor, "1 H W 1"] = (
            torch.from_numpy(rgba[..., 3:] > 0.5).unsqueeze(0).to(self.rank)
        )
        print(
            f"[INFO] single image dataset: load image {self.cfg.image_path} {self.rgb.shape}"
        )

        # load depth
        if self.cfg.requires_depth:
            depth_path = self.cfg.image_path.replace("_rgba.png", "_depth.png")
            assert os.path.exists(depth_path)
            depth = cv2.imread(depth_path, cv2.IMREAD_UNCHANGED)
            depth = cv2.resize(
                depth, (self.width, self.height), interpolation=cv2.INTER_AREA
            )
            self.depth: Float[Tensor, "1 H W 1"] = (
                torch.from_numpy(depth.astype(np.float32) / 255.0)
                .unsqueeze(0)
                .to(self.rank)
            )
            print(
                f"[INFO] single image dataset: load depth {depth_path} {self.depth.shape}"
            )
        else:
            self.depth = None

        # load normal
        if self.cfg.requires_normal:
            normal_path = self.cfg.image_path.replace("_rgba.png", "_normal.png")
            assert os.path.exists(normal_path)
            normal = cv2.imread(normal_path, cv2.IMREAD_UNCHANGED)
            normal = cv2.resize(
                normal, (self.width, self.height), interpolation=cv2.INTER_AREA
            )
            self.normal: Float[Tensor, "1 H W 3"] = (
                torch.from_numpy(normal.astype(np.float32) / 255.0)
                .unsqueeze(0)
                .to(self.rank)
            )
            print(
                f"[INFO] single image dataset: load normal {normal_path} {self.normal.shape}"
            )
        else:
            self.normal = None
    def get_all_images(self):
        return self.rgb

    def update_step_(self, epoch: int, global_step: int, on_load_weights: bool = False):
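        # bisect_right finds the last milestone not exceeding global_step,
        # i.e. the resolution bucket the current training step falls into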
        size_ind = bisect.bisect_right(self.resolution_milestones, global_step) - 1
        self.height = self.heights[size_ind]
        if self.height == self.prev_height:
            return

        self.prev_height = self.height
        self.width = self.widths[size_ind]
        self.directions_unit_focal = self.directions_unit_focals[size_ind]
        self.focal_length = self.focal_lengths[size_ind]
        threestudio.debug(f"Training height: {self.height}, width: {self.width}")
        self.set_rays()
        self.load_images()


class SingleImageIterableDataset(IterableDataset, SingleImageDataBase, Updateable):
    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.setup(cfg, split)

    def collate(self, batch) -> Dict[str, Any]:
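        # every "batch" is the same single reference view; randomly sampled
        # novel views for guidance are attached under "random_camera" below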
        batch = {
            "rays_o": self.rays_o,
            "rays_d": self.rays_d,
            "mvp_mtx": self.mvp_mtx,
            "camera_positions": self.camera_position,
            "light_positions": self.light_position,
            "elevation": self.elevation_deg,
            "azimuth": self.azimuth_deg,
            "camera_distances": self.camera_distance,
            "rgb": self.rgb,
            "ref_depth": self.depth,
            "ref_normal": self.normal,
            "mask": self.mask,
            "height": self.cfg.height,
            "width": self.cfg.width,
        }
        if self.cfg.use_random_camera:
            batch["random_camera"] = self.random_pose_generator.collate(None)
        return batch

    def update_step(self, epoch: int, global_step: int, on_load_weights: bool = False):
        self.update_step_(epoch, global_step, on_load_weights)
        # random_pose_generator only exists when use_random_camera is enabled
        if self.cfg.use_random_camera:
            self.random_pose_generator.update_step(epoch, global_step, on_load_weights)
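    # the iterable dataset yields empty placeholders; all real data is built
    # in collate(), so the loader only needs an endless stream of items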
    def __iter__(self):
        while True:
            yield {}


class SingleImageDataset(Dataset, SingleImageDataBase):
    def __init__(self, cfg: Any, split: str) -> None:
        super().__init__()
        self.setup(cfg, split)

    def __len__(self):
        return len(self.random_pose_generator)

    def __getitem__(self, index):
        return self.random_pose_generator[index]

        # if index == 0:
        #     return {
        #         'rays_o': self.rays_o[0],
        #         'rays_d': self.rays_d[0],
        #         'mvp_mtx': self.mvp_mtx[0],
        #         'camera_positions': self.camera_position[0],
        #         'light_positions': self.light_position[0],
        #         'elevation': self.elevation_deg[0],
        #         'azimuth': self.azimuth_deg[0],
        #         'camera_distances': self.camera_distance[0],
        #         'rgb': self.rgb[0],
        #         'depth': self.depth[0],
        #         'mask': self.mask[0]
        #     }
        # else:
        #     return self.random_pose_generator[index - 1]


@register("single-image-datamodule")
class SingleImageDataModule(pl.LightningDataModule):
    cfg: SingleImageDataModuleConfig

    def __init__(self, cfg: Optional[Union[dict, DictConfig]] = None) -> None:
        super().__init__()
        self.cfg = parse_structured(SingleImageDataModuleConfig, cfg)

    def setup(self, stage=None) -> None:
        if stage in [None, "fit"]:
            self.train_dataset = SingleImageIterableDataset(self.cfg, "train")
        if stage in [None, "fit", "validate"]:
            self.val_dataset = SingleImageDataset(self.cfg, "val")
        if stage in [None, "test", "predict"]:
            self.test_dataset = SingleImageDataset(self.cfg, "test")

    def prepare_data(self):
        pass

    def general_loader(self, dataset, batch_size, collate_fn=None) -> DataLoader:
        return DataLoader(
            dataset, num_workers=0, batch_size=batch_size, collate_fn=collate_fn
        )

    def train_dataloader(self) -> DataLoader:
        return self.general_loader(
            self.train_dataset,
            batch_size=self.cfg.batch_size,
            collate_fn=self.train_dataset.collate,
        )

    def val_dataloader(self) -> DataLoader:
        return self.general_loader(self.val_dataset, batch_size=1)

    def test_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, batch_size=1)

    def predict_dataloader(self) -> DataLoader:
        return self.general_loader(self.test_dataset, batch_size=1)
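

# ---------------------------------------------------------------------------
# Minimal usage sketch (not part of the module proper): build the data module
# from a plain dict and pull one training batch. The image path below is
# purely illustrative; point it at an existing RGBA image. Note that tensors
# are moved to the local rank's device, so a CUDA GPU is assumed.
# ---------------------------------------------------------------------------
if __name__ == "__main__":
    cfg = {
        "image_path": "load/images/example_rgba.png",  # hypothetical path
        "height": 96,
        "width": 96,
        "use_random_camera": True,
    }
    dm = SingleImageDataModule(cfg)
    dm.setup("fit")
    batch = next(iter(dm.train_dataloader()))
    for key, value in batch.items():
        print(key, getattr(value, "shape", type(value)))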