Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import array | |
import asyncio | |
import concurrent.futures | |
import math | |
import socket | |
import sys | |
from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined] | |
from collections import OrderedDict, deque | |
from concurrent.futures import Future | |
from contextvars import Context, copy_context | |
from dataclasses import dataclass | |
from functools import partial, wraps | |
from inspect import ( | |
CORO_RUNNING, | |
CORO_SUSPENDED, | |
GEN_RUNNING, | |
GEN_SUSPENDED, | |
getcoroutinestate, | |
getgeneratorstate, | |
) | |
from io import IOBase | |
from os import PathLike | |
from queue import Queue | |
from socket import AddressFamily, SocketKind | |
from threading import Thread | |
from types import TracebackType | |
from typing import ( | |
IO, | |
Any, | |
AsyncGenerator, | |
Awaitable, | |
Callable, | |
Collection, | |
Coroutine, | |
Generator, | |
Iterable, | |
Mapping, | |
Optional, | |
Sequence, | |
Tuple, | |
TypeVar, | |
Union, | |
cast, | |
) | |
from weakref import WeakKeyDictionary | |
import sniffio | |
from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc | |
from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable | |
from .._core._eventloop import claim_worker_thread, threadlocals | |
from .._core._exceptions import ( | |
BrokenResourceError, | |
BusyResourceError, | |
ClosedResourceError, | |
EndOfStream, | |
WouldBlock, | |
) | |
from .._core._exceptions import ExceptionGroup as BaseExceptionGroup | |
from .._core._sockets import GetAddrInfoReturnType, convert_ipv6_sockaddr | |
from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter | |
from .._core._synchronization import Event as BaseEvent | |
from .._core._synchronization import ResourceGuard | |
from .._core._tasks import CancelScope as BaseCancelScope | |
from ..abc import IPSockAddrType, UDPPacketType | |
from ..lowlevel import RunVar | |
if sys.version_info < (3, 8):

    def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
        """Return the coroutine wrapped by *task*, read from its private attribute."""
        return task._coro

else:

    def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]:
        """Return the coroutine wrapped by *task*, using the public accessor."""
        return task.get_coro()
from asyncio import all_tasks, create_task, current_task, get_running_loop | |
from asyncio import run as native_run | |
def _get_task_callbacks(task: asyncio.Task) -> Iterable[Callable]: | |
return [cb for cb, context in task._callbacks] | |
# Type variable for the value produced by the callables and awaitables run by this module
T_Retval = TypeVar("T_Retval")
# Contravariant type variable, used for the value passed to the task status' started()
T_contra = TypeVar("T_contra", contravariant=True)
# Check whether there is native support for task names in asyncio (3.8+)
_native_task_names = hasattr(asyncio.Task, "get_name")
# RunVar in which find_root_task() stores the root task once it has found one
_root_task: RunVar[asyncio.Task | None] = RunVar("_root_task")
def find_root_task() -> asyncio.Task:
    """
    Locate the task that everything else in this event loop run hangs off.

    Three strategies are tried in order:

    1. the task previously stored in ``_root_task``, unless it has finished;
    2. an unfinished task carrying the ``run_until_complete()`` done-callback (or a
       callback defined in ``uvloop.loop``), which is then stored in ``_root_task``;
    3. the host task of the outermost cancel scope of the current task, falling back
       to the current task itself.
    """
    cached = _root_task.get(None)
    if cached is not None and not cached.done():
        return cached

    # Look for a task that has been started via run_until_complete()
    for candidate in all_tasks():
        if not candidate._callbacks or candidate.done():
            continue

        for callback in _get_task_callbacks(candidate):
            if (
                callback is _run_until_complete_cb
                or getattr(callback, "__module__", None) == "uvloop.loop"
            ):
                _root_task.set(candidate)
                return candidate

    # Look up the topmost task in the AnyIO task tree, if possible
    current = cast(asyncio.Task, current_task())
    state = _task_states.get(current)
    if not state:
        return current

    # Climb to the outermost cancel scope; its host task is the best guess for the root
    scope = state.cancel_scope
    while scope and scope._parent_scope is not None:
        scope = scope._parent_scope

    if scope is not None:
        return cast(asyncio.Task, scope._host_task)

    return current
def get_callable_name(func: Callable) -> str:
    """
    Build a dotted display name for *func* from its module and qualified name.

    Either part is left out when the attribute is missing or falsy, so the result
    may be just the module, just the qualified name, or an empty string.
    """
    name_parts = (
        getattr(func, "__module__", None),
        getattr(func, "__qualname__", None),
    )
    return ".".join(filter(None, name_parts))
# | |
# Event loop | |
# | |
# Storage weakly keyed by event loop (see the type comment below).
# NOTE(review): presumably the backing store for RunVar values -- confirm against
# ..lowlevel, which is not visible from here.
_run_vars = (
    WeakKeyDictionary()
)  # type: WeakKeyDictionary[asyncio.AbstractEventLoop, Any]
# The "current token" of this backend is the running event loop itself
current_token = get_running_loop
def _task_started(task: asyncio.Task) -> bool:
    """
    Return ``True`` if the task has been started and has not finished.

    The task's coroutine is inspected first as a native coroutine and, if that fails
    with :exc:`AttributeError`, as a generator.

    :raises Exception: if the object is neither (both inspections raised
        :exc:`AttributeError`)
    """
    coro = cast(Coroutine[Any, Any, Any], get_coro(task))
    try:
        return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED)
    except AttributeError:
        # Not a native coroutine; try treating it as a generator instead
        try:
            return getgeneratorstate(cast(Generator, coro)) in (
                GEN_RUNNING,
                GEN_SUSPENDED,
            )
        except AttributeError:
            # task coro is async_generator_asend https://bugs.python.org/issue37771
            raise Exception(f"Cannot determine if task {task} has started or not")
def _maybe_set_event_loop_policy( | |
policy: asyncio.AbstractEventLoopPolicy | None, use_uvloop: bool | |
) -> None: | |
# On CPython, use uvloop when possible if no other policy has been given and if not | |
# explicitly disabled | |
if policy is None and use_uvloop and sys.implementation.name == "cpython": | |
try: | |
import uvloop | |
except ImportError: | |
pass | |
else: | |
# Test for missing shutdown_default_executor() (uvloop 0.14.0 and earlier) | |
if not hasattr( | |
asyncio.AbstractEventLoop, "shutdown_default_executor" | |
) or hasattr(uvloop.loop.Loop, "shutdown_default_executor"): | |
policy = uvloop.EventLoopPolicy() | |
if policy is not None: | |
asyncio.set_event_loop_policy(policy) | |
def run(
    func: Callable[..., Awaitable[T_Retval]],
    *args: object,
    debug: bool = False,
    use_uvloop: bool = False,
    policy: asyncio.AbstractEventLoopPolicy | None = None,
) -> T_Retval:
    """
    Run ``func(*args)`` to completion in a new event loop and return its result.

    :param func: the async callable to run
    :param args: positional arguments passed to *func*
    :param debug: passed through to the native ``asyncio.run()``
    :param use_uvloop: ``True`` to try installing uvloop's event loop policy
    :param policy: an explicit event loop policy to install before running
    """

    async def wrapper() -> T_Retval:
        # Register a task state for the root task for as long as func is running
        root = cast(asyncio.Task, current_task())
        root_state = TaskState(None, get_callable_name(func), None)
        _task_states[root] = root_state
        if _native_task_names:
            root.set_name(root_state.name)

        try:
            return await func(*args)
        finally:
            del _task_states[root]

    _maybe_set_event_loop_policy(policy, use_uvloop)
    return native_run(wrapper(), debug=debug)
