|
|
|
import functools |
|
|
|
|
|
class OutputHook:
    """Capture the forward outputs of selected sub-modules through forward hooks.

    For every dotted name in ``outputs`` the matching attribute of ``module``
    is resolved with ``rgetattr`` and ``register_forward_hook`` is called on
    it. Each time a hooked layer runs, its output is stored in
    ``layer_outputs`` under that name, replacing any previous value.

    The instance is also a context manager: leaving the ``with`` block
    removes every hook that was registered.

    Args:
        module: Root object whose attributes are looked up by dotted name.
        outputs: List or tuple of dotted attribute names to hook. Any other
            value (including the default ``None``) registers no hooks.
        as_tensor: If true, store each hook output unchanged; otherwise
            convert it with ``.detach().cpu().numpy()``.

    Raises:
        ModuleNotFoundError: If a requested name cannot be resolved on
            ``module`` or the resolved object cannot take a forward hook.
    """

    def __init__(self, module, outputs=None, as_tensor=False):
        self.outputs = outputs
        self.as_tensor = as_tensor
        # Maps each hooked name to the most recent output seen for it.
        self.layer_outputs = {}
        self.register(module)

    @staticmethod
    def _to_numpy(output):
        """Convert ``output`` with ``.detach().cpu().numpy()``.

        Lists and tuples are converted element by element (recursively) and
        keep their list/tuple kind, so layers that return several values do
        not fail inside the hook.
        """
        if isinstance(output, (list, tuple)):
            converted = [OutputHook._to_numpy(item) for item in output]
            return converted if isinstance(output, list) else tuple(converted)
        return output.detach().cpu().numpy()

    def register(self, module):
        """Attach a forward hook to every layer named in ``self.outputs``.

        The handles returned by ``register_forward_hook`` are collected in
        ``self.handles`` so that ``remove`` can detach them later.

        Raises:
            ModuleNotFoundError: If a name cannot be resolved on ``module``
                or the resolved object has no ``register_forward_hook``.
        """

        def hook_wrapper(name):
            # Bind ``name`` per layer so each hook writes to its own key.
            def hook(_model, _inputs, output):
                if self.as_tensor:
                    self.layer_outputs[name] = output
                else:
                    self.layer_outputs[name] = self._to_numpy(output)

            return hook

        self.handles = []
        if not isinstance(self.outputs, (list, tuple)):
            return

        for name in self.outputs:
            try:
                layer = rgetattr(module, name)
                handle = layer.register_forward_hook(hook_wrapper(name))
            except (AttributeError, ModuleNotFoundError) as module_not_found:
                # A failed attribute lookup raises AttributeError, not
                # ModuleNotFoundError, so it must be caught here for the
                # descriptive error below to be raised at all.
                raise ModuleNotFoundError(
                    f'Module {name} not found') from module_not_found
            self.handles.append(handle)

    def remove(self):
        """Detach every registered hook by calling ``remove()`` on its handle."""
        for handle in self.handles:
            handle.remove()

    def __enter__(self):
        return self

    def __exit__(self, exc_type, exc_val, exc_tb):
        # Returns None, so an exception raised in the ``with`` body propagates.
        self.remove()
|
|
|
|
|
|
|
|
|
def rgetattr(obj, attr, *args):
    """Resolve a dotted attribute path such as ``'a.b.c'`` starting at ``obj``.

    Each path component is looked up with ``getattr`` on the result of the
    previous lookup. Any extra positional argument is forwarded to every
    ``getattr`` call, so a single extra value acts as the default at each
    step of the walk.
    """
    return functools.reduce(
        lambda current, part: getattr(current, part, *args),
        attr.split('.'),
        obj,
    )
|
|