import os

import numpy as np
from PIL import Image

import torch
import torch.utils.data as data
import torchvision.transforms as transform

from datasets.base import BaseDataset


class CitySegmentation(BaseDataset):
    """Cityscapes semantic segmentation dataset.

    Maps the raw Cityscapes label ids (-1..33) onto the 19 classes used for
    training and evaluation; every other id is mapped to -1 (ignore).
    """
    NUM_CLASS = 19

    def __init__(self, root, split='val', mode='testval',
                 transform=None, target_transform=None, **kwargs):
        super(CitySegmentation, self).__init__(
            root, split, mode, transform, target_transform, **kwargs)
        self.images, self.mask_paths = get_city_pairs(self.root, self.split)
        assert len(self.images) == len(self.mask_paths)
        if len(self.images) == 0:
            raise RuntimeError("Found 0 images in subfolders of: " + self.root + "\n")
        # Train ids: -1 (ignore) plus the 19 evaluation classes 0..18.
        self._indices = np.array(range(-1, 19))
        # Raw Cityscapes label ids of interest (0 = unlabeled plus the 19
        # ids that take part in evaluation).
        self._classes = np.array([0, 7, 8, 11, 12, 13, 17, 19, 20, 21, 22,
                                  23, 24, 25, 26, 27, 28, 31, 32, 33])
        # Lookup table indexed by the position of a raw id in self._mapping;
        # entries of -1 mark ids that are ignored during evaluation.
        self._key = np.array([-1, -1, -1, -1, -1, -1,
                              -1, -1, 0, 1, -1, -1,
                              2, 3, 4, -1, -1, -1,
                              5, -1, 6, 7, 8, 9,
                              10, 11, 12, 13, 14, 15,
                              -1, -1, 16, 17, 18])
        # Sorted array of all valid raw label ids (-1..33), used by np.digitize.
        self._mapping = np.array(range(-1, len(self._key) - 1)).astype('int32')

    def _class_to_index(self, mask):
        # Every value present in the mask must be a known raw label id.
        values = np.unique(mask)
        for value in values:
            assert value in self._mapping
        # np.digitize finds each pixel's position in the sorted _mapping
        # array; that position indexes _key to obtain the train id
        # (for example, raw id 7 "road" becomes train id 0).
        index = np.digitize(mask.ravel(), self._mapping, right=True)
        return self._key[index].reshape(mask.shape)

    def __getitem__(self, index):
        img = Image.open(self.images[index]).convert('RGB')
        mask = Image.open(self.mask_paths[index])
        # Mode-specific joint image/mask transforms (expected from BaseDataset).
        if self.mode == 'testval':
            img, mask = self._testval_transform(img, mask)
        elif self.mode == 'val':
            img, mask = self._val_transform(img, mask)
        elif self.mode == 'train':
            img, mask = self._train_transform(img, mask)
        # Per-input transforms (e.g. normalization) applied independently.
        if self.transform is not None:
            img = self.transform(img)
        if self.target_transform is not None:
            mask = self.target_transform(mask)
        return img, mask

    def _mask_transform(self, mask):
        # Convert the raw label-id mask into a LongTensor of train ids.
        target = self._class_to_index(np.array(mask).astype('int32'))
        return torch.from_numpy(target).long()

    def __len__(self):
        return len(self.images)


def get_city_pairs(folder, split='val'):
    def get_path_pairs(img_folder, mask_folder):
        img_paths = []
        mask_paths = []
        # Walk the image tree and derive the matching gtFine label path
        # for every leftImg8bit image.
        for root, directories, files in os.walk(img_folder):
            for filename in files:
                if filename.endswith(".png"):
                    imgpath = os.path.join(root, filename)
                    foldername = os.path.basename(os.path.dirname(imgpath))
                    maskname = filename.replace('leftImg8bit', 'gtFine_labelIds')
                    maskpath = os.path.join(mask_folder, foldername, maskname)
                    if os.path.isfile(imgpath) and os.path.isfile(maskpath):
                        img_paths.append(imgpath)
                        mask_paths.append(maskpath)
                    else:
                        print('cannot find the mask or image:', imgpath, maskpath)
        print('Found {} images in the folder {}'.format(len(img_paths), img_folder))
        return img_paths, mask_paths

    img_folder = os.path.join(folder, 'leftImg8bit/' + split)
    mask_folder = os.path.join(folder, 'gtFine/' + split)
    img_paths, mask_paths = get_path_pairs(img_folder, mask_folder)
    return img_paths, mask_paths
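

# Minimal usage sketch (not part of the original module). It assumes the
# Cityscapes data lives under ./datasets/cityscapes with the standard
# leftImg8bit/<split>/<city>/ and gtFine/<split>/<city>/ layout, and that
# BaseDataset provides the mode-specific joint transforms used above.
# The normalization constants are the usual ImageNet values, given only
# for illustration.
if __name__ == '__main__':
    input_transform = transform.Compose([
        transform.ToTensor(),
        transform.Normalize([.485, .456, .406], [.229, .224, .225])])
    valset = CitySegmentation(root='./datasets/cityscapes', split='val',
                              mode='testval', transform=input_transform)
    print('validation samples:', len(valset))
    # The dataset can then be wrapped in a standard DataLoader for evaluation.
    valloader = data.DataLoader(valset, batch_size=1, shuffle=False)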