From ec58a4cce7e6aa37efcc48bc2cdd974f7ed02615 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francis=20Clairicia-Rose-Claire-Jos=C3=A9phine?= Date: Thu, 9 May 2024 13:28:09 +0200 Subject: [PATCH] Documentation: Fixed API docstrings (#292) --- docs/source/_static/css/rtfd.css | 4 + docs/source/api/lowlevel/async/servers.rst | 8 ++ docs/source/api/lowlevel/sync/transports.rst | 2 +- docs/source/conf.py | 39 +++++-- docs/source/index.rst | 2 + pdm.lock | 18 +-- pyproject.toml | 10 +- src/easynetwork/clients/async_tcp.py | 12 +- src/easynetwork/clients/tcp.py | 8 +- .../lowlevel/api_async/backend/abc.py | 14 +-- .../lowlevel/api_async/endpoints/datagram.py | 57 +++++----- .../lowlevel/api_async/endpoints/stream.py | 61 +++++----- .../lowlevel/api_async/servers/datagram.py | 63 ++++++++--- .../lowlevel/api_async/servers/stream.py | 107 ++++++++++++------ .../lowlevel/api_async/transports/tls.py | 88 ++++++++------ .../lowlevel/api_sync/endpoints/datagram.py | 51 +++++---- .../lowlevel/api_sync/endpoints/stream.py | 55 +++++---- .../api_sync/transports/base_selector.py | 19 +++- .../lowlevel/api_sync/transports/socket.py | 54 ++++++++- src/easynetwork/lowlevel/socket.py | 8 +- src/easynetwork/serializers/pickle.py | 10 +- src/easynetwork/serializers/struct.py | 12 +- .../serializers/wrapper/compressor.py | 12 +- src/easynetwork/servers/async_tcp.py | 17 ++- src/easynetwork/servers/misc.py | 6 +- .../test_servers/test_stream.py | 14 +-- .../test_transports/test_selector.py | 2 +- 27 files changed, 481 insertions(+), 272 deletions(-) create mode 100644 docs/source/_static/css/rtfd.css diff --git a/docs/source/_static/css/rtfd.css b/docs/source/_static/css/rtfd.css new file mode 100644 index 00000000..4bd06cb6 --- /dev/null +++ b/docs/source/_static/css/rtfd.css @@ -0,0 +1,4 @@ +/* Force content to use more spaces than default (800px) */ +.wy-nav-content { + max-width: 1345px; +} diff --git a/docs/source/api/lowlevel/async/servers.rst b/docs/source/api/lowlevel/async/servers.rst 
index 1a2c4433..1bb65c22 100644 --- a/docs/source/api/lowlevel/async/servers.rst +++ b/docs/source/api/lowlevel/async/servers.rst @@ -9,9 +9,17 @@ Stream Servers :members: :no-docstring: +.. autoclass:: easynetwork.lowlevel.api_async.servers.stream::Client() + :members: + :exclude-members: __init__ + Datagram Servers ================ .. automodule:: easynetwork.lowlevel.api_async.servers.datagram :members: :no-docstring: + +.. autoclass:: easynetwork.lowlevel.api_async.servers.datagram::DatagramClientContext() + :members: + :exclude-members: __init__ diff --git a/docs/source/api/lowlevel/sync/transports.rst b/docs/source/api/lowlevel/sync/transports.rst index a5053cec..645c7650 100644 --- a/docs/source/api/lowlevel/sync/transports.rst +++ b/docs/source/api/lowlevel/sync/transports.rst @@ -11,7 +11,7 @@ Abstract Base Classes :no-docstring: ``selectors``-based transports ------------------------------- +============================== .. seealso:: diff --git a/docs/source/conf.py b/docs/source/conf.py index a350ed5c..086eb05b 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -8,14 +8,14 @@ import os.path import sys -from importlib.metadata import version as get_version +from importlib.metadata import version as _get_distribution_version # -- Project information ----------------------------------------------------- # https://www.sphinx-doc.org/en/master/usage/configuration.html#project-information project = "EasyNetwork" copyright = "2024, Francis Clairicia-Rose-Claire-Josephine" author = "FrankySnow9" -release = get_version("easynetwork") +release = _get_distribution_version("easynetwork") version = ".".join(release.split(".")[:3]) # -- General configuration --------------------------------------------------- @@ -66,7 +66,8 @@ # -- sphinx.ext.autodoc configuration ---------------------------------------- # https://www.sphinx-doc.org/en/master/usage/extensions/autodoc.html -autoclass_content = "both" +autoclass_content = "class" +autodoc_class_signature = 
"separated" autodoc_member_order = "bysource" autodoc_default_options = { "undoc-members": None, @@ -75,22 +76,25 @@ "show-inheritance": None, } autodoc_preserve_defaults = True -autodoc_typehints = "description" +autodoc_typehints = "both" autodoc_typehints_description_target = "documented_params" autodoc_type_aliases = { - "_typing_bz2.BZ2Compressor": "bz2.BZ2Compressor", - "_typing_bz2.BZ2Decompressor": "bz2.BZ2Decompressor", - "_typing_zlib._Compress": "zlib.Compress", - "_typing_zlib._Decompress": "zlib.Decompress", - "_typing_pickle.Pickler": "pickle.Pickler", - "_typing_pickle.Unpickler": "pickle.Unpickler", - "_typing_struct.Struct": "struct.Struct", - "_typing_ssl.SSLContext": "ssl.SSLContext", "_socket._RetAddress": "typing.Any", "_socket.socket": "socket.socket", + "BZ2Compressor": "bz2.BZ2Compressor", + "BZ2Decompressor": "bz2.BZ2Decompressor", "contextvars.Context": "contextvars.Context", + "MemoryBIO": "ssl.MemoryBIO", + "Pickler": "pickle.Pickler", "ReadableBuffer": "bytes | bytearray | memoryview", + "SSLContext": "ssl.SSLContext", + "SSLSession": "ssl.SSLSession", + "SSLSocket": "ssl.SSLSocket", + "Struct": "struct.Struct", + "Unpickler": "pickle.Unpickler", "WriteableBuffer": "bytearray | memoryview", + "ZLibCompress": "zlib.Compress", + "ZLibDecompress": "zlib.Decompress", } autodoc_inherit_docstrings = False autodoc_mock_imports = [ @@ -148,6 +152,7 @@ ] html_css_files = [ "css/details.css", + "css/rtfd.css", ] # -- sphinx-rtd-theme configuration ------------------------------------------ @@ -156,3 +161,13 @@ html_theme_options = { "navigation_depth": -1, # Unlimited } + + +# ----------------------------------------------------------------------------- + + +def setup(app) -> None: + import warnings + from sphinx import RemovedInNextVersionWarning + + warnings.filterwarnings("ignore", category=RemovedInNextVersionWarning, module="sphinx_toolbox.more_autodoc.autoprotocol") diff --git a/docs/source/index.rst b/docs/source/index.rst index 
b0822d0b..ee8aaf5e 100644 --- a/docs/source/index.rst +++ b/docs/source/index.rst @@ -21,6 +21,7 @@ Welcome to EasyNetwork's documentation! ------------ * :github:repo:`Source code ` + * `Release Notes `_ * :ref:`genindex` * :ref:`modindex` @@ -29,5 +30,6 @@ Welcome to EasyNetwork's documentation! :github: :pypi: easynetwork + Release Notes genindex modindex diff --git a/pdm.lock b/pdm.lock index a2b246d2..7a136e2e 100644 --- a/pdm.lock +++ b/pdm.lock @@ -5,7 +5,7 @@ groups = ["default", "bandit", "benchmark-servers", "benchmark-servers-deps", "build", "cbor", "coverage", "dev", "doc", "flake8", "format", "micro-benchmark", "msgpack", "mypy", "pre-commit", "sniffio", "test", "tox", "types-msgpack", "uvloop"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.1" -content_hash = "sha256:16e715e8862a21ee6e5f50fbe6b1d295febab22c93e0cf5cad2bbf33291619c6" +content_hash = "sha256:2eedd6a31781a0825df6f986c3b6026fb800a491116812d59c3d61b74740eec9" [[package]] name = "alabaster" @@ -1712,17 +1712,17 @@ files = [ [[package]] name = "sphinx" -version = "6.2.1" -requires_python = ">=3.8" +version = "7.3.7" +requires_python = ">=3.9" summary = "Python documentation generator" groups = ["doc"] dependencies = [ "Jinja2>=3.0", - "Pygments>=2.13", - "alabaster<0.8,>=0.7", + "Pygments>=2.14", + "alabaster~=0.7.14", "babel>=2.9", "colorama>=0.4.5; sys_platform == \"win32\"", - "docutils<0.20,>=0.18.1", + "docutils<0.22,>=0.18.1", "imagesize>=1.3", "packaging>=21.0", "requests>=2.25.0", @@ -1732,11 +1732,11 @@ dependencies = [ "sphinxcontrib-htmlhelp>=2.0.0", "sphinxcontrib-jsmath", "sphinxcontrib-qthelp", - "sphinxcontrib-serializinghtml>=1.1.5", + "sphinxcontrib-serializinghtml>=1.1.9", ] files = [ - {file = "Sphinx-6.2.1.tar.gz", hash = "sha256:6d56a34697bb749ffa0152feafc4b19836c755d90a7c59b72bc7dfd371b9cc6b"}, - {file = "sphinx-6.2.1-py3-none-any.whl", hash = "sha256:97787ff1fa3256a3eef9eda523a63dbf299f7b47e053cfcf684a1c2a8380c912"}, + {file = 
"sphinx-7.3.7-py3-none-any.whl", hash = "sha256:413f75440be4cacf328f580b4274ada4565fb2187d696a84970c23f77b64d8c3"}, + {file = "sphinx-7.3.7.tar.gz", hash = "sha256:a4a7db75ed37531c05002d56ed6948d4c42f473a36f46e1382b0bd76ca9627bc"}, ] [[package]] diff --git a/pyproject.toml b/pyproject.toml index fc07c3d4..0af66802 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -112,11 +112,11 @@ coverage = [ "coverage~=7.0", ] doc = [ - "sphinx>=6.2.1,<7", - "sphinx-rtd-theme>=1.2.2", - "sphinx-tabs>=3.4.1", - "sphinx-toolbox>=3.5.0", - "enum-tools[sphinx]>=0.10.0", + "sphinx>=7.3,<8", + "sphinx-rtd-theme>=2.0,<3", + "sphinx-tabs>=3.4.1,<4", + "sphinx-toolbox>=3.5.0,<4", + "enum-tools[sphinx]>=0.12.0,<1", ] micro-benchmark = [ "pytest~=7.4", diff --git a/src/easynetwork/clients/async_tcp.py b/src/easynetwork/clients/async_tcp.py index e2dbb114..82a97809 100644 --- a/src/easynetwork/clients/async_tcp.py +++ b/src/easynetwork/clients/async_tcp.py @@ -53,7 +53,7 @@ from .abc import AbstractAsyncNetworkClient if TYPE_CHECKING: - import ssl as _typing_ssl + from ssl import SSLContext @dataclasses.dataclass(kw_only=True, slots=True) @@ -96,7 +96,7 @@ def __init__( *, local_address: tuple[str, int] | None = ..., happy_eyeballs_delay: float | None = ..., - ssl: _typing_ssl.SSLContext | bool | None = ..., + ssl: SSLContext | bool | None = ..., server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., ssl_shutdown_timeout: float | None = ..., @@ -114,7 +114,7 @@ def __init__( protocol: StreamProtocol[_T_SentPacket, _T_ReceivedPacket], backend: AsyncBackend | BuiltinAsyncBackendToken | None = ..., *, - ssl: _typing_ssl.SSLContext | bool | None = ..., + ssl: SSLContext | bool | None = ..., server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., ssl_shutdown_timeout: float | None = ..., @@ -131,7 +131,7 @@ def __init__( protocol: StreamProtocol[_T_SentPacket, _T_ReceivedPacket], backend: AsyncBackend | BuiltinAsyncBackendToken | None = None, *, - 
ssl: _typing_ssl.SSLContext | bool | None = None, + ssl: SSLContext | bool | None = None, server_hostname: str | None = None, ssl_handshake_timeout: float | None = None, ssl_shutdown_timeout: float | None = None, @@ -302,7 +302,7 @@ async def __create_ssl_over_tcp_connection( backend: AsyncBackend, host: str, port: int, - ssl_context: _typing_ssl.SSLContext, + ssl_context: SSLContext, *, server_hostname: str | None, ssl_handshake_timeout: float | None, @@ -334,7 +334,7 @@ async def __create_ssl_over_tcp_connection( async def __wrap_ssl_over_stream_socket_client_side( backend: AsyncBackend, socket: _socket.socket, - ssl_context: _typing_ssl.SSLContext, + ssl_context: SSLContext, *, server_hostname: str, ssl_handshake_timeout: float | None, diff --git a/src/easynetwork/clients/tcp.py b/src/easynetwork/clients/tcp.py index ab85c84d..af5b7161 100644 --- a/src/easynetwork/clients/tcp.py +++ b/src/easynetwork/clients/tcp.py @@ -51,7 +51,7 @@ from .abc import AbstractNetworkClient if TYPE_CHECKING: - import ssl as _typing_ssl + from ssl import SSLContext class TCPNetworkClient(AbstractNetworkClient[_T_SentPacket, _T_ReceivedPacket]): @@ -75,7 +75,7 @@ def __init__( *, connect_timeout: float | None = ..., local_address: tuple[str, int] | None = ..., - ssl: _typing_ssl.SSLContext | bool | None = ..., + ssl: SSLContext | bool | None = ..., server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., ssl_shutdown_timeout: float | None = ..., @@ -93,7 +93,7 @@ def __init__( /, protocol: StreamProtocol[_T_SentPacket, _T_ReceivedPacket], *, - ssl: _typing_ssl.SSLContext | bool | None = ..., + ssl: SSLContext | bool | None = ..., server_hostname: str | None = ..., ssl_handshake_timeout: float | None = ..., ssl_shutdown_timeout: float | None = ..., @@ -110,7 +110,7 @@ def __init__( /, protocol: StreamProtocol[_T_SentPacket, _T_ReceivedPacket], *, - ssl: _typing_ssl.SSLContext | bool | None = None, + ssl: SSLContext | bool | None = None, server_hostname: str | None 
= None, ssl_handshake_timeout: float | None = None, ssl_shutdown_timeout: float | None = None, diff --git a/src/easynetwork/lowlevel/api_async/backend/abc.py b/src/easynetwork/lowlevel/api_async/backend/abc.py index 5bdee0f7..a89d4412 100644 --- a/src/easynetwork/lowlevel/api_async/backend/abc.py +++ b/src/easynetwork/lowlevel/api_async/backend/abc.py @@ -34,7 +34,7 @@ from collections.abc import Awaitable, Callable, Coroutine, Mapping, Sequence from contextlib import AbstractContextManager from types import TracebackType -from typing import TYPE_CHECKING, Any, Generic, NoReturn, ParamSpec, Protocol, Self, TypeVar, TypeVarTuple +from typing import TYPE_CHECKING, Any, Generic, NoReturn, ParamSpec, Protocol, Self, TypeVar, TypeVarTuple, Unpack if TYPE_CHECKING: import concurrent.futures @@ -451,9 +451,9 @@ async def __aexit__( @abstractmethod def start_soon( self, - coro_func: Callable[[*_T_PosArgs], Coroutine[Any, Any, _T]], + coro_func: Callable[[Unpack[_T_PosArgs]], Coroutine[Any, Any, _T]], /, - *args: *_T_PosArgs, + *args: Unpack[_T_PosArgs], name: str | None = ..., ) -> None: """ @@ -470,9 +470,9 @@ def start_soon( @abstractmethod async def start( self, - coro_func: Callable[[*_T_PosArgs], Coroutine[Any, Any, _T]], + coro_func: Callable[[Unpack[_T_PosArgs]], Coroutine[Any, Any, _T]], /, - *args: *_T_PosArgs, + *args: Unpack[_T_PosArgs], name: str | None = ..., ) -> Task[_T]: """ @@ -625,8 +625,8 @@ class AsyncBackend(metaclass=ABCMeta): @abstractmethod def bootstrap( self, - coro_func: Callable[[*_T_PosArgs], Coroutine[Any, Any, _T]], - *args: *_T_PosArgs, + coro_func: Callable[[Unpack[_T_PosArgs]], Coroutine[Any, Any, _T]], + *args: Unpack[_T_PosArgs], runner_options: Mapping[str, Any] | None = ..., ) -> _T: """ diff --git a/src/easynetwork/lowlevel/api_async/endpoints/datagram.py b/src/easynetwork/lowlevel/api_async/endpoints/datagram.py index 24d70d94..1b7b2bbb 100644 --- a/src/easynetwork/lowlevel/api_async/endpoints/datagram.py +++ 
b/src/easynetwork/lowlevel/api_async/endpoints/datagram.py @@ -27,15 +27,15 @@ from collections.abc import Callable, Mapping from typing import Any, Generic -from .... import protocol as protocol_module from ...._typevars import _T_ReceivedPacket, _T_SentPacket from ....exceptions import DatagramProtocolParseError +from ....protocol import DatagramProtocol from ... import _utils from ..backend.abc import AsyncBackend -from ..transports import abc as transports +from ..transports.abc import AsyncBaseTransport, AsyncDatagramReadTransport, AsyncDatagramTransport, AsyncDatagramWriteTransport -class AsyncDatagramReceiverEndpoint(transports.AsyncBaseTransport, Generic[_T_ReceivedPacket]): +class AsyncDatagramReceiverEndpoint(AsyncBaseTransport, Generic[_T_ReceivedPacket]): """ A read-only communication endpoint based on unreliable packets of data. """ @@ -48,8 +48,8 @@ class AsyncDatagramReceiverEndpoint(transports.AsyncBaseTransport, Generic[_T_Re def __init__( self, - transport: transports.AsyncDatagramReadTransport, - protocol: protocol_module.DatagramProtocol[Any, _T_ReceivedPacket], + transport: AsyncDatagramReadTransport, + protocol: DatagramProtocol[Any, _T_ReceivedPacket], ) -> None: """ Parameters: @@ -57,14 +57,14 @@ def __init__( protocol: The :term:`protocol object` to use. 
""" - if not isinstance(transport, transports.AsyncDatagramReadTransport): + if not isinstance(transport, AsyncDatagramReadTransport): raise TypeError(f"Expected an AsyncDatagramReadTransport object, got {transport!r}") - if not isinstance(protocol, protocol_module.DatagramProtocol): + if not isinstance(protocol, DatagramProtocol): raise TypeError(f"Expected a DatagramProtocol object, got {protocol!r}") self.__receiver: _DataReceiverImpl[_T_ReceivedPacket] = _DataReceiverImpl(transport, protocol) - self.__transport: transports.AsyncDatagramReadTransport = transport + self.__transport: AsyncDatagramReadTransport = transport self.__recv_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently receving data on this endpoint") def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: @@ -106,16 +106,17 @@ async def recv_packet(self) -> _T_ReceivedPacket: return await receiver.receive() - @_utils.inherit_doc(transports.AsyncBaseTransport) + @_utils.inherit_doc(AsyncBaseTransport) def backend(self) -> AsyncBackend: return self.__transport.backend() @property + @_utils.inherit_doc(AsyncBaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes -class AsyncDatagramSenderEndpoint(transports.AsyncBaseTransport, Generic[_T_SentPacket]): +class AsyncDatagramSenderEndpoint(AsyncBaseTransport, Generic[_T_SentPacket]): """ A write-only communication endpoint based on unreliable packets of data. """ @@ -128,8 +129,8 @@ class AsyncDatagramSenderEndpoint(transports.AsyncBaseTransport, Generic[_T_Sent def __init__( self, - transport: transports.AsyncDatagramWriteTransport, - protocol: protocol_module.DatagramProtocol[_T_SentPacket, Any], + transport: AsyncDatagramWriteTransport, + protocol: DatagramProtocol[_T_SentPacket, Any], ) -> None: """ Parameters: @@ -137,14 +138,14 @@ def __init__( protocol: The :term:`protocol object` to use. 
""" - if not isinstance(transport, transports.AsyncDatagramWriteTransport): + if not isinstance(transport, AsyncDatagramWriteTransport): raise TypeError(f"Expected an AsyncDatagramWriteTransport object, got {transport!r}") - if not isinstance(protocol, protocol_module.DatagramProtocol): + if not isinstance(protocol, DatagramProtocol): raise TypeError(f"Expected a DatagramProtocol object, got {protocol!r}") self.__sender: _DataSenderImpl[_T_SentPacket] = _DataSenderImpl(transport, protocol) - self.__transport: transports.AsyncDatagramWriteTransport = transport + self.__transport: AsyncDatagramWriteTransport = transport self.__send_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently sending data on this endpoint") def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: @@ -186,16 +187,17 @@ async def send_packet(self, packet: _T_SentPacket) -> None: await sender.send(packet) - @_utils.inherit_doc(transports.AsyncBaseTransport) + @_utils.inherit_doc(AsyncBaseTransport) def backend(self) -> AsyncBackend: return self.__transport.backend() @property + @_utils.inherit_doc(AsyncBaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes -class AsyncDatagramEndpoint(transports.AsyncBaseTransport, Generic[_T_SentPacket, _T_ReceivedPacket]): +class AsyncDatagramEndpoint(AsyncBaseTransport, Generic[_T_SentPacket, _T_ReceivedPacket]): """ A full-duplex communication endpoint based on unreliable packets of data. """ @@ -210,8 +212,8 @@ class AsyncDatagramEndpoint(transports.AsyncBaseTransport, Generic[_T_SentPacket def __init__( self, - transport: transports.AsyncDatagramTransport, - protocol: protocol_module.DatagramProtocol[_T_SentPacket, _T_ReceivedPacket], + transport: AsyncDatagramTransport, + protocol: DatagramProtocol[_T_SentPacket, _T_ReceivedPacket], ) -> None: """ Parameters: @@ -219,15 +221,15 @@ def __init__( protocol: The :term:`protocol object` to use. 
""" - if not isinstance(transport, transports.AsyncDatagramTransport): + if not isinstance(transport, AsyncDatagramTransport): raise TypeError(f"Expected an AsyncDatagramTransport object, got {transport!r}") - if not isinstance(protocol, protocol_module.DatagramProtocol): + if not isinstance(protocol, DatagramProtocol): raise TypeError(f"Expected a DatagramProtocol object, got {protocol!r}") self.__sender: _DataSenderImpl[_T_SentPacket] = _DataSenderImpl(transport, protocol) self.__receiver: _DataReceiverImpl[_T_ReceivedPacket] = _DataReceiverImpl(transport, protocol) - self.__transport: transports.AsyncDatagramTransport = transport + self.__transport: AsyncDatagramTransport = transport self.__send_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently sending data on this endpoint") self.__recv_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently receving data on this endpoint") @@ -285,19 +287,20 @@ async def recv_packet(self) -> _T_ReceivedPacket: return await receiver.receive() - @_utils.inherit_doc(transports.AsyncBaseTransport) + @_utils.inherit_doc(AsyncBaseTransport) def backend(self) -> AsyncBackend: return self.__transport.backend() @property + @_utils.inherit_doc(AsyncBaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes @dataclasses.dataclass(slots=True) class _DataSenderImpl(Generic[_T_SentPacket]): - transport: transports.AsyncDatagramWriteTransport - protocol: protocol_module.DatagramProtocol[_T_SentPacket, Any] + transport: AsyncDatagramWriteTransport + protocol: DatagramProtocol[_T_SentPacket, Any] async def send(self, packet: _T_SentPacket) -> None: try: @@ -312,8 +315,8 @@ async def send(self, packet: _T_SentPacket) -> None: @dataclasses.dataclass(slots=True) class _DataReceiverImpl(Generic[_T_ReceivedPacket]): - transport: transports.AsyncDatagramReadTransport - protocol: protocol_module.DatagramProtocol[Any, _T_ReceivedPacket] + 
transport: AsyncDatagramReadTransport + protocol: DatagramProtocol[Any, _T_ReceivedPacket] async def receive(self) -> _T_ReceivedPacket: datagram = await self.transport.recv() diff --git a/src/easynetwork/lowlevel/api_async/endpoints/stream.py b/src/easynetwork/lowlevel/api_async/endpoints/stream.py index 81f4ed06..8519bb02 100644 --- a/src/easynetwork/lowlevel/api_async/endpoints/stream.py +++ b/src/easynetwork/lowlevel/api_async/endpoints/stream.py @@ -28,16 +28,22 @@ from collections.abc import Callable, Mapping from typing import Any, Generic, Literal, assert_never -from .... import protocol as protocol_module from ...._typevars import _T_ReceivedPacket, _T_SentPacket from ....exceptions import UnsupportedOperation +from ....protocol import StreamProtocol from ....warnings import ManualBufferAllocationWarning from ... import _stream, _utils from ..backend.abc import AsyncBackend -from ..transports import abc as transports +from ..transports.abc import ( + AsyncBaseTransport, + AsyncBufferedStreamReadTransport, + AsyncStreamReadTransport, + AsyncStreamTransport, + AsyncStreamWriteTransport, +) -class AsyncStreamReceiverEndpoint(transports.AsyncBaseTransport, Generic[_T_ReceivedPacket]): +class AsyncStreamReceiverEndpoint(AsyncBaseTransport, Generic[_T_ReceivedPacket]): """ A read-only communication endpoint based on continuous stream data transport. """ @@ -50,8 +56,8 @@ class AsyncStreamReceiverEndpoint(transports.AsyncBaseTransport, Generic[_T_Rece def __init__( self, - transport: transports.AsyncStreamReadTransport, - protocol: protocol_module.StreamProtocol[Any, _T_ReceivedPacket], + transport: AsyncStreamReadTransport, + protocol: StreamProtocol[Any, _T_ReceivedPacket], max_recv_size: int, *, manual_buffer_allocation: Literal["try", "no", "force"] = "try", @@ -76,7 +82,7 @@ def __init__( :exc:`.ManualBufferAllocationWarning`. 
""" - if not isinstance(transport, transports.AsyncStreamReadTransport): + if not isinstance(transport, AsyncStreamReadTransport): raise TypeError(f"Expected an AsyncStreamReadTransport object, got {transport!r}") _check_max_recv_size_value(max_recv_size) _check_manual_buffer_allocation_value(manual_buffer_allocation) @@ -90,7 +96,7 @@ def __init__( manual_buffer_allocation_warning_stacklevel=manual_buffer_allocation_warning_stacklevel, ) - self.__transport: transports.AsyncStreamReadTransport = transport + self.__transport: AsyncStreamReadTransport = transport self.__recv_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently receving data on this endpoint") def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: @@ -134,7 +140,7 @@ async def recv_packet(self) -> _T_ReceivedPacket: return await receiver.receive() - @_utils.inherit_doc(transports.AsyncBaseTransport) + @_utils.inherit_doc(AsyncBaseTransport) def backend(self) -> AsyncBackend: return self.__transport.backend() @@ -144,11 +150,12 @@ def max_recv_size(self) -> int: return self.__receiver.max_recv_size @property + @_utils.inherit_doc(AsyncBaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes -class AsyncStreamSenderEndpoint(transports.AsyncBaseTransport, Generic[_T_SentPacket]): +class AsyncStreamSenderEndpoint(AsyncBaseTransport, Generic[_T_SentPacket]): """ A write-only communication endpoint based on continuous stream data transport. """ @@ -161,8 +168,8 @@ class AsyncStreamSenderEndpoint(transports.AsyncBaseTransport, Generic[_T_SentPa def __init__( self, - transport: transports.AsyncStreamWriteTransport, - protocol: protocol_module.StreamProtocol[_T_SentPacket, Any], + transport: AsyncStreamWriteTransport, + protocol: StreamProtocol[_T_SentPacket, Any], ) -> None: """ Parameters: @@ -170,12 +177,12 @@ def __init__( protocol: The :term:`protocol object` to use. 
""" - if not isinstance(transport, transports.AsyncStreamWriteTransport): + if not isinstance(transport, AsyncStreamWriteTransport): raise TypeError(f"Expected an AsyncStreamWriteTransport object, got {transport!r}") self.__sender: _DataSenderImpl[_T_SentPacket] = _DataSenderImpl(transport, _stream.StreamDataProducer(protocol)) - self.__transport: transports.AsyncStreamWriteTransport = transport + self.__transport: AsyncStreamWriteTransport = transport self.__send_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently sending data on this endpoint") def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: @@ -218,16 +225,17 @@ async def send_packet(self, packet: _T_SentPacket) -> None: return await sender.send(packet) - @_utils.inherit_doc(transports.AsyncBaseTransport) + @_utils.inherit_doc(AsyncBaseTransport) def backend(self) -> AsyncBackend: return self.__transport.backend() @property + @_utils.inherit_doc(AsyncBaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes -class AsyncStreamEndpoint(transports.AsyncBaseTransport, Generic[_T_SentPacket, _T_ReceivedPacket]): +class AsyncStreamEndpoint(AsyncBaseTransport, Generic[_T_SentPacket, _T_ReceivedPacket]): """ A full-duplex communication endpoint based on continuous stream data transport. """ @@ -243,8 +251,8 @@ class AsyncStreamEndpoint(transports.AsyncBaseTransport, Generic[_T_SentPacket, def __init__( self, - transport: transports.AsyncStreamTransport, - protocol: protocol_module.StreamProtocol[_T_SentPacket, _T_ReceivedPacket], + transport: AsyncStreamTransport, + protocol: StreamProtocol[_T_SentPacket, _T_ReceivedPacket], max_recv_size: int, *, manual_buffer_allocation: Literal["try", "no", "force"] = "try", @@ -269,7 +277,7 @@ def __init__( :exc:`.ManualBufferAllocationWarning`. 
""" - if not isinstance(transport, transports.AsyncStreamTransport): + if not isinstance(transport, AsyncStreamTransport): raise TypeError(f"Expected an AsyncStreamTransport object, got {transport!r}") _check_max_recv_size_value(max_recv_size) _check_manual_buffer_allocation_value(manual_buffer_allocation) @@ -284,7 +292,7 @@ def __init__( manual_buffer_allocation_warning_stacklevel=manual_buffer_allocation_warning_stacklevel, ) - self.__transport: transports.AsyncStreamTransport = transport + self.__transport: AsyncStreamTransport = transport self.__send_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently sending data on this endpoint") self.__recv_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently receving data on this endpoint") self.__eof_sent: bool = False @@ -369,7 +377,7 @@ async def recv_packet(self) -> _T_ReceivedPacket: return await receiver.receive() - @_utils.inherit_doc(transports.AsyncBaseTransport) + @_utils.inherit_doc(AsyncBaseTransport) def backend(self) -> AsyncBackend: return self.__transport.backend() @@ -379,13 +387,14 @@ def max_recv_size(self) -> int: return self.__receiver.max_recv_size @property + @_utils.inherit_doc(AsyncBaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes @dataclasses.dataclass(slots=True) class _DataSenderImpl(Generic[_T_SentPacket]): - transport: transports.AsyncStreamWriteTransport + transport: AsyncStreamWriteTransport producer: _stream.StreamDataProducer[_T_SentPacket] async def send(self, packet: _T_SentPacket) -> None: @@ -394,7 +403,7 @@ async def send(self, packet: _T_SentPacket) -> None: @dataclasses.dataclass(slots=True) class _DataReceiverImpl(Generic[_T_ReceivedPacket]): - transport: transports.AsyncStreamReadTransport + transport: AsyncStreamReadTransport consumer: _stream.StreamDataConsumer[_T_ReceivedPacket] max_recv_size: int _eof_reached: bool = dataclasses.field(init=False, 
default=False) @@ -429,7 +438,7 @@ async def receive(self) -> _T_ReceivedPacket: @dataclasses.dataclass(slots=True) class _BufferedReceiverImpl(Generic[_T_ReceivedPacket]): - transport: transports.AsyncBufferedStreamReadTransport + transport: AsyncBufferedStreamReadTransport consumer: _stream.BufferedStreamDataConsumer[_T_ReceivedPacket] _eof_reached: bool = dataclasses.field(init=False, default=False) @@ -465,8 +474,8 @@ async def receive(self) -> _T_ReceivedPacket: def _get_receiver( - transport: transports.AsyncStreamReadTransport, - protocol: protocol_module.StreamProtocol[Any, _T_ReceivedPacket], + transport: AsyncStreamReadTransport, + protocol: StreamProtocol[Any, _T_ReceivedPacket], *, max_recv_size: int, manual_buffer_allocation: Literal["try", "no", "force"], @@ -479,7 +488,7 @@ def _get_receiver( case "try" | "force": try: buffered_consumer = _stream.BufferedStreamDataConsumer(protocol, max_recv_size) - if not isinstance(transport, transports.AsyncBufferedStreamReadTransport): + if not isinstance(transport, AsyncBufferedStreamReadTransport): msg = f"The transport implementation {transport!r} does not implement AsyncBufferedStreamReadTransport interface" if manual_buffer_allocation == "try": warnings.warn( diff --git a/src/easynetwork/lowlevel/api_async/servers/datagram.py b/src/easynetwork/lowlevel/api_async/servers/datagram.py index a62925e0..a3870103 100644 --- a/src/easynetwork/lowlevel/api_async/servers/datagram.py +++ b/src/easynetwork/lowlevel/api_async/servers/datagram.py @@ -16,7 +16,7 @@ from __future__ import annotations -__all__ = ["AsyncDatagramServer", "DatagramClientContext"] +__all__ = ["AsyncDatagramServer"] import contextlib import contextvars @@ -29,13 +29,13 @@ from contextlib import AsyncExitStack, ExitStack from typing import Any, Generic, NoReturn, TypeVar -from .... import protocol as protocol_module from ...._typevars import _T_Request, _T_Response from ....exceptions import DatagramProtocolParseError -from ... 
import _utils, typed_attr +from ....protocol import DatagramProtocol +from ... import _utils from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction from ..backend.abc import AsyncBackend, ICondition, ILock, TaskGroup -from ..transports import abc as transports +from ..transports.abc import AsyncBaseTransport, AsyncDatagramListener _T_Address = TypeVar("_T_Address", bound=Hashable) @@ -47,6 +47,10 @@ @dataclasses.dataclass(frozen=True, unsafe_hash=True) class DatagramClientContext(Generic[_T_Response, _T_Address]): + """ + Contains information about the remote endpoint which sends a datagram. + """ + __slots__ = ( "address", "server", @@ -54,30 +58,41 @@ class DatagramClientContext(Generic[_T_Response, _T_Address]): ) address: _T_Address + """The client address""" + server: AsyncDatagramServer[Any, _T_Response, _T_Address] + """The server which receives the datagram.""" -class AsyncDatagramServer(typed_attr.TypedAttributeProvider, Generic[_T_Request, _T_Response, _T_Address]): +class AsyncDatagramServer(AsyncBaseTransport, Generic[_T_Request, _T_Response, _T_Address]): + """ + Datagram packet listener interface. + """ + __slots__ = ( "__listener", "__protocol", "__sendto_lock", "__serve_guard", - "__weakref__", ) def __init__( self, - listener: transports.AsyncDatagramListener[_T_Address], - protocol: protocol_module.DatagramProtocol[_T_Response, _T_Request], + listener: AsyncDatagramListener[_T_Address], + protocol: DatagramProtocol[_T_Response, _T_Request], ) -> None: - if not isinstance(listener, transports.AsyncDatagramListener): + """ + Parameters: + listener: the transport implementation to wrap. + protocol: The :term:`protocol object` to use. 
+ """ + if not isinstance(listener, AsyncDatagramListener): raise TypeError(f"Expected an AsyncDatagramListener object, got {listener!r}") - if not isinstance(protocol, protocol_module.DatagramProtocol): + if not isinstance(protocol, DatagramProtocol): raise TypeError(f"Expected a DatagramProtocol object, got {protocol!r}") - self.__listener: transports.AsyncDatagramListener[_T_Address] = listener - self.__protocol: protocol_module.DatagramProtocol[_T_Response, _T_Request] = protocol + self.__listener: AsyncDatagramListener[_T_Address] = listener + self.__protocol: DatagramProtocol[_T_Response, _T_Request] = protocol self.__sendto_lock: ILock = listener.backend().create_lock() self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently receiving datagrams") @@ -105,7 +120,7 @@ async def aclose(self) -> None: """ await self.__listener.aclose() - @_utils.inherit_doc(transports.AsyncBaseTransport) + @_utils.inherit_doc(AsyncBaseTransport) def backend(self) -> AsyncBackend: return self.__listener.backend() @@ -134,6 +149,25 @@ async def serve( ], task_group: TaskGroup | None = None, ) -> NoReturn: + """ + Receive incoming datagrams as they come in and start tasks to handle them. + + Important: + There will always be only one active generator per client. + All the pending datagrams received while the generator is running are queued. + + This behavior is designed to act like a stream request handler. + + Note: + If the generator returns before the first :keyword:`yield` statement, the received datagram is discarded. + + This is useful when a client that you do not expect to see sends something; the datagrams are parsed only when + the generator hits a :keyword:`yield` statement. + + Parameters: + datagram_received_cb: a callable that will be used to handle each received datagram. + task_group: the task group that will be used to start tasks for handling each received datagram. 
+ """ with self.__serve_guard: listener = self.__listener backend = listener.backend() @@ -257,7 +291,7 @@ def __on_task_done( @staticmethod def __parse_datagram( datagram: bytes, - protocol: protocol_module.DatagramProtocol[_T_Response, _T_Request], + protocol: DatagramProtocol[_T_Response, _T_Request], ) -> AsyncGenAction[_T_Request]: try: try: @@ -272,6 +306,7 @@ def __parse_datagram( return SendAction(request) @property + @_utils.inherit_doc(AsyncBaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__listener.extra_attributes diff --git a/src/easynetwork/lowlevel/api_async/servers/stream.py b/src/easynetwork/lowlevel/api_async/servers/stream.py index 909d4f30..c92e561c 100644 --- a/src/easynetwork/lowlevel/api_async/servers/stream.py +++ b/src/easynetwork/lowlevel/api_async/servers/stream.py @@ -16,7 +16,7 @@ from __future__ import annotations -__all__ = ["AsyncStreamClient", "AsyncStreamServer"] +__all__ = ["AsyncStreamServer"] import contextlib import dataclasses @@ -24,34 +24,45 @@ from collections.abc import AsyncGenerator, Callable, Mapping from typing import Any, Generic, Literal, NoReturn, assert_never -from .... import protocol as protocol_module from ...._typevars import _T_Request, _T_Response from ....exceptions import UnsupportedOperation +from ....protocol import StreamProtocol from ....warnings import ManualBufferAllocationWarning -from ... import _stream, _utils, typed_attr +from ... 
import _stream, _utils from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction from ..backend.abc import AsyncBackend, TaskGroup -from ..transports import abc as transports, utils as transports_utils +from ..transports import utils as transports_utils +from ..transports.abc import ( + AsyncBaseTransport, + AsyncBufferedStreamReadTransport, + AsyncListener, + AsyncStreamReadTransport, + AsyncStreamTransport, + AsyncStreamWriteTransport, +) + + +class Client(AsyncBaseTransport, Generic[_T_Response]): + """ + Write-end of the connected client. + """ - -class AsyncStreamClient(typed_attr.TypedAttributeProvider, Generic[_T_Response]): __slots__ = ( "__transport", "__producer", "__exit_stack", "__send_guard", - "__weakref__", ) def __init__( self, - transport: transports.AsyncStreamWriteTransport, + transport: AsyncStreamWriteTransport, producer: _stream.StreamDataProducer[_T_Response], exit_stack: contextlib.AsyncExitStack, ) -> None: super().__init__() - self.__transport: transports.AsyncStreamWriteTransport = transport + self.__transport: AsyncStreamWriteTransport = transport self.__producer: _stream.StreamDataProducer[_T_Response] = producer self.__exit_stack: contextlib.AsyncExitStack = exit_stack self.__send_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently sending data on this endpoint") @@ -86,40 +97,64 @@ async def send_packet(self, packet: _T_Response) -> None: with self.__send_guard: await self.__transport.send_all_from_iterable(self.__producer.generate(packet)) + @_utils.inherit_doc(AsyncBaseTransport) + def backend(self) -> AsyncBackend: + return self.__transport.backend() + @property + @_utils.inherit_doc(AsyncBaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes -class AsyncStreamServer(typed_attr.TypedAttributeProvider, Generic[_T_Request, _T_Response]): +class AsyncStreamServer(AsyncBaseTransport, Generic[_T_Request, _T_Response]): + """ + Stream 
listener interface. + """ + __slots__ = ( "__listener", "__protocol", "__max_recv_size", "__serve_guard", "__manual_buffer_allocation", - "__weakref__", ) def __init__( self, - listener: transports.AsyncListener[transports.AsyncStreamTransport], - protocol: protocol_module.StreamProtocol[_T_Response, _T_Request], + listener: AsyncListener[AsyncStreamTransport], + protocol: StreamProtocol[_T_Response, _T_Request], max_recv_size: int, *, manual_buffer_allocation: Literal["try", "no", "force"] = "try", ) -> None: - if not isinstance(listener, transports.AsyncListener): + """ + Parameters: + listener: the transport implementation to wrap. + protocol: The :term:`protocol object` to use. + max_recv_size: Read buffer size. + manual_buffer_allocation: Select whether or not to enable the manual buffer allocation system: + + * ``"try"``: (the default) will use the buffer API if the transport and protocol support it, + and fall back to the default implementation otherwise. + Emits a :exc:`.ManualBufferAllocationWarning` if only the transport does not support it. + + * ``"no"``: does not use the buffer API, even if they both support it. + + * ``"force"``: requires the buffer API. Raises :exc:`.UnsupportedOperation` if it fails and + no warnings are emitted. 
+ """ + if not isinstance(listener, AsyncListener): raise TypeError(f"Expected an AsyncListener object, got {listener!r}") - if not isinstance(protocol, protocol_module.StreamProtocol): + if not isinstance(protocol, StreamProtocol): raise TypeError(f"Expected a StreamProtocol object, got {protocol!r}") if not isinstance(max_recv_size, int) or max_recv_size <= 0: raise ValueError("'max_recv_size' must be a strictly positive integer") if manual_buffer_allocation not in ("try", "no", "force"): raise ValueError('"manual_buffer_allocation" must be "try", "no" or "force"') - self.__listener: transports.AsyncListener[transports.AsyncStreamTransport] = listener - self.__protocol: protocol_module.StreamProtocol[_T_Response, _T_Request] = protocol + self.__listener: AsyncListener[AsyncStreamTransport] = listener + self.__protocol: StreamProtocol[_T_Response, _T_Request] = protocol self.__max_recv_size: int = max_recv_size self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently accepting new connections") self.__manual_buffer_allocation: Literal["try", "no", "force"] = manual_buffer_allocation @@ -148,25 +183,32 @@ async def aclose(self) -> None: """ await self.__listener.aclose() - @_utils.inherit_doc(transports.AsyncBaseTransport) + @_utils.inherit_doc(AsyncBaseTransport) def backend(self) -> AsyncBackend: return self.__listener.backend() async def serve( self, - client_connected_cb: Callable[[AsyncStreamClient[_T_Response]], AsyncGenerator[float | None, _T_Request]], + client_connected_cb: Callable[[Client[_T_Response]], AsyncGenerator[float | None, _T_Request]], task_group: TaskGroup | None = None, ) -> NoReturn: + """ + Accept incoming connections as they come in and start tasks to handle them. + + Parameters: + client_connected_cb: a callable that will be used to handle each accepted connection. + task_group: the task group that will be used to start tasks for handling each accepted connection. 
+ """ with self.__serve_guard: handler = _utils.prepend_argument(client_connected_cb, self.__client_coroutine) await self.__listener.serve(handler, task_group) async def __client_coroutine( self, - client_connected_cb: Callable[[AsyncStreamClient[_T_Response]], AsyncGenerator[float | None, _T_Request]], - transport: transports.AsyncStreamTransport, + client_connected_cb: Callable[[Client[_T_Response]], AsyncGenerator[float | None, _T_Request]], + transport: AsyncStreamTransport, ) -> None: - if not isinstance(transport, transports.AsyncStreamTransport): + if not isinstance(transport, AsyncStreamTransport): raise TypeError(f"Expected an AsyncStreamTransport object, got {transport!r}") async with contextlib.AsyncExitStack() as task_exit_stack: @@ -180,14 +222,12 @@ async def __client_coroutine( case "try" | "force" as manual_buffer_allocation: try: consumer = _stream.BufferedStreamDataConsumer(self.__protocol, self.__max_recv_size) - if not isinstance(transport, transports.AsyncBufferedStreamReadTransport): + if not isinstance(transport, AsyncBufferedStreamReadTransport): msg = f"The transport implementation {transport!r} does not implement AsyncBufferedStreamReadTransport interface" if manual_buffer_allocation == "try": - warnings.warn( - f'{msg}. Consider explicitly setting the "manual_buffer_allocation" strategy to "no".', - category=ManualBufferAllocationWarning, - stacklevel=1, - ) + _warn_msg = f'{msg}. Consider explicitly setting the "manual_buffer_allocation" strategy to "no".' 
+ warnings.warn(_warn_msg, category=ManualBufferAllocationWarning, stacklevel=1) + del _warn_msg raise UnsupportedOperation(msg) request_receiver = _BufferedRequestReceiver( transport=transport, @@ -216,11 +256,9 @@ async def __client_coroutine( client_exit_stack = await task_exit_stack.enter_async_context(contextlib.AsyncExitStack()) client_exit_stack.callback(consumer.clear) - client = AsyncStreamClient(transport, producer, client_exit_stack) - - request_handler_generator = client_connected_cb(client) + request_handler_generator = client_connected_cb(Client(transport, producer, client_exit_stack)) - del client_exit_stack, task_exit_stack, client_connected_cb, client + del client_exit_stack, task_exit_stack, client_connected_cb timeout: float | None try: @@ -243,13 +281,14 @@ async def __client_coroutine( await request_handler_generator.aclose() @property + @_utils.inherit_doc(AsyncBaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__listener.extra_attributes @dataclasses.dataclass(kw_only=True, eq=False, frozen=True, slots=True) class _RequestReceiver(Generic[_T_Request]): - transport: transports.AsyncStreamReadTransport + transport: AsyncStreamReadTransport consumer: _stream.StreamDataConsumer[_T_Request] max_recv_size: int __null_timeout_ctx: contextlib.nullcontext[None] = dataclasses.field(init=False, default_factory=contextlib.nullcontext) @@ -283,7 +322,7 @@ async def next(self, timeout: float | None) -> AsyncGenAction[_T_Request]: @dataclasses.dataclass(kw_only=True, eq=False, frozen=True, slots=True) class _BufferedRequestReceiver(Generic[_T_Request]): - transport: transports.AsyncBufferedStreamReadTransport + transport: AsyncBufferedStreamReadTransport consumer: _stream.BufferedStreamDataConsumer[_T_Request] __null_timeout_ctx: contextlib.nullcontext[None] = dataclasses.field(init=False, default_factory=contextlib.nullcontext) diff --git a/src/easynetwork/lowlevel/api_async/transports/tls.py 
b/src/easynetwork/lowlevel/api_async/transports/tls.py index a4b21525..d1f32cfe 100644 --- a/src/easynetwork/lowlevel/api_async/transports/tls.py +++ b/src/easynetwork/lowlevel/api_async/transports/tls.py @@ -38,11 +38,11 @@ from ....exceptions import UnsupportedOperation from ... import _utils, constants, socket as socket_tools from ..backend.abc import AsyncBackend, TaskGroup -from . import abc as transports +from .abc import AsyncBufferedStreamReadTransport, AsyncListener, AsyncStreamReadTransport, AsyncStreamTransport from .utils import aclose_forcefully if TYPE_CHECKING: - import ssl as _typing_ssl + from ssl import MemoryBIO, SSLContext, SSLObject, SSLSession from _typeshed import WriteableBuffer @@ -51,18 +51,22 @@ @dataclasses.dataclass(repr=False, eq=False, slots=True, kw_only=True) -class AsyncTLSStreamTransport(transports.AsyncStreamTransport, transports.AsyncBufferedStreamReadTransport): - _transport: transports.AsyncStreamTransport +class AsyncTLSStreamTransport(AsyncStreamTransport, AsyncBufferedStreamReadTransport): + """ + SSL/TLS wrapper for a continuous stream transport. 
+ """ + + _transport: AsyncStreamTransport _standard_compatible: bool _shutdown_timeout: float - _ssl_object: _typing_ssl.SSLObject - _read_bio: _typing_ssl.MemoryBIO - _write_bio: _typing_ssl.MemoryBIO + _ssl_object: SSLObject + _read_bio: MemoryBIO + _write_bio: MemoryBIO __incoming_reader: _IncomingDataReader = dataclasses.field(init=False) __closing: bool = dataclasses.field(init=False, default=False) def __post_init__(self) -> None: - if isinstance(self._transport, transports.AsyncBufferedStreamReadTransport): + if isinstance(self._transport, AsyncBufferedStreamReadTransport): self.__incoming_reader = _BufferedIncomingDataReader(transport=self._transport) else: self.__incoming_reader = _IncomingDataReader(transport=self._transport) @@ -70,16 +74,32 @@ def __post_init__(self) -> None: @classmethod async def wrap( cls, - transport: transports.AsyncStreamTransport, - ssl_context: _typing_ssl.SSLContext, + transport: AsyncStreamTransport, + ssl_context: SSLContext, *, handshake_timeout: float | None = None, shutdown_timeout: float | None = None, server_side: bool | None = None, server_hostname: str | None = None, standard_compatible: bool = True, - session: _typing_ssl.SSLSession | None = None, + session: SSLSession | None = None, ) -> Self: + """ + Parameters: + transport: The transport to wrap. + ssl_context: a :class:`ssl.SSLContext` object to use to create the transport. + handshake_timeout: The time in seconds to wait for the TLS handshake to complete before aborting the connection. + ``60.0`` seconds if :data:`None` (default). + shutdown_timeout: The time in seconds to wait for the SSL shutdown to complete before aborting the connection. + ``30.0`` seconds if :data:`None` (default). + server_side: Indicates whether we are a client or a server for the handshake part. If it is set to :data:`None`, + it is deduced according to `server_hostname`. + server_hostname: sets or overrides the hostname that the target server's certificate will be matched against. 
 + If `server_side` is :data:`True`, you must pass a value for `server_hostname`. + standard_compatible: If :data:`False`, skip the closing handshake when closing the connection, + and don't raise an exception if the peer does the same. + session: If an SSL session already exists, use it instead. + """ assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used if server_side is None: @@ -129,11 +149,11 @@ def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: msg = f"unclosed transport {self!r} pointing to {transport!r} (and cannot be closed synchronously)" _warn(msg, ResourceWarning, source=self) - @_utils.inherit_doc(transports.AsyncStreamTransport) + @_utils.inherit_doc(AsyncStreamTransport) def is_closing(self) -> bool: return self.__closing - @_utils.inherit_doc(transports.AsyncStreamTransport) + @_utils.inherit_doc(AsyncStreamTransport) async def aclose(self) -> None: with contextlib.ExitStack() as stack: stack.callback(self.__incoming_reader.close) @@ -153,11 +173,11 @@ async def aclose(self) -> None: await self._transport.aclose() - @_utils.inherit_doc(transports.AsyncStreamTransport) + @_utils.inherit_doc(AsyncStreamTransport) def backend(self) -> AsyncBackend: return self._transport.backend() - @_utils.inherit_doc(transports.AsyncStreamTransport) + @_utils.inherit_doc(AsyncStreamTransport) async def recv(self, bufsize: int) -> bytes: assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used try: @@ -170,7 +190,7 @@ async def recv(self, bufsize: int) -> bytes: return b"" raise - @_utils.inherit_doc(transports.AsyncBufferedStreamReadTransport) + @_utils.inherit_doc(AsyncBufferedStreamReadTransport) async def recv_into(self, buffer: WriteableBuffer) -> int: assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used nbytes = memoryview(buffer).nbytes or 1024 @@ -184,7 +204,7 @@ async def recv_into(self, buffer: WriteableBuffer) -> int: return 0 raise - 
@_utils.inherit_doc(transports.AsyncStreamTransport) + @_utils.inherit_doc(AsyncStreamTransport) async def send_all(self, data: bytes | bytearray | memoryview) -> None: assert _ssl_module is not None, "stdlib ssl module not available" # nosec assert_used try: @@ -192,7 +212,7 @@ async def send_all(self, data: bytes | bytearray | memoryview) -> None: except _ssl_module.SSLZeroReturnError as exc: raise _utils.error_from_errno(errno.ECONNRESET) from exc - @_utils.inherit_doc(transports.AsyncStreamTransport) + @_utils.inherit_doc(AsyncStreamTransport) async def send_eof(self) -> None: raise UnsupportedOperation("SSL/TLS API does not support sending EOF.") @@ -230,7 +250,7 @@ async def _retry_ssl_method( return result @property - @_utils.inherit_doc(transports.AsyncStreamTransport) + @_utils.inherit_doc(AsyncStreamTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return { **self._transport.extra_attributes, @@ -239,7 +259,7 @@ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: } -class AsyncTLSListener(transports.AsyncListener[AsyncTLSStreamTransport]): +class AsyncTLSListener(AsyncListener[AsyncTLSStreamTransport]): __slots__ = ( "__listener", "__ssl_context", @@ -250,16 +270,16 @@ class AsyncTLSListener(transports.AsyncListener[AsyncTLSStreamTransport]): def __init__( self, - listener: transports.AsyncListener[transports.AsyncStreamTransport], - ssl_context: _typing_ssl.SSLContext, + listener: AsyncListener[AsyncStreamTransport], + ssl_context: SSLContext, *, handshake_timeout: float | None = None, shutdown_timeout: float | None = None, standard_compatible: bool = True, ) -> None: super().__init__() - self.__listener: transports.AsyncListener[transports.AsyncStreamTransport] = listener - self.__ssl_context: _typing_ssl.SSLContext = ssl_context + self.__listener: AsyncListener[AsyncStreamTransport] = listener + self.__ssl_context: SSLContext = ssl_context self.__handshake_timeout: float | None = handshake_timeout 
self.__shutdown_timeout: float | None = shutdown_timeout self.__standard_compatible: bool = standard_compatible @@ -274,15 +294,15 @@ def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: msg = f"unclosed listener {self!r} pointing to {listener!r} (and cannot be closed synchronously)" _warn(msg, ResourceWarning, source=self) - @_utils.inherit_doc(transports.AsyncListener) + @_utils.inherit_doc(AsyncListener) def is_closing(self) -> bool: return self.__listener.is_closing() - @_utils.inherit_doc(transports.AsyncListener) + @_utils.inherit_doc(AsyncListener) async def aclose(self) -> None: return await self.__listener.aclose() - @_utils.inherit_doc(transports.AsyncListener) + @_utils.inherit_doc(AsyncListener) async def serve( self, handler: Callable[[AsyncTLSStreamTransport], Coroutine[Any, Any, None]], @@ -292,7 +312,7 @@ async def serve( logger = logging.getLogger(__name__) @functools.wraps(handler) - async def tls_handler_wrapper(stream: transports.AsyncStreamTransport, /) -> None: + async def tls_handler_wrapper(stream: AsyncStreamTransport, /) -> None: try: stream = await AsyncTLSStreamTransport.wrap( stream, @@ -318,12 +338,12 @@ async def tls_handler_wrapper(stream: transports.AsyncStreamTransport, /) -> Non await listener.serve(tls_handler_wrapper, task_group) - @_utils.inherit_doc(transports.AsyncListener) + @_utils.inherit_doc(AsyncListener) def backend(self) -> AsyncBackend: return self.__listener.backend() @property - @_utils.inherit_doc(transports.AsyncListener) + @_utils.inherit_doc(AsyncListener) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return { **self.__listener.extra_attributes, @@ -334,10 +354,10 @@ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: @dataclasses.dataclass(kw_only=True, eq=False, slots=True) class _IncomingDataReader: - transport: transports.AsyncStreamReadTransport + transport: AsyncStreamReadTransport max_size: Final[int] = 256 * 1024 # 256KiB - async def readinto(self, 
read_bio: _typing_ssl.MemoryBIO) -> int: + async def readinto(self, read_bio: MemoryBIO) -> int: data = await self.transport.recv(self.max_size) if data: return read_bio.write(data) @@ -350,7 +370,7 @@ def close(self) -> None: @dataclasses.dataclass(kw_only=True, eq=False, slots=True) class _BufferedIncomingDataReader(_IncomingDataReader): - transport: transports.AsyncBufferedStreamReadTransport + transport: AsyncBufferedStreamReadTransport buffer: bytearray | None = dataclasses.field(init=False) buffer_view: memoryview = dataclasses.field(init=False) @@ -358,7 +378,7 @@ def __post_init__(self) -> None: self.buffer = bytearray(self.max_size) self.buffer_view = memoryview(self.buffer) - async def readinto(self, read_bio: _typing_ssl.MemoryBIO) -> int: + async def readinto(self, read_bio: MemoryBIO) -> int: buffer = self.buffer_view nbytes = await self.transport.recv_into(buffer) if nbytes: diff --git a/src/easynetwork/lowlevel/api_sync/endpoints/datagram.py b/src/easynetwork/lowlevel/api_sync/endpoints/datagram.py index 77e583b7..8f1fdd3e 100644 --- a/src/easynetwork/lowlevel/api_sync/endpoints/datagram.py +++ b/src/easynetwork/lowlevel/api_sync/endpoints/datagram.py @@ -28,14 +28,14 @@ from collections.abc import Callable, Mapping from typing import Any, Generic -from .... import protocol as protocol_module from ...._typevars import _T_ReceivedPacket, _T_SentPacket from ....exceptions import DatagramProtocolParseError +from ....protocol import DatagramProtocol from ... import _utils -from ..transports import abc as transports +from ..transports.abc import BaseTransport, DatagramReadTransport, DatagramTransport, DatagramWriteTransport -class DatagramReceiverEndpoint(transports.BaseTransport, Generic[_T_ReceivedPacket]): +class DatagramReceiverEndpoint(BaseTransport, Generic[_T_ReceivedPacket]): """ A read-only communication endpoint based on unreliable packets of data. 
""" @@ -47,8 +47,8 @@ class DatagramReceiverEndpoint(transports.BaseTransport, Generic[_T_ReceivedPack def __init__( self, - transport: transports.DatagramReadTransport, - protocol: protocol_module.DatagramProtocol[Any, _T_ReceivedPacket], + transport: DatagramReadTransport, + protocol: DatagramProtocol[Any, _T_ReceivedPacket], ) -> None: """ Parameters: @@ -56,14 +56,14 @@ def __init__( protocol: The :term:`protocol object` to use. """ - if not isinstance(transport, transports.DatagramReadTransport): + if not isinstance(transport, DatagramReadTransport): raise TypeError(f"Expected a DatagramReadTransport object, got {transport!r}") - if not isinstance(protocol, protocol_module.DatagramProtocol): + if not isinstance(protocol, DatagramProtocol): raise TypeError(f"Expected a DatagramProtocol object, got {protocol!r}") self.__receiver: _DataReceiverImpl[_T_ReceivedPacket] = _DataReceiverImpl(transport, protocol) - self.__transport: transports.DatagramReadTransport = transport + self.__transport: DatagramReadTransport = transport def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: try: @@ -113,11 +113,12 @@ def recv_packet(self, *, timeout: float | None = None) -> _T_ReceivedPacket: return receiver.receive(timeout) @property + @_utils.inherit_doc(BaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes -class DatagramSenderEndpoint(transports.BaseTransport, Generic[_T_SentPacket]): +class DatagramSenderEndpoint(BaseTransport, Generic[_T_SentPacket]): """ A write-only communication endpoint based on unreliable packets of data. 
""" @@ -129,8 +130,8 @@ class DatagramSenderEndpoint(transports.BaseTransport, Generic[_T_SentPacket]): def __init__( self, - transport: transports.DatagramWriteTransport, - protocol: protocol_module.DatagramProtocol[_T_SentPacket, Any], + transport: DatagramWriteTransport, + protocol: DatagramProtocol[_T_SentPacket, Any], ) -> None: """ Parameters: @@ -138,14 +139,14 @@ def __init__( protocol: The :term:`protocol object` to use. """ - if not isinstance(transport, transports.DatagramWriteTransport): + if not isinstance(transport, DatagramWriteTransport): raise TypeError(f"Expected a DatagramWriteTransport object, got {transport!r}") - if not isinstance(protocol, protocol_module.DatagramProtocol): + if not isinstance(protocol, DatagramProtocol): raise TypeError(f"Expected a DatagramProtocol object, got {protocol!r}") self.__sender: _DataSenderImpl[_T_SentPacket] = _DataSenderImpl(transport, protocol) - self.__transport: transports.DatagramWriteTransport = transport + self.__transport: DatagramWriteTransport = transport def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: try: @@ -197,11 +198,12 @@ def send_packet(self, packet: _T_SentPacket, *, timeout: float | None = None) -> return sender.send(packet, timeout) @property + @_utils.inherit_doc(BaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes -class DatagramEndpoint(transports.BaseTransport, Generic[_T_SentPacket, _T_ReceivedPacket]): +class DatagramEndpoint(BaseTransport, Generic[_T_SentPacket, _T_ReceivedPacket]): """ A full-duplex communication endpoint based on unreliable packets of data. 
""" @@ -214,8 +216,8 @@ class DatagramEndpoint(transports.BaseTransport, Generic[_T_SentPacket, _T_Recei def __init__( self, - transport: transports.DatagramTransport, - protocol: protocol_module.DatagramProtocol[_T_SentPacket, _T_ReceivedPacket], + transport: DatagramTransport, + protocol: DatagramProtocol[_T_SentPacket, _T_ReceivedPacket], ) -> None: """ Parameters: @@ -223,15 +225,15 @@ def __init__( protocol: The :term:`protocol object` to use. """ - if not isinstance(transport, transports.DatagramTransport): + if not isinstance(transport, DatagramTransport): raise TypeError(f"Expected a DatagramTransport object, got {transport!r}") - if not isinstance(protocol, protocol_module.DatagramProtocol): + if not isinstance(protocol, DatagramProtocol): raise TypeError(f"Expected a DatagramProtocol object, got {protocol!r}") self.__sender: _DataSenderImpl[_T_SentPacket] = _DataSenderImpl(transport, protocol) self.__receiver: _DataReceiverImpl[_T_ReceivedPacket] = _DataReceiverImpl(transport, protocol) - self.__transport: transports.DatagramTransport = transport + self.__transport: DatagramTransport = transport def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: try: @@ -306,14 +308,15 @@ def recv_packet(self, *, timeout: float | None = None) -> _T_ReceivedPacket: return receiver.receive(timeout) @property + @_utils.inherit_doc(BaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes @dataclasses.dataclass(slots=True) class _DataSenderImpl(Generic[_T_SentPacket]): - transport: transports.DatagramWriteTransport - protocol: protocol_module.DatagramProtocol[_T_SentPacket, Any] + transport: DatagramWriteTransport + protocol: DatagramProtocol[_T_SentPacket, Any] def send(self, packet: _T_SentPacket, timeout: float) -> None: try: @@ -328,8 +331,8 @@ def send(self, packet: _T_SentPacket, timeout: float) -> None: @dataclasses.dataclass(slots=True) class 
_DataReceiverImpl(Generic[_T_ReceivedPacket]): - transport: transports.DatagramReadTransport - protocol: protocol_module.DatagramProtocol[Any, _T_ReceivedPacket] + transport: DatagramReadTransport + protocol: DatagramProtocol[Any, _T_ReceivedPacket] def receive(self, timeout: float) -> _T_ReceivedPacket: datagram = self.transport.recv(timeout) diff --git a/src/easynetwork/lowlevel/api_sync/endpoints/stream.py b/src/easynetwork/lowlevel/api_sync/endpoints/stream.py index 9d44483d..f92595ea 100644 --- a/src/easynetwork/lowlevel/api_sync/endpoints/stream.py +++ b/src/easynetwork/lowlevel/api_sync/endpoints/stream.py @@ -29,15 +29,21 @@ from collections.abc import Callable, Mapping from typing import Any, Generic, Literal, assert_never -from .... import protocol as protocol_module from ...._typevars import _T_ReceivedPacket, _T_SentPacket from ....exceptions import UnsupportedOperation +from ....protocol import StreamProtocol from ....warnings import ManualBufferAllocationWarning from ... import _stream, _utils -from ..transports import abc as transports +from ..transports.abc import ( + BaseTransport, + BufferedStreamReadTransport, + StreamReadTransport, + StreamTransport, + StreamWriteTransport, +) -class StreamReceiverEndpoint(transports.BaseTransport, Generic[_T_ReceivedPacket]): +class StreamReceiverEndpoint(BaseTransport, Generic[_T_ReceivedPacket]): """ A read-only communication endpoint based on continuous stream data transport. """ @@ -49,8 +55,8 @@ class StreamReceiverEndpoint(transports.BaseTransport, Generic[_T_ReceivedPacket def __init__( self, - transport: transports.StreamReadTransport, - protocol: protocol_module.StreamProtocol[Any, _T_ReceivedPacket], + transport: StreamReadTransport, + protocol: StreamProtocol[Any, _T_ReceivedPacket], max_recv_size: int, *, manual_buffer_allocation: Literal["try", "no", "force"] = "try", @@ -75,7 +81,7 @@ def __init__( :exc:`.ManualBufferAllocationWarning`. 
""" - if not isinstance(transport, transports.StreamReadTransport): + if not isinstance(transport, StreamReadTransport): raise TypeError(f"Expected a StreamReadTransport object, got {transport!r}") _check_max_recv_size_value(max_recv_size) _check_manual_buffer_allocation_value(manual_buffer_allocation) @@ -89,7 +95,7 @@ def __init__( manual_buffer_allocation_warning_stacklevel=manual_buffer_allocation_warning_stacklevel, ) - self.__transport: transports.StreamReadTransport = transport + self.__transport: StreamReadTransport = transport def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: try: @@ -146,11 +152,12 @@ def max_recv_size(self) -> int: return self.__receiver.max_recv_size @property + @_utils.inherit_doc(BaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes -class StreamSenderEndpoint(transports.BaseTransport, Generic[_T_SentPacket]): +class StreamSenderEndpoint(BaseTransport, Generic[_T_SentPacket]): """ A write-only communication endpoint based on continuous stream data transport. """ @@ -162,8 +169,8 @@ class StreamSenderEndpoint(transports.BaseTransport, Generic[_T_SentPacket]): def __init__( self, - transport: transports.StreamWriteTransport, - protocol: protocol_module.StreamProtocol[_T_SentPacket, Any], + transport: StreamWriteTransport, + protocol: StreamProtocol[_T_SentPacket, Any], ) -> None: """ Parameters: @@ -171,11 +178,11 @@ def __init__( protocol: The :term:`protocol object` to use. 
""" - if not isinstance(transport, transports.StreamWriteTransport): + if not isinstance(transport, StreamWriteTransport): raise TypeError(f"Expected a StreamWriteTransport object, got {transport!r}") self.__sender: _DataSenderImpl[_T_SentPacket] = _DataSenderImpl(transport, _stream.StreamDataProducer(protocol)) - self.__transport: transports.StreamReadTransport | transports.StreamWriteTransport = transport + self.__transport: StreamReadTransport | StreamWriteTransport = transport def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: try: @@ -229,11 +236,12 @@ def send_packet(self, packet: _T_SentPacket, *, timeout: float | None = None) -> return sender.send(packet, timeout) @property + @_utils.inherit_doc(BaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes -class StreamEndpoint(transports.BaseTransport, Generic[_T_SentPacket, _T_ReceivedPacket]): +class StreamEndpoint(BaseTransport, Generic[_T_SentPacket, _T_ReceivedPacket]): """ A full-duplex communication endpoint based on continuous stream data transport. """ @@ -247,8 +255,8 @@ class StreamEndpoint(transports.BaseTransport, Generic[_T_SentPacket, _T_Receive def __init__( self, - transport: transports.StreamTransport, - protocol: protocol_module.StreamProtocol[_T_SentPacket, _T_ReceivedPacket], + transport: StreamTransport, + protocol: StreamProtocol[_T_SentPacket, _T_ReceivedPacket], max_recv_size: int, *, manual_buffer_allocation: Literal["try", "no", "force"] = "try", @@ -273,7 +281,7 @@ def __init__( :exc:`.ManualBufferAllocationWarning`. 
""" - if not isinstance(transport, transports.StreamTransport): + if not isinstance(transport, StreamTransport): raise TypeError(f"Expected a StreamTransport object, got {transport!r}") _check_max_recv_size_value(max_recv_size) _check_manual_buffer_allocation_value(manual_buffer_allocation) @@ -288,7 +296,7 @@ def __init__( manual_buffer_allocation_warning_stacklevel=manual_buffer_allocation_warning_stacklevel, ) - self.__transport: transports.StreamTransport = transport + self.__transport: StreamTransport = transport self.__eof_sent: bool = False def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: @@ -392,13 +400,14 @@ def max_recv_size(self) -> int: return self.__receiver.max_recv_size @property + @_utils.inherit_doc(BaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__transport.extra_attributes @dataclasses.dataclass(slots=True) class _DataSenderImpl(Generic[_T_SentPacket]): - transport: transports.StreamWriteTransport + transport: StreamWriteTransport producer: _stream.StreamDataProducer[_T_SentPacket] def send(self, packet: _T_SentPacket, timeout: float) -> None: @@ -407,7 +416,7 @@ def send(self, packet: _T_SentPacket, timeout: float) -> None: @dataclasses.dataclass(slots=True) class _DataReceiverImpl(Generic[_T_ReceivedPacket]): - transport: transports.StreamReadTransport + transport: StreamReadTransport consumer: _stream.StreamDataConsumer[_T_ReceivedPacket] max_recv_size: int _eof_reached: bool = dataclasses.field(init=False, default=False) @@ -448,7 +457,7 @@ def receive(self, timeout: float) -> _T_ReceivedPacket: @dataclasses.dataclass(slots=True) class _BufferedReceiverImpl(Generic[_T_ReceivedPacket]): - transport: transports.BufferedStreamReadTransport + transport: BufferedStreamReadTransport consumer: _stream.BufferedStreamDataConsumer[_T_ReceivedPacket] _eof_reached: bool = dataclasses.field(init=False, default=False) @@ -491,8 +500,8 @@ def receive(self, timeout: float) -> 
_T_ReceivedPacket: def _get_receiver( - transport: transports.StreamReadTransport, - protocol: protocol_module.StreamProtocol[Any, _T_ReceivedPacket], + transport: StreamReadTransport, + protocol: StreamProtocol[Any, _T_ReceivedPacket], *, max_recv_size: int, manual_buffer_allocation: Literal["try", "no", "force"], @@ -505,7 +514,7 @@ def _get_receiver( case "try" | "force": try: buffered_consumer = _stream.BufferedStreamDataConsumer(protocol, max_recv_size) - if not isinstance(transport, transports.BufferedStreamReadTransport): + if not isinstance(transport, BufferedStreamReadTransport): msg = f"The transport implementation {transport!r} does not implement BufferedStreamReadTransport interface" if manual_buffer_allocation == "try": warnings.warn( diff --git a/src/easynetwork/lowlevel/api_sync/transports/base_selector.py b/src/easynetwork/lowlevel/api_sync/transports/base_selector.py index 30795843..2794bd29 100644 --- a/src/easynetwork/lowlevel/api_sync/transports/base_selector.py +++ b/src/easynetwork/lowlevel/api_sync/transports/base_selector.py @@ -102,6 +102,24 @@ def _retry( callback: Callable[[], _T_Return], timeout: float, ) -> tuple[_T_Return, float]: + """ + Calls `callback` without argument and returns the output. + + If the callable raises :class:`WouldBlockOnRead` or :class:`WouldBlockOnWrite`, waits for ``fileno`` to be + available for reading or writing respectively, and retries to call the callback. + + Parameters: + callback: the function to call. + timeout: the maximum amount of seconds to wait for the file descriptor to be available. + + Raises: + TimeoutError: timed out + + Returns: + a tuple with the result of the callback and the timeout which is deduced from the waited time. + + :meta public: + """ timeout = _utils.validate_timeout_delay(timeout, positive_check=True) retry_interval = self._retry_interval event: int @@ -298,7 +316,6 @@ def send_noblock(self, data: bytes | bytearray | memoryview) -> None: Parameters: data: the bytes to send. 
- timeout: the allowed time (in seconds) for blocking operations. Can be set to :data:`math.inf`. Raises: WouldBlockOnRead: the operation would block when reading the pipe. diff --git a/src/easynetwork/lowlevel/api_sync/transports/socket.py b/src/easynetwork/lowlevel/api_sync/transports/socket.py index 0ad51b72..2b175d63 100644 --- a/src/easynetwork/lowlevel/api_sync/transports/socket.py +++ b/src/easynetwork/lowlevel/api_sync/transports/socket.py @@ -44,7 +44,7 @@ from . import base_selector if TYPE_CHECKING: - import ssl as _typing_ssl + from ssl import SSLContext, SSLSession, SSLSocket from _typeshed import WriteableBuffer @@ -62,6 +62,10 @@ def _close_stream_socket(sock: socket.socket) -> None: class SocketStreamTransport(base_selector.SelectorStreamTransport, base_selector.SelectorBufferedStreamReadTransport): + """ + A stream data transport implementation which wraps a stream :class:`~socket.socket`. + """ + __slots__ = ("__socket",) def __init__( @@ -71,6 +75,13 @@ def __init__( *, selector_factory: Callable[[], selectors.BaseSelector] | None = None, ) -> None: + """ + Parameters: + sock: The :data:`~socket.SOCK_STREAM` socket to wrap. + retry_interval: The maximum wait time to wait for a blocking operation before retrying. + Set it to :data:`math.inf` to disable this feature. + selector_factory: If given, the callable object to use to create a new :class:`selectors.BaseSelector` instance. + """ super().__init__(retry_interval=retry_interval, selector_factory=selector_factory) _utils.check_socket_no_ssl(sock) @@ -162,12 +173,16 @@ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: class SSLStreamTransport(base_selector.SelectorStreamTransport, base_selector.SelectorBufferedStreamReadTransport): + """ + A stream data transport implementation which wraps a stream :class:`~socket.socket`. 
+ """ + __slots__ = ("__socket", "__ssl_shutdown_timeout", "__standard_compatible") def __init__( self, sock: socket.socket, - ssl_context: _typing_ssl.SSLContext, + ssl_context: SSLContext, retry_interval: float, *, handshake_timeout: float | None = None, @@ -175,9 +190,28 @@ def __init__( server_side: bool | None = None, server_hostname: str | None = None, standard_compatible: bool = True, - session: _typing_ssl.SSLSession | None = None, + session: SSLSession | None = None, selector_factory: Callable[[], selectors.BaseSelector] | None = None, ) -> None: + """ + Parameters: + sock: The :data:`~socket.SOCK_STREAM` socket to wrap. + ssl_context: a :class:`ssl.SSLContext` object to use to create the transport. + retry_interval: The maximum wait time to wait for a blocking operation before retrying. + Set it to :data:`math.inf` to disable this feature. + handshake_timeout: The time in seconds to wait for the TLS handshake to complete before aborting the connection. + ``60.0`` seconds if :data:`None` (default). + shutdown_timeout: The time in seconds to wait for the SSL shutdown to complete before aborting the connection. + ``30.0`` seconds if :data:`None` (default). + server_side: Indicates whether we are a client or a server for the handshake part. If it is set to :data:`None`, + it is deduced according to `server_hostname`. + server_hostname: sets or overrides the hostname that the target server's certificate will be matched against. + If `server_side` is :data:`True`, you must pass a value for `server_hostname`. + standard_compatible: If :data:`False`, skip the closing handshake when closing the connection, + and don't raise an exception if the peer does the same. + session: If an SSL session already exits, use it insead. + selector_factory: If given, the callable object to use to create a new :class:`selectors.BaseSelector` instance. 
+ """ super().__init__(retry_interval=retry_interval, selector_factory=selector_factory) if handshake_timeout is None: @@ -194,7 +228,7 @@ def __init__( raise ValueError("A 'SOCK_STREAM' socket is expected") if server_side is None: server_side = not server_hostname - self.__socket: _typing_ssl.SSLSocket = ssl_context.wrap_socket( + self.__socket: SSLSocket = ssl_context.wrap_socket( sock, server_side=server_side, server_hostname=server_hostname, @@ -286,6 +320,10 @@ def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: class SocketDatagramTransport(base_selector.SelectorDatagramTransport): + """ + A datagram transport implementation which wraps a datagram :class:`~socket.socket`. + """ + __slots__ = ("__socket", "__max_datagram_size") def __init__( @@ -296,6 +334,14 @@ def __init__( max_datagram_size: int = constants.MAX_DATAGRAM_BUFSIZE, selector_factory: Callable[[], selectors.BaseSelector] | None = None, ) -> None: + """ + Parameters: + sock: The :data:`~socket.SOCK_DGRAM` socket to wrap. + retry_interval: The maximum wait time to wait for a blocking operation before retrying. + Set it to :data:`math.inf` to disable this feature. + max_datagram_size: The maximum packet size supported by :manpage:`recvfrom(2)` for the current socket. + selector_factory: If given, the callable object to use to create a new :class:`selectors.BaseSelector` instance. 
+ """ super().__init__(retry_interval=retry_interval, selector_factory=selector_factory) if max_datagram_size <= 0: diff --git a/src/easynetwork/lowlevel/socket.py b/src/easynetwork/lowlevel/socket.py index ce31169f..4562f374 100644 --- a/src/easynetwork/lowlevel/socket.py +++ b/src/easynetwork/lowlevel/socket.py @@ -63,7 +63,7 @@ from ._final import runtime_final_class if TYPE_CHECKING: - import ssl as _typing_ssl + from ssl import SSLContext, SSLObject, SSLSocket, _PeerCertRetDictType _P = ParamSpec("_P") _T_Return = TypeVar("_T_Return") @@ -101,10 +101,10 @@ class INETSocketAttribute(SocketAttribute): class TLSAttribute(typed_attr.TypedAttributeSet): __slots__ = () - sslcontext: _typing_ssl.SSLContext = typed_attr.typed_attribute() + sslcontext: SSLContext = typed_attr.typed_attribute() """:class:`ssl.SSLContext` instance.""" - peercert: _typing_ssl._PeerCertRetDictType = typed_attr.typed_attribute() + peercert: _PeerCertRetDictType = typed_attr.typed_attribute() """peer certificate; result of :meth:`ssl.SSLSocket.getpeercert`.""" cipher: tuple[str, str, int] = typed_attr.typed_attribute() @@ -610,7 +610,7 @@ def _get_socket_extra(sock: ISocket, *, wrap_in_proxy: bool = True) -> dict[Any, } -def _get_tls_extra(ssl_object: _typing_ssl.SSLObject | _typing_ssl.SSLSocket) -> dict[Any, Callable[[], Any]]: +def _get_tls_extra(ssl_object: SSLObject | SSLSocket) -> dict[Any, Callable[[], Any]]: return { TLSAttribute.sslcontext: lambda: ssl_object.context, TLSAttribute.peercert: lambda: _value_or_lookup_error(ssl_object.getpeercert()), diff --git a/src/easynetwork/serializers/pickle.py b/src/easynetwork/serializers/pickle.py index 91eeb0cf..ab9fe160 100644 --- a/src/easynetwork/serializers/pickle.py +++ b/src/easynetwork/serializers/pickle.py @@ -32,7 +32,7 @@ from .abc import AbstractPacketSerializer if TYPE_CHECKING: - import pickle as _typing_pickle + from pickle import Pickler, Unpickler def _get_default_pickler_protocol() -> int: @@ -78,8 +78,8 @@ def __init__( 
pickler_config: PicklerConfig | None = None, unpickler_config: UnpicklerConfig | None = None, *, - pickler_cls: type[_typing_pickle.Pickler] | None = None, - unpickler_cls: type[_typing_pickle.Unpickler] | None = None, + pickler_cls: type[Pickler] | None = None, + unpickler_cls: type[Unpickler] | None = None, pickler_optimize: bool = False, debug: bool = False, ) -> None: @@ -101,8 +101,8 @@ def __init__( import pickletools self.__optimize = pickletools.optimize - self.__pickler_cls: Callable[[IO[bytes]], _typing_pickle.Pickler] - self.__unpickler_cls: Callable[[IO[bytes]], _typing_pickle.Unpickler] + self.__pickler_cls: Callable[[IO[bytes]], pickle.Pickler] + self.__unpickler_cls: Callable[[IO[bytes]], pickle.Unpickler] if pickler_config is None: pickler_config = PicklerConfig() diff --git a/src/easynetwork/serializers/struct.py b/src/easynetwork/serializers/struct.py index 45d1d1b7..ec838364 100644 --- a/src/easynetwork/serializers/struct.py +++ b/src/easynetwork/serializers/struct.py @@ -28,7 +28,7 @@ from .base_stream import FixedSizePacketSerializer if TYPE_CHECKING: - import struct as _typing_struct + from struct import Struct from _typeshed import ReadableBuffer, SupportsKeysAndGetItem @@ -84,14 +84,14 @@ def __init__(self, format: str, *, debug: bool = False) -> None: format: The :class:`struct.Struct` format definition string. debug: If :data:`True`, add information to :exc:`.DeserializeError` via the ``error_info`` attribute. 
""" - from struct import Struct, error + import struct as struct_module if format and format[0] not in _ENDIANNESS_CHARACTERS: format = f"!{format}" # network byte order - struct = Struct(format) + struct = struct_module.Struct(format) super().__init__(struct.size, debug=debug) - self.__s: _typing_struct.Struct = struct - self.__error_cls = error + self.__s: Struct = struct + self.__error_cls = struct_module.error @abstractmethod def iter_values(self, packet: _T_SentDTOPacket, /) -> Iterable[Any]: @@ -180,7 +180,7 @@ def deserialize_from_buffer(self, data: ReadableBuffer) -> _T_ReceivedDTOPacket: @property @final - def struct(self) -> _typing_struct.Struct: + def struct(self) -> Struct: """The underlying :class:`struct.Struct` instance. Read-only attribute.""" return self.__s diff --git a/src/easynetwork/serializers/wrapper/compressor.py b/src/easynetwork/serializers/wrapper/compressor.py index 871d1fda..2c04b5fe 100644 --- a/src/easynetwork/serializers/wrapper/compressor.py +++ b/src/easynetwork/serializers/wrapper/compressor.py @@ -35,8 +35,8 @@ from ..base_stream import _wrap_generic_buffered_incremental_deserialize, _wrap_generic_incremental_deserialize if TYPE_CHECKING: - import bz2 as _typing_bz2 - import zlib as _typing_zlib + from bz2 import BZ2Compressor, BZ2Decompressor + from zlib import _Compress as ZLibCompress, _Decompress as ZLibDecompress from _typeshed import ReadableBuffer @@ -285,14 +285,14 @@ def __init__( self.__decompressor_factory = bz2.BZ2Decompressor @final - def new_compressor_stream(self) -> _typing_bz2.BZ2Compressor: + def new_compressor_stream(self) -> BZ2Compressor: """ See :meth:`.AbstractCompressorSerializer.new_compressor_stream` documentation for details. 
""" return self.__compressor_factory(self.__compresslevel) @final - def new_decompressor_stream(self) -> _typing_bz2.BZ2Decompressor: + def new_decompressor_stream(self) -> BZ2Decompressor: """ See :meth:`.AbstractCompressorSerializer.new_decompressor_stream` documentation for details. """ @@ -327,14 +327,14 @@ def __init__( self.__decompressor_factory = zlib.decompressobj @final - def new_compressor_stream(self) -> _typing_zlib._Compress: + def new_compressor_stream(self) -> ZLibCompress: """ See :meth:`.AbstractCompressorSerializer.new_compressor_stream` documentation for details. """ return self.__compressor_factory(self.__compresslevel) @final - def new_decompressor_stream(self) -> _typing_zlib._Decompress: + def new_decompressor_stream(self) -> ZLibDecompress: """ See :meth:`.AbstractCompressorSerializer.new_decompressor_stream` documentation for details. """ diff --git a/src/easynetwork/servers/async_tcp.py b/src/easynetwork/servers/async_tcp.py index 092bbb4c..bca77874 100644 --- a/src/easynetwork/servers/async_tcp.py +++ b/src/easynetwork/servers/async_tcp.py @@ -49,7 +49,7 @@ from .misc import build_lowlevel_stream_server_handler if TYPE_CHECKING: - import ssl as _typing_ssl + from ssl import SSLContext class AsyncTCPNetworkServer(AbstractAsyncNetworkServer, Generic[_T_Request, _T_Response]): @@ -82,7 +82,7 @@ def __init__( request_handler: AsyncStreamRequestHandler[_T_Request, _T_Response], backend: AsyncBackend | BuiltinAsyncBackendToken | None = None, *, - ssl: _typing_ssl.SSLContext | None = None, + ssl: SSLContext | None = None, ssl_handshake_timeout: float | None = None, ssl_shutdown_timeout: float | None = None, ssl_standard_compatible: bool | None = None, @@ -222,7 +222,7 @@ async def __create_ssl_over_tcp_listeners( host: str | Sequence[str] | None, port: int, backlog: int, - ssl_context: _typing_ssl.SSLContext, + ssl_context: SSLContext, *, ssl_handshake_timeout: float | None, ssl_shutdown_timeout: float | None, @@ -399,7 +399,7 @@ async def 
__serve( @contextlib.asynccontextmanager async def __client_initializer( self, - lowlevel_client: _stream_server.AsyncStreamClient[_T_Response], + lowlevel_client: _stream_server.Client[_T_Response], ) -> AsyncIterator[AsyncStreamClient[_T_Response] | None]: async with contextlib.AsyncExitStack() as client_exit_stack: self.__attach_server() @@ -422,7 +422,7 @@ async def __client_initializer( client_exit_stack.callback(self.__set_socket_linger_if_not_closed, lowlevel_client.extra(INETSocketAttribute.socket)) logger: logging.Logger = self.__logger - client = _ConnectedClientAPI(self.__backend, client_address, lowlevel_client) + client = _ConnectedClientAPI(client_address, lowlevel_client) del lowlevel_client @@ -516,13 +516,12 @@ class _ConnectedClientAPI(AsyncStreamClient[_T_Response]): def __init__( self, - backend: AsyncBackend, address: SocketAddress, - client: _stream_server.AsyncStreamClient[_T_Response], + client: _stream_server.Client[_T_Response], ) -> None: - self.__client: _stream_server.AsyncStreamClient[_T_Response] = client + self.__client: _stream_server.Client[_T_Response] = client self.__closing: bool = False - self.__send_lock = backend.create_lock() + self.__send_lock = client.backend().create_lock() self.__proxy: SocketProxy = SocketProxy(client.extra(INETSocketAttribute.socket)) self.__address: SocketAddress = address self.__extra_attributes_cache: Mapping[Any, Callable[[], Any]] | None = None diff --git a/src/easynetwork/servers/misc.py b/src/easynetwork/servers/misc.py index 1c31fb66..571af44a 100644 --- a/src/easynetwork/servers/misc.py +++ b/src/easynetwork/servers/misc.py @@ -38,13 +38,13 @@ def build_lowlevel_stream_server_handler( initializer: Callable[ - [_lowlevel_stream_server.AsyncStreamClient[_T_Response]], + [_lowlevel_stream_server.Client[_T_Response]], AbstractAsyncContextManager[AsyncStreamClient[_T_Response] | None], ], request_handler: AsyncStreamRequestHandler[_T_Request, _T_Response], *, logger: logging.Logger | None = None, 
-) -> Callable[[_lowlevel_stream_server.AsyncStreamClient[_T_Response]], AsyncGenerator[float | None, _T_Request]]: +) -> Callable[[_lowlevel_stream_server.Client[_T_Response]], AsyncGenerator[float | None, _T_Request]]: """ Creates an :term:`asynchronous generator` function, usable by :meth:`.AsyncStreamServer.serve`, from an :class:`AsyncStreamRequestHandler`. @@ -64,7 +64,7 @@ def build_lowlevel_stream_server_handler( logger = logging.getLogger(__name__) async def handler( - lowlevel_client: _lowlevel_stream_server.AsyncStreamClient[_T_Response], / + lowlevel_client: _lowlevel_stream_server.Client[_T_Response], / ) -> AsyncGenerator[float | None, _T_Request]: async with initializer(lowlevel_client) as client, AsyncExitStack() as request_handler_exit_stack: del lowlevel_client diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_stream.py b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_stream.py index 7d8a5a21..efd42603 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_stream.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_stream.py @@ -10,7 +10,7 @@ from easynetwork.exceptions import UnsupportedOperation from easynetwork.lowlevel._stream import StreamDataProducer -from easynetwork.lowlevel.api_async.servers.stream import AsyncStreamClient, AsyncStreamServer +from easynetwork.lowlevel.api_async.servers.stream import AsyncStreamServer, Client from easynetwork.lowlevel.api_async.transports.abc import ( AsyncBufferedStreamReadTransport, AsyncListener, @@ -61,13 +61,13 @@ def client( mock_stream_transport: MagicMock, mock_stream_protocol: MagicMock, client_exit_stack: contextlib.AsyncExitStack, - ) -> AsyncStreamClient[Any]: - return AsyncStreamClient(mock_stream_transport, StreamDataProducer(mock_stream_protocol), client_exit_stack) + ) -> Client[Any]: + return Client(mock_stream_transport, StreamDataProducer(mock_stream_protocol), client_exit_stack) 
@pytest.mark.parametrize("transport_closed", [False, True]) async def test____is_closing____default( self, - client: AsyncStreamClient[Any], + client: Client[Any], mock_stream_transport: MagicMock, transport_closed: bool, ) -> None: @@ -84,7 +84,7 @@ async def test____is_closing____default( async def test____aclose____default( self, - client: AsyncStreamClient[Any], + client: Client[Any], mock_stream_transport: MagicMock, client_exit_stack: contextlib.AsyncExitStack, mocker: MockerFixture, @@ -103,7 +103,7 @@ async def test____aclose____default( async def test____extra_attributes____default( self, - client: AsyncStreamClient[Any], + client: Client[Any], mock_stream_transport: MagicMock, mocker: MockerFixture, ) -> None: @@ -118,7 +118,7 @@ async def test____extra_attributes____default( async def test____send_packet____send_bytes_to_transport( self, - client: AsyncStreamClient[Any], + client: Client[Any], mock_stream_transport: MagicMock, mock_stream_protocol: MagicMock, mocker: MockerFixture, diff --git a/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_selector.py b/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_selector.py index 247d086e..d0bca6ad 100644 --- a/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_selector.py +++ b/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_selector.py @@ -57,7 +57,7 @@ def retry_interval(request: pytest.FixtureRequest) -> float: @pytest.fixture @staticmethod def mock_transport(retry_interval: float, mock_selector: MagicMock, mocker: MockerFixture) -> MagicMock: - mock_transport = mocker.NonCallableMagicMock(spec=SelectorStreamTransport) + mock_transport = mocker.NonCallableMagicMock(spec=SelectorBaseTransport) SelectorBaseTransport.__init__(mock_transport, retry_interval, lambda: mock_selector) return mock_transport