|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import glob |
|
import io |
|
import json |
|
import os |
|
import pdb |
|
import random |
|
import tarfile |
|
from enum import Enum |
|
from typing import Union |
|
|
|
import numpy as np |
|
import torch |
|
from PIL import Image |
|
from torch.utils.data import Dataset |
|
from torchvision.transforms import InterpolationMode, Resize, CenterCrop |
|
import torchvision.transforms as transforms |
|
from transformers import CLIPTextModel, CLIPTokenizer |
|
from src.util.depth_transform import DepthNormalizerBase |
|
import random |
|
|
|
from src.dataset.eval_base_dataset import DatasetMode, DepthFileNameMode |
|
from pycocotools import mask as coco_mask |
|
from scipy.ndimage import gaussian_filter |
|
|
|
def read_image_from_tar(tar_obj, img_rel_path):
    """Open an image stored as a member of an already-opened tar archive.

    Args:
        tar_obj: An open ``tarfile.TarFile``.
        img_rel_path: Member path relative to the archive root; it is looked up
            as ``"./" + img_rel_path``.

    Returns:
        The ``PIL.Image.Image`` decoded from the member's bytes.

    Raises:
        KeyError: If no such member exists (raised by ``TarFile.extractfile``).
        FileNotFoundError: If the member exists but is not a regular file
            (``extractfile`` returns ``None`` for those).
    """
    member = tar_obj.extractfile("./" + img_rel_path)
    if member is None:
        raise FileNotFoundError(
            f"Tar member './{img_rel_path}' is not a regular file"
        )
    # Copy the bytes into memory so the returned image does not depend on the
    # tar member's file object staying open.
    image = Image.open(io.BytesIO(member.read()))
    return image