# | |
# Miscellaneous | |
# | |
# The backend's sleep primitive is asyncio's own
sleep = asyncio.sleep
# | |
# Timeouts and cancellation | |
# | |
# This module refers to asyncio's cancellation exception under this name
CancelledError = asyncio.CancelledError
class CancelScope(BaseCancelScope):
    """
    asyncio implementation of a cancel scope.

    A scope is entered and exited by one host task and may additionally contain
    tasks spawned into it. Cancelling the scope (directly or through its deadline)
    repeatedly calls ``cancel()`` on the contained tasks until none of them is
    eligible for cancellation any more.
    """

    def __new__(
        cls, *, deadline: float = math.inf, shield: bool = False
    ) -> CancelScope:
        # Create the instance directly instead of going through the base class' __new__
        return object.__new__(cls)

    def __init__(self, deadline: float = math.inf, shield: bool = False):
        """
        :param deadline: event loop time at which the scope cancels itself
            (``math.inf`` disables the timeout)
        :param shield: ``True`` to shield the scope from cancellation of parent scopes
        """
        self._deadline = deadline
        self._shield = shield
        self._parent_scope: CancelScope | None = None
        self._cancel_called = False
        self._active = False
        self._timeout_handle: asyncio.TimerHandle | None = None
        self._cancel_handle: asyncio.Handle | None = None
        self._tasks: set[asyncio.Task] = set()
        self._host_task: asyncio.Task | None = None
        self._timeout_expired = False
        # Number of task.cancel() calls made by this scope (undone in _uncancel())
        self._cancel_calls: int = 0

    def __enter__(self) -> CancelScope:
        """
        Make this scope the current cancel scope of the current task.

        :raises RuntimeError: if the scope is already active
        """
        if self._active:
            raise RuntimeError(
                "Each CancelScope may only be used for a single 'with' block"
            )

        self._host_task = host_task = cast(asyncio.Task, current_task())
        self._tasks.add(host_task)
        try:
            task_state = _task_states[host_task]
        except KeyError:
            # First scope ever entered in this task: create its task state
            task_name = host_task.get_name() if _native_task_names else None
            task_state = TaskState(None, task_name, self)
            _task_states[host_task] = task_state
        else:
            # Nest this scope under the task's current one
            self._parent_scope = task_state.cancel_scope
            task_state.cancel_scope = self

        self._timeout()
        self._active = True

        # Start cancelling the host task if the scope was cancelled before entering
        if self._cancel_called:
            self._deliver_cancellation()

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Leave the scope and decide whether a cancellation exception is swallowed.

        :return: the result of :meth:`_uncancel` when the exception consists solely
            of cancellation errors caused by this scope, ``None`` otherwise
        :raises RuntimeError: if the scope is not active, is exited from another
            task, or is not the current scope of its host task
        """
        if not self._active:
            raise RuntimeError("This cancel scope is not active")
        if current_task() is not self._host_task:
            raise RuntimeError(
                "Attempted to exit cancel scope in a different task than it was "
                "entered in"
            )

        assert self._host_task is not None
        host_task_state = _task_states.get(self._host_task)
        if host_task_state is None or host_task_state.cancel_scope is not self:
            raise RuntimeError(
                "Attempted to exit a cancel scope that isn't the current tasks's "
                "current cancel scope"
            )

        self._active = False
        if self._timeout_handle:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        self._tasks.remove(self._host_task)

        host_task_state.cancel_scope = self._parent_scope

        # Restart the cancellation effort in the farthest directly cancelled parent scope if this
        # one was shielded
        if self._shield:
            self._deliver_cancellation_to_parent()

        if exc_val is not None:
            exceptions = (
                exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val]
            )
            if all(isinstance(exc, CancelledError) for exc in exceptions):
                if self._timeout_expired:
                    return self._uncancel()
                elif not self._cancel_called:
                    # Task was cancelled natively
                    return None
                elif not self._parent_cancelled():
                    # This scope was directly cancelled
                    return self._uncancel()

        return None

    def _uncancel(self) -> bool:
        """
        Undo the cancel() calls this scope made on its host task.

        :return: ``True`` if the cancellation exception should be swallowed. This is
            always the case before Python 3.11; on later versions only when the host
            task has no other pending cancellation requests.
        """
        if sys.version_info < (3, 11) or self._host_task is None:
            self._cancel_calls = 0
            return True

        # Uncancel all AnyIO cancellations
        for i in range(self._cancel_calls):
            self._host_task.uncancel()

        self._cancel_calls = 0
        return not self._host_task.cancelling()

    def _timeout(self) -> None:
        """Cancel the scope if its deadline has passed, otherwise schedule a re-check."""
        if self._deadline != math.inf:
            loop = get_running_loop()
            if loop.time() >= self._deadline:
                self._timeout_expired = True
                self.cancel()
            else:
                self._timeout_handle = loop.call_at(self._deadline, self._timeout)

    def _deliver_cancellation(self) -> None:
        """
        Deliver cancellation to directly contained tasks and nested cancel scopes.

        Schedule another run at the end if we still have tasks eligible for cancellation.
        """
        should_retry = False
        current = current_task()
        for task in self._tasks:
            if task._must_cancel:  # type: ignore[attr-defined]
                continue

            # The task is eligible for cancellation if it has started and is not in a cancel
            # scope shielded from this one
            cancel_scope = _task_states[task].cancel_scope
            while cancel_scope is not self:
                if cancel_scope is None or cancel_scope._shield:
                    break
                else:
                    cancel_scope = cancel_scope._parent_scope
            else:
                # Reached this scope without crossing a shield
                should_retry = True
                if task is not current and (
                    task is self._host_task or _task_started(task)
                ):
                    self._cancel_calls += 1
                    task.cancel()

        # Schedule another callback if there are still tasks left
        if should_retry:
            self._cancel_handle = get_running_loop().call_soon(
                self._deliver_cancellation
            )
        else:
            self._cancel_handle = None

    def _deliver_cancellation_to_parent(self) -> None:
        """Start cancellation effort in the farthest directly cancelled parent scope"""
        scope = self._parent_scope
        scope_to_cancel: CancelScope | None = None
        while scope is not None:
            # Only consider scopes that don't already have a delivery scheduled
            if scope._cancel_called and scope._cancel_handle is None:
                scope_to_cancel = scope

            # No point in looking beyond any shielded scope
            if scope._shield:
                break

            scope = scope._parent_scope

        if scope_to_cancel is not None:
            scope_to_cancel._deliver_cancellation()

    def _parent_cancelled(self) -> bool:
        """Return ``True`` if a parent scope not hidden behind a shield was cancelled."""
        # Check whether any parent has been cancelled
        cancel_scope = self._parent_scope
        while cancel_scope is not None and not cancel_scope._shield:
            if cancel_scope._cancel_called:
                return True
            else:
                cancel_scope = cancel_scope._parent_scope

        return False

    def cancel(self) -> DeprecatedAwaitable:
        """
        Cancel the scope; calls after the first one have no further effect.

        If the scope has not been entered yet, the cancellation is delivered when it is.
        """
        if not self._cancel_called:
            if self._timeout_handle:
                self._timeout_handle.cancel()
                self._timeout_handle = None

            self._cancel_called = True
            if self._host_task is not None:
                self._deliver_cancellation()

        return DeprecatedAwaitable(self.cancel)

    # The accessors below are read as plain attributes elsewhere in this module
    # (e.g. ``cancel_scope.deadline``, ``cancel_scope.cancel_called``), so they have to
    # be properties rather than methods.
    @property
    def deadline(self) -> float:
        """The event loop time at which this scope times out."""
        return self._deadline

    @deadline.setter
    def deadline(self, value: float) -> None:
        self._deadline = float(value)
        if self._timeout_handle is not None:
            self._timeout_handle.cancel()
            self._timeout_handle = None

        # Re-arm the timeout only for a live scope that hasn't been cancelled yet
        if self._active and not self._cancel_called:
            self._timeout()

    @property
    def cancel_called(self) -> bool:
        """``True`` once :meth:`cancel` has been called."""
        return self._cancel_called

    @property
    def shield(self) -> bool:
        """``True`` if this scope is shielded from cancellation of its parents."""
        return self._shield

    @shield.setter
    def shield(self, value: bool) -> None:
        if self._shield != value:
            self._shield = value
            # Dropping the shield exposes the scope to pending parent cancellations
            if not value:
                self._deliver_cancellation_to_parent()
async def checkpoint() -> None:
    """Yield control to the event loop once by sleeping for zero seconds."""
    await sleep(0)
async def checkpoint_if_cancelled() -> None:
    """
    Yield to the event loop only if an enclosing cancel scope has been cancelled.

    The scopes of the current task are walked from the innermost outwards; the walk
    stops at the first shielded scope. Nothing happens when there is no current task
    or the task has no registered state.
    """
    task = current_task()
    if task is None:
        return

    try:
        scope = _task_states[task].cancel_scope
    except KeyError:
        return

    while scope:
        if scope.cancel_called:
            # Stay on the same scope: the pending cancellation is expected to surface here
            await sleep(0)
            continue

        if scope.shield:
            break

        scope = scope._parent_scope
async def cancel_shielded_checkpoint() -> None:
    """Yield control to the event loop once, from inside a shielded cancel scope."""
    shielded_scope = CancelScope(shield=True)
    with shielded_scope:
        await sleep(0)
def current_effective_deadline() -> float:
    """
    Return the nearest deadline among the cancel scopes enclosing the current task.

    :return: ``math.inf`` if the task has no registered state or no scope has a
        deadline, ``-math.inf`` if a scope reached during the walk has already been
        cancelled, otherwise the smallest deadline seen up to and including the
        first shielded scope
    """
    try:
        scope = _task_states[current_task()].cancel_scope  # type: ignore[index]
    except KeyError:
        return math.inf

    effective = math.inf
    while scope:
        effective = min(effective, scope.deadline)
        if scope._cancel_called:
            # An already cancelled scope means the deadline has effectively passed
            return -math.inf

        if scope.shield:
            break

        scope = scope._parent_scope

    return effective
def current_time() -> float:
    """Return the current time according to the running event loop's clock."""
    loop = get_running_loop()
    return loop.time()
# | |
# Task states | |
# | |
class TaskState:
    """
    Auxiliary per-task record kept outside of the Task object.

    The data cannot be stored on the Task instance itself because there are no
    guarantees about its implementation.
    """

    __slots__ = ("parent_id", "name", "cancel_scope")

    def __init__(
        self,
        parent_id: int | None,
        name: str | None,
        cancel_scope: CancelScope | None,
    ):
        """
        :param parent_id: ``id()`` of the parent task, if any
        :param name: display name of the task, if any
        :param cancel_scope: the task's current (innermost) cancel scope, if any
        """
        self.cancel_scope = cancel_scope
        self.name = name
        self.parent_id = parent_id
# TaskState of every known task, weakly keyed by the task object (see the type comment)
_task_states = WeakKeyDictionary()  # type: WeakKeyDictionary[asyncio.Task, TaskState]
# | |
# Task groups | |
# | |
class ExceptionGroup(BaseExceptionGroup):
    """Exception group whose member exceptions are supplied at construction time."""

    def __init__(self, exceptions: list[BaseException]):
        """
        :param exceptions: the exceptions in this group; stored as given, not copied
        """
        super().__init__()
        self.exceptions = exceptions
