|
|
|
import os |
|
import time |
|
import numpy as np |
|
import warnings |
|
import random |
|
from omegaconf.listconfig import ListConfig |
|
from webdataset import pipelinefilter |
|
import torch |
|
import torchvision.transforms.functional as TVF |
|
from torchvision.transforms import InterpolationMode |
|
from torchvision.transforms.transforms import _interpolation_modes_from_int |
|
from typing import Sequence |
|
|
|
from michelangelo.utils import instantiate_from_config |
|
|
|
|
|
def _uid_buffer_pick(buf_dict, rng): |
|
uid_keys = list(buf_dict.keys()) |
|
selected_uid = rng.choice(uid_keys) |
|
buf = buf_dict[selected_uid] |
|
|
|
k = rng.randint(0, len(buf) - 1) |
|
sample = buf[k] |
|
buf[k] = buf[-1] |
|
buf.pop() |
|
|
|
if len(buf) == 0: |
|
del buf_dict[selected_uid] |
|
|
|
return sample |
|
|
|
|
|
def _add_to_buf_dict(buf_dict, sample): |
|
key = sample["__key__"] |
|
uid, uid_sample_id = key.split("_") |
|
if uid not in buf_dict: |
|
buf_dict[uid] = [] |
|
buf_dict[uid].append(sample) |
|
|
|
return buf_dict |
|
|
|
|
|
def _uid_shuffle(data, bufsize=1000, initial=100, rng=None, handler=None):
    """Shuffle a sample stream through a buffer that is grouped by uid.

    Incoming samples are stored in per-uid buckets via ``_add_to_buf_dict``.
    Each yielded sample is drawn with ``_uid_buffer_pick``: a buffered uid is
    chosen first, then one of that uid's samples, so a uid holding few
    samples is as likely to be drawn from as one holding many.

    Nothing is yielded until ``initial`` samples are buffered. While the
    buffer is below ``bufsize`` an extra sample is read on every step, so the
    buffer grows to about ``bufsize`` samples. Shuffling at startup is less
    random; this is traded off against yielding samples quickly. When the
    input is exhausted, the remaining buffered samples are drained in random
    order.

    Args:
        data: iterable of samples, each a mapping with a ``"__key__"`` entry.
            Any iterable is accepted, not only iterators.
        bufsize: target number of samples kept in the buffer.
        initial: number of buffered samples required before the first yield
            (capped at ``bufsize``).
        rng: ``random`` module or ``random.Random`` instance. Defaults to a
            new ``random.Random`` seeded from the process id and current time.
        handler: unused; kept so the signature stays unchanged.

    Yields:
        The input samples, in shuffled order.
    """
    if rng is None:
        rng = random.Random(int((os.getpid() + time.time()) * 1e9))

    # ``next(data)`` below needs a real iterator. ``iter`` is a no-op for
    # iterators and makes lists and other plain iterables work too.
    data = iter(data)

    initial = min(initial, bufsize)
    buf_dict = {}
    current_samples = 0

    for sample in data:
        _add_to_buf_dict(buf_dict, sample)
        current_samples += 1

        if current_samples < bufsize:
            # Read one extra sample per step so the buffer fills up while
            # samples are already being yielded.
            try:
                extra = next(data)
            except StopIteration:
                pass
            else:
                _add_to_buf_dict(buf_dict, extra)
                current_samples += 1

        if current_samples >= initial:
            current_samples -= 1
            yield _uid_buffer_pick(buf_dict, rng)

    # Input exhausted: drain whatever is still buffered.
    while current_samples > 0:
        current_samples -= 1
        yield _uid_buffer_pick(buf_dict, rng)
