# Spaces:
# Runtime error
# Runtime error
# Copyright (c) OpenMMLab. All rights reserved.
import functools

import torch
def assert_tensor_type(func):
    """Guard a method so it only runs when the container holds a tensor.

    The first positional argument passed to the wrapper (``self`` when this
    decorates a method) must expose ``data`` and ``datatype`` attributes;
    ``data`` is checked before ``func`` is invoked.

    Args:
        func (callable): The method to guard.

    Returns:
        callable: A wrapper that forwards all arguments to ``func`` and
        returns its result. ``functools.wraps`` copies ``func``'s name and
        docstring onto it.

    Raises:
        AttributeError: If ``args[0].data`` is not a ``torch.Tensor``. The
            message names the owner's class, the wrapped function and the
            owner's ``datatype``, mimicking a missing-attribute error.
    """

    # Preserve func.__name__ / __doc__ so decorated methods stay introspectable.
    @functools.wraps(func)
    def wrapper(*args, **kwargs):
        if not isinstance(args[0].data, torch.Tensor):
            raise AttributeError(
                f'{args[0].__class__.__name__} has no attribute '
                f'{func.__name__} for type {args[0].datatype}')
        return func(*args, **kwargs)

    return wrapper
class DataContainer:
    """A container for any type of objects.

    Typically tensors will be stacked in the collate function and sliced along
    some dimension in the scatter function. This behavior has some limitations.

    1. All tensors have to be the same size.
    2. Types are limited (numpy array or Tensor).

    We design `DataContainer` and `MMDataParallel` to overcome these
    limitations. The behavior can be either of the following.

    - copy to GPU, pad all tensors to the same size and stack them
    - copy to GPU without stacking
    - leave the objects as is and pass it to the model
    - pad_dims specifies the number of last few dimensions to do padding

    This class only stores the settings below and exposes them as read-only
    properties. Acting on them is left to the collate/scatter code that
    consumes the container (not part of this class).

    Args:
        data: Object to wrap. The tensor-only helpers ``size`` and ``dim``
            require it to be a ``torch.Tensor``.
        stack (bool): Flag exposed via ``stack``; presumably tells the
            collate step whether to pad and stack. Default: False.
        padding_value: Value exposed via ``padding_value``; presumably the
            fill value used when padding. Default: 0.
        cpu_only (bool): Flag exposed via ``cpu_only``; presumably keeps the
            data off the GPU. Default: False.
        pad_dims (int | None): Number of trailing dimensions to pad (see
            above). Must be ``None``, 1, 2 or 3. Default: 2.

    Raises:
        AssertionError: If ``pad_dims`` is not one of the allowed values.
    """

    def __init__(self,
                 data,
                 stack=False,
                 padding_value=0,
                 cpu_only=False,
                 pad_dims=2):
        self._data = data
        self._cpu_only = cpu_only
        self._stack = stack
        self._padding_value = padding_value
        assert pad_dims in [None, 1, 2, 3]
        self._pad_dims = pad_dims

    def __repr__(self):
        return f'{self.__class__.__name__}({repr(self.data)})'

    def __len__(self):
        return len(self._data)

    # The accessors below must be properties: __repr__, datatype, size, dim
    # and assert_tensor_type all read them as attributes (e.g. ``self.data``).
    @property
    def data(self):
        """The wrapped object, returned as-is."""
        return self._data

    @property
    def datatype(self):
        """``Tensor.type()`` string for tensors, else ``type(self.data)``."""
        if isinstance(self.data, torch.Tensor):
            return self.data.type()
        else:
            return type(self.data)

    @property
    def cpu_only(self):
        """The ``cpu_only`` flag given at construction."""
        return self._cpu_only

    @property
    def stack(self):
        """The ``stack`` flag given at construction."""
        return self._stack

    @property
    def padding_value(self):
        """The ``padding_value`` given at construction."""
        return self._padding_value

    @property
    def pad_dims(self):
        """The ``pad_dims`` value given at construction (None, 1, 2 or 3)."""
        return self._pad_dims

    @assert_tensor_type
    def size(self, *args, **kwargs):
        """Forward to ``torch.Tensor.size``; tensor data only."""
        return self.data.size(*args, **kwargs)

    @assert_tensor_type
    def dim(self):
        """Forward to ``torch.Tensor.dim``; tensor data only."""
        return self.data.dim()