|
|
|
|
|
class BaseInpaintDataset(Dataset):
    """Image dataset pairing each RGB image with a random inpainting mask and a caption.

    Samples are listed in one or more JSON files matched by the glob pattern
    ``filename_ls_path``. The decoded contents of the files are concatenated
    with ``+=``; presumably each holds a list of dict entries (verify against
    the data). Each entry is read for ``"rgb_path"`` and, if present,
    ``"caption"`` (``None`` or a dict with ``"coca_caption"`` and
    ``"spatial_caption"``).

    The mask is the union of a few object segmentations. These are read from
    the JSON file at ``rgb_path`` with ``.jpg`` replaced by ``.json`` (key
    ``"annotations"``, each item carrying a pycocotools-decodable
    ``"segmentation"``). With probability ``grid_mask_prob``, or whenever those
    annotations cannot be used, a random grid mask is generated instead.
    """

    # Sampling weights for the caption: (coca caption, spatial caption, "").
    caption_weights = (0.4, 0.4, 0.2)
    # Inclusive bounds on how many shuffled annotations are merged into the mask.
    object_num_range = (5, 10)
    # Probability of using a grid mask even when annotations are available.
    grid_mask_prob = 0.1
    # Inclusive bounds on the side length (in pixels) of one grid cell.
    grid_size_range = (5, 14)
    # Probability that each individual grid cell is masked.
    grid_cell_prob = 0.2

    def __init__(
        self,
        mode: DatasetMode,
        filename_ls_path: str,
        dataset_dir: str,
        disp_name: str,
        depth_transform: Union[DepthNormalizerBase, None] = None,
        tokenizer: CLIPTokenizer = None,
        augmentation_args: dict = None,
        resize_to_hw=None,
        move_invalid_to_far_plane: bool = True,
        rgb_transform=lambda x: x / 255.0 * 2 - 1,
        **kwargs,
    ) -> None:
        """Store the configuration and load the sample list.

        Args:
            mode: Dataset mode; ``DatasetMode.TRAIN`` enables
                :meth:`_training_preprocess` in ``__getitem__``.
            filename_ls_path: Glob pattern of JSON sample-list files.
            dataset_dir: Dataset location; only used here to detect a tar archive.
            disp_name: Stored as ``self.disp_name``; not used by this class.
            depth_transform: Callable used to normalize depth rasters in training.
            tokenizer: Stored as ``self.tokenizer``; not used by this class.
            augmentation_args: Object whose ``lr_flip_p`` attribute is read by
                :meth:`_augment_data`; ``None`` disables augmentation.
            resize_to_hw: Target ``(H, W)`` for resize + center-crop, or ``None``.
            move_invalid_to_far_plane: Whether invalid filled-depth pixels are
                set to the normalizer's far plane during training.
            rgb_transform: Stored as ``self.rgb_transform``. NOTE(review): this
                class does not apply it (see :meth:`_load_rgb_data`).
            **kwargs: Accepted and ignored.
        """
        super().__init__()
        self.mode = mode

        self.filename_ls_path = filename_ls_path
        self.dataset_dir = dataset_dir
        self.disp_name = disp_name

        self.depth_transform: DepthNormalizerBase = depth_transform
        self.augm_args = augmentation_args
        self.resize_to_hw = resize_to_hw
        self.rgb_transform = rgb_transform
        self.move_invalid_to_far_plane = move_invalid_to_far_plane
        self.tokenizer = tokenizer

        # glob() order depends on the filesystem; sort for a reproducible
        # sample order across machines and runs.
        self.filenames = []
        for path in sorted(glob.glob(self.filename_ls_path)):
            with open(path, "r") as f:
                self.filenames += json.load(f)

        self.tar_obj = None
        self.is_tar = os.path.isfile(dataset_dir) and tarfile.is_tarfile(dataset_dir)

    def __len__(self):
        """Return the number of entries loaded from the sample-list files."""
        return len(self.filenames)

    def __getitem__(self, index):
        """Return one sample as a single dict.

        The dict holds the rasters from :meth:`_get_data_item` (after
        :meth:`_training_preprocess` in TRAIN mode), merged with the metadata
        keys ``"index"``, ``"rgb_path"`` and ``"text"``.
        """
        rasters, other = self._get_data_item(index)
        if DatasetMode.TRAIN == self.mode:
            rasters = self._training_preprocess(rasters)

        outputs = rasters
        outputs.update(other)
        return outputs

    def _get_data_item(self, index):
        """Load the rasters and metadata for entry ``index``.

        Returns:
            ``(rasters, other)``. ``rasters`` holds ``"rgb_int"``, ``"rgb_norm"``
            and ``"mask"`` (a ``(1, H, W)`` float32 tensor of 0/1), resized when
            ``resize_to_hw`` is set. ``other`` holds ``"index"``, ``"rgb_path"``
            and ``"text"``.
        """
        entry = self.filenames[index]
        rgb_path = entry["rgb_path"]
        # Random draw order (caption, annotations, grid) matches the
        # original implementation.
        caption = self._sample_caption(entry.get("caption"))

        rasters = {}
        rasters.update(self._load_rgb_data(rgb_path))
        rasters["mask"] = self._build_inpaint_mask(
            rgb_path, rasters["rgb_int"].shape[-2:]
        )
        rasters = self._resize_rasters(rasters)

        other = {"index": index, "rgb_path": rgb_path, "text": caption}
        return rasters, other

    def _sample_caption(self, caption_info):
        """Pick the caption for one entry.

        Chooses the COCA caption, the spatial caption, or ``""`` with weights
        ``caption_weights``. Returns ``""`` when ``caption_info`` is ``None``.
        """
        if caption_info is None:
            return ""
        choices = [
            caption_info["coca_caption"],
            caption_info["spatial_caption"],
            "",
        ]
        return random.choices(choices, weights=self.caption_weights)[0]

    def _build_inpaint_mask(self, rgb_path, hw):
        """Return the inpainting mask for an image of spatial size ``hw`` as a (1, H, W) float32 tensor."""
        mask = self._load_annotation_mask(rgb_path)
        # Drawn unconditionally so the RNG stream does not depend on whether
        # the annotations loaded.
        use_grid = random.random() < self.grid_mask_prob
        if use_grid or mask is None:
            mask = self._random_grid_mask(hw)
        return torch.from_numpy(mask).unsqueeze(0).to(torch.float32)

    def _load_annotation_mask(self, rgb_path):
        """Union up to ``object_num_range`` randomly chosen object masks.

        Returns:
            A ``uint8`` array of 0/1, or ``None`` if the annotation file is
            missing, malformed, has no annotations, or fails to decode.
        """
        try:
            with open(rgb_path.replace(".jpg", ".json")) as f:
                annotations = json.load(f)["annotations"]
            random.shuffle(annotations)
            object_num = random.randint(*self.object_num_range)
            mask = None
            for anno in annotations[:object_num]:
                obj = np.asarray(coco_mask.decode(anno["segmentation"]), dtype=bool)
                # Logical OR keeps the mask binary: overlapping objects must not
                # add up to values > 1.
                mask = obj if mask is None else (mask | obj)
        except Exception:
            # Best effort: any problem with the annotations falls back to a
            # grid mask in the caller.
            return None
        if mask is None:
            return None
        return mask.astype(np.uint8)

    def _random_grid_mask(self, hw):
        """Mask random square cells of a grid covering an image of size ``hw``.

        The cell side is drawn from ``grid_size_range``. Each full cell is set
        to 1 with probability ``grid_cell_prob``. Pixels beyond the last full
        row or column of cells stay 0.
        """
        rows, cols = hw
        mask = np.zeros((rows, cols))
        grid_size = random.randint(*self.grid_size_range)
        grid_rows, grid_cols = rows // grid_size, cols // grid_size

        # One draw per cell, in the same row-major order as a per-cell
        # np.random.rand() loop, so results match the original implementation.
        cells = np.random.rand(grid_rows, grid_cols) < self.grid_cell_prob
        cell_pixels = cells.repeat(grid_size, axis=0).repeat(grid_size, axis=1)
        mask[: grid_rows * grid_size, : grid_cols * grid_size][cell_pixels] = 1
        return mask

    def _resize_rasters(self, rasters):
        """Resize and center-crop every raster to ``resize_to_hw``.

        Each raster's shorter edge is scaled to ``max(resize_to_hw)`` with
        NEAREST_EXACT, then center-cropped. Rasters already at the target size
        are left untouched. This avoids re-zooming non-square targets when
        called both at load time and in :meth:`_training_preprocess`.
        """
        if self.resize_to_hw is None:
            return rasters
        target_hw = tuple(self.resize_to_hw)
        resize_transform = transforms.Compose([
            Resize(size=max(self.resize_to_hw), interpolation=InterpolationMode.NEAREST_EXACT),
            CenterCrop(size=self.resize_to_hw),
        ])
        return {
            k: v if tuple(v.shape[-2:]) == target_hw else resize_transform(v)
            for k, v in rasters.items()
        }

    def _load_rgb_data(self, rgb_path):
        """Read an image and return ``{"rgb_int": int tensor, "rgb_norm": float tensor in [-1, 1]}``, both channel-first.

        NOTE(review): ``self.rgb_transform`` is not applied. The normalization
        below equals its default value.
        """
        rgb = self._read_rgb_file(rgb_path)
        rgb_norm = rgb / 255.0 * 2.0 - 1.0

        outputs = {
            "rgb_int": torch.from_numpy(rgb).int(),
            "rgb_norm": torch.from_numpy(rgb_norm).float(),
        }
        return outputs

    def _get_data_path(self, index):
        """Return ``(rgb_rel_path, depth_rel_path, text_rel_path)`` from a list-style entry.

        NOTE(review): this indexes the entry positionally, while
        :meth:`_get_data_item` treats entries as dicts keyed by ``"rgb_path"``.
        It is not called in this class; confirm it is still needed.
        """
        filename_line = self.filenames[index]

        rgb_rel_path = filename_line[0]

        depth_rel_path, text_rel_path = None, None
        if DatasetMode.RGB_ONLY != self.mode:
            depth_rel_path = filename_line[1]
            if len(filename_line) > 2:
                text_rel_path = filename_line[2]
        return rgb_rel_path, depth_rel_path, text_rel_path

    def _read_image(self, img_path) -> np.ndarray:
        """Read an image file into a numpy array, closing the file afterwards."""
        with Image.open(img_path) as image:
            # np.asarray copies the decoded pixels, so the array remains valid
            # after the file is closed.
            return np.asarray(image)

    def _read_rgb_file(self, path) -> np.ndarray:
        """Read an image as a channel-first integer array of shape (C, H, W).

        2-D (single-channel) images are replicated to three channels so that
        grayscale files do not fail the channel-first transpose.
        """
        rgb = self._read_image(path)
        if rgb.ndim == 2:
            rgb = np.stack([rgb] * 3, axis=-1)
        rgb = np.transpose(rgb, (2, 0, 1)).astype(int)
        return rgb

    def _read_depth_file(self, path):
        """Read a depth file as-is (no decoding or scaling is applied)."""
        return self._read_image(path)

    def _training_preprocess(self, rasters):
        """Apply training-time augmentation, optional depth normalization, and resizing.

        Depth normalization runs only when ``"depth_raw_linear"`` is present in
        ``rasters``. This class's own :meth:`_get_data_item` produces no depth
        rasters; they are expected from subclasses. When it runs, it also reads
        ``"depth_filled_linear"``, ``"valid_mask_raw"`` and
        ``"valid_mask_filled"``.
        """
        if self.augm_args is not None:
            rasters = self._augment_data(rasters)

        if "depth_raw_linear" in rasters:
            rasters["depth_raw_norm"] = self.depth_transform(
                rasters["depth_raw_linear"], rasters["valid_mask_raw"]
            ).clone()
            rasters["depth_filled_norm"] = self.depth_transform(
                rasters["depth_filled_linear"], rasters["valid_mask_filled"]
            ).clone()

            if self.move_invalid_to_far_plane:
                if self.depth_transform.far_plane_at_max:
                    far_value = self.depth_transform.norm_max
                else:
                    far_value = self.depth_transform.norm_min
                rasters["depth_filled_norm"][~rasters["valid_mask_filled"]] = far_value

        return self._resize_rasters(rasters)

    def _augment_data(self, rasters_dict):
        """Flip every raster along its last axis with probability ``augm_args.lr_flip_p``."""
        lr_flip_p = self.augm_args.lr_flip_p
        if random.random() < lr_flip_p:
            rasters_dict = {k: v.flip(-1) for k, v in rasters_dict.items()}

        return rasters_dict

    def __del__(self):
        """Close the tar handle if one was opened."""
        if hasattr(self, "tar_obj") and self.tar_obj is not None:
            self.tar_obj.close()
            self.tar_obj = None
