import functools
import pickle
import warnings

import h5py
import numpy as np


class Cache:
    """Dictionary-like persistent cache stored in an HDF5 file via h5py.

    NumPy arrays are written to the file as-is; any other value is pickled
    and stored as a 0-d bytes array.  Instances can be used as a context
    manager (the file is closed on exit) and as a function decorator
    (see ``__call__``).
    """

    def __init__(self, file: str, mode: str = 'a', overwrite=False):
        """Open the backing HDF5 file.

        Args:
            file: Path passed to ``h5py.File``.
            mode: File mode passed to ``h5py.File``.
            overwrite: When true, the decorator form recomputes and re-stores
                results even if ``cache_id`` is already present.
        """
        self.db_file = h5py.File(file, mode=mode)
        self.overwrite = overwrite

    @staticmethod
    def _key(key):
        """Normalize ``key`` to the string used as the dataset name.

        Strings are returned unchanged, lists are normalized element by
        element (recursively) and joined with single spaces, and anything
        else is converted with ``str``.
        """
        if isinstance(key, str):
            return key
        if isinstance(key, list):
            return ' '.join(Cache._key(part) for part in key)
        return str(key)

    @staticmethod
    def _value(value: np.ndarray):
        """Convert a stored dataset (or array) back to the cached object.

        An ``h5py.Dataset`` is read fully into memory first.  A 0-d
        bytes-typed value is treated as a pickle written by ``__setitem__``
        and unpickled; everything else is returned as read.
        """
        if isinstance(value, h5py.Dataset):
            value = value[()]
        # ``__setitem__`` always stores pickles as 0-d bytes arrays, so only
        # 0-d bytes values are unpickled.  Multi-element bytes arrays stored
        # by the user are returned unchanged rather than fed to the unpickler.
        if value.dtype.name.startswith('bytes') and np.ndim(value) == 0:
            # SECURITY: pickle.loads can execute arbitrary code; only open
            # cache files that come from a trusted source.
            value = pickle.loads(value)
        return value

    def __getitem__(self, key):
        """Return the object cached under ``key``.

        Raises:
            KeyError: If the normalized key is not present in the file.
        """
        key = self._key(key)
        if key not in self:
            raise KeyError(key)
        return self._value(self.db_file[key])

    def __setitem__(self, key, value) -> None:
        """Store ``value`` under ``key``, replacing any existing entry."""
        key = self._key(key)
        # The existing entry is deleted before the new value is assigned.
        if key in self.db_file:
            del self.db_file[key]
        if not isinstance(value, np.ndarray):
            # Non-array objects are pickled and wrapped in a 0-d bytes array.
            value = np.array(pickle.dumps(value))
        self.db_file[key] = value

    def __delitem__(self, key) -> None:
        """Remove ``key`` from the cache; a missing key is silently ignored."""
        key = self._key(key)
        if key in self.db_file:
            del self.db_file[key]

    def __len__(self) -> int:
        """Return ``len()`` of the underlying HDF5 file object."""
        return len(self.db_file)

    def close(self) -> None:
        """Close the underlying HDF5 file."""
        self.db_file.close()

    def __enter__(self):
        """Return the cache itself for use in a ``with`` statement."""
        return self

    def __exit__(self, exc_type, exc_val, exc_tb) -> None:
        """Close the file when leaving the ``with`` block."""
        self.close()

    def __contains__(self, item):
        """Return whether the normalized ``item`` exists in the file."""
        return self._key(item) in self.db_file

    def __call__(self, function):
        """
        The object of the class could also be used as a decorator. Provide an
        additional keyword argument `cache_id` when calling the decorated
        function, and the result will be cached under that id.

        If `cache_id` is omitted, a warning is issued and the function is
        called without caching.  If `cache_id` is already cached and
        ``overwrite`` is false, the cached value is returned without calling
        the function.
        """
        @functools.wraps(function)  # keep the wrapped function's name and docstring
        def wrapper(*args, **kwargs):
            if 'cache_id' not in kwargs:
                warnings.warn(
                    "`cache_id' argument not found. Cache is disabled.",
                    stacklevel=2,  # attribute the warning to the caller
                )
                return function(*args, **kwargs)

            # ``cache_id`` is consumed here and never forwarded to ``function``.
            cache_id = kwargs.pop('cache_id')
            if cache_id in self and not self.overwrite:
                return self[cache_id]
            rst = function(*args, **kwargs)
            self[cache_id] = rst
            return rst

        return wrapper