Spaces:
Runtime error
Runtime error
# -*- coding: utf-8 -*- | |
""" | |
Provides a reload() function that acts recursively. | |
Python's normal :func:`python:reload` function only reloads the module that it's | |
passed. The :func:`reload` function in this module also reloads everything | |
imported from that module, which is useful when you're changing files deep | |
inside a package. | |
To use this as your default reload function, type this:: | |
import builtins | |
from IPython.lib import deepreload | |
builtins.reload = deepreload.reload | |
A reference to the original :func:`python:reload` is stored in this module as | |
:data:`original_reload`, so you can restore it later. | |
This code is almost entirely based on knee.py, which is a Python | |
re-implementation of hierarchical module import. | |
""" | |
#***************************************************************************** | |
# Copyright (C) 2001 Nathaniel Gray <[email protected]> | |
# | |
# Distributed under the terms of the BSD License. The full license is in | |
# the file COPYING, distributed as part of this software. | |
#***************************************************************************** | |
import builtins as builtin_mod | |
from contextlib import contextmanager | |
import importlib | |
import sys | |
from types import ModuleType | |
from warnings import warn | |
import types | |
# The ``__import__`` that was installed when this module was first imported.
original_import = builtin_mod.__import__


@contextmanager
def replace_import_hook(new_import):
    """Temporarily install *new_import* as ``builtins.__import__``.

    Use as ``with replace_import_hook(hook): ...``.  Whatever hook was
    installed on entry is put back on exit, even if the body raises.
    The ``@contextmanager`` decorator is required: without it this is a bare
    generator function and cannot be used in a ``with`` statement.
    """
    saved_import = builtin_mod.__import__
    builtin_mod.__import__ = new_import
    try:
        yield
    finally:
        builtin_mod.__import__ = saved_import
def get_parent(globals, level):
    """Return ``(parent, name)`` for the package an import is performed in.

    ``parent, name = get_parent(globals, level)``

    If *globals* comes from the module foo.bar.bat (not itself a package),
    this returns the sys.modules entry for foo.bar.  If *globals* is from a
    package's __init__.py, the package's own sys.modules entry is returned.
    Each level above 1 strips one more trailing component from the package
    name (``from .. import x``).

    Returns ``(None, '')`` when *level* is falsy, *globals* is not a dict,
    or *globals* does not belong to a package.  When ``__package__`` is not
    set, it is computed and written back into *globals*.

    If the resolved package is missing from sys.modules, a warning is issued
    and ``(None, name)`` is returned for ``level < 1``.  Otherwise a
    SystemError is raised.

    Raises ValueError for a non-string ``__package__``, a relative import
    in a non-package, or a relative import beyond the top-level package.
    """
    if not level or not isinstance(globals, dict):
        return None, ''

    pkgname = globals.get('__package__', None)
    if pkgname is not None:
        # __package__ is set, so use it
        if not hasattr(pkgname, 'rindex'):
            raise ValueError('__package__ set to non-string')
        if len(pkgname) == 0:
            if level > 0:
                raise ValueError('Attempted relative import in non-package')
            return None, ''
        name = pkgname
    else:
        # __package__ not set, so figure it out and set it
        if '__name__' not in globals:
            return None, ''
        modname = globals['__name__']
        if '__path__' in globals:
            # __path__ is set, so modname is already the package name
            globals['__package__'] = name = modname
        else:
            # Normal module, so work out the package name if any
            lastdot = modname.rfind('.')
            if lastdot < 0 < level:
                raise ValueError("Attempted relative import in non-package")
            if lastdot < 0:
                globals['__package__'] = None
                return None, ''
            globals['__package__'] = name = modname[:lastdot]

    # Each extra leading dot beyond the first drops one trailing component.
    dot = len(name)
    for _ in range(level, 1, -1):
        try:
            dot = name.rindex('.', 0, dot)
        except ValueError as e:
            raise ValueError("attempted relative import beyond top-level "
                             "package") from e
    name = name[:dot]

    try:
        parent = sys.modules[name]
    except KeyError as e:
        # Only a missing entry is handled; KeyboardInterrupt and friends
        # must propagate instead of being converted into a warning/error.
        if level < 1:
            warn("Parent module '%.200s' not found while handling absolute "
                 "import" % name)
            parent = None
        else:
            raise SystemError("Parent module '%.200s' not loaded, cannot "
                              "perform relative import" % name) from e

    # We expect, but can't guarantee, if parent != None, that:
    # - parent.__name__ == name
    # - parent.__dict__ is globals
    return parent, name
def load_next(mod, altmod, name, buf):
    """Load the first dotted component of *name* below *mod*.

    ``mod, name, buf = load_next(mod, altmod, name, buf)``

    *altmod* is either None or the same object as *mod*.  The first
    component of *name* is loaded through ``import_submodule`` as a child
    of *mod*, with *buf* extended to its full dotted name.  If that yields
    None and *altmod* differs from *mod*, the component is retried through
    *altmod* under its bare name, and *buf* restarts from that bare name.

    Returns ``(module, rest, buf)``.  *rest* is the part of *name* after
    the first dot, or None when there was no dot.  An empty *name* returns
    ``(mod, None, buf)`` unchanged.
    """
    if not len(name):
        # Only 'from . import' (or __import__("")) produces an empty name.
        return mod, None, buf

    first, sep, remainder = name.partition('.')
    if not first:
        # name starts with a dot
        raise ValueError('Empty module name')
    rest = remainder if sep else None

    buf = buf + '.' + first if buf != '' else buf + first

    found = import_submodule(mod, first, buf)
    if found is None and mod != altmod:
        found = import_submodule(altmod, first, first)
        if found is not None:
            buf = first
    if found is None:
        raise ImportError("No module named %.200s" % name)
    return found, rest, buf
# Need to keep track of what we've already reloaded to prevent cyclic evil.
# Maps dotted module name -> 1 for every module handled in the current pass.
found_now = {}


def import_submodule(mod, subname, fullname):
    """Reload or import module *fullname* and bind it on its parent *mod*.

    ``m = import_submodule(mod, subname, fullname)``

    Requires: if *mod* is None, ``subname == fullname``; otherwise
    ``mod.__name__ + "." + subname == fullname``.

    A name already recorded in ``found_now`` that is also in sys.modules is
    returned as-is.  Otherwise the name is recorded first, then the module
    is reloaded (if already in sys.modules) or imported (if not), and bound
    as ``mod.<subname>``.  If loading fails, the previous module object (if
    any) is put back into sys.modules and the exception is re-raised.
    """
    if fullname in found_now and fullname in sys.modules:
        return sys.modules[fullname]

    print('Reloading', fullname)
    found_now[fullname] = 1
    oldm = sys.modules.get(fullname, None)
    try:
        if oldm is not None:
            m = importlib.reload(oldm)
        else:
            # Import by absolute dotted name: import_module's ``package``
            # argument only applies to dot-prefixed (relative) names, so
            # passing (subname, mod) would look up a top-level `subname`.
            m = importlib.import_module(fullname)
    except BaseException:
        # A failed (re)load may have removed the entry from sys.modules;
        # put back the original module object before re-raising.
        if oldm is not None:
            sys.modules[fullname] = oldm
        raise
    add_submodule(mod, m, fullname, subname)
    return m


def add_submodule(mod, submod, fullname, subname):
    """Set ``mod.<subname> = submod``.

    Does nothing when *mod* is None.  When *submod* is None it is fetched
    from ``sys.modules[fullname]`` (raising KeyError if absent).
    """
    if mod is None:
        return  # top-level module: no parent to attach it to
    if submod is None:
        submod = sys.modules[fullname]
    setattr(mod, subname, submod)
def ensure_fromlist(mod, fromlist, buf, recursive):
    """Handle 'from module import a, b, c' imports.

    Only acts when *mod* has a ``__path__`` (is a package).  Every item of
    *fromlist* must be a string (TypeError otherwise).  ``'*'`` expands to
    ``mod.__all__`` when present.  It is skipped when *recursive* is true,
    so a ``'*'`` inside ``__all__`` cannot recurse forever.  Any other name
    that is not already an attribute of *mod* is loaded via
    ``import_submodule`` as ``buf + '.' + name``.

    Processing always continues past ``'*'`` to the remaining items.
    Previously a stale ``if not ret: return 0`` check (this function never
    returns a truthy value) dropped every item after ``'*'``.
    """
    if not hasattr(mod, '__path__'):
        return
    for item in fromlist:
        if not hasattr(item, 'rindex'):
            raise TypeError("Item in ``from list'' not a string")
        if item == '*':
            if not recursive:
                try:
                    exported = mod.__all__
                except AttributeError:
                    pass
                else:
                    ensure_fromlist(mod, exported, buf, 1)
            continue
        if not hasattr(mod, item):
            import_submodule(mod, item, buf + '.' + item)
def deep_import_hook(name, globals=None, locals=None, fromlist=None, level=-1):
    """Replacement for ``__import__`` that (re)loads every component.

    The package context comes from ``get_parent(globals, level)``.  The
    dotted *name* is then consumed one component at a time through
    ``load_next``.  When *fromlist* is non-empty it is handed to
    ``ensure_fromlist`` and the last module resolved is returned.
    Otherwise the first module resolved is returned.  *locals* is accepted
    for signature compatibility and ignored.
    """
    parent, buf = get_parent(globals, level)

    # With level < 0 the alternate module is None, so load_next may retry a
    # failed lookup under the bare name; otherwise it equals the parent.
    fallback = None if level < 0 else parent
    head, remaining, buf = load_next(parent, fallback, name, buf)

    tail = head
    while remaining:
        tail, remaining, buf = load_next(tail, tail, remaining, buf)

    # tail can only be None if both get_parent and load_next saw an empty
    # module name: someone called __import__("") or doctored faulty bytecode.
    if tail is None:
        raise ValueError('Empty module name')

    if fromlist:
        ensure_fromlist(tail, fromlist, buf, 0)
        return tail
    return head
# Modules currently being reloaded by deep_reload_hook, keyed by __name__.
modules_reloading = {}


def deep_reload_hook(m):
    """Replacement for reload(): reload module *m* and return the result.

    The ``types`` module is returned untouched.  A module whose reload is
    already in progress (a recursive reload) gets back the in-progress
    module object instead of being reloaded again.

    Raises TypeError if *m* is not a module and ImportError if its name is
    not in sys.modules.  If the reload fails, ``sys.modules[m.__name__]`` is
    reset to *m* and the exception is re-raised.
    """
    # Hardcode this one as it would raise a NotImplementedError from the
    # bowels of Python and screw up the import machinery after.
    # unlike other imports the `exclude` list already in place is not enough.
    if m is types:
        return m
    if not isinstance(m, ModuleType):
        raise TypeError("reload() argument must be module")
    name = m.__name__
    if name not in sys.modules:
        raise ImportError("reload(): module %.200s not in sys.modules" % name)

    if name in modules_reloading:
        # Recursive reload: return early *without* touching the registry so
        # the outer, still-running reload keeps its recursion guard.
        return modules_reloading[name]

    modules_reloading[name] = m
    try:
        return importlib.reload(m)
    except BaseException:
        # Make sure sys.modules still points at the original module object.
        sys.modules[name] = m
        raise
    finally:
        # Drop only our own entry; enclosing reloads stay registered.
        modules_reloading.pop(name, None)
# Keep a handle on the stock reload so callers can restore it later.
original_reload = importlib.reload


# Recursive replacement for reload()
def reload(
    module,
    exclude=(
        *sys.builtin_module_names,
        "sys",
        "os.path",
        "builtins",
        "__main__",
        "numpy",
        "numpy._globals",
    ),
):
    """Recursively reload *module* together with the modules it imports.

    Every name in *exclude* is pre-marked as already handled, so it is never
    reloaded.  The default excludes everything in
    ``sys.builtin_module_names`` plus ``sys``, ``os.path``, ``builtins`` and
    ``__main__`` (to avoid, e.g., resetting display, exception and io
    hooks), as well as ``numpy`` and ``numpy._globals``.

    The reload runs with ``deep_import_hook`` installed as the import hook.
    Afterwards the record of visited modules is reset, whether or not the
    reload succeeded.
    """
    global found_now
    found_now.update((excluded, 1) for excluded in exclude)
    try:
        with replace_import_hook(deep_import_hook):
            return deep_reload_hook(module)
    finally:
        found_now = {}