|
|
|
|
|
# Pipeline-stage form of ``_uid_shuffle`` produced by webdataset's
# ``pipelinefilter`` wrapper. NOTE(review): presumably this lets the stage be
# configured with keyword arguments (bufsize, initial, rng) and applied to the
# sample stream later -- confirm against the webdataset documentation.
uid_shuffle = pipelinefilter(_uid_shuffle)
|
|
|
|
|
class RandomSample(object):
    """Randomly sub-sample surface, volume and near-surface data of a sample.

    Args:
        num_volume_samples: number of rows drawn from ``"vol_points"``.
        num_near_samples: number of rows drawn from ``"near_points"``.
    """

    def __init__(self,
                 num_volume_samples: int = 1024,
                 num_near_samples: int = 1024):

        super().__init__()

        self.num_volume_samples = num_volume_samples
        self.num_near_samples = num_near_samples

    @staticmethod
    def _sample_points_with_labels(rng, points, labels, num_samples):
        """Draw ``num_samples`` rows of ``points`` and append their labels as a last column."""
        # Draw without replacement when possible. Fall back to replacement
        # when fewer points are available than requested, instead of raising.
        replace = points.shape[0] < num_samples
        ind = rng.choice(points.shape[0], num_samples, replace=replace)
        return np.concatenate([points[ind], labels[ind][:, np.newaxis]], axis=1)

    def __call__(self, sample):
        """Build a reduced sample with ``"surface"`` and ``"geo_points"`` only.

        Args:
            sample: mapping with the entries ``"surface"``, ``"vol_points"``,
                ``"vol_label"``, ``"near_points"`` and ``"near_label"``. The
                point arrays are indexed along axis 0 and the label arrays
                are 1-D arrays indexed the same way.

        Returns:
            A new dict. ``"surface"`` is one randomly chosen entry of the
            input ``"surface"`` along axis 0. ``"geo_points"`` stacks the
            sampled volume rows followed by the sampled near rows, each row
            being the point coordinates with its label as the last column.
            All other input entries are dropped.
        """
        rng = np.random.default_rng()

        # A single random index: selects one entry along the first axis.
        total_surface = sample["surface"]
        ind = rng.choice(total_surface.shape[0], replace=False)
        surface = total_surface[ind]

        vol_points_labels = self._sample_points_with_labels(
            rng, sample["vol_points"], sample["vol_label"], self.num_volume_samples
        )
        near_points_labels = self._sample_points_with_labels(
            rng, sample["near_points"], sample["near_label"], self.num_near_samples
        )

        geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)

        return {
            "surface": surface,
            "geo_points": geo_points,
        }
|
|
|
|
|
class SplitRandomSample(object):
    """Randomly sub-sample surface, volume and near-surface points of a sample.

    Args:
        use_surface_sample: when true, ``"surface"`` is sub-sampled to
            ``num_surface_samples`` rows; otherwise it is passed through.
        num_surface_samples: number of surface rows to draw.
        num_volume_samples: number of rows drawn from ``"vol_points"``.
        num_near_samples: number of rows drawn from ``"near_points"``.
    """

    def __init__(self,
                 use_surface_sample: bool = False,
                 num_surface_samples: int = 4096,
                 num_volume_samples: int = 1024,
                 num_near_samples: int = 1024):

        super().__init__()

        self.use_surface_sample = use_surface_sample
        self.num_surface_samples = num_surface_samples
        self.num_volume_samples = num_volume_samples
        self.num_near_samples = num_near_samples

    @staticmethod
    def _choose_rows(rng, num_available, num_samples):
        """Return ``num_samples`` row indices, with replacement only if there are too few rows."""
        replace = num_available < num_samples
        return rng.choice(num_available, num_samples, replace=replace)

    def __call__(self, sample):
        """Build a reduced sample with ``"surface"`` and ``"geo_points"`` only.

        Args:
            sample: mapping with the entries ``"surface"``, ``"vol_points"``,
                ``"vol_label"``, ``"near_points"`` and ``"near_label"``. The
                point arrays are indexed along axis 0 and the label arrays
                are 1-D arrays indexed the same way.

        Returns:
            A new dict. ``"surface"`` is the (optionally sub-sampled) input
            surface. ``"geo_points"`` stacks the sampled volume rows followed
            by the sampled near rows, each row being the point coordinates
            with its label as the last column. All other input entries are
            dropped.
        """
        rng = np.random.default_rng()

        surface = sample["surface"]
        if self.use_surface_sample:
            ind = self._choose_rows(rng, surface.shape[0], self.num_surface_samples)
            surface = surface[ind]

        vol_points = sample["vol_points"]
        vol_label = sample["vol_label"]
        near_points = sample["near_points"]
        near_label = sample["near_label"]

        # Volume and near points use the same fallback as the surface: draw
        # with replacement only when fewer rows exist than requested, instead
        # of raising.
        ind = self._choose_rows(rng, vol_points.shape[0], self.num_volume_samples)
        vol_points_labels = np.concatenate(
            [vol_points[ind], vol_label[ind][:, np.newaxis]], axis=1
        )

        ind = self._choose_rows(rng, near_points.shape[0], self.num_near_samples)
        near_points_labels = np.concatenate(
            [near_points[ind], near_label[ind][:, np.newaxis]], axis=1
        )

        geo_points = np.concatenate([vol_points_labels, near_points_labels], axis=0)

        return {
            "surface": surface,
            "geo_points": geo_points,
        }