class _AsyncioTaskStatus(abc.TaskStatus):
    """Task status object handed to a task started through ``TaskGroup.start()``."""

    def __init__(self, future: asyncio.Future, parent_id: int):
        """
        :param future: future that receives the value passed to :meth:`started`
        :param parent_id: ``id()`` of the task recorded as parent once started
        """
        self._parent_id = parent_id
        self._future = future

    def started(self, value: T_contra | None = None) -> None:
        """
        Deliver the start value and re-parent the calling task.

        :raises RuntimeError: if the future already has a result
        """
        try:
            self._future.set_result(value)
        except asyncio.InvalidStateError:
            raise RuntimeError(
                "called 'started' twice on the same task status"
            ) from None

        calling_task = cast(asyncio.Task, current_task())
        _task_states[calling_task].parent_id = self._parent_id
class TaskGroup(abc.TaskGroup):
    """
    asyncio implementation of a task group.

    All child tasks run inside the group's own cancel scope. Leaving the ``async
    with`` block waits for every child to finish and then raises whatever exceptions
    the children (and the block body) produced.
    """

    def __init__(self) -> None:
        self.cancel_scope: CancelScope = CancelScope()
        self._active = False
        # Exceptions raised by the host block and the child tasks
        self._exceptions: list[BaseException] = []

    async def __aenter__(self) -> TaskGroup:
        """Enter the group's cancel scope and allow tasks to be spawned."""
        self.cancel_scope.__enter__()
        self._active = True
        return self

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        """
        Wait for all child tasks to finish, then raise any collected exceptions.

        :return: the verdict of the cancel scope's ``__exit__()`` for the exception
            raised in the host block
        :raises ExceptionGroup: if more than one exception remains after filtering
        """
        ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb)
        if exc_val is not None:
            # The host block failed: take the children down with it
            self.cancel_scope.cancel()
            self._exceptions.append(exc_val)

        while self.cancel_scope._tasks:
            try:
                await asyncio.wait(self.cancel_scope._tasks)
            except asyncio.CancelledError:
                # The host task was cancelled while waiting; pass it on to the children
                self.cancel_scope.cancel()

        self._active = False
        if not self.cancel_scope._parent_cancelled():
            exceptions = self._filter_cancellation_errors(self._exceptions)
        else:
            exceptions = self._exceptions

        try:
            if len(exceptions) > 1:
                if all(
                    isinstance(e, CancelledError) and not e.args for e in exceptions
                ):
                    # Tasks were cancelled natively, without a cancellation message
                    raise CancelledError
                else:
                    raise ExceptionGroup(exceptions)
            elif exceptions and exceptions[0] is not exc_val:
                raise exceptions[0]
        except BaseException as exc:
            # Clear the context here, as it can only be done in-flight.
            # If the context is not cleared, it can result in recursive tracebacks (see #145).
            exc.__context__ = None
            raise

        return ignore_exception

    # Called both as ``self._filter_cancellation_errors(...)`` and as
    # ``TaskGroup._filter_cancellation_errors(...)`` with a single argument, which only
    # works for a static method.
    @staticmethod
    def _filter_cancellation_errors(
        exceptions: Sequence[BaseException],
    ) -> list[BaseException]:
        """
        Return *exceptions* without cancellation errors that carry no arguments.

        Nested exception groups are filtered recursively: a group with one survivor
        is replaced by that survivor, a group with several survivors is kept as is,
        and a group with none is dropped.
        """
        filtered_exceptions: list[BaseException] = []
        for exc in exceptions:
            if isinstance(exc, ExceptionGroup):
                new_exceptions = TaskGroup._filter_cancellation_errors(exc.exceptions)
                if len(new_exceptions) > 1:
                    filtered_exceptions.append(exc)
                elif len(new_exceptions) == 1:
                    filtered_exceptions.append(new_exceptions[0])
                elif new_exceptions:
                    # NOTE(review): unreachable -- the two branches above already cover
                    # every non-empty list. Kept to preserve the existing behavior.
                    new_exc = ExceptionGroup(new_exceptions)
                    new_exc.__cause__ = exc.__cause__
                    new_exc.__context__ = exc.__context__
                    new_exc.__traceback__ = exc.__traceback__
                    filtered_exceptions.append(new_exc)
            elif not isinstance(exc, CancelledError) or exc.args:
                filtered_exceptions.append(exc)

        return filtered_exceptions

    async def _run_wrapped_task(
        self, coro: Coroutine, task_status_future: asyncio.Future | None
    ) -> None:
        """
        Await *coro* and do the exception bookkeeping for it inside the task itself.

        Used instead of the done-callback for foreign coroutines and on Python 3.7.
        """
        # This is the code path for Python 3.7 on which asyncio freaks out if a task
        # raises a BaseException.
        __traceback_hide__ = __tracebackhide__ = True  # noqa: F841
        task = cast(asyncio.Task, current_task())
        try:
            await coro
        except BaseException as exc:
            if task_status_future is None or task_status_future.done():
                self._exceptions.append(exc)
                self.cancel_scope.cancel()
            else:
                # Failed before started() was called: report to the start() caller
                task_status_future.set_exception(exc)
        else:
            if task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )
        finally:
            if task in self.cancel_scope._tasks:
                self.cancel_scope._tasks.remove(task)
                del _task_states[task]

    def _spawn(
        self,
        func: Callable[..., Awaitable[Any]],
        args: tuple,
        name: object,
        task_status_future: asyncio.Future | None = None,
    ) -> asyncio.Task:
        """
        Create a child task running ``func(*args)`` in the group's cancel scope.

        :param func: the async callable to run
        :param args: positional arguments for *func*
        :param name: task name; derived from *func* when ``None``
        :param task_status_future: future for the start value when called from
            :meth:`start`
        :return: the new task
        :raises RuntimeError: if the task group is not active
        :raises TypeError: if *func* did not return a coroutine
        """

        def task_done(_task: asyncio.Task) -> None:
            # This is the code path for Python 3.8+
            assert _task in self.cancel_scope._tasks
            self.cancel_scope._tasks.remove(_task)
            del _task_states[_task]

            try:
                exc = _task.exception()
            except CancelledError as e:
                # Dig out the innermost cancellation error of the context chain
                while isinstance(e.__context__, CancelledError):
                    e = e.__context__

                exc = e

            if exc is not None:
                if task_status_future is None or task_status_future.done():
                    self._exceptions.append(exc)
                    self.cancel_scope.cancel()
                else:
                    task_status_future.set_exception(exc)
            elif task_status_future is not None and not task_status_future.done():
                task_status_future.set_exception(
                    RuntimeError("Child exited without calling task_status.started()")
                )

        if not self._active:
            raise RuntimeError(
                "This task group is not active; no new tasks can be started."
            )

        options: dict[str, Any] = {}
        name = get_callable_name(func) if name is None else str(name)
        if _native_task_names:
            options["name"] = name

        kwargs = {}
        if task_status_future:
            # Until started() is called, the task counts as a child of the caller
            parent_id = id(current_task())
            kwargs["task_status"] = _AsyncioTaskStatus(
                task_status_future, id(self.cancel_scope._host_task)
            )
        else:
            parent_id = id(self.cancel_scope._host_task)

        coro = func(*args, **kwargs)
        if not asyncio.iscoroutine(coro):
            raise TypeError(
                f"Expected an async function, but {func} appears to be synchronous"
            )

        # Coroutine-like objects without a frame attribute get wrapped instead
        foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame")
        if foreign_coro or sys.version_info < (3, 8):
            coro = self._run_wrapped_task(coro, task_status_future)

        task = create_task(coro, **options)
        if not foreign_coro and sys.version_info >= (3, 8):
            task.add_done_callback(task_done)

        # Make the spawned task inherit the task group's cancel scope
        _task_states[task] = TaskState(
            parent_id=parent_id, name=name, cancel_scope=self.cancel_scope
        )
        self.cancel_scope._tasks.add(task)
        return task

    def start_soon(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        """Spawn ``func(*args)`` as a child task without waiting for it to start."""
        self._spawn(func, args, name)

    async def start(
        self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None
    ) -> None:
        """
        Spawn ``func(*args, task_status=...)`` and wait until it signals readiness.

        The value passed to ``task_status.started()`` is returned to the caller. If
        the wait is cancelled, the child task is cancelled too.
        """
        future: asyncio.Future = asyncio.Future()
        task = self._spawn(func, args, name, future)

        # If the task raises an exception after sending a start value without a switch point
        # between, the task group is cancelled and this method never proceeds to process the
        # completed future. That's why we have to have a shielded cancel scope here.
        with CancelScope(shield=True):
            try:
                return await future
            except CancelledError:
                task.cancel()
                raise
# | |
# Threads | |
# | |
# Alias for an (optional return value, optional exception) pair.
# NOTE(review): not referenced in the visible part of this file -- confirm it is used.
_Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]]
class WorkerThread(Thread):
    """
    Thread that runs queued callables and reports their outcome to the event loop.

    Work items are ``(context, func, args, future)`` tuples taken from ``queue``;
    ``None`` is the shutdown command.
    """

    # Idle time after which run_sync_in_worker_thread() stops a worker
    MAX_IDLE_TIME = 10  # seconds

    def __init__(
        self,
        root_task: asyncio.Task,
        workers: set[WorkerThread],
        idle_workers: deque[WorkerThread],
    ):
        """
        :param root_task: the task whose event loop receives the results
        :param workers: the shared set of all workers
        :param idle_workers: the shared queue of workers that are free to take work
        """
        super().__init__(name="AnyIO worker thread")
        self.root_task = root_task
        self.workers = workers
        self.idle_workers = idle_workers
        self.loop = root_task._loop
        self.queue: Queue[
            tuple[Context, Callable, tuple, asyncio.Future] | None
        ] = Queue(2)
        self.idle_since = current_time()
        self.stopping = False

    def _report_result(
        self, future: asyncio.Future, result: Any, exc: BaseException | None
    ) -> None:
        """
        Mark the worker idle again and complete *future* with the outcome.

        Scheduled onto the event loop by :meth:`run` via ``call_soon_threadsafe()``.
        """
        self.idle_since = current_time()
        if not self.stopping:
            self.idle_workers.append(self)

        if not future.cancelled():
            if exc is not None:
                # A StopIteration is replaced by a RuntimeError before it is set on
                # the future
                if isinstance(exc, StopIteration):
                    new_exc = RuntimeError("coroutine raised StopIteration")
                    new_exc.__cause__ = exc
                    exc = new_exc

                future.set_exception(exc)
            else:
                future.set_result(result)

    def run(self) -> None:
        """Process work items from the queue until the shutdown command arrives."""
        with claim_worker_thread("asyncio"):
            threadlocals.loop = self.loop
            while True:
                item = self.queue.get()
                if item is None:
                    # Shutdown command received
                    return

                context, func, args, future = item
                # Skip the call entirely if the caller has already given up on it
                if not future.cancelled():
                    result = None
                    exception: BaseException | None = None
                    try:
                        result = context.run(func, *args)
                    except BaseException as exc:
                        exception = exc

                    # The outcome is dropped if the loop has been closed meanwhile
                    if not self.loop.is_closed():
                        self.loop.call_soon_threadsafe(
                            self._report_result, future, result, exception
                        )

                self.queue.task_done()

    def stop(self, f: asyncio.Task | None = None) -> None:
        """
        Tell the thread to shut down and remove it from the worker collections.

        :param f: ignored; lets the method double as a task done-callback
        """
        self.stopping = True
        self.queue.put_nowait(None)
        self.workers.discard(self)
        try:
            self.idle_workers.remove(self)
        except ValueError:
            # The worker was not idle
            pass
