File size: 3,801 Bytes
ec0c8fa |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 62 63 64 65 66 67 68 69 70 71 72 73 74 75 76 77 78 79 80 81 82 83 84 85 86 87 88 |
# decorator
import functools
import inspect
from numbers import Number

import numpy as np
def get_args_order(func, args, kwargs):
    """
    Locate each call argument in ``func``'s positional parameter list.

    Parameter positions are taken from ``inspect.getfullargspec(func).args``.

    Parameters
    ----------
    func : callable
        The function being called.
    args : sequence
        Positional arguments of the call.
    kwargs : mapping
        Keyword arguments of the call.

    Returns
    -------
    args_order : list of int
        For each leading positional argument that lands on a named parameter,
        that parameter's index. Positional arguments beyond the available named
        parameters are not represented.
    kwargs_order : dict
        ``{name: index}`` for every keyword argument that names a positional
        parameter. Other keywords are omitted.
    """
    param_names = inspect.getfullargspec(func).args
    position_of = {name: pos for pos, name in enumerate(param_names)}
    # Keyword arguments that target a named positional parameter.
    kwargs_order = {key: position_of[key] for key in kwargs if key in position_of}
    # Positional arguments fill, in order, the parameters not already given by keyword.
    unclaimed = [name for name in param_names if name not in kwargs_order]
    args_order = [position_of[name] for name, _ in zip(unclaimed, args)]
    return args_order, kwargs_order
def broadcast_args(args, kwargs, args_dim, kwargs_dim):
    """
    Broadcast the leading (batch) axes of array arguments against each other.

    An argument takes part when it is an ``np.ndarray`` and its entry in
    ``args_dim`` / ``kwargs_dim`` is not ``None``. That entry is the number of
    trailing "core" axes; all axes before them are batch axes. The batch shapes
    of all participating arrays are broadcast with NumPy rules. Each
    participating array is then expanded with ``np.broadcast_to`` (which returns
    a read-only view) to ``(*spatial, *core_shape)``. Every other argument is
    returned unchanged.

    Parameters
    ----------
    args : list or tuple
        Positional arguments. A list is updated in place and returned.
    kwargs : dict
        Keyword arguments; updated in place and returned.
    args_dim : sequence
        Core-axis count (or ``None``) for each positional argument. Arguments
        without an entry are treated as ``None``.
    kwargs_dim : dict
        Core-axis count (or ``None``) for each keyword, looked up by name.
        Missing keys are treated as ``None``.

    Returns
    -------
    args : list
    kwargs : dict
    spatial : list of int
        The broadcast batch shape; empty when no argument takes part.

    Raises
    ------
    ValueError
        If the batch shapes cannot be broadcast together.
    """
    if not isinstance(args, list):
        args = list(args)
    # Pad so that every positional argument has a dim (extra ones are left alone).
    args_dim = list(args_dim) + [None] * (len(args) - len(args_dim))
    # Match keyword values to their dims by name, never by position.
    pairs = list(zip(args, args_dim))
    pairs += [(arg, kwargs_dim.get(key)) for key, arg in kwargs.items()]
    batch_shapes = [
        arg.shape[:arg.ndim - dim]
        for arg, dim in pairs
        if isinstance(arg, np.ndarray) and dim is not None
    ]
    try:
        spatial = list(np.broadcast_shapes(*batch_shapes)) if batch_shapes else []
    except ValueError as err:
        raise ValueError("Cannot broadcast arguments.") from err

    for i, (arg, dim) in enumerate(zip(args, args_dim)):
        if isinstance(arg, np.ndarray) and dim is not None:
            args[i] = np.broadcast_to(arg, (*spatial, *arg.shape[arg.ndim - dim:]))
    for key, arg in kwargs.items():
        dim = kwargs_dim.get(key)
        if isinstance(arg, np.ndarray) and dim is not None:
            kwargs[key] = np.broadcast_to(arg, (*spatial, *arg.shape[arg.ndim - dim:]))
    return args, kwargs, spatial