|
|
|
|
|
class FeatureSelection(object):
    """Keep only the surface columns that belong to a named feature set."""

    # Column indices of ``sample["surface"]`` kept for each feature type.
    VALID_SURFACE_FEATURE_DIMS = {
        "none": [0, 1, 2],
        "watertight_normal": [0, 1, 2, 3, 4, 5],
        "normal": [0, 1, 2, 6, 7, 8]
    }

    def __init__(self, surface_feature_type: str):
        """Look up the column set for ``surface_feature_type``.

        Raises:
            KeyError: if ``surface_feature_type`` is not a key of
                ``VALID_SURFACE_FEATURE_DIMS``.
        """
        self.surface_feature_type = surface_feature_type
        self.surface_dims = self.VALID_SURFACE_FEATURE_DIMS[surface_feature_type]

    def __call__(self, sample):
        """Replace ``sample["surface"]`` by its selected columns and return ``sample``."""
        full_surface = sample["surface"]
        sample["surface"] = full_surface[:, self.surface_dims]
        return sample
|
|
|
|
|
class AxisScaleTransform(object):
    """Random per-axis scaling of the xyz columns of a dict sample.

    Args:
        interval: ``(low, high)`` range from which the three axis factors
            are drawn uniformly.
        jitter: when true, Gaussian noise is added to the surface coordinates.
        jitter_scale: multiplier applied to the standard-normal noise.
    """

    def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
        assert isinstance(interval, (tuple, list, ListConfig))
        low, high = interval[0], interval[1]

        self.interval = interval
        self.min_val = low
        self.max_val = high
        self.inter_size = high - low
        self.jitter = jitter
        self.jitter_scale = jitter_scale

    def __call__(self, sample):
        """Scale ``sample["surface"]`` and ``sample["geo_points"]`` in place.

        Only the first three columns (last axis) of both tensors are changed.
        After the axis scaling, both are multiplied by one common factor so
        that the largest absolute surface coordinate becomes 0.999999. With
        jitter enabled, the surface coordinates additionally receive noise
        and are clamped to [-1.015, 1.015].

        Returns:
            The same ``sample`` mapping, with both tensors updated.
        """
        surface_xyz = sample["surface"][..., 0:3]
        geo_xyz = sample["geo_points"][..., 0:3]

        # One factor per axis, shared by the surface and the query points.
        axis_factors = torch.rand(1, 3) * self.inter_size + self.min_val
        surface_xyz = surface_xyz * axis_factors
        geo_xyz = geo_xyz * axis_factors

        # Normalise by the scaled surface extent.
        norm = (1 / torch.abs(surface_xyz).max().item()) * 0.999999
        surface_xyz *= norm
        geo_xyz *= norm

        if self.jitter:
            noise = torch.randn_like(surface_xyz)
            surface_xyz += self.jitter_scale * noise
            surface_xyz.clamp_(min=-1.015, max=1.015)

        # Write back into the original tensors; other columns stay untouched.
        sample["surface"][..., 0:3] = surface_xyz
        sample["geo_points"][..., 0:3] = geo_xyz

        return sample