|
|
|
def get_pred_name(rgb_basename, name_mode, suffix=".png"):
    """Map an RGB file name to the matching prediction file name.

    The naming rule is selected by comparing ``name_mode`` against the
    ``DepthFileNameMode`` constants, in this order:

    * ``rgb_id``  -> ``"pred_"`` + the second ``"_"``-separated field
    * ``i_d_rgb`` -> ``"_rgb."`` replaced by ``"_pred."``
    * ``id``      -> ``"pred_"`` + the whole basename
    * ``rgb_i_d`` -> ``"pred_"`` + everything after the first ``"_"``

    The extension of the renamed file is then swapped for ``suffix``.

    Raises:
        NotImplementedError: If ``name_mode`` matches none of the modes.
    """
    renamers = (
        ("rgb_id", lambda name: "pred_" + name.split("_")[1]),
        ("i_d_rgb", lambda name: name.replace("_rgb.", "_pred.")),
        ("id", lambda name: "pred_" + name),
        ("rgb_i_d", lambda name: "pred_" + "_".join(name.split("_")[1:])),
    )
    # Modes are looked up lazily and compared in order, first match wins.
    for mode_attr, rename in renamers:
        if getattr(DepthFileNameMode, mode_attr) == name_mode:
            renamed = rename(rgb_basename)
            break
    else:
        raise NotImplementedError

    stem, _old_ext = os.path.splitext(renamed)
    return stem + suffix
|
|