from __future__ import annotations import array import asyncio import concurrent.futures import math import socket import sys from asyncio.base_events import _run_until_complete_cb # type: ignore[attr-defined] from collections import OrderedDict, deque from concurrent.futures import Future from contextvars import Context, copy_context from dataclasses import dataclass from functools import partial, wraps from inspect import ( CORO_RUNNING, CORO_SUSPENDED, GEN_RUNNING, GEN_SUSPENDED, getcoroutinestate, getgeneratorstate, ) from io import IOBase from os import PathLike from queue import Queue from socket import AddressFamily, SocketKind from threading import Thread from types import TracebackType from typing import ( IO, Any, AsyncGenerator, Awaitable, Callable, Collection, Coroutine, Generator, Iterable, Mapping, Optional, Sequence, Tuple, TypeVar, Union, cast, ) from weakref import WeakKeyDictionary import sniffio from .. import CapacityLimiterStatistics, EventStatistics, TaskInfo, abc from .._core._compat import DeprecatedAsyncContextManager, DeprecatedAwaitable from .._core._eventloop import claim_worker_thread, threadlocals from .._core._exceptions import ( BrokenResourceError, BusyResourceError, ClosedResourceError, EndOfStream, WouldBlock, ) from .._core._exceptions import ExceptionGroup as BaseExceptionGroup from .._core._sockets import GetAddrInfoReturnType, convert_ipv6_sockaddr from .._core._synchronization import CapacityLimiter as BaseCapacityLimiter from .._core._synchronization import Event as BaseEvent from .._core._synchronization import ResourceGuard from .._core._tasks import CancelScope as BaseCancelScope from ..abc import IPSockAddrType, UDPPacketType from ..lowlevel import RunVar if sys.version_info >= (3, 8): def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]: return task.get_coro() else: def get_coro(task: asyncio.Task) -> Generator | Awaitable[Any]: return task._coro from asyncio import all_tasks, create_task, current_task, get_running_loop from asyncio import run as native_run def _get_task_callbacks(task: asyncio.Task) -> Iterable[Callable]: return [cb for cb, context in task._callbacks] T_Retval = TypeVar("T_Retval") T_contra = TypeVar("T_contra", contravariant=True) # Check whether there is native support for task names in asyncio (3.8+) _native_task_names = hasattr(asyncio.Task, "get_name") _root_task: RunVar[asyncio.Task | None] = RunVar("_root_task") def find_root_task() -> asyncio.Task: root_task = _root_task.get(None) if root_task is not None and not root_task.done(): return root_task # Look for a task that has been started via run_until_complete() for task in all_tasks(): if task._callbacks and not task.done(): for cb in _get_task_callbacks(task): if ( cb is _run_until_complete_cb or getattr(cb, "__module__", None) == "uvloop.loop" ): _root_task.set(task) return task # Look up the topmost task in the AnyIO task tree, if possible task = cast(asyncio.Task, current_task()) state = _task_states.get(task) if state: cancel_scope = state.cancel_scope while cancel_scope and cancel_scope._parent_scope is not None: cancel_scope = cancel_scope._parent_scope if cancel_scope is not None: return cast(asyncio.Task, cancel_scope._host_task) return task def get_callable_name(func: Callable) -> str: module = getattr(func, "__module__", None) qualname = getattr(func, "__qualname__", None) return ".".join([x for x in (module, qualname) if x]) # # Event loop # _run_vars = ( WeakKeyDictionary() ) # type: WeakKeyDictionary[asyncio.AbstractEventLoop, Any] current_token = get_running_loop def _task_started(task: asyncio.Task) -> bool: """Return ``True`` if the task has been started and has not finished.""" coro = cast(Coroutine[Any, Any, Any], get_coro(task)) try: return getcoroutinestate(coro) in (CORO_RUNNING, CORO_SUSPENDED) except AttributeError: try: return getgeneratorstate(cast(Generator, coro)) in ( GEN_RUNNING, GEN_SUSPENDED, ) except AttributeError: # task coro is async_genenerator_asend https://bugs.python.org/issue37771 raise Exception(f"Cannot determine if task {task} has started or not") def _maybe_set_event_loop_policy( policy: asyncio.AbstractEventLoopPolicy | None, use_uvloop: bool ) -> None: # On CPython, use uvloop when possible if no other policy has been given and if not # explicitly disabled if policy is None and use_uvloop and sys.implementation.name == "cpython": try: import uvloop except ImportError: pass else: # Test for missing shutdown_default_executor() (uvloop 0.14.0 and earlier) if not hasattr( asyncio.AbstractEventLoop, "shutdown_default_executor" ) or hasattr(uvloop.loop.Loop, "shutdown_default_executor"): policy = uvloop.EventLoopPolicy() if policy is not None: asyncio.set_event_loop_policy(policy) def run( func: Callable[..., Awaitable[T_Retval]], *args: object, debug: bool = False, use_uvloop: bool = False, policy: asyncio.AbstractEventLoopPolicy | None = None, ) -> T_Retval: @wraps(func) async def wrapper() -> T_Retval: task = cast(asyncio.Task, current_task()) task_state = TaskState(None, get_callable_name(func), None) _task_states[task] = task_state if _native_task_names: task.set_name(task_state.name) try: return await func(*args) finally: del _task_states[task] _maybe_set_event_loop_policy(policy, use_uvloop) return native_run(wrapper(), debug=debug) # # Miscellaneous # sleep = asyncio.sleep # # Timeouts and cancellation # CancelledError = asyncio.CancelledError class CancelScope(BaseCancelScope): def __new__( cls, *, deadline: float = math.inf, shield: bool = False ) -> CancelScope: return object.__new__(cls) def __init__(self, deadline: float = math.inf, shield: bool = False): self._deadline = deadline self._shield = shield self._parent_scope: CancelScope | None = None self._cancel_called = False self._active = False self._timeout_handle: asyncio.TimerHandle | None = None self._cancel_handle: asyncio.Handle | None = None self._tasks: set[asyncio.Task] = set() self._host_task: asyncio.Task | None = None self._timeout_expired = False self._cancel_calls: int = 0 def __enter__(self) -> CancelScope: if self._active: raise RuntimeError( "Each CancelScope may only be used for a single 'with' block" ) self._host_task = host_task = cast(asyncio.Task, current_task()) self._tasks.add(host_task) try: task_state = _task_states[host_task] except KeyError: task_name = host_task.get_name() if _native_task_names else None task_state = TaskState(None, task_name, self) _task_states[host_task] = task_state else: self._parent_scope = task_state.cancel_scope task_state.cancel_scope = self self._timeout() self._active = True # Start cancelling the host task if the scope was cancelled before entering if self._cancel_called: self._deliver_cancellation() return self def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: if not self._active: raise RuntimeError("This cancel scope is not active") if current_task() is not self._host_task: raise RuntimeError( "Attempted to exit cancel scope in a different task than it was " "entered in" ) assert self._host_task is not None host_task_state = _task_states.get(self._host_task) if host_task_state is None or host_task_state.cancel_scope is not self: raise RuntimeError( "Attempted to exit a cancel scope that isn't the current tasks's " "current cancel scope" ) self._active = False if self._timeout_handle: self._timeout_handle.cancel() self._timeout_handle = None self._tasks.remove(self._host_task) host_task_state.cancel_scope = self._parent_scope # Restart the cancellation effort in the farthest directly cancelled parent scope if this # one was shielded if self._shield: self._deliver_cancellation_to_parent() if exc_val is not None: exceptions = ( exc_val.exceptions if isinstance(exc_val, ExceptionGroup) else [exc_val] ) if all(isinstance(exc, CancelledError) for exc in exceptions): if self._timeout_expired: return self._uncancel() elif not self._cancel_called: # Task was cancelled natively return None elif not self._parent_cancelled(): # This scope was directly cancelled return self._uncancel() return None def _uncancel(self) -> bool: if sys.version_info < (3, 11) or self._host_task is None: self._cancel_calls = 0 return True # Uncancel all AnyIO cancellations for i in range(self._cancel_calls): self._host_task.uncancel() self._cancel_calls = 0 return not self._host_task.cancelling() def _timeout(self) -> None: if self._deadline != math.inf: loop = get_running_loop() if loop.time() >= self._deadline: self._timeout_expired = True self.cancel() else: self._timeout_handle = loop.call_at(self._deadline, self._timeout) def _deliver_cancellation(self) -> None: """ Deliver cancellation to directly contained tasks and nested cancel scopes. Schedule another run at the end if we still have tasks eligible for cancellation. """ should_retry = False current = current_task() for task in self._tasks: if task._must_cancel: # type: ignore[attr-defined] continue # The task is eligible for cancellation if it has started and is not in a cancel # scope shielded from this one cancel_scope = _task_states[task].cancel_scope while cancel_scope is not self: if cancel_scope is None or cancel_scope._shield: break else: cancel_scope = cancel_scope._parent_scope else: should_retry = True if task is not current and ( task is self._host_task or _task_started(task) ): self._cancel_calls += 1 task.cancel() # Schedule another callback if there are still tasks left if should_retry: self._cancel_handle = get_running_loop().call_soon( self._deliver_cancellation ) else: self._cancel_handle = None def _deliver_cancellation_to_parent(self) -> None: """Start cancellation effort in the farthest directly cancelled parent scope""" scope = self._parent_scope scope_to_cancel: CancelScope | None = None while scope is not None: if scope._cancel_called and scope._cancel_handle is None: scope_to_cancel = scope # No point in looking beyond any shielded scope if scope._shield: break scope = scope._parent_scope if scope_to_cancel is not None: scope_to_cancel._deliver_cancellation() def _parent_cancelled(self) -> bool: # Check whether any parent has been cancelled cancel_scope = self._parent_scope while cancel_scope is not None and not cancel_scope._shield: if cancel_scope._cancel_called: return True else: cancel_scope = cancel_scope._parent_scope return False def cancel(self) -> DeprecatedAwaitable: if not self._cancel_called: if self._timeout_handle: self._timeout_handle.cancel() self._timeout_handle = None self._cancel_called = True if self._host_task is not None: self._deliver_cancellation() return DeprecatedAwaitable(self.cancel) @property def deadline(self) -> float: return self._deadline @deadline.setter def deadline(self, value: float) -> None: self._deadline = float(value) if self._timeout_handle is not None: self._timeout_handle.cancel() self._timeout_handle = None if self._active and not self._cancel_called: self._timeout() @property def cancel_called(self) -> bool: return self._cancel_called @property def shield(self) -> bool: return self._shield @shield.setter def shield(self, value: bool) -> None: if self._shield != value: self._shield = value if not value: self._deliver_cancellation_to_parent() async def checkpoint() -> None: await sleep(0) async def checkpoint_if_cancelled() -> None: task = current_task() if task is None: return try: cancel_scope = _task_states[task].cancel_scope except KeyError: return while cancel_scope: if cancel_scope.cancel_called: await sleep(0) elif cancel_scope.shield: break else: cancel_scope = cancel_scope._parent_scope async def cancel_shielded_checkpoint() -> None: with CancelScope(shield=True): await sleep(0) def current_effective_deadline() -> float: try: cancel_scope = _task_states[current_task()].cancel_scope # type: ignore[index] except KeyError: return math.inf deadline = math.inf while cancel_scope: deadline = min(deadline, cancel_scope.deadline) if cancel_scope._cancel_called: deadline = -math.inf break elif cancel_scope.shield: break else: cancel_scope = cancel_scope._parent_scope return deadline def current_time() -> float: return get_running_loop().time() # # Task states # class TaskState: """ Encapsulates auxiliary task information that cannot be added to the Task instance itself because there are no guarantees about its implementation. """ __slots__ = "parent_id", "name", "cancel_scope" def __init__( self, parent_id: int | None, name: str | None, cancel_scope: CancelScope | None, ): self.parent_id = parent_id self.name = name self.cancel_scope = cancel_scope _task_states = WeakKeyDictionary() # type: WeakKeyDictionary[asyncio.Task, TaskState] # # Task groups # class ExceptionGroup(BaseExceptionGroup): def __init__(self, exceptions: list[BaseException]): super().__init__() self.exceptions = exceptions class _AsyncioTaskStatus(abc.TaskStatus): def __init__(self, future: asyncio.Future, parent_id: int): self._future = future self._parent_id = parent_id def started(self, value: T_contra | None = None) -> None: try: self._future.set_result(value) except asyncio.InvalidStateError: raise RuntimeError( "called 'started' twice on the same task status" ) from None task = cast(asyncio.Task, current_task()) _task_states[task].parent_id = self._parent_id class TaskGroup(abc.TaskGroup): def __init__(self) -> None: self.cancel_scope: CancelScope = CancelScope() self._active = False self._exceptions: list[BaseException] = [] async def __aenter__(self) -> TaskGroup: self.cancel_scope.__enter__() self._active = True return self async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: ignore_exception = self.cancel_scope.__exit__(exc_type, exc_val, exc_tb) if exc_val is not None: self.cancel_scope.cancel() self._exceptions.append(exc_val) while self.cancel_scope._tasks: try: await asyncio.wait(self.cancel_scope._tasks) except asyncio.CancelledError: self.cancel_scope.cancel() self._active = False if not self.cancel_scope._parent_cancelled(): exceptions = self._filter_cancellation_errors(self._exceptions) else: exceptions = self._exceptions try: if len(exceptions) > 1: if all( isinstance(e, CancelledError) and not e.args for e in exceptions ): # Tasks were cancelled natively, without a cancellation message raise CancelledError else: raise ExceptionGroup(exceptions) elif exceptions and exceptions[0] is not exc_val: raise exceptions[0] except BaseException as exc: # Clear the context here, as it can only be done in-flight. # If the context is not cleared, it can result in recursive tracebacks (see #145). exc.__context__ = None raise return ignore_exception @staticmethod def _filter_cancellation_errors( exceptions: Sequence[BaseException], ) -> list[BaseException]: filtered_exceptions: list[BaseException] = [] for exc in exceptions: if isinstance(exc, ExceptionGroup): new_exceptions = TaskGroup._filter_cancellation_errors(exc.exceptions) if len(new_exceptions) > 1: filtered_exceptions.append(exc) elif len(new_exceptions) == 1: filtered_exceptions.append(new_exceptions[0]) elif new_exceptions: new_exc = ExceptionGroup(new_exceptions) new_exc.__cause__ = exc.__cause__ new_exc.__context__ = exc.__context__ new_exc.__traceback__ = exc.__traceback__ filtered_exceptions.append(new_exc) elif not isinstance(exc, CancelledError) or exc.args: filtered_exceptions.append(exc) return filtered_exceptions async def _run_wrapped_task( self, coro: Coroutine, task_status_future: asyncio.Future | None ) -> None: # This is the code path for Python 3.7 on which asyncio freaks out if a task # raises a BaseException. __traceback_hide__ = __tracebackhide__ = True # noqa: F841 task = cast(asyncio.Task, current_task()) try: await coro except BaseException as exc: if task_status_future is None or task_status_future.done(): self._exceptions.append(exc) self.cancel_scope.cancel() else: task_status_future.set_exception(exc) else: if task_status_future is not None and not task_status_future.done(): task_status_future.set_exception( RuntimeError("Child exited without calling task_status.started()") ) finally: if task in self.cancel_scope._tasks: self.cancel_scope._tasks.remove(task) del _task_states[task] def _spawn( self, func: Callable[..., Awaitable[Any]], args: tuple, name: object, task_status_future: asyncio.Future | None = None, ) -> asyncio.Task: def task_done(_task: asyncio.Task) -> None: # This is the code path for Python 3.8+ assert _task in self.cancel_scope._tasks self.cancel_scope._tasks.remove(_task) del _task_states[_task] try: exc = _task.exception() except CancelledError as e: while isinstance(e.__context__, CancelledError): e = e.__context__ exc = e if exc is not None: if task_status_future is None or task_status_future.done(): self._exceptions.append(exc) self.cancel_scope.cancel() else: task_status_future.set_exception(exc) elif task_status_future is not None and not task_status_future.done(): task_status_future.set_exception( RuntimeError("Child exited without calling task_status.started()") ) if not self._active: raise RuntimeError( "This task group is not active; no new tasks can be started." ) options: dict[str, Any] = {} name = get_callable_name(func) if name is None else str(name) if _native_task_names: options["name"] = name kwargs = {} if task_status_future: parent_id = id(current_task()) kwargs["task_status"] = _AsyncioTaskStatus( task_status_future, id(self.cancel_scope._host_task) ) else: parent_id = id(self.cancel_scope._host_task) coro = func(*args, **kwargs) if not asyncio.iscoroutine(coro): raise TypeError( f"Expected an async function, but {func} appears to be synchronous" ) foreign_coro = not hasattr(coro, "cr_frame") and not hasattr(coro, "gi_frame") if foreign_coro or sys.version_info < (3, 8): coro = self._run_wrapped_task(coro, task_status_future) task = create_task(coro, **options) if not foreign_coro and sys.version_info >= (3, 8): task.add_done_callback(task_done) # Make the spawned task inherit the task group's cancel scope _task_states[task] = TaskState( parent_id=parent_id, name=name, cancel_scope=self.cancel_scope ) self.cancel_scope._tasks.add(task) return task def start_soon( self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None ) -> None: self._spawn(func, args, name) async def start( self, func: Callable[..., Awaitable[Any]], *args: object, name: object = None ) -> None: future: asyncio.Future = asyncio.Future() task = self._spawn(func, args, name, future) # If the task raises an exception after sending a start value without a switch point # between, the task group is cancelled and this method never proceeds to process the # completed future. That's why we have to have a shielded cancel scope here. with CancelScope(shield=True): try: return await future except CancelledError: task.cancel() raise # # Threads # _Retval_Queue_Type = Tuple[Optional[T_Retval], Optional[BaseException]] class WorkerThread(Thread): MAX_IDLE_TIME = 10 # seconds def __init__( self, root_task: asyncio.Task, workers: set[WorkerThread], idle_workers: deque[WorkerThread], ): super().__init__(name="AnyIO worker thread") self.root_task = root_task self.workers = workers self.idle_workers = idle_workers self.loop = root_task._loop self.queue: Queue[ tuple[Context, Callable, tuple, asyncio.Future] | None ] = Queue(2) self.idle_since = current_time() self.stopping = False def _report_result( self, future: asyncio.Future, result: Any, exc: BaseException | None ) -> None: self.idle_since = current_time() if not self.stopping: self.idle_workers.append(self) if not future.cancelled(): if exc is not None: if isinstance(exc, StopIteration): new_exc = RuntimeError("coroutine raised StopIteration") new_exc.__cause__ = exc exc = new_exc future.set_exception(exc) else: future.set_result(result) def run(self) -> None: with claim_worker_thread("asyncio"): threadlocals.loop = self.loop while True: item = self.queue.get() if item is None: # Shutdown command received return context, func, args, future = item if not future.cancelled(): result = None exception: BaseException | None = None try: result = context.run(func, *args) except BaseException as exc: exception = exc if not self.loop.is_closed(): self.loop.call_soon_threadsafe( self._report_result, future, result, exception ) self.queue.task_done() def stop(self, f: asyncio.Task | None = None) -> None: self.stopping = True self.queue.put_nowait(None) self.workers.discard(self) try: self.idle_workers.remove(self) except ValueError: pass _threadpool_idle_workers: RunVar[deque[WorkerThread]] = RunVar( "_threadpool_idle_workers" ) _threadpool_workers: RunVar[set[WorkerThread]] = RunVar("_threadpool_workers") async def run_sync_in_worker_thread( func: Callable[..., T_Retval], *args: object, cancellable: bool = False, limiter: CapacityLimiter | None = None, ) -> T_Retval: await checkpoint() # If this is the first run in this event loop thread, set up the necessary variables try: idle_workers = _threadpool_idle_workers.get() workers = _threadpool_workers.get() except LookupError: idle_workers = deque() workers = set() _threadpool_idle_workers.set(idle_workers) _threadpool_workers.set(workers) async with (limiter or current_default_thread_limiter()): with CancelScope(shield=not cancellable): future: asyncio.Future = asyncio.Future() root_task = find_root_task() if not idle_workers: worker = WorkerThread(root_task, workers, idle_workers) worker.start() workers.add(worker) root_task.add_done_callback(worker.stop) else: worker = idle_workers.pop() # Prune any other workers that have been idle for MAX_IDLE_TIME seconds or longer now = current_time() while idle_workers: if now - idle_workers[0].idle_since < WorkerThread.MAX_IDLE_TIME: break expired_worker = idle_workers.popleft() expired_worker.root_task.remove_done_callback(expired_worker.stop) expired_worker.stop() context = copy_context() context.run(sniffio.current_async_library_cvar.set, None) worker.queue.put_nowait((context, func, args, future)) return await future def run_sync_from_thread( func: Callable[..., T_Retval], *args: object, loop: asyncio.AbstractEventLoop | None = None, ) -> T_Retval: @wraps(func) def wrapper() -> None: try: f.set_result(func(*args)) except BaseException as exc: f.set_exception(exc) if not isinstance(exc, Exception): raise f: concurrent.futures.Future[T_Retval] = Future() loop = loop or threadlocals.loop loop.call_soon_threadsafe(wrapper) return f.result() def run_async_from_thread( func: Callable[..., Awaitable[T_Retval]], *args: object ) -> T_Retval: f: concurrent.futures.Future[T_Retval] = asyncio.run_coroutine_threadsafe( func(*args), threadlocals.loop ) return f.result() class BlockingPortal(abc.BlockingPortal): def __new__(cls) -> BlockingPortal: return object.__new__(cls) def __init__(self) -> None: super().__init__() self._loop = get_running_loop() def _spawn_task_from_thread( self, func: Callable, args: tuple, kwargs: dict[str, Any], name: object, future: Future, ) -> None: run_sync_from_thread( partial(self._task_group.start_soon, name=name), self._call_func, func, args, kwargs, future, loop=self._loop, ) # # Subprocesses # @dataclass(eq=False) class StreamReaderWrapper(abc.ByteReceiveStream): _stream: asyncio.StreamReader async def receive(self, max_bytes: int = 65536) -> bytes: data = await self._stream.read(max_bytes) if data: return data else: raise EndOfStream async def aclose(self) -> None: self._stream.feed_eof() @dataclass(eq=False) class StreamWriterWrapper(abc.ByteSendStream): _stream: asyncio.StreamWriter async def send(self, item: bytes) -> None: self._stream.write(item) await self._stream.drain() async def aclose(self) -> None: self._stream.close() @dataclass(eq=False) class Process(abc.Process): _process: asyncio.subprocess.Process _stdin: StreamWriterWrapper | None _stdout: StreamReaderWrapper | None _stderr: StreamReaderWrapper | None async def aclose(self) -> None: if self._stdin: await self._stdin.aclose() if self._stdout: await self._stdout.aclose() if self._stderr: await self._stderr.aclose() await self.wait() async def wait(self) -> int: return await self._process.wait() def terminate(self) -> None: self._process.terminate() def kill(self) -> None: self._process.kill() def send_signal(self, signal: int) -> None: self._process.send_signal(signal) @property def pid(self) -> int: return self._process.pid @property def returncode(self) -> int | None: return self._process.returncode @property def stdin(self) -> abc.ByteSendStream | None: return self._stdin @property def stdout(self) -> abc.ByteReceiveStream | None: return self._stdout @property def stderr(self) -> abc.ByteReceiveStream | None: return self._stderr async def open_process( command: str | bytes | Sequence[str | bytes], *, shell: bool, stdin: int | IO[Any] | None, stdout: int | IO[Any] | None, stderr: int | IO[Any] | None, cwd: str | bytes | PathLike | None = None, env: Mapping[str, str] | None = None, start_new_session: bool = False, ) -> Process: await checkpoint() if shell: process = await asyncio.create_subprocess_shell( cast(Union[str, bytes], command), stdin=stdin, stdout=stdout, stderr=stderr, cwd=cwd, env=env, start_new_session=start_new_session, ) else: process = await asyncio.create_subprocess_exec( *command, stdin=stdin, stdout=stdout, stderr=stderr, cwd=cwd, env=env, start_new_session=start_new_session, ) stdin_stream = StreamWriterWrapper(process.stdin) if process.stdin else None stdout_stream = StreamReaderWrapper(process.stdout) if process.stdout else None stderr_stream = StreamReaderWrapper(process.stderr) if process.stderr else None return Process(process, stdin_stream, stdout_stream, stderr_stream) def _forcibly_shutdown_process_pool_on_exit( workers: set[Process], _task: object ) -> None: """ Forcibly shuts down worker processes belonging to this event loop.""" child_watcher: asyncio.AbstractChildWatcher | None try: child_watcher = asyncio.get_event_loop_policy().get_child_watcher() except NotImplementedError: child_watcher = None # Close as much as possible (w/o async/await) to avoid warnings for process in workers: if process.returncode is None: continue process._stdin._stream._transport.close() # type: ignore[union-attr] process._stdout._stream._transport.close() # type: ignore[union-attr] process._stderr._stream._transport.close() # type: ignore[union-attr] process.kill() if child_watcher: child_watcher.remove_child_handler(process.pid) async def _shutdown_process_pool_on_exit(workers: set[Process]) -> None: """ Shuts down worker processes belonging to this event loop. NOTE: this only works when the event loop was started using asyncio.run() or anyio.run(). """ process: Process try: await sleep(math.inf) except asyncio.CancelledError: for process in workers: if process.returncode is None: process.kill() for process in workers: await process.aclose() def setup_process_pool_exit_at_shutdown(workers: set[Process]) -> None: kwargs: dict[str, Any] = ( {"name": "AnyIO process pool shutdown task"} if _native_task_names else {} ) create_task(_shutdown_process_pool_on_exit(workers), **kwargs) find_root_task().add_done_callback( partial(_forcibly_shutdown_process_pool_on_exit, workers) ) # # Sockets and networking # class StreamProtocol(asyncio.Protocol): read_queue: deque[bytes] read_event: asyncio.Event write_event: asyncio.Event exception: Exception | None = None def connection_made(self, transport: asyncio.BaseTransport) -> None: self.read_queue = deque() self.read_event = asyncio.Event() self.write_event = asyncio.Event() self.write_event.set() cast(asyncio.Transport, transport).set_write_buffer_limits(0) def connection_lost(self, exc: Exception | None) -> None: if exc: self.exception = BrokenResourceError() self.exception.__cause__ = exc self.read_event.set() self.write_event.set() def data_received(self, data: bytes) -> None: self.read_queue.append(data) self.read_event.set() def eof_received(self) -> bool | None: self.read_event.set() return True def pause_writing(self) -> None: self.write_event = asyncio.Event() def resume_writing(self) -> None: self.write_event.set() class DatagramProtocol(asyncio.DatagramProtocol): read_queue: deque[tuple[bytes, IPSockAddrType]] read_event: asyncio.Event write_event: asyncio.Event exception: Exception | None = None def connection_made(self, transport: asyncio.BaseTransport) -> None: self.read_queue = deque(maxlen=100) # arbitrary value self.read_event = asyncio.Event() self.write_event = asyncio.Event() self.write_event.set() def connection_lost(self, exc: Exception | None) -> None: self.read_event.set() self.write_event.set() def datagram_received(self, data: bytes, addr: IPSockAddrType) -> None: addr = convert_ipv6_sockaddr(addr) self.read_queue.append((data, addr)) self.read_event.set() def error_received(self, exc: Exception) -> None: self.exception = exc def pause_writing(self) -> None: self.write_event.clear() def resume_writing(self) -> None: self.write_event.set() class SocketStream(abc.SocketStream): def __init__(self, transport: asyncio.Transport, protocol: StreamProtocol): self._transport = transport self._protocol = protocol self._receive_guard = ResourceGuard("reading from") self._send_guard = ResourceGuard("writing to") self._closed = False @property def _raw_socket(self) -> socket.socket: return self._transport.get_extra_info("socket") async def receive(self, max_bytes: int = 65536) -> bytes: with self._receive_guard: await checkpoint() if ( not self._protocol.read_event.is_set() and not self._transport.is_closing() ): self._transport.resume_reading() await self._protocol.read_event.wait() self._transport.pause_reading() try: chunk = self._protocol.read_queue.popleft() except IndexError: if self._closed: raise ClosedResourceError from None elif self._protocol.exception: raise self._protocol.exception else: raise EndOfStream from None if len(chunk) > max_bytes: # Split the oversized chunk chunk, leftover = chunk[:max_bytes], chunk[max_bytes:] self._protocol.read_queue.appendleft(leftover) # If the read queue is empty, clear the flag so that the next call will block until # data is available if not self._protocol.read_queue: self._protocol.read_event.clear() return chunk async def send(self, item: bytes) -> None: with self._send_guard: await checkpoint() if self._closed: raise ClosedResourceError elif self._protocol.exception is not None: raise self._protocol.exception try: self._transport.write(item) except RuntimeError as exc: if self._transport.is_closing(): raise BrokenResourceError from exc else: raise await self._protocol.write_event.wait() async def send_eof(self) -> None: try: self._transport.write_eof() except OSError: pass async def aclose(self) -> None: if not self._transport.is_closing(): self._closed = True try: self._transport.write_eof() except OSError: pass self._transport.close() await sleep(0) self._transport.abort() class UNIXSocketStream(abc.SocketStream): _receive_future: asyncio.Future | None = None _send_future: asyncio.Future | None = None _closing = False def __init__(self, raw_socket: socket.socket): self.__raw_socket = raw_socket self._loop = get_running_loop() self._receive_guard = ResourceGuard("reading from") self._send_guard = ResourceGuard("writing to") @property def _raw_socket(self) -> socket.socket: return self.__raw_socket def _wait_until_readable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: def callback(f: object) -> None: del self._receive_future loop.remove_reader(self.__raw_socket) f = self._receive_future = asyncio.Future() self._loop.add_reader(self.__raw_socket, f.set_result, None) f.add_done_callback(callback) return f def _wait_until_writable(self, loop: asyncio.AbstractEventLoop) -> asyncio.Future: def callback(f: object) -> None: del self._send_future loop.remove_writer(self.__raw_socket) f = self._send_future = asyncio.Future() self._loop.add_writer(self.__raw_socket, f.set_result, None) f.add_done_callback(callback) return f async def send_eof(self) -> None: with self._send_guard: self._raw_socket.shutdown(socket.SHUT_WR) async def receive(self, max_bytes: int = 65536) -> bytes: loop = get_running_loop() await checkpoint() with self._receive_guard: while True: try: data = self.__raw_socket.recv(max_bytes) except BlockingIOError: await self._wait_until_readable(loop) except OSError as exc: if self._closing: raise ClosedResourceError from None else: raise BrokenResourceError from exc else: if not data: raise EndOfStream return data async def send(self, item: bytes) -> None: loop = get_running_loop() await checkpoint() with self._send_guard: view = memoryview(item) while view: try: bytes_sent = self.__raw_socket.send(item) except BlockingIOError: await self._wait_until_writable(loop) except OSError as exc: if self._closing: raise ClosedResourceError from None else: raise BrokenResourceError from exc else: view = view[bytes_sent:] async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]: if not isinstance(msglen, int) or msglen < 0: raise ValueError("msglen must be a non-negative integer") if not isinstance(maxfds, int) or maxfds < 1: raise ValueError("maxfds must be a positive integer") loop = get_running_loop() fds = array.array("i") await checkpoint() with self._receive_guard: while True: try: message, ancdata, flags, addr = self.__raw_socket.recvmsg( msglen, socket.CMSG_LEN(maxfds * fds.itemsize) ) except BlockingIOError: await self._wait_until_readable(loop) except OSError as exc: if self._closing: raise ClosedResourceError from None else: raise BrokenResourceError from exc else: if not message and not ancdata: raise EndOfStream break for cmsg_level, cmsg_type, cmsg_data in ancdata: if cmsg_level != socket.SOL_SOCKET or cmsg_type != socket.SCM_RIGHTS: raise RuntimeError( f"Received unexpected ancillary data; message = {message!r}, " f"cmsg_level = {cmsg_level}, cmsg_type = {cmsg_type}" ) fds.frombytes(cmsg_data[: len(cmsg_data) - (len(cmsg_data) % fds.itemsize)]) return message, list(fds) async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None: if not message: raise ValueError("message must not be empty") if not fds: raise ValueError("fds must not be empty") loop = get_running_loop() filenos: list[int] = [] for fd in fds: if isinstance(fd, int): filenos.append(fd) elif isinstance(fd, IOBase): filenos.append(fd.fileno()) fdarray = array.array("i", filenos) await checkpoint() with self._send_guard: while True: try: # The ignore can be removed after mypy picks up # https://github.com/python/typeshed/pull/5545 self.__raw_socket.sendmsg( [message], [(socket.SOL_SOCKET, socket.SCM_RIGHTS, fdarray)] ) break except BlockingIOError: await self._wait_until_writable(loop) except OSError as exc: if self._closing: raise ClosedResourceError from None else: raise BrokenResourceError from exc async def aclose(self) -> None: if not self._closing: self._closing = True if self.__raw_socket.fileno() != -1: self.__raw_socket.close() if self._receive_future: self._receive_future.set_result(None) if self._send_future: self._send_future.set_result(None) class TCPSocketListener(abc.SocketListener): _accept_scope: CancelScope | None = None _closed = False def __init__(self, raw_socket: socket.socket): self.__raw_socket = raw_socket self._loop = cast(asyncio.BaseEventLoop, get_running_loop()) self._accept_guard = ResourceGuard("accepting connections from") @property def _raw_socket(self) -> socket.socket: return self.__raw_socket async def accept(self) -> abc.SocketStream: if self._closed: raise ClosedResourceError with self._accept_guard: await checkpoint() with CancelScope() as self._accept_scope: try: client_sock, _addr = await self._loop.sock_accept(self._raw_socket) except asyncio.CancelledError: # Workaround for https://bugs.python.org/issue41317 try: self._loop.remove_reader(self._raw_socket) except (ValueError, NotImplementedError): pass if self._closed: raise ClosedResourceError from None raise finally: self._accept_scope = None client_sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, 1) transport, protocol = await self._loop.connect_accepted_socket( StreamProtocol, client_sock ) return SocketStream(transport, protocol) async def aclose(self) -> None: if self._closed: return self._closed = True if self._accept_scope: # Workaround for https://bugs.python.org/issue41317 try: self._loop.remove_reader(self._raw_socket) except (ValueError, NotImplementedError): pass self._accept_scope.cancel() await sleep(0) self._raw_socket.close() class UNIXSocketListener(abc.SocketListener): def __init__(self, raw_socket: socket.socket): self.__raw_socket = raw_socket self._loop = get_running_loop() self._accept_guard = ResourceGuard("accepting connections from") self._closed = False async def accept(self) -> abc.SocketStream: await checkpoint() with self._accept_guard: while True: try: client_sock, _ = self.__raw_socket.accept() client_sock.setblocking(False) return UNIXSocketStream(client_sock) except BlockingIOError: f: asyncio.Future = asyncio.Future() self._loop.add_reader(self.__raw_socket, f.set_result, None) f.add_done_callback( lambda _: self._loop.remove_reader(self.__raw_socket) ) await f except OSError as exc: if self._closed: raise ClosedResourceError from None else: raise BrokenResourceError from exc async def aclose(self) -> None: self._closed = True self.__raw_socket.close() @property def _raw_socket(self) -> socket.socket: return self.__raw_socket class UDPSocket(abc.UDPSocket): def __init__( self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol ): self._transport = transport self._protocol = protocol self._receive_guard = ResourceGuard("reading from") self._send_guard = ResourceGuard("writing to") self._closed = False @property def _raw_socket(self) -> socket.socket: return self._transport.get_extra_info("socket") async def aclose(self) -> None: if not self._transport.is_closing(): self._closed = True self._transport.close() async def receive(self) -> tuple[bytes, IPSockAddrType]: with self._receive_guard: await checkpoint() # If the buffer is empty, ask for more data if not self._protocol.read_queue and not self._transport.is_closing(): self._protocol.read_event.clear() await self._protocol.read_event.wait() try: return self._protocol.read_queue.popleft() except IndexError: if self._closed: raise ClosedResourceError from None else: raise BrokenResourceError from None async def send(self, item: UDPPacketType) -> None: with self._send_guard: await checkpoint() await self._protocol.write_event.wait() if self._closed: raise ClosedResourceError elif self._transport.is_closing(): raise BrokenResourceError else: self._transport.sendto(*item) class ConnectedUDPSocket(abc.ConnectedUDPSocket): def __init__( self, transport: asyncio.DatagramTransport, protocol: DatagramProtocol ): self._transport = transport self._protocol = protocol self._receive_guard = ResourceGuard("reading from") self._send_guard = ResourceGuard("writing to") self._closed = False @property def _raw_socket(self) -> socket.socket: return self._transport.get_extra_info("socket") async def aclose(self) -> None: if not self._transport.is_closing(): self._closed = True self._transport.close() async def receive(self) -> bytes: with self._receive_guard: await checkpoint() # If the buffer is empty, ask for more data if not self._protocol.read_queue and not self._transport.is_closing(): self._protocol.read_event.clear() await self._protocol.read_event.wait() try: packet = self._protocol.read_queue.popleft() except IndexError: if self._closed: raise ClosedResourceError from None else: raise BrokenResourceError from None return packet[0] async def send(self, item: bytes) -> None: with self._send_guard: await checkpoint() await self._protocol.write_event.wait() if self._closed: raise ClosedResourceError elif self._transport.is_closing(): raise BrokenResourceError else: self._transport.sendto(item) async def connect_tcp( host: str, port: int, local_addr: tuple[str, int] | None = None ) -> SocketStream: transport, protocol = cast( Tuple[asyncio.Transport, StreamProtocol], await get_running_loop().create_connection( StreamProtocol, host, port, local_addr=local_addr ), ) transport.pause_reading() return SocketStream(transport, protocol) async def connect_unix(path: str) -> UNIXSocketStream: await checkpoint() loop = get_running_loop() raw_socket = socket.socket(socket.AF_UNIX) raw_socket.setblocking(False) while True: try: raw_socket.connect(path) except BlockingIOError: f: asyncio.Future = asyncio.Future() loop.add_writer(raw_socket, f.set_result, None) f.add_done_callback(lambda _: loop.remove_writer(raw_socket)) await f except BaseException: raw_socket.close() raise else: return UNIXSocketStream(raw_socket) async def create_udp_socket( family: socket.AddressFamily, local_address: IPSockAddrType | None, remote_address: IPSockAddrType | None, reuse_port: bool, ) -> UDPSocket | ConnectedUDPSocket: result = await get_running_loop().create_datagram_endpoint( DatagramProtocol, local_addr=local_address, remote_addr=remote_address, family=family, reuse_port=reuse_port, ) transport = result[0] protocol = result[1] if protocol.exception: transport.close() raise protocol.exception if not remote_address: return UDPSocket(transport, protocol) else: return ConnectedUDPSocket(transport, protocol) async def getaddrinfo( host: bytes | str, port: str | int | None, *, family: int | AddressFamily = 0, type: int | SocketKind = 0, proto: int = 0, flags: int = 0, ) -> GetAddrInfoReturnType: # https://github.com/python/typeshed/pull/4304 result = await get_running_loop().getaddrinfo( host, port, family=family, type=type, proto=proto, flags=flags ) return cast(GetAddrInfoReturnType, result) async def getnameinfo(sockaddr: IPSockAddrType, flags: int = 0) -> tuple[str, str]: return await get_running_loop().getnameinfo(sockaddr, flags) _read_events: RunVar[dict[Any, asyncio.Event]] = RunVar("read_events") _write_events: RunVar[dict[Any, asyncio.Event]] = RunVar("write_events") async def wait_socket_readable(sock: socket.socket) -> None: await checkpoint() try: read_events = _read_events.get() except LookupError: read_events = {} _read_events.set(read_events) if read_events.get(sock): raise BusyResourceError("reading from") from None loop = get_running_loop() event = read_events[sock] = asyncio.Event() loop.add_reader(sock, event.set) try: await event.wait() finally: if read_events.pop(sock, None) is not None: loop.remove_reader(sock) readable = True else: readable = False if not readable: raise ClosedResourceError async def wait_socket_writable(sock: socket.socket) -> None: await checkpoint() try: write_events = _write_events.get() except LookupError: write_events = {} _write_events.set(write_events) if write_events.get(sock): raise BusyResourceError("writing to") from None loop = get_running_loop() event = write_events[sock] = asyncio.Event() loop.add_writer(sock.fileno(), event.set) try: await event.wait() finally: if write_events.pop(sock, None) is not None: loop.remove_writer(sock) writable = True else: writable = False if not writable: raise ClosedResourceError # # Synchronization # class Event(BaseEvent): def __new__(cls) -> Event: return object.__new__(cls) def __init__(self) -> None: self._event = asyncio.Event() def set(self) -> DeprecatedAwaitable: self._event.set() return DeprecatedAwaitable(self.set) def is_set(self) -> bool: return self._event.is_set() async def wait(self) -> None: if await self._event.wait(): await checkpoint() def statistics(self) -> EventStatistics: return EventStatistics(len(self._event._waiters)) # type: ignore[attr-defined] class CapacityLimiter(BaseCapacityLimiter): _total_tokens: float = 0 def __new__(cls, total_tokens: float) -> CapacityLimiter: return object.__new__(cls) def __init__(self, total_tokens: float): self._borrowers: set[Any] = set() self._wait_queue: OrderedDict[Any, asyncio.Event] = OrderedDict() self.total_tokens = total_tokens async def __aenter__(self) -> None: await self.acquire() async def __aexit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> None: self.release() @property def total_tokens(self) -> float: return self._total_tokens @total_tokens.setter def total_tokens(self, value: float) -> None: if not isinstance(value, int) and not math.isinf(value): raise TypeError("total_tokens must be an int or math.inf") if value < 1: raise ValueError("total_tokens must be >= 1") old_value = self._total_tokens self._total_tokens = value events = [] for event in self._wait_queue.values(): if value <= old_value: break if not event.is_set(): events.append(event) old_value += 1 for event in events: event.set() @property def borrowed_tokens(self) -> int: return len(self._borrowers) @property def available_tokens(self) -> float: return self._total_tokens - len(self._borrowers) def acquire_nowait(self) -> DeprecatedAwaitable: self.acquire_on_behalf_of_nowait(current_task()) return DeprecatedAwaitable(self.acquire_nowait) def acquire_on_behalf_of_nowait(self, borrower: object) -> DeprecatedAwaitable: if borrower in self._borrowers: raise RuntimeError( "this borrower is already holding one of this CapacityLimiter's " "tokens" ) if self._wait_queue or len(self._borrowers) >= self._total_tokens: raise WouldBlock self._borrowers.add(borrower) return DeprecatedAwaitable(self.acquire_on_behalf_of_nowait) async def acquire(self) -> None: return await self.acquire_on_behalf_of(current_task()) async def acquire_on_behalf_of(self, borrower: object) -> None: await checkpoint_if_cancelled() try: self.acquire_on_behalf_of_nowait(borrower) except WouldBlock: event = asyncio.Event() self._wait_queue[borrower] = event try: await event.wait() except BaseException: self._wait_queue.pop(borrower, None) raise self._borrowers.add(borrower) else: try: await cancel_shielded_checkpoint() except BaseException: self.release() raise def release(self) -> None: self.release_on_behalf_of(current_task()) def release_on_behalf_of(self, borrower: object) -> None: try: self._borrowers.remove(borrower) except KeyError: raise RuntimeError( "this borrower isn't holding any of this CapacityLimiter's " "tokens" ) from None # Notify the next task in line if this limiter has free capacity now if self._wait_queue and len(self._borrowers) < self._total_tokens: event = self._wait_queue.popitem(last=False)[1] event.set() def statistics(self) -> CapacityLimiterStatistics: return CapacityLimiterStatistics( self.borrowed_tokens, self.total_tokens, tuple(self._borrowers), len(self._wait_queue), ) _default_thread_limiter: RunVar[CapacityLimiter] = RunVar("_default_thread_limiter") def current_default_thread_limiter() -> CapacityLimiter: try: return _default_thread_limiter.get() except LookupError: limiter = CapacityLimiter(40) _default_thread_limiter.set(limiter) return limiter # # Operating system signals # class _SignalReceiver(DeprecatedAsyncContextManager["_SignalReceiver"]): def __init__(self, signals: tuple[int, ...]): self._signals = signals self._loop = get_running_loop() self._signal_queue: deque[int] = deque() self._future: asyncio.Future = asyncio.Future() self._handled_signals: set[int] = set() def _deliver(self, signum: int) -> None: self._signal_queue.append(signum) if not self._future.done(): self._future.set_result(None) def __enter__(self) -> _SignalReceiver: for sig in set(self._signals): self._loop.add_signal_handler(sig, self._deliver, sig) self._handled_signals.add(sig) return self def __exit__( self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None, ) -> bool | None: for sig in self._handled_signals: self._loop.remove_signal_handler(sig) return None def __aiter__(self) -> _SignalReceiver: return self async def __anext__(self) -> int: await checkpoint() if not self._signal_queue: self._future = asyncio.Future() await self._future return self._signal_queue.popleft() def open_signal_receiver(*signals: int) -> _SignalReceiver: return _SignalReceiver(signals) # # Testing and debugging # def _create_task_info(task: asyncio.Task) -> TaskInfo: task_state = _task_states.get(task) if task_state is None: name = task.get_name() if _native_task_names else None parent_id = None else: name = task_state.name parent_id = task_state.parent_id return TaskInfo(id(task), parent_id, name, get_coro(task)) def get_current_task() -> TaskInfo: return _create_task_info(current_task()) # type: ignore[arg-type] def get_running_tasks() -> list[TaskInfo]: return [_create_task_info(task) for task in all_tasks() if not task.done()] async def wait_all_tasks_blocked() -> None: await checkpoint() this_task = current_task() while True: for task in all_tasks(): if task is this_task: continue if task._fut_waiter is None or task._fut_waiter.done(): # type: ignore[attr-defined] await sleep(0.1) break else: return class TestRunner(abc.TestRunner): def __init__( self, debug: bool = False, use_uvloop: bool = False, policy: asyncio.AbstractEventLoopPolicy | None = None, ): self._exceptions: list[BaseException] = [] _maybe_set_event_loop_policy(policy, use_uvloop) self._loop = asyncio.new_event_loop() self._loop.set_debug(debug) self._loop.set_exception_handler(self._exception_handler) asyncio.set_event_loop(self._loop) def _cancel_all_tasks(self) -> None: to_cancel = all_tasks(self._loop) if not to_cancel: return for task in to_cancel: task.cancel() self._loop.run_until_complete( asyncio.gather(*to_cancel, return_exceptions=True) ) for task in to_cancel: if task.cancelled(): continue if task.exception() is not None: raise cast(BaseException, task.exception()) def _exception_handler( self, loop: asyncio.AbstractEventLoop, context: dict[str, Any] ) -> None: if isinstance(context.get("exception"), Exception): self._exceptions.append(context["exception"]) else: loop.default_exception_handler(context) def _raise_async_exceptions(self) -> None: # Re-raise any exceptions raised in asynchronous callbacks if self._exceptions: exceptions, self._exceptions = self._exceptions, [] if len(exceptions) == 1: raise exceptions[0] elif exceptions: raise ExceptionGroup(exceptions) def close(self) -> None: try: self._cancel_all_tasks() self._loop.run_until_complete(self._loop.shutdown_asyncgens()) finally: asyncio.set_event_loop(None) self._loop.close() def run_asyncgen_fixture( self, fixture_func: Callable[..., AsyncGenerator[T_Retval, Any]], kwargs: dict[str, Any], ) -> Iterable[T_Retval]: async def fixture_runner() -> None: agen = fixture_func(**kwargs) try: retval = await agen.asend(None) self._raise_async_exceptions() except BaseException as exc: f.set_exception(exc) return else: f.set_result(retval) await event.wait() try: await agen.asend(None) except StopAsyncIteration: pass else: await agen.aclose() raise RuntimeError("Async generator fixture did not stop") f = self._loop.create_future() event = asyncio.Event() fixture_task = self._loop.create_task(fixture_runner()) self._loop.run_until_complete(f) yield f.result() event.set() self._loop.run_until_complete(fixture_task) self._raise_async_exceptions() def run_fixture( self, fixture_func: Callable[..., Coroutine[Any, Any, T_Retval]], kwargs: dict[str, Any], ) -> T_Retval: retval = self._loop.run_until_complete(fixture_func(**kwargs)) self._raise_async_exceptions() return retval def run_test( self, test_func: Callable[..., Coroutine[Any, Any, Any]], kwargs: dict[str, Any] ) -> None: try: self._loop.run_until_complete(test_func(**kwargs)) except Exception as exc: self._exceptions.append(exc) self._raise_async_exceptions()