|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import os |
|
import pdb |
|
|
|
from .base_depth_dataset import BaseDepthDataset |
|
from .eval_base_dataset import EvaluateBaseDataset, DatasetMode, get_pred_name |
|
from .diode_dataset import DIODEDataset |
|
from .eth3d_dataset import ETH3DDataset |
|
from .hypersim_dataset import HypersimDataset |
|
from .kitti_dataset import KITTIDataset |
|
from .nyu_dataset import NYUDataset |
|
from .scannet_dataset import ScanNetDataset |
|
from .vkitti_dataset import VirtualKITTIDataset |
|
from .depthanything_dataset import DepthAnythingDataset |
|
from .base_inpaint_dataset import BaseInpaintDataset |
|
|
|
# Registry from a dataset config's `name` field to the class that
# get_dataset / get_eval_dataset instantiate for it. The special name "mixed"
# is checked by those functions before this registry is consulted.
dataset_name_class_dict = {
    "hypersim": HypersimDataset,
    "vkitti": VirtualKITTIDataset,
    "nyu_v2": NYUDataset,
    "kitti": KITTIDataset,
    "eth3d": ETH3DDataset,
    "diode": DIODEDataset,
    "scannet": ScanNetDataset,
    "depthanything": DepthAnythingDataset,
    "inpainting": BaseInpaintDataset,
}
|
|
|
|
|
def get_dataset(
    cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs
):
    """Build the dataset (or list of datasets) described by ``cfg_data_split``.

    Args:
        cfg_data_split: Dataset config node. Its ``name`` selects the dataset
            class. For a ``"mixed"`` config, ``dataset_list`` holds the
            sub-configs. Otherwise ``filenames`` and ``dir`` are read, and every
            entry of the config is also unpacked into the dataset constructor
            as keyword arguments.
        base_data_dir: Root directory that ``cfg_data_split.dir`` is joined to.
        mode: Dataset mode forwarded to the dataset constructor.
        **kwargs: Extra keyword arguments forwarded to the dataset constructor
            (and to every sub-dataset of a mixed config).

    Returns:
        For ``name == "mixed"``, a list with one entry per sub-config (each
        built recursively by this function). Otherwise, a single instance of
        the class registered under ``name`` in ``dataset_name_class_dict``.

    Raises:
        NotImplementedError: If ``name`` is neither ``"mixed"`` nor a key of
            ``dataset_name_class_dict``.
    """
    if "mixed" == cfg_data_split.name:
        # Sub-configs are resolved recursively with the same mode/kwargs.
        return [
            get_dataset(_cfg, base_data_dir, mode, **kwargs)
            for _cfg in cfg_data_split.dataset_list
        ]

    if cfg_data_split.name not in dataset_name_class_dict:
        raise NotImplementedError(
            f"Unknown dataset name {cfg_data_split.name!r}; expected 'mixed' "
            f"or one of {sorted(dataset_name_class_dict)}"
        )

    dataset_class = dataset_name_class_dict[cfg_data_split.name]
    # NOTE: the config is unpacked alongside explicit keyword arguments, so a
    # config key named `mode`, `filename_ls_path` or `dataset_dir`, or a key
    # present in both the config and **kwargs, makes this call raise TypeError
    # (multiple values for the same keyword argument).
    return dataset_class(
        mode=mode,
        filename_ls_path=cfg_data_split.filenames,
        dataset_dir=os.path.join(base_data_dir, cfg_data_split.dir),
        **cfg_data_split,
        **kwargs,
    )
|
|
|
def get_eval_dataset(
    cfg_data_split, base_data_dir: str, mode: DatasetMode, **kwargs
) -> "EvaluateBaseDataset | list[EvaluateBaseDataset]":
    """Build the evaluation dataset(s) described by ``cfg_data_split``.

    Construction is identical to ``get_dataset``; this function only adds the
    check that a ``"mixed"`` config is used in ``DatasetMode.TRAIN`` mode.

    Args:
        cfg_data_split: Dataset config node (``name`` selects the dataset).
        base_data_dir: Root directory that ``cfg_data_split.dir`` is joined to.
        mode: Dataset mode forwarded to the dataset constructor.
        **kwargs: Extra keyword arguments forwarded to the dataset constructor.

    Returns:
        A list of datasets for a ``"mixed"`` config, otherwise a single dataset.
        The return annotation is a string so it is never evaluated at runtime.

    Raises:
        AssertionError: If ``name == "mixed"`` and ``mode`` is not
            ``DatasetMode.TRAIN`` (only while assertions are enabled).
        NotImplementedError: If ``name`` is not a known dataset name.
    """
    if "mixed" == cfg_data_split.name:
        assert DatasetMode.TRAIN == mode, "Only training mode supports mixed datasets."

    # Delegate so the registry lookup and constructor call live in one place
    # and cannot drift apart between the training and evaluation paths.
    return get_dataset(cfg_data_split, base_data_dir, mode, **kwargs)
|
|