|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import tarfile |
|
from io import BytesIO |
|
|
|
import numpy as np |
|
import torch |
|
|
|
from .eval_base_dataset import EvaluateBaseDataset, DepthFileNameMode, DatasetMode |
|
|
|
|
|
class DIODEDataset(EvaluateBaseDataset):
    """Evaluation dataset for DIODE whose depth maps and masks are ``.npy`` files.

    Unlike a dataset that derives its validity mask from the depth values, this
    class reads the mask from a dedicated ``.npy`` file listed next to the RGB
    and depth paths of every sample.

    NOTE(review): ``is_tar``, ``tar_obj``, ``dataset_dir``, ``filenames``,
    ``mode``, ``_load_rgb_data`` and ``_load_depth_data`` are expected to be
    provided by ``EvaluateBaseDataset`` -- confirm against the base class.
    """

    def __init__(
        self,
        **kwargs,
    ) -> None:
        """Initialise the base dataset with DIODE-specific fixed parameters.

        Args:
            **kwargs: Forwarded unchanged to ``EvaluateBaseDataset.__init__``.
        """
        super().__init__(
            # DIODE data parameters (depth range presumably in metres -- TODO confirm)
            min_depth=0.6,
            max_depth=350,
            has_filled_depth=False,
            name_mode=DepthFileNameMode.id,
            **kwargs,
        )

    def _read_npy_file(self, rel_path):
        """Load a ``.npy`` array from the dataset directory or tar archive.

        Args:
            rel_path: Path of the ``.npy`` file relative to the dataset root.

        Returns:
            The loaded array with all singleton axes squeezed out and one
            leading singleton axis prepended (``(1, H, W)`` for 2-D data).

        Raises:
            KeyError: If the tar archive has no member named ``"./" + rel_path``.
            FileNotFoundError: If the tar member exists but is not a regular
                file, so no content can be extracted from it.
        """
        if self.is_tar:
            # The archive is opened lazily on first access and kept for reuse.
            if self.tar_obj is None:
                self.tar_obj = tarfile.open(self.dataset_dir)
            member_name = "./" + rel_path
            fileobj = self.tar_obj.extractfile(member_name)
            # extractfile() returns None for directories and other non-file
            # members; fail with a clear message instead of an AttributeError.
            if fileobj is None:
                raise FileNotFoundError(
                    f"'{member_name}' in '{self.dataset_dir}' is not a regular file"
                )
            # Copy the bytes into memory and release the member handle right
            # away; np.load needs a seekable file-like object.
            with fileobj:
                npy_path_or_content = BytesIO(fileobj.read())
        else:
            npy_path_or_content = os.path.join(self.dataset_dir, rel_path)
        data = np.load(npy_path_or_content).squeeze()[np.newaxis, :, :]
        return data

    def _read_depth_file(self, rel_path):
        """Read a depth map; DIODE depth is stored as a plain ``.npy`` array."""
        depth = self._read_npy_file(rel_path)
        return depth

    def _get_data_path(self, index):
        """Return the ``self.filenames`` entry for sample ``index``."""
        return self.filenames[index]

    def _get_data_item(self, index):
        """Assemble the rasters and metadata for one sample.

        Special case for this dataset: the validity mask is read from its own
        file rather than derived from the depth values.

        Args:
            index: Position of the sample in ``self.filenames``; the entry must
                unpack into (RGB path, depth path, mask path).

        Returns:
            A ``(rasters, other)`` tuple. ``rasters`` holds whatever
            ``_load_rgb_data`` returns and, unless the mode is
            ``DatasetMode.RGB_ONLY``, the output of ``_load_depth_data`` plus the
            boolean tensors ``valid_mask_raw`` and ``valid_mask_filled``.
            ``other`` holds ``index`` and ``rgb_relative_path``.
        """
        rgb_rel_path, depth_rel_path, mask_rel_path = self._get_data_path(index=index)

        rasters = {}

        # RGB data
        rasters.update(self._load_rgb_data(rgb_rel_path=rgb_rel_path))

        # Depth data
        if DatasetMode.RGB_ONLY != self.mode:
            # No filled depth is passed for this dataset.
            depth_data = self._load_depth_data(
                depth_rel_path=depth_rel_path, filled_rel_path=None
            )
            rasters.update(depth_data)

            # Valid mask: the same mask serves both keys; each gets its own
            # copy so later in-place edits of one do not affect the other.
            mask = self._read_npy_file(mask_rel_path).astype(bool)
            mask = torch.from_numpy(mask).bool()
            rasters["valid_mask_raw"] = mask.clone()
            rasters["valid_mask_filled"] = mask.clone()

        other = {"index": index, "rgb_relative_path": rgb_rel_path}

        return rasters, other
|
|