File size: 1,836 Bytes
2cd560a |
1 2 3 4 5 6 7 8 9 10 11 12 13 14 15 16 17 18 19 20 21 22 23 24 25 26 27 28 29 30 31 32 33 34 35 36 37 38 39 40 41 42 43 44 45 46 47 48 49 50 51 52 53 54 55 56 57 58 59 60 61 |
# Copyright (c) OpenMMLab. All rights reserved.
import functools
class OutputHook:
    """Record the forward outputs of selected sub-layers of a module.

    For every dotted name in ``outputs`` the matching attribute of ``module``
    is resolved with ``rgetattr`` and a forward hook is attached to it through
    its ``register_forward_hook`` method. Each time a hooked layer runs, its
    output is stored in ``layer_outputs`` under that name.

    The instance is a context manager: leaving the ``with`` block removes
    every hook that was registered.

    Args:
        module: Root object whose sub-layers are hooked. Presumably a
            ``torch.nn.Module`` - only ``register_forward_hook`` is required
            of the resolved layers.
        outputs: List or tuple of dotted attribute paths to hook. Any other
            value (including ``None``) registers no hooks.
        as_tensor: When true, outputs are stored as-is. Otherwise each output
            (or each element of a list/tuple output) is converted with
            ``.detach().cpu().numpy()``.

    Raises:
        ModuleNotFoundError: If a name in ``outputs`` cannot be resolved on
            ``module`` or the resolved object cannot take a forward hook.
    """

    def __init__(self, module, outputs=None, as_tensor=False):
        self.outputs = outputs
        self.as_tensor = as_tensor
        # Maps layer name -> most recent output captured for that layer.
        self.layer_outputs = {}
        # Handles returned by register_forward_hook; needed to detach later.
        self.handles = []
        self.register(module)

    def register(self, module):
        """Attach a recording forward hook to every layer in ``outputs``."""

        def hook_wrapper(name):
            # Bind ``name`` per layer so each hook writes to its own key.

            def hook(model, input, output):
                if self.as_tensor:
                    self.layer_outputs[name] = output
                elif isinstance(output, (list, tuple)):
                    # Layers may return several values; tuples are handled
                    # like lists instead of failing on ``tuple.detach``.
                    self.layer_outputs[name] = [
                        out.detach().cpu().numpy() for out in output
                    ]
                else:
                    self.layer_outputs[name] = output.detach().cpu().numpy()

            return hook

        self.handles = []
        if isinstance(self.outputs, (list, tuple)):
            for name in self.outputs:
                try:
                    layer = rgetattr(module, name)
                    h = layer.register_forward_hook(hook_wrapper(name))
                except (AttributeError,
                        ModuleNotFoundError) as module_not_found:
                    # ``rgetattr`` reports a missing attribute with
                    # AttributeError, so that is what must be translated.
                    # Detach the hooks registered so far so that a failed
                    # construction leaves nothing behind on ``module``.
                    self.remove()
                    raise ModuleNotFoundError(
                        f'Module {name} not found') from module_not_found
                self.handles.append(h)

    def remove(self):
        """Detach every registered hook; safe to call more than once."""
        # Swap the list out first so spent handles are never removed twice.
        handles, self.handles = self.handles, []
        for h in handles:
            h.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        self.remove()
# using wonder's beautiful simplification:
# https://stackoverflow.com/questions/31174295/getattr-and-setattr-on-nested-objects
def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path such as ``'a.b.c'`` on ``obj``.

    Args:
        obj: Object on which the lookup starts.
        attr: Attribute names joined by ``'.'``; each one is looked up on the
            result of the previous lookup.
        *args: Optional default forwarded to every ``getattr`` call, so a
            missing attribute at any level yields the default for that step.

    Returns:
        The value found at the end of the path.

    Raises:
        AttributeError: If a step is missing and no default was supplied.
    """
    current = obj
    for part in attr.split('.'):
        current = getattr(current, part, *args)
    return current
|