def _dim_for(dims, index):
    """Return ``dims[index]``, or ``None`` when ``dims`` does not cover that parameter."""
    return dims[index] if index < len(dims) else None


def _restore_results(results, spatial):
    """
    Expand the leading (flattened batch) axis of each output back into ``spatial``.

    A tuple/list of outputs is handled element-wise and returned as the same
    kind of container (namedtuples keep their type). A single output is returned
    on its own. Outputs without a ``reshape`` method (e.g. Python scalars or
    ``None``) are passed through unchanged.
    """
    is_sequence = isinstance(results, (tuple, list))
    outputs = list(results) if is_sequence else [results]
    restored = [
        out.reshape([*spatial, *out.shape[1:]]) if hasattr(out, "reshape") else out
        for out in outputs
    ]
    if not is_sequence:
        return restored[0]
    if isinstance(results, list):
        return restored
    if hasattr(results, "_fields"):
        # namedtuple: rebuild the same type instead of keeping only one element.
        return type(results)._make(restored)
    return tuple(restored)


def batched(*dims):
    """
    Decorator that lets a function written for one flat batch accept any batch shape.

    ``dims[k]`` is the number of trailing "core" axes expected for the k-th
    positional parameter of the decorated function, in the order reported by
    ``inspect.getfullargspec(func).args``. ``None`` means the argument is passed
    through unchanged. Parameters beyond ``len(dims)``, extra ``*args`` and
    keyword arguments that do not name a positional parameter are also passed
    through unchanged.

    On each call:

    1. Numbers, lists and tuples given for batched parameters are converted
       with ``np.array``.
    2. The leading (batch) axes of all batched arrays are broadcast together.
    3. Each batched array is reshaped to ``(N, *core_shape)`` and the function
       is called.
    4. Every output that supports ``reshape`` has its leading axis expanded back
       into the broadcast batch shape.

    If no argument was actually batched, the function's result is returned
    as-is.
    """
    def decorator(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            args = list(args)
            # Core-axis count for every argument; None means "leave untouched".
            args_order, kwargs_order = get_args_order(func, args, kwargs)
            args_dim = [_dim_for(dims, i) for i in args_order]
            # Positional arguments beyond the named parameters have no dim.
            args_dim += [None] * (len(args) - len(args_dim))
            # One entry per keyword (in kwargs order), so every key has a dim.
            kwargs_dim = {
                key: _dim_for(dims, kwargs_order[key]) if key in kwargs_order else None
                for key in kwargs
            }
            # Promote Python numbers / sequences to arrays so they can be broadcast.
            for i, (arg, dim) in enumerate(zip(args, args_dim)):
                if dim is not None and isinstance(arg, (Number, list, tuple)):
                    args[i] = np.array(arg)
            for key, arg in kwargs.items():
                if kwargs_dim[key] is not None and isinstance(arg, (Number, list, tuple)):
                    kwargs[key] = np.array(arg)
            # Broadcast the batch axes, then collapse them into one leading axis.
            args, kwargs, spatial = broadcast_args(args, kwargs, args_dim, kwargs_dim)
            flattened = False
            for i, (arg, dim) in enumerate(zip(args, args_dim)):
                if dim is not None and isinstance(arg, np.ndarray):
                    args[i] = arg.reshape([-1, *arg.shape[arg.ndim - dim:]])
                    flattened = True
            for key, arg in kwargs.items():
                dim = kwargs_dim[key]
                if dim is not None and isinstance(arg, np.ndarray):
                    kwargs[key] = arg.reshape([-1, *arg.shape[arg.ndim - dim:]])
                    flattened = True
            results = func(*args, **kwargs)
            if not flattened:
                # No batch axis was introduced, so there is nothing to undo.
                return results
            return _restore_results(results, spatial)
        return wrapper
    return decorator