Spaces:
Runtime error
Runtime error
from __future__ import annotations | |
import socket | |
from abc import abstractmethod | |
from contextlib import AsyncExitStack | |
from io import IOBase | |
from ipaddress import IPv4Address, IPv6Address | |
from socket import AddressFamily | |
from typing import ( | |
Any, | |
Callable, | |
Collection, | |
Mapping, | |
Tuple, | |
TypeVar, | |
Union, | |
) | |
from .._core._tasks import create_task_group | |
from .._core._typedattr import ( | |
TypedAttributeProvider, | |
TypedAttributeSet, | |
typed_attribute, | |
) | |
from ._streams import ByteStream, Listener, UnreliableObjectStream | |
from ._tasks import TaskGroup | |
# --- Address-related type aliases -------------------------------------------

# An IP address given either as text or as an ``ipaddress`` object.
IPAddressType = Union[str, IPv4Address, IPv6Address]

# A (host, port) pair.
IPSockAddrType = Tuple[str, int]

# Any socket address handled here: a (host, port) pair or a plain string.
SockAddrType = Union[IPSockAddrType, str]

# One datagram: the payload bytes paired with a (host, port) address.
UDPPacketType = Tuple[bytes, IPSockAddrType]

# Generic placeholder for the return type of a callable.
T_Retval = TypeVar("T_Retval")
class SocketAttribute(TypedAttributeSet):
    """Typed attribute keys describing a socket-backed object."""

    #: address family of the wrapped socket
    family: AddressFamily = typed_attribute()

    #: address the wrapped socket is bound to locally
    local_address: SockAddrType = typed_attribute()

    #: local port number (only meaningful for IP sockets)
    local_port: int = typed_attribute()

    #: the wrapped standard library socket object itself
    raw_socket: socket.socket = typed_attribute()

    #: address of the peer the wrapped socket is connected to
    remote_address: SockAddrType = typed_attribute()

    #: peer port number (only meaningful for connected IP sockets)
    remote_port: int = typed_attribute()
class _SocketProvider(TypedAttributeProvider):
    """
    Mixin that publishes :class:`SocketAttribute` values for a socket-backed object.

    Concrete classes supply the standard library socket through
    :attr:`_raw_socket`; every attribute offered here is derived from it.
    """

    # NOTE(review): exposed as a property because it overrides a member of
    # TypedAttributeProvider that is consumed as a mapping, not called --
    # confirm against the base class definition.
    @property
    def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
        """
        Build the mapping of available socket attributes to their getters.

        Always present: ``family``, ``local_address`` and ``raw_socket``.
        ``remote_address`` is added only when the socket has a peer, and the
        port attributes only for ``AF_INET`` / ``AF_INET6`` sockets.

        :return: a mapping of :class:`SocketAttribute` keys to zero-argument
            callables producing the attribute value
        """
        # Function-scope import (presumably to avoid an import cycle with the
        # core sockets module -- TODO confirm).
        from .._core._sockets import convert_ipv6_sockaddr as convert

        attributes: dict[Any, Callable[[], Any]] = {
            SocketAttribute.family: lambda: self._raw_socket.family,
            SocketAttribute.local_address: lambda: convert(
                self._raw_socket.getsockname()
            ),
            SocketAttribute.raw_socket: lambda: self._raw_socket,
        }

        # The peer address is looked up once, right here; the getters below
        # return this snapshot instead of querying the socket again.
        try:
            peername: tuple[str, int] | None = convert(self._raw_socket.getpeername())
        except OSError:
            # No peer available (e.g. the socket is not connected).
            peername = None

        # Provide the remote address for connected sockets
        if peername is not None:
            attributes[SocketAttribute.remote_address] = lambda: peername

        # Provide local and remote ports for IP based sockets
        if self._raw_socket.family in (AddressFamily.AF_INET, AddressFamily.AF_INET6):
            attributes[
                SocketAttribute.local_port
            ] = lambda: self._raw_socket.getsockname()[1]
            if peername is not None:
                remote_port = peername[1]
                attributes[SocketAttribute.remote_port] = lambda: remote_port

        return attributes

    # Must be a property: extra_attributes reads ``self._raw_socket.family``
    # and calls ``self._raw_socket.getsockname()`` directly on the value.
    @property
    @abstractmethod
    def _raw_socket(self) -> socket.socket:
        """Return the underlying standard library socket object."""
class SocketStream(ByteStream, _SocketProvider):
    """
    A byte stream whose transport is a socket.

    Exposes every applicable extra attribute defined in :class:`~SocketAttribute`.
    """
class UNIXSocketStream(SocketStream):
    """Socket stream that can additionally exchange file descriptors with its peer."""

    # Declared abstract: the body is documentation only, so a subclass has to
    # provide the actual implementation.
    @abstractmethod
    async def send_fds(self, message: bytes, fds: Collection[int | IOBase]) -> None:
        """
        Send file descriptors along with a message to the peer.

        :param message: a non-empty bytestring
        :param fds: a collection of files (either numeric file descriptors or open file or socket
            objects)
        """

    # Declared abstract: without an implementation this would return ``None``
    # instead of the documented ``(message, file descriptors)`` tuple.
    @abstractmethod
    async def receive_fds(self, msglen: int, maxfds: int) -> tuple[bytes, list[int]]:
        """
        Receive file descriptors along with a message from the peer.

        :param msglen: length of the message to expect from the peer
        :param maxfds: maximum number of file descriptors to expect from the peer
        :return: a tuple of (message, file descriptors)
        """
class SocketListener(Listener[SocketStream], _SocketProvider):
    """
    Listens to incoming socket connections.

    Supports all relevant extra attributes from :class:`~SocketAttribute`.
    """

    # Declared abstract: serve() hands the result of accept() straight to the
    # handler, so a documentation-only body would deliver ``None`` there.
    @abstractmethod
    async def accept(self) -> SocketStream:
        """Accept an incoming connection."""

    async def serve(
        self,
        handler: Callable[[SocketStream], Any],
        task_group: TaskGroup | None = None,
    ) -> None:
        """
        Accept connections in an endless loop, starting a task for each one.

        This method has no normal exit: the accept loop only ends when an
        exception propagates out of it.

        :param handler: a callable started via ``task_group.start_soon`` with
            each accepted stream as its sole argument
        :param task_group: the task group in which handlers are started; if
            omitted, a new task group is created and kept open for the
            duration of this call
        """
        async with AsyncExitStack() as exit_stack:
            # Only a task group created here is bound to the exit stack; one
            # passed in by the caller is used as-is and never entered or exited.
            if task_group is None:
                task_group = await exit_stack.enter_async_context(create_task_group())

            while True:
                stream = await self.accept()
                task_group.start_soon(handler, stream)
class UDPSocket(UnreliableObjectStream[UDPPacketType], _SocketProvider):
    """
    A UDP socket that is not connected to any particular peer.

    Exposes every applicable extra attribute defined in :class:`~SocketAttribute`.
    """

    async def sendto(self, data: bytes, host: str, port: int) -> None:
        """
        Send ``data`` to ``(host, port)``.

        Shorthand for calling :meth:`~.UnreliableObjectSendStream.send` with
        ``(data, (host, port))``.
        """
        destination = (host, port)
        packet = (data, destination)
        return await self.send(packet)
class ConnectedUDPSocket(UnreliableObjectStream[bytes], _SocketProvider):
    """
    A UDP socket that is connected to a peer.

    Exposes every applicable extra attribute defined in :class:`~SocketAttribute`.
    """