|
|
|
|
|
class ToTensor(object):
    """Convert selected entries of a dict sample to float32 torch tensors.

    Args:
        tensor_keys: names of the entries to convert. Names that are absent
            from a sample are skipped.
    """

    def __init__(self, tensor_keys=("surface", "geo_points", "tex_points")):
        self.tensor_keys = tensor_keys

    def __call__(self, sample):
        """Convert the configured entries in place and return ``sample``."""
        present_keys = [name for name in self.tensor_keys if name in sample]
        for name in present_keys:
            sample[name] = torch.tensor(sample[name], dtype=torch.float32)
        return sample
|
|
|
|
|
class AxisScale(object):
    """Random per-axis scaling for a surface tensor and optional companions.

    Args:
        interval: ``(low, high)`` range from which the three axis factors
            are drawn uniformly.
        jitter: when true, Gaussian noise is added to the scaled surface.
        jitter_scale: multiplier applied to the standard-normal noise.
    """

    def __init__(self, interval=(0.75, 1.25), jitter=True, jitter_scale=0.005):
        assert isinstance(interval, (tuple, list, ListConfig))
        self.interval = interval
        self.jitter = jitter
        self.jitter_scale = jitter_scale

    def __call__(self, surface, *args):
        """Scale ``surface`` and every extra tensor by the same random factors.

        All tensors are multiplied by a ``(1, 3)`` tensor of axis factors, so
        their last dimension must broadcast against 3. They are then
        multiplied by one common factor chosen so that the largest absolute
        value of the scaled surface becomes 0.999999. With jitter enabled,
        only the surface receives noise and is clamped to [-1, 1]. The inputs
        are not modified.

        Args:
            surface: tensor that defines the normalisation.
            *args: further tensors transformed with the same factors.

        Returns:
            The transformed surface alone when no extra tensors are given,
            otherwise a tuple ``(surface, *transformed_args)``.
        """
        # Use the configured interval rather than a hard-coded (0.75, 1.25).
        low, high = self.interval[0], self.interval[1]
        scaling = torch.rand(1, 3) * (high - low) + low

        surface = surface * scaling
        scale = (1 / torch.abs(surface).max().item()) * 0.999999
        surface *= scale

        args_outputs = [_arg * scaling * scale for _arg in args]

        if self.jitter:
            surface += self.jitter_scale * torch.randn_like(surface)
            surface.clamp_(min=-1, max=1)

        if len(args) == 0:
            return surface
        return (surface, *args_outputs)
|
|
|
|
|
class RandomResize(torch.nn.Module):
    """Resize an image to a random fraction of the target size and back.

    On every call the image is first resized to ``size`` multiplied by a
    random ratio drawn from ``resize_radio``, using an interpolation mode
    picked at random from ``allow_resize_interpolations``. It is then resized
    to ``size`` with ``interpolation``. The transform is always applied.

    Args:
        size: target size, an int or a sequence of one or two ints.
        resize_radio: ``(low, high)`` range of the random ratio.
        allow_resize_interpolations: candidate interpolation modes for the
            first, randomly sized resize.
        interpolation: interpolation mode of the final resize. Passing an int
            is deprecated and is converted with a warning.
        max_size: forwarded to both ``TVF.resize`` calls.
        antialias: forwarded to both ``TVF.resize`` calls.

    Raises:
        TypeError: if ``size`` is neither an int nor a sequence.
        ValueError: if ``size`` is a sequence whose length is not 1 or 2.
    """

    def __init__(
        self,
        size,
        resize_radio=(0.5, 1),
        allow_resize_interpolations=(InterpolationMode.BICUBIC, InterpolationMode.BILINEAR, InterpolationMode.BILINEAR),
        interpolation=InterpolationMode.BICUBIC,
        max_size=None,
        antialias=None,
    ):
        super().__init__()
        if not isinstance(size, (int, Sequence)):
            raise TypeError(f"Size should be int or sequence. Got {type(size)}")
        if isinstance(size, Sequence) and len(size) not in (1, 2):
            raise ValueError("If size is a sequence, it should have 1 or 2 values")

        self.size = size
        self.max_size = max_size

        if isinstance(interpolation, int):
            warnings.warn(
                "Argument 'interpolation' of type int is deprecated since 0.13 and will be removed in 0.15. "
                "Please use InterpolationMode enum."
            )
            interpolation = _interpolation_modes_from_int(interpolation)

        self.interpolation = interpolation
        self.antialias = antialias

        self.resize_radio = resize_radio
        self.allow_resize_interpolations = allow_resize_interpolations

    def random_resize_params(self):
        """Draw the intermediate size and interpolation mode for one call.

        Returns:
            ``(size, interpolation)``. ``size`` is an int when ``self.size``
            is an int, otherwise a tuple with one entry per entry of
            ``self.size``. Every entry is at least 1.
        """
        radio = torch.rand(1) * (self.resize_radio[1] - self.resize_radio[0]) + self.resize_radio[0]

        # Floor each dimension at 1 so a small ratio cannot request a
        # zero-sized image.
        if isinstance(self.size, int):
            size = max(1, int(self.size * radio))
        elif isinstance(self.size, Sequence):
            # Scale every entry, so length-1 sequences (accepted by
            # ``__init__``) work as well as length-2 ones.
            size = tuple(max(1, int(side * radio)) for side in self.size)
        else:
            raise RuntimeError()

        interpolation = self.allow_resize_interpolations[
            torch.randint(low=0, high=len(self.allow_resize_interpolations), size=(1,))
        ]
        return size, interpolation

    def forward(self, img):
        """Resize ``img`` to the random intermediate size, then to ``self.size``."""
        size, interpolation = self.random_resize_params()
        img = TVF.resize(img, size, interpolation, self.max_size, self.antialias)
        img = TVF.resize(img, self.size, self.interpolation, self.max_size, self.antialias)
        return img

    def __repr__(self) -> str:
        detail = (
            f"(size={self.size}, interpolation={self.interpolation.value}, "
            f"max_size={self.max_size}, antialias={self.antialias}, "
            f"resize_radio={self.resize_radio})"
        )
        return f"{self.__class__.__name__}{detail}"
|
|
|
|
|
class Compose(object):
    """Composes several transforms together. This transform does not support torchscript.

    The transforms are called in order. The first one receives the
    positional arguments given to the composed object. Each following one
    receives the previous result: a ``tuple`` or ``list`` result is unpacked
    into positional arguments, and any other result (a dict sample, a tensor,
    ...) is passed on as one single argument.

    Args:
        transforms (list of ``Transform`` objects): list of transforms to compose.

    Example:
        >>> transforms.Compose([
        >>>     transforms.CenterCrop(10),
        >>>     transforms.ToTensor(),
        >>> ])
    """

    def __init__(self, transforms):
        self.transforms = transforms

    def __call__(self, *args):
        """Run all transforms in order and return the result of the last one.

        With an empty transform list the input arguments are returned
        unchanged as a tuple.
        """
        result = args
        for t in self.transforms:
            result = t(*args)
            # Only tuples/lists are argument packs. Star-unpacking a dict or
            # a tensor would feed its keys/rows to the next transform.
            args = result if isinstance(result, (tuple, list)) else (result,)
        return result

    def __repr__(self):
        lines = [self.__class__.__name__ + '(']
        lines.extend(' {0}'.format(t) for t in self.transforms)
        lines.append(')')
        return '\n'.join(lines)
|
|
|
|
|
def identity(*args, **kwargs):
    """Return the positional arguments unchanged.

    A single positional argument is returned as is; any other number of
    positional arguments (including none) is returned as a tuple. Keyword
    arguments are accepted and ignored.
    """
    return args[0] if len(args) == 1 else args
|
|
|
|
|
def build_transforms(cfg):
    """Build a transform pipeline from a configuration mapping.

    Args:
        cfg: ``None``, or a mapping whose values are passed one by one to
            ``instantiate_from_config``. The mapping's keys are not used.

    Returns:
        ``identity`` when ``cfg`` is ``None``; otherwise a ``Compose`` of the
        instantiated transforms in the mapping's iteration order.
    """
    if cfg is None:
        return identity

    built = []
    for _name, transform_cfg in cfg.items():
        transform_instance = instantiate_from_config(transform_cfg)
        built.append(transform_instance)
        # Print each transform as it is built.
        print(f"Build transform: {transform_instance}")

    return Compose(built)
|
|
|
|