# Idle worker threads; run_sync_in_worker_thread() takes from the right end and prunes
# expired workers from the left end
_threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar(
    "_threadpool_idle_workers"
)
# All worker threads started by run_sync_in_worker_thread() that have not been stopped
_threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers")
async def run_sync_in_worker_thread(
    func: Callable[..., T_Retval],
    *args: object,
    cancellable: bool = False,
    limiter: CapacityLimiter | None = None,
) -> T_Retval:
    """
    Call ``func(*args)`` in a worker thread and return its result.

    :param func: the callable to run
    :param args: positional arguments passed to *func*
    :param cancellable: ``True`` to leave the wait for the result unshielded from
        cancellation
    :param limiter: capacity limiter to hold during the call (the default thread
        limiter if omitted)
    """
    await checkpoint()

    # If this is the first run in this event loop thread, set up the necessary variables
    try:
        idle_workers = _threadpool_idle_workers.get()
        workers = _threadpool_workers.get()
    except LookupError:
        idle_workers = deque()
        workers = set()
        _threadpool_idle_workers.set(idle_workers)
        _threadpool_workers.set(workers)

    async with (limiter or current_default_thread_limiter()):
        with CancelScope(shield=not cancellable):
            future: asyncio.Future = asyncio.Future()
            root_task = find_root_task()
            if not idle_workers:
                # No idle worker available: start a new one, stopped with the root task
                worker = WorkerThread(root_task, workers, idle_workers)
                worker.start()
                workers.add(worker)
                root_task.add_done_callback(worker.stop)
            else:
                # Reuse the most recently idled worker
                worker = idle_workers.pop()

                # Prune any other workers that have been idle for MAX_IDLE_TIME seconds or longer
                now = current_time()
                while idle_workers:
                    if now - idle_workers[0].idle_since < WorkerThread.MAX_IDLE_TIME:
                        break

                    expired_worker = idle_workers.popleft()
                    expired_worker.root_task.remove_done_callback(expired_worker.stop)
                    expired_worker.stop()

            # Run func in a copy of the current context in which sniffio's current async
            # library variable is reset to None
            context = copy_context()
            context.run(sniffio.current_async_library_cvar.set, None)
            worker.queue.put_nowait((context, func, args, future))
            return await future
def run_sync_from_thread(
    func: Callable[..., T_Retval],
    *args: object,
    loop: asyncio.AbstractEventLoop | None = None,
) -> T_Retval:
    """
    Call ``func(*args)`` in the event loop thread and block until it has finished.

    :param func: the callable to run in the event loop thread
    :param args: positional arguments passed to *func*
    :param loop: the target event loop (defaults to the loop stored in the thread
        locals of the calling thread)
    :return: the return value of *func*; an exception raised by *func* is re-raised
        in the calling thread
    """
    result_future: concurrent.futures.Future[T_Retval] = Future()

    def deliver() -> None:
        try:
            result_future.set_result(func(*args))
        except BaseException as exc:
            result_future.set_exception(exc)
            # Anything that is not a regular Exception also propagates in the loop
            if not isinstance(exc, Exception):
                raise

    target_loop = loop or threadlocals.loop
    target_loop.call_soon_threadsafe(deliver)
    return result_future.result()
def run_async_from_thread(
    func: Callable[..., Awaitable[T_Retval]], *args: object
) -> T_Retval:
    """
    Run ``func(*args)`` in the event loop stored in the calling thread's thread
    locals and block until it has finished.

    :return: the result of the coroutine
    """
    coro = func(*args)
    pending: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe(
        coro, threadlocals.loop
    )
    return pending.result()
class BlockingPortal(abc.BlockingPortal):
    """Blocking portal bound to the event loop that was running when it was created."""

    def __new__(cls) -> BlockingPortal:
        # Create the instance directly instead of going through the base class' __new__
        return object.__new__(cls)

    def __init__(self) -> None:
        super().__init__()
        self._loop = get_running_loop()

    def _spawn_task_from_thread(
        self,
        func: Callable,
        args: tuple,
        kwargs: dict[str, Any],
        name: object,
        future: Future,
    ) -> None:
        """
        Start ``self._call_func(func, args, kwargs, future)`` as a task in the
        portal's task group, from another thread.
        """
        start_named_task = partial(self._task_group.start_soon, name=name)
        run_sync_from_thread(
            start_named_task,
            self._call_func,
            func,
            args,
            kwargs,
            future,
            loop=self._loop,
        )
# | |
# Subprocesses | |
# | |
# Instances are built positionally (``StreamReaderWrapper(process.stdout)``), so the
# class needs the constructor that the dataclass decorator generates from its field
@dataclass(eq=False)
class StreamReaderWrapper(abc.ByteReceiveStream):
    """Byte receive stream backed by an :class:`asyncio.StreamReader`."""

    _stream: asyncio.StreamReader

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Read up to *max_bytes* bytes from the underlying reader.

        :raises EndOfStream: if the reader returned no data
        """
        data = await self._stream.read(max_bytes)
        if data:
            return data
        else:
            raise EndOfStream

    async def aclose(self) -> None:
        """Mark the underlying reader as having reached end-of-file."""
        self._stream.feed_eof()
# Instances are built positionally (``StreamWriterWrapper(process.stdin)``), so the
# class needs the constructor that the dataclass decorator generates from its field
@dataclass(eq=False)
class StreamWriterWrapper(abc.ByteSendStream):
    """Byte send stream backed by an :class:`asyncio.StreamWriter`."""

    _stream: asyncio.StreamWriter

    async def send(self, item: bytes) -> None:
        """Write *item* to the underlying writer and wait for it to drain."""
        self._stream.write(item)
        await self._stream.drain()

    async def aclose(self) -> None:
        """Close the underlying writer."""
        self._stream.close()
# Instances are built positionally by open_process(), so the class needs the
# constructor that the dataclass decorator generates from its fields
@dataclass(eq=False)
class Process(abc.Process):
    """Wrapper around an :class:`asyncio.subprocess.Process` and its standard streams."""

    _process: asyncio.subprocess.Process
    _stdin: StreamWriterWrapper | None
    _stdout: StreamReaderWrapper | None
    _stderr: StreamReaderWrapper | None

    async def aclose(self) -> None:
        """Close whichever standard streams exist, then wait for the process to exit."""
        if self._stdin:
            await self._stdin.aclose()
        if self._stdout:
            await self._stdout.aclose()
        if self._stderr:
            await self._stderr.aclose()

        await self.wait()

    async def wait(self) -> int:
        """Wait for the process to exit and return its return code."""
        return await self._process.wait()

    def terminate(self) -> None:
        """Call ``terminate()`` on the underlying process."""
        self._process.terminate()

    def kill(self) -> None:
        """Call ``kill()`` on the underlying process."""
        self._process.kill()

    def send_signal(self, signal: int) -> None:
        """Send *signal* to the underlying process."""
        self._process.send_signal(signal)

    # The accessors below are read as plain attributes elsewhere in this module
    # (``process.returncode``, ``process.pid``), so they have to be properties.
    @property
    def pid(self) -> int:
        """The process ID of the underlying process."""
        return self._process.pid

    @property
    def returncode(self) -> int | None:
        """The return code of the underlying process; ``None`` while it is not set."""
        return self._process.returncode

    @property
    def stdin(self) -> abc.ByteSendStream | None:
        """The stream wrapping the process' standard input, if any."""
        return self._stdin

    @property
    def stdout(self) -> abc.ByteReceiveStream | None:
        """The stream wrapping the process' standard output, if any."""
        return self._stdout

    @property
    def stderr(self) -> abc.ByteReceiveStream | None:
        """The stream wrapping the process' standard error, if any."""
        return self._stderr
async def open_process(
    command: str | bytes | Sequence[str | bytes],
    *,
    shell: bool,
    stdin: int | IO[Any] | None,
    stdout: int | IO[Any] | None,
    stderr: int | IO[Any] | None,
    cwd: str | bytes | PathLike | None = None,
    env: Mapping[str, str] | None = None,
    start_new_session: bool = False,
) -> Process:
    """
    Start a subprocess and wrap it, along with its standard streams, in a Process.

    :param command: a single command string when *shell* is true, otherwise the
        sequence of program arguments
    :param shell: ``True`` to run *command* through the shell
    :param stdin: passed through to asyncio's subprocess creation function
    :param stdout: passed through to asyncio's subprocess creation function
    :param stderr: passed through to asyncio's subprocess creation function
    :param cwd: passed through to asyncio's subprocess creation function
    :param env: passed through to asyncio's subprocess creation function
    :param start_new_session: passed through to asyncio's subprocess creation function
    """
    await checkpoint()

    # Options shared by both ways of creating the subprocess
    spawn_options: dict[str, Any] = {
        "stdin": stdin,
        "stdout": stdout,
        "stderr": stderr,
        "cwd": cwd,
        "env": env,
        "start_new_session": start_new_session,
    }
    if shell:
        process = await asyncio.create_subprocess_shell(
            cast(Union[str, bytes], command), **spawn_options
        )
    else:
        process = await asyncio.create_subprocess_exec(*command, **spawn_options)

    # Only wrap the streams that actually exist
    stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None
    stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None
    stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None
    return Process(process, stdin_stream, stdout_stream, stderr_stream)
def _forcibly_shutdown_process_pool_on_exit( | |
workers: set[Process], _task: object | |
) -> None: | |
""" | |
Forcibly shuts down worker processes belonging to this event loop.""" | |
child_watcher: asyncio.AbstractChildWatcher | None | |
try: | |
child_watcher = asyncio.get_event_loop_policy().get_child_watcher() | |
except NotImplementedError: | |
child_watcher = None | |
# Close as much as possible (w/o async/await) to avoid warnings | |
for process in workers: | |
if process.returncode is None: | |
continue | |
process._stdin._stream._transport.close() # type: ignore[union-attr] | |
process._stdout._stream._transport.close() # type: ignore[union-attr] | |
process._stderr._stream._transport.close() # type: ignore[union-attr] | |
process.kill() | |
if child_watcher: | |
child_watcher.remove_child_handler(process.pid) | |
async def _shutdown_process_pool_on_exit(workers: set[Process]) -> None: | |
""" | |
Shuts down worker processes belonging to this event loop. | |
NOTE: this only works when the event loop was started using asyncio.run() or anyio.run(). | |
""" | |
process: Process | |
try: | |
await sleep(math.inf) | |
except asyncio.CancelledError: | |
for process in workers: | |
if process.returncode is None: | |
process.kill() | |
for process in workers: | |
await process.aclose() | |
def setup_process_pool_exit_at_shutdown(workers: set[Process]) -> None:
    """
    Arrange for *workers* to be shut down when the event loop run ends.

    A background task performs the shutdown when it is cancelled, and a done-callback
    on the root task forcibly cleans up as well.
    """
    task_options: dict[str, Any]
    if _native_task_names:
        task_options = {"name": "AnyIO process pool shutdown task"}
    else:
        task_options = {}

    create_task(_shutdown_process_pool_on_exit(workers), **task_options)

    on_root_task_done = partial(_forcibly_shutdown_process_pool_on_exit, workers)
    find_root_task().add_done_callback(on_root_task_done)
# | |
# Sockets and networking | |
# | |
class StreamProtocol(asyncio.Protocol):
    """
    Protocol that queues received chunks and exposes readiness through events.

    ``read_event`` is set whenever data, EOF or a connection loss has been reported;
    ``write_event`` is set while writing is not paused.
    """

    read_queue: deque[bytes]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque()
        self.read_event = asyncio.Event()
        self.write_event = asyncio.Event()
        self.write_event.set()
        # Set the transport's write buffer limit to zero
        stream_transport = cast(asyncio.Transport, transport)
        stream_transport.set_write_buffer_limits(0)

    def connection_lost(self, exc: Exception | None) -> None:
        if exc:
            # Keep the original error as the cause of the one stored for readers
            broken = BrokenResourceError()
            broken.__cause__ = exc
            self.exception = broken

        # Wake up anything waiting on either event
        self.read_event.set()
        self.write_event.set()

    def data_received(self, data: bytes) -> None:
        self.read_queue.append(data)
        self.read_event.set()

    def eof_received(self) -> bool | None:
        self.read_event.set()
        return True

    def pause_writing(self) -> None:
        # Swap in a fresh, unset event instead of clearing the current one
        self.write_event = asyncio.Event()

    def resume_writing(self) -> None:
        self.write_event.set()
class DatagramProtocol(asyncio.DatagramProtocol):
    """
    Protocol that queues received datagrams and exposes readiness through events.

    At most 100 datagrams are kept in ``read_queue``; ``write_event`` is set while
    writing is not paused.
    """

    read_queue: deque[tuple[bytes, IPSockAddrType]]
    read_event: asyncio.Event
    write_event: asyncio.Event
    exception: Exception | None = None

    def connection_made(self, transport: asyncio.BaseTransport) -> None:
        self.read_queue = deque(maxlen=100)  # arbitrary value
        self.read_event = asyncio.Event()
        ready_to_write = asyncio.Event()
        ready_to_write.set()
        self.write_event = ready_to_write

    def connection_lost(self, exc: Exception | None) -> None:
        # Wake up anything waiting on either event
        self.read_event.set()
        self.write_event.set()

    def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None:
        packet = (data, convert_ipv6_sockaddr(addr))
        self.read_queue.append(packet)
        self.read_event.set()

    def error_received(self, exc: Exception) -> None:
        self.exception = exc

    def pause_writing(self) -> None:
        self.write_event.clear()

    def resume_writing(self) -> None:
        self.write_event.set()
class SocketStream(abc.SocketStream):
    """Byte stream backed by an asyncio transport and a :class:`StreamProtocol`."""

    def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket object reported by the transport's ``"socket"`` extra info."""
        return self._transport.get_extra_info("socket")

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Return the next chunk of received data, at most ``max_bytes`` long.

        :raises ClosedResourceError: if no data is left and this stream was closed
        :raises EndOfStream: if no data is left and no error was recorded
        :raises Exception: the protocol's recorded exception, if no data is left

        """
        with self._receive_guard:
            await checkpoint()

            # Nothing buffered yet: let the transport read until the protocol
            # reports data, EOF or a connection loss
            if (
                not self._protocol.read_event.is_set()
                and not self._transport.is_closing()
            ):
                self._transport.resume_reading()
                await self._protocol.read_event.wait()
                self._transport.pause_reading()

            try:
                chunk = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                elif self._protocol.exception:
                    raise self._protocol.exception
                else:
                    raise EndOfStream from None

            if len(chunk) > max_bytes:
                # Split the oversized chunk
                chunk, leftover = chunk[:max_bytes], chunk[max_bytes:]
                self._protocol.read_queue.appendleft(leftover)

            # If the read queue is empty, clear the flag so that the next call will
            # block until data is available
            if not self._protocol.read_queue:
                self._protocol.read_event.clear()

        return chunk

    async def send(self, item: bytes) -> None:
        """
        Write ``item`` to the transport and wait until writing is not paused.

        :raises ClosedResourceError: if this stream was closed
        :raises BrokenResourceError: if the write fails because the transport is
            closing
        :raises Exception: the protocol's recorded exception, if there is one

        """
        with self._send_guard:
            await checkpoint()

            if self._closed:
                raise ClosedResourceError
            elif self._protocol.exception is not None:
                raise self._protocol.exception

            try:
                self._transport.write(item)
            except RuntimeError as exc:
                if self._transport.is_closing():
                    raise BrokenResourceError from exc
                else:
                    raise

            await self._protocol.write_event.wait()

    async def send_eof(self) -> None:
        """Half-close the write side; an ``OSError`` from the transport is ignored."""
        try:
            self._transport.write_eof()
        except OSError:
            pass

    async def aclose(self) -> None:
        """Close the stream: write EOF (best effort), close, yield once, then abort."""
        if not self._transport.is_closing():
            self._closed = True
            try:
                self._transport.write_eof()
            except OSError:
                pass

            self._transport.close()
            await sleep(0)
            self._transport.abort()
class UNIXSocketStream(abc.SocketStream):
    """Byte stream over a non-blocking raw socket, with file descriptor passing."""

    # Futures of the currently pending readability/writability waits (if any)
    _receive_future: asyncio.Future | None = None
    _send_future: asyncio.Future | None = None
    _closing = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")

    @property
    def _raw_socket(self) -> socket.socket:
        """The underlying raw socket."""
        return self.__raw_socket

    def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that completes once the socket is reported readable."""

        def callback(f: object) -> None:
            # Drop the instance attribute (falling back to the class-level None)
            # and stop watching the socket
            del self._receive_future
            loop.remove_reader(self.__raw_socket)

        f = self._receive_future = asyncio.Future()
        self._loop.add_reader(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future:
        """Return a future that completes once the socket is reported writable."""

        def callback(f: object) -> None:
            # Drop the instance attribute (falling back to the class-level None)
            # and stop watching the socket
            del self._send_future
            loop.remove_writer(self.__raw_socket)

        f = self._send_future = asyncio.Future()
        self._loop.add_writer(self.__raw_socket, f.set_result, None)
        f.add_done_callback(callback)
        return f

    async def send_eof(self) -> None:
        """Shut down the write side of the socket."""
        with self._send_guard:
            self._raw_socket.shutdown(socket.SHUT_WR)

    async def receive(self, max_bytes: int = 65536) -> bytes:
        """
        Receive up to ``max_bytes`` bytes from the socket.

        :raises EndOfStream: if the socket returned no data
        :raises ClosedResourceError: if the socket fails after this stream was closed
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        loop = get_running_loop()
        await checkpoint()
        with self._receive_guard:
            while True:
                try:
                    data = self.__raw_socket.recv(max_bytes)
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not data:
                        raise EndOfStream

                    return data

    async def send(self, item: bytes) -> None:
        """
        Send all of ``item`` over the socket.

        :raises ClosedResourceError: if the socket fails after this stream was closed
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        loop = get_running_loop()
        await checkpoint()
        with self._send_guard:
            view = memoryview(item)
            while view:
                try:
                    # Send only what is still outstanding; sending ``item`` again
                    # after a partial send would repeat bytes already delivered
                    bytes_sent = self.__raw_socket.send(view)
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    view = view[bytes_sent:]

    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive a message along with file descriptors passed as ancillary data.

        :param msglen: maximum number of message bytes to receive
        :param maxfds: maximum number of file descriptors to make room for
        :return: a tuple of (message, list of file descriptors)
        :raises ValueError: if ``msglen`` is not a non-negative integer or ``maxfds``
            is not a positive integer
        :raises EndOfStream: if neither a message nor ancillary data was received
        :raises RuntimeError: if ancillary data other than ``SCM_RIGHTS`` was received
        :raises ClosedResourceError: if the socket fails after this stream was closed
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        if not isinstance(msglen, int) or msglen < 0:
            raise ValueError("msglen must be a non-negative integer")
        if not isinstance(maxfds, int) or maxfds < 1:
            raise ValueError("maxfds must be a positive integer")

        loop = get_running_loop()
        fds = array.array("i")
        await checkpoint()
        with self._receive_guard:
            while True:
                try:
                    message, ancdata, _flags, _addr = self.__raw_socket.recvmsg(
                        msglen, socket.CMSG_LEN(maxfds * fds.itemsize)
                    )
                except BlockingIOError:
                    await self._wait_until_readable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc
                else:
                    if not message and not ancdata:
                        raise EndOfStream

                    break

        for cmsg_level, cmsg_type, cmsg_data in ancdata:
            if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS:
                raise RuntimeError(
                    f"Received unexpected ancillary data; message = {message!r}, "
                    f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}"
                )

            # Ignore any trailing bytes that don't make up a whole descriptor
            fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)])

        return message, list(fds)

    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send a message along with file descriptors as ``SCM_RIGHTS`` ancillary data.

        :param message: a non-empty message
        :param fds: a non-empty collection of file descriptors (integers) or
            ``IOBase`` objects; items of any other type are skipped
        :raises ValueError: if ``message`` or ``fds`` is empty
        :raises ClosedResourceError: if the socket fails after this stream was closed
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        if not message:
            raise ValueError("message must not be empty")
        if not fds:
            raise ValueError("fds must not be empty")

        loop = get_running_loop()
        filenos: list[int] = []
        for fd in fds:
            if isinstance(fd, int):
                filenos.append(fd)
            elif isinstance(fd, IOBase):
                filenos.append(fd.fileno())

        fdarray = array.array("i", filenos)
        await checkpoint()
        with self._send_guard:
            while True:
                try:
                    self.__raw_socket.sendmsg(
                        [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)]
                    )
                    break
                except BlockingIOError:
                    await self._wait_until_writable(loop)
                except OSError as exc:
                    if self._closing:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Close the socket and wake up any pending receive/send waits."""
        if not self._closing:
            self._closing = True
            if self.__raw_socket.fileno() != -1:
                self.__raw_socket.close()

            if self._receive_future:
                self._receive_future.set_result(None)
            if self._send_future:
                self._send_future.set_result(None)
class TCPSocketListener(abc.SocketListener):
    """Listener that accepts TCP connections from a raw listening socket."""

    # Cancel scope of the accept() call currently in progress, if any
    _accept_scope: CancelScope | None = None
    _closed = False

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = cast(asyncio.BaseEventLoop, get_running_loop())
        self._accept_guard = ResourceGuard("accepting connections from")

    @property
    def _raw_socket(self) -> socket.socket:
        """The underlying raw listening socket."""
        return self.__raw_socket

    async def accept(self) -> abc.SocketStream:
        """
        Accept one incoming connection and wrap it in a :class:`SocketStream`.

        ``TCP_NODELAY`` is enabled on the accepted socket.

        :raises ClosedResourceError: if the listener is closed, either before the
            call or while it is waiting for a connection

        """
        if self._closed:
            raise ClosedResourceError

        with self._accept_guard:
            await checkpoint()
            with CancelScope() as self._accept_scope:
                try:
                    client_sock, _addr = await self._loop.sock_accept(self._raw_socket)
                except asyncio.CancelledError:
                    # Workaround for https://bugs.python.org/issue41317
                    try:
                        self._loop.remove_reader(self._raw_socket)
                    except (ValueError, NotImplementedError):
                        pass

                    if self._closed:
                        raise ClosedResourceError from None

                    raise
                finally:
                    self._accept_scope = None

        client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1)
        transport, protocol = await self._loop.connect_accepted_socket(
            StreamProtocol, client_sock
        )
        return SocketStream(transport, protocol)

    async def aclose(self) -> None:
        """Close the listener, cancelling a pending accept() call if there is one."""
        if self._closed:
            return

        self._closed = True
        if self._accept_scope:
            # Workaround for https://bugs.python.org/issue41317
            try:
                self._loop.remove_reader(self._raw_socket)
            except (ValueError, NotImplementedError):
                pass

            self._accept_scope.cancel()
            await sleep(0)

        self._raw_socket.close()
class UNIXSocketListener(abc.SocketListener):
    """Listener that accepts connections from a non-blocking raw socket."""

    def __init__(self, raw_socket: socket.socket):
        self.__raw_socket = raw_socket
        self._loop = get_running_loop()
        self._accept_guard = ResourceGuard("accepting connections from")
        self._closed = False

    async def accept(self) -> abc.SocketStream:
        """
        Accept one incoming connection and wrap it in a :class:`UNIXSocketStream`.

        :raises ClosedResourceError: if the socket fails after the listener was
            closed
        :raises BrokenResourceError: if the socket fails for any other reason

        """
        await checkpoint()
        with self._accept_guard:
            while True:
                try:
                    client_sock, _ = self.__raw_socket.accept()
                    client_sock.setblocking(False)
                    return UNIXSocketStream(client_sock)
                except BlockingIOError:
                    # No pending connection: wait until the socket is reported
                    # readable, then try again
                    f: asyncio.Future = asyncio.Future()
                    self._loop.add_reader(self.__raw_socket, f.set_result, None)
                    f.add_done_callback(
                        lambda _: self._loop.remove_reader(self.__raw_socket)
                    )
                    await f
                except OSError as exc:
                    if self._closed:
                        raise ClosedResourceError from None
                    else:
                        raise BrokenResourceError from exc

    async def aclose(self) -> None:
        """Mark the listener as closed and close the raw socket."""
        self._closed = True
        self.__raw_socket.close()

    @property
    def _raw_socket(self) -> socket.socket:
        """The underlying raw listening socket."""
        return self.__raw_socket
class UDPSocket(abc.UDPSocket):
    """Unconnected UDP socket backed by a datagram transport/protocol pair."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket object reported by the transport's ``"socket"`` extra info."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if not self._transport.is_closing():
            self._closed = True
            self._transport.close()

    async def receive(self) -> tuple[bytes, IPSockAddrType]:
        """
        Return the next queued packet as a ``(data, address)`` tuple.

        :raises ClosedResourceError: if no packet is queued and this socket was
            closed
        :raises BrokenResourceError: if no packet is queued and this socket was
            not closed through :meth:`aclose`

        """
        with self._receive_guard:
            await checkpoint()

            # If the buffer is empty, ask for more data
            if not self._protocol.read_queue and not self._transport.is_closing():
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                return self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

    async def send(self, item: UDPPacketType) -> None:
        """
        Send a packet; ``item`` is unpacked into the arguments of ``sendto()``.

        :raises ClosedResourceError: if this socket was closed
        :raises BrokenResourceError: if the transport is closing for another reason

        """
        with self._send_guard:
            await checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            elif self._transport.is_closing():
                raise BrokenResourceError
            else:
                self._transport.sendto(*item)
class ConnectedUDPSocket(abc.ConnectedUDPSocket):
    """Connected UDP socket backed by a datagram transport/protocol pair."""

    def __init__(
        self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol
    ):
        self._transport = transport
        self._protocol = protocol
        self._receive_guard = ResourceGuard("reading from")
        self._send_guard = ResourceGuard("writing to")
        self._closed = False

    @property
    def _raw_socket(self) -> socket.socket:
        """The socket object reported by the transport's ``"socket"`` extra info."""
        return self._transport.get_extra_info("socket")

    async def aclose(self) -> None:
        """Close the transport unless it is already closing."""
        if not self._transport.is_closing():
            self._closed = True
            self._transport.close()

    async def receive(self) -> bytes:
        """
        Return the payload of the next queued packet.

        :raises ClosedResourceError: if no packet is queued and this socket was
            closed
        :raises BrokenResourceError: if no packet is queued and this socket was
            not closed through :meth:`aclose`

        """
        with self._receive_guard:
            await checkpoint()

            # If the buffer is empty, ask for more data
            if not self._protocol.read_queue and not self._transport.is_closing():
                self._protocol.read_event.clear()
                await self._protocol.read_event.wait()

            try:
                packet = self._protocol.read_queue.popleft()
            except IndexError:
                if self._closed:
                    raise ClosedResourceError from None
                else:
                    raise BrokenResourceError from None

            # The queue holds (data, address) tuples; only the data is returned
            return packet[0]

    async def send(self, item: bytes) -> None:
        """
        Send a packet to the connected peer.

        :raises ClosedResourceError: if this socket was closed
        :raises BrokenResourceError: if the transport is closing for another reason

        """
        with self._send_guard:
            await checkpoint()
            await self._protocol.write_event.wait()
            if self._closed:
                raise ClosedResourceError
            elif self._transport.is_closing():
                raise BrokenResourceError
            else:
                self._transport.sendto(item)
async def connect_tcp(
    host: str, port: int, local_addr: tuple[str, int] | None = None
) -> SocketStream:
    """
    Open a TCP connection and wrap it in a :class:`SocketStream`.

    :param host: host to connect to
    :param port: port to connect to
    :param local_addr: optional local address, passed on to ``create_connection()``
    :return: the connected stream, with reading paused on its transport

    """
    loop = get_running_loop()
    endpoint = await loop.create_connection(
        StreamProtocol, host, port, local_addr=local_addr
    )
    transport, protocol = cast(Tuple[asyncio.Transport, StreamProtocol], endpoint)

    # Reading is resumed on demand by SocketStream.receive()
    transport.pause_reading()
    return SocketStream(transport, protocol)
async def connect_unix(path: str) -> UNIXSocketStream:
    """
    Connect to a UNIX socket and wrap the connection in a :class:`UNIXSocketStream`.

    The raw socket is closed if connecting fails for any reason, including an
    exception (such as cancellation) raised while waiting for the socket to become
    writable.

    :param path: path of the socket to connect to
    :return: the connected stream

    """
    await checkpoint()
    loop = get_running_loop()
    raw_socket = socket.socket(socket.AF_UNIX)
    raw_socket.setblocking(False)
    try:
        while True:
            try:
                raw_socket.connect(path)
            except BlockingIOError:
                # Wait until the socket is reported writable, then retry
                f: asyncio.Future = asyncio.Future()
                loop.add_writer(raw_socket, f.set_result, None)
                f.add_done_callback(lambda _: loop.remove_writer(raw_socket))
                await f
            else:
                return UNIXSocketStream(raw_socket)
    except BaseException:
        # This outer handler also covers exceptions raised during "await f",
        # which a sibling "except" clause of the inner try would not catch
        raw_socket.close()
        raise
async def create_udp_socket(
    family: socket.AddressFamily,
    local_address: IPSockAddrType | None,
    remote_address: IPSockAddrType | None,
    reuse_port: bool,
) -> UDPSocket | ConnectedUDPSocket:
    """
    Create a datagram endpoint and wrap it in a UDP socket object.

    :return: a :class:`ConnectedUDPSocket` if ``remote_address`` is truthy,
        otherwise a :class:`UDPSocket`
    :raises Exception: the exception recorded by the protocol, if there is one (the
        transport is closed first)

    """
    loop = get_running_loop()
    transport, protocol = await loop.create_datagram_endpoint(
        DatagramProtocol,
        local_addr=local_address,
        remote_addr=remote_address,
        family=family,
        reuse_port=reuse_port,
    )
    if protocol.exception:
        transport.close()
        raise protocol.exception

    socket_class = ConnectedUDPSocket if remote_address else UDPSocket
    return socket_class(transport, protocol)
async def getaddrinfo(
    host: bytes | str,
    port: str | int | None,
    *,
    family: int | AddressFamily = 0,
    type: int | SocketKind = 0,
    proto: int = 0,
    flags: int = 0,
) -> GetAddrInfoReturnType:
    """Resolve ``host`` and ``port`` through the running event loop's resolver."""
    loop = get_running_loop()
    # https://github.com/python/typeshed/pull/4304
    raw_result = await loop.getaddrinfo(
        host, port, family=family, type=type, proto=proto, flags=flags
    )
    return cast(GetAddrInfoReturnType, raw_result)
async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> tuple[str, str]:
    """Look up ``sockaddr`` through the running event loop's ``getnameinfo()``."""
    loop = get_running_loop()
    name_info = await loop.getnameinfo(sockaddr, flags)
    return name_info
# Events of the sockets currently being waited on, keyed by socket
_read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events")
_write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events")


async def wait_socket_readable(sock: socket.socket) -> None:
    """
    Wait until the event loop reports the given socket as readable.

    :param sock: the socket to wait on
    :raises BusyResourceError: if there is already a registered wait for reading
        from this socket
    :raises ClosedResourceError: if the socket's entry was removed from the
        registry while waiting

    """
    await checkpoint()
    try:
        pending = _read_events.get()
    except LookupError:
        pending = {}
        _read_events.set(pending)

    if pending.get(sock):
        raise BusyResourceError("reading from") from None

    loop = get_running_loop()
    ready = pending[sock] = asyncio.Event()
    loop.add_reader(sock, ready.set)
    try:
        await ready.wait()
    finally:
        # If the entry is gone, something else has already unregistered the socket
        still_registered = pending.pop(sock, None) is not None
        if still_registered:
            loop.remove_reader(sock)

    if not still_registered:
        raise ClosedResourceError
async def wait_socket_writable(sock: socket.socket) -> None:
    """
    Wait until the event loop reports the given socket as writable.

    :param sock: the socket to wait on
    :raises BusyResourceError: if there is already a registered wait for writing
        to this socket
    :raises ClosedResourceError: if the socket's entry was removed from the
        registry while waiting

    """
    await checkpoint()
    try:
        pending = _write_events.get()
    except LookupError:
        pending = {}
        _write_events.set(pending)

    if pending.get(sock):
        raise BusyResourceError("writing to") from None

    loop = get_running_loop()
    ready = pending[sock] = asyncio.Event()
    loop.add_writer(sock.fileno(), ready.set)
    try:
        await ready.wait()
    finally:
        # If the entry is gone, something else has already unregistered the socket
        still_registered = pending.pop(sock, None) is not None
        if still_registered:
            loop.remove_writer(sock)

    if not still_registered:
        raise ClosedResourceError
# | |
# Synchronization | |
# | |
class Event(BaseEvent):
    """Event implementation that wraps an :class:`asyncio.Event`."""

    def __new__(cls) -> Event:
        # Create the instance directly through object.__new__()
        return object.__new__(cls)

    def __init__(self) -> None:
        self._event = asyncio.Event()

    def set(self) -> DeprecatedAwaitable:
        """Set the wrapped event."""
        self._event.set()
        return DeprecatedAwaitable(self.set)

    def is_set(self) -> bool:
        """Return whether the wrapped event is set."""
        return self._event.is_set()

    async def wait(self) -> None:
        """Wait for the wrapped event, then run a checkpoint if it returned true."""
        result = await self._event.wait()
        if result:
            await checkpoint()

    def statistics(self) -> EventStatistics:
        """Return statistics built from the wrapped event's number of waiters."""
        waiter_count = len(self._event._waiters)  # type: ignore[attr-defined]
        return EventStatistics(waiter_count)
class CapacityLimiter(BaseCapacityLimiter):
    """Capacity limiter that hands out tokens to borrowers in FIFO order."""

    _total_tokens: float = 0

    def __new__(cls, total_tokens: float) -> CapacityLimiter:
        # Create the instance directly through object.__new__()
        return object.__new__(cls)

    def __init__(self, total_tokens: float):
        self._borrowers: set[Any] = set()
        # Maps each waiting borrower to the event used for waking it up
        self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict()
        # Goes through the property setter, which validates the value
        self.total_tokens = total_tokens

    async def __aenter__(self) -> None:
        await self.acquire()

    async def __aexit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> None:
        self.release()

    @property
    def total_tokens(self) -> float:
        """The total number of tokens available for borrowing."""
        return self._total_tokens

    @total_tokens.setter
    def total_tokens(self, value: float) -> None:
        """
        Change the total number of tokens.

        If the total is increased, one waiter is woken up for each added token.

        :raises TypeError: if the value is neither an int nor ``math.inf``
        :raises ValueError: if the value is less than 1

        """
        if not isinstance(value, int) and not math.isinf(value):
            raise TypeError("total_tokens must be an int or math.inf")
        if value < 1:
            raise ValueError("total_tokens must be >= 1")

        waiters_to_notify = max(value - self._total_tokens, 0)
        self._total_tokens = value

        # Remove each notified waiter from the queue so no stale entry is left
        # behind to block acquire_nowait() or absorb a later release()
        while self._wait_queue and waiters_to_notify:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()
            waiters_to_notify -= 1

    @property
    def borrowed_tokens(self) -> int:
        """The number of tokens currently borrowed."""
        return len(self._borrowers)

    @property
    def available_tokens(self) -> float:
        """The number of tokens currently available for borrowing."""
        return self._total_tokens - len(self._borrowers)

    def acquire_nowait(self) -> DeprecatedAwaitable:
        """Acquire a token for the current task without waiting."""
        self.acquire_on_behalf_of_nowait(current_task())
        return DeprecatedAwaitable(self.acquire_nowait)

    def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable:
        """
        Acquire a token for the given borrower without waiting.

        :raises RuntimeError: if the borrower already holds a token
        :raises WouldBlock: if others are already waiting or no token is available

        """
        if borrower in self._borrowers:
            raise RuntimeError(
                "this borrower is already holding one of this CapacityLimiter's "
                "tokens"
            )

        if self._wait_queue or len(self._borrowers) >= self._total_tokens:
            raise WouldBlock

        self._borrowers.add(borrower)
        return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait)

    async def acquire(self) -> None:
        """Acquire a token for the current task, waiting if necessary."""
        return await self.acquire_on_behalf_of(current_task())

    async def acquire_on_behalf_of(self, borrower: object) -> None:
        """Acquire a token for the given borrower, waiting if necessary."""
        await checkpoint_if_cancelled()
        try:
            self.acquire_on_behalf_of_nowait(borrower)
        except WouldBlock:
            event = asyncio.Event()
            self._wait_queue[borrower] = event
            try:
                await event.wait()
            except BaseException:
                # Leave the queue if the wait was interrupted
                self._wait_queue.pop(borrower, None)
                raise

            self._borrowers.add(borrower)
        else:
            try:
                await cancel_shielded_checkpoint()
            except BaseException:
                # Give the token back if the checkpoint raised
                self.release()
                raise

    def release(self) -> None:
        """Release the token held by the current task."""
        self.release_on_behalf_of(current_task())

    def release_on_behalf_of(self, borrower: object) -> None:
        """
        Release the token held by the given borrower.

        :raises RuntimeError: if the borrower does not hold a token

        """
        try:
            self._borrowers.remove(borrower)
        except KeyError:
            raise RuntimeError(
                "this borrower isn't holding any of this CapacityLimiter's " "tokens"
            ) from None

        # Notify the next task in line if this limiter has free capacity now
        if self._wait_queue and len(self._borrowers) < self._total_tokens:
            event = self._wait_queue.popitem(last=False)[1]
            event.set()

    def statistics(self) -> CapacityLimiterStatistics:
        """Return a snapshot of the limiter's current state."""
        return CapacityLimiterStatistics(
            self.borrowed_tokens,
            self.total_tokens,
            tuple(self._borrowers),
            len(self._wait_queue),
        )
# Holds the default capacity limiter for worker threads
_default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter")


def current_default_thread_limiter() -> CapacityLimiter:
    """
    Return the default thread limiter, creating it on first use.

    A newly created limiter has 40 tokens and is stored in the run variable.
    """
    try:
        limiter = _default_thread_limiter.get()
    except LookupError:
        limiter = CapacityLimiter(40)
        _default_thread_limiter.set(limiter)

    return limiter
# | |
# Operating system signals | |
# | |
class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]):
    """Context manager and async iterator that yields received signal numbers."""

    def __init__(self, signals: tuple[int, ...]):
        self._signals = signals
        self._loop = get_running_loop()
        self._signal_queue: deque[int] = deque()
        self._future: asyncio.Future = asyncio.Future()
        self._handled_signals: set[int] = set()

    def _deliver(self, signum: int) -> None:
        """Queue the signal number and wake up a pending ``__anext__()`` call."""
        self._signal_queue.append(signum)
        waiter = self._future
        if not waiter.done():
            waiter.set_result(None)

    def __enter__(self) -> _SignalReceiver:
        # Install one handler per distinct signal, remembering each for __exit__()
        unique_signals = set(self._signals)
        for signum in unique_signals:
            self._loop.add_signal_handler(signum, self._deliver, signum)
            self._handled_signals.add(signum)

        return self

    def __exit__(
        self,
        exc_type: type[BaseException] | None,
        exc_val: BaseException | None,
        exc_tb: TracebackType | None,
    ) -> bool | None:
        # Remove every handler that __enter__() managed to install
        for signum in self._handled_signals:
            self._loop.remove_signal_handler(signum)

        return None

    def __aiter__(self) -> _SignalReceiver:
        return self

    async def __anext__(self) -> int:
        await checkpoint()
        if not self._signal_queue:
            # Nothing queued: wait on a fresh future until _deliver() completes it
            self._future = asyncio.Future()
            await self._future

        return self._signal_queue.popleft()
def open_signal_receiver(*signals: int) -> _SignalReceiver:
    """Create a signal receiver for the given signal numbers."""
    receiver = _SignalReceiver(signals)
    return receiver
# | |
# Testing and debugging | |
# | |
def _create_task_info(task: asyncio.Task) -> TaskInfo:
    """
    Build a :class:`TaskInfo` for the given task.

    The name and parent ID come from the task's entry in ``_task_states`` when
    there is one. Otherwise the parent ID is ``None`` and the name is the task's
    native name, or ``None`` when native task names are unavailable.
    """
    state = _task_states.get(task)
    if state is not None:
        name = state.name
        parent_id = state.parent_id
    elif _native_task_names:
        name = task.get_name()
        parent_id = None
    else:
        name = None
        parent_id = None

    return TaskInfo(id(task), parent_id, name, get_coro(task))
def get_current_task() -> TaskInfo:
    """Return a :class:`TaskInfo` describing the currently running task."""
    task = current_task()
    return _create_task_info(task)  # type: ignore[arg-type]
def get_running_tasks() -> list[TaskInfo]:
    """Return a :class:`TaskInfo` for every task that is not done yet."""
    infos: list[TaskInfo] = []
    for task in all_tasks():
        if task.done():
            continue

        infos.append(_create_task_info(task))

    return infos
async def wait_all_tasks_blocked() -> None:
    """
    Wait until every task other than the current one is waiting on a pending future.

    The tasks are re-checked every 0.1 seconds for as long as any of them has no
    future to wait on, or is waiting on one that is already done.
    """
    await checkpoint()
    this_task = current_task()

    def is_not_blocked(task: asyncio.Task) -> bool:
        if task is this_task:
            return False

        waiter = task._fut_waiter  # type: ignore[attr-defined]
        return waiter is None or waiter.done()

    while any(is_not_blocked(task) for task in all_tasks()):
        await sleep(0.1)
class TestRunner(abc.TestRunner):
    """Test runner that drives fixtures and tests on its own new event loop."""

    def __init__(
        self,
        debug: bool = False,
        use_uvloop: bool = False,
        policy: asyncio.AbstractEventLoopPolicy | None = None,
    ):
        # Exceptions reported to the loop's exception handler, re-raised later
        self._exceptions: list[BaseException] = []
        _maybe_set_event_loop_policy(policy, use_uvloop)

        loop = asyncio.new_event_loop()
        self._loop = loop
        loop.set_debug(debug)
        loop.set_exception_handler(self._exception_handler)
        asyncio.set_event_loop(loop)

    def _cancel_all_tasks(self) -> None:
        """Cancel all of the loop's tasks and raise the first exception found."""
        pending_tasks = all_tasks(self._loop)
        if not pending_tasks:
            return

        for task in pending_tasks:
            task.cancel()

        self._loop.run_until_complete(
            asyncio.gather(*pending_tasks, return_exceptions=True)
        )

        for task in pending_tasks:
            if not task.cancelled() and task.exception() is not None:
                raise cast(BaseException, task.exception())

    def _exception_handler(
        self, loop: asyncio.AbstractEventLoop, context: dict[str, Any]
    ) -> None:
        """Collect ``Exception`` instances; defer anything else to the default."""
        exception = context.get("exception")
        if isinstance(exception, Exception):
            self._exceptions.append(context["exception"])
        else:
            loop.default_exception_handler(context)

    def _raise_async_exceptions(self) -> None:
        """Re-raise any exceptions raised in asynchronous callbacks."""
        if not self._exceptions:
            return

        collected, self._exceptions = self._exceptions, []
        if len(collected) == 1:
            raise collected[0]

        raise ExceptionGroup(collected)

    def close(self) -> None:
        """Cancel leftover tasks, shut down async generators and close the loop."""
        try:
            self._cancel_all_tasks()
            self._loop.run_until_complete(self._loop.shutdown_asyncgens())
        finally:
            asyncio.set_event_loop(None)
            self._loop.close()

    def run_asyncgen_fixture(
        self,
        fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]],
        kwargs: dict[str, Any],
    ) -> Iterable[T_Retval]:
        """
        Run an async generator fixture, yielding the value it produces.

        The generator is advanced once for setup, and once more for teardown after
        this (synchronous) generator is resumed.
        """

        async def fixture_runner() -> None:
            agen = fixture_func(**kwargs)
            try:
                retval = await agen.asend(None)
                self._raise_async_exceptions()
            except BaseException as exc:
                setup_future.set_exception(exc)
                return
            else:
                setup_future.set_result(retval)

            # Wait until the consumer is done with the fixture value
            await teardown_event.wait()
            try:
                await agen.asend(None)
            except StopAsyncIteration:
                pass
            else:
                await agen.aclose()
                raise RuntimeError("Async generator fixture did not stop")

        setup_future = self._loop.create_future()
        teardown_event = asyncio.Event()
        fixture_task = self._loop.create_task(fixture_runner())

        self._loop.run_until_complete(setup_future)
        yield setup_future.result()

        teardown_event.set()
        self._loop.run_until_complete(fixture_task)
        self._raise_async_exceptions()

    def run_fixture(
        self,
        fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]],
        kwargs: dict[str, Any],
    ) -> T_Retval:
        """Run a coroutine fixture to completion and return its value."""
        fixture_coro = fixture_func(**kwargs)
        retval = self._loop.run_until_complete(fixture_coro)
        self._raise_async_exceptions()
        return retval

    def run_test(
        self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any]
    ) -> None:
        """Run a test coroutine, raising its exception together with collected ones."""
        try:
            self._loop.run_until_complete(test_func(**kwargs))
        except Exception as exc:
            self._exceptions.append(exc)

        self._raise_async_exceptions()