From 1288c09ee5991eb523df1240ba56d526bea08da4 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francis=20Clairicia-Rose-Claire-Jos=C3=A9phine?= Date: Sun, 22 Sep 2024 15:02:48 +0200 Subject: [PATCH] Low-level API: Added stapled async/blocking transports (#352) --- docs/source/api/lowlevel/async/servers.rst | 1 + docs/source/api/lowlevel/async/transports.rst | 19 + docs/source/api/lowlevel/sync/transports.rst | 21 +- .../api_async/transports/composite.py | 249 ++++++++++++ .../lowlevel/api_sync/transports/composite.py | 199 +++++++++ tests/unit_test/_utils.py | 21 +- tests/unit_test/test_async/mock_tools.py | 1 + .../test_transports/test_composite.py | 376 ++++++++++++++++++ tests/unit_test/test_sync/mock_tools.py | 1 + .../test_transports/test_composite.py | 363 +++++++++++++++++ 10 files changed, 1249 insertions(+), 2 deletions(-) create mode 100644 src/easynetwork/lowlevel/api_async/transports/composite.py create mode 100644 src/easynetwork/lowlevel/api_sync/transports/composite.py create mode 100644 tests/unit_test/test_async/test_lowlevel_api/test_transports/test_composite.py create mode 100644 tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_composite.py diff --git a/docs/source/api/lowlevel/async/servers.rst b/docs/source/api/lowlevel/async/servers.rst index 4a02c71e..299e8bf4 100644 --- a/docs/source/api/lowlevel/async/servers.rst +++ b/docs/source/api/lowlevel/async/servers.rst @@ -27,6 +27,7 @@ Datagram Servers :members: .. autotypevar:: easynetwork.lowlevel.api_async.servers.datagram::_T_Address + :no-index: .. autoclass:: DatagramClientContext() :no-index: diff --git a/docs/source/api/lowlevel/async/transports.rst b/docs/source/api/lowlevel/async/transports.rst index 3b6b79df..d84ff162 100644 --- a/docs/source/api/lowlevel/async/transports.rst +++ b/docs/source/api/lowlevel/async/transports.rst @@ -24,6 +24,25 @@ SSL/TLS Support :members: +Composite Data Transports +========================= + +.. automodule:: easynetwork.lowlevel.api_async.transports.composite + :members: + +.. autotypevar:: _T_SendStreamTransport + :no-index: + +.. autotypevar:: _T_ReceiveStreamTransport + :no-index: + +.. autotypevar:: _T_SendDatagramTransport + :no-index: + +.. autotypevar:: _T_ReceiveDatagramTransport + :no-index: + + Miscellaneous ============= diff --git a/docs/source/api/lowlevel/sync/transports.rst b/docs/source/api/lowlevel/sync/transports.rst index 8a5178a7..ce859251 100644 --- a/docs/source/api/lowlevel/sync/transports.rst +++ b/docs/source/api/lowlevel/sync/transports.rst @@ -17,13 +17,32 @@ Abstract Base Classes :special-members: __enter__, __exit__ -``selectors``-based transports +``selectors``-based Transports ============================== .. automodule:: easynetwork.lowlevel.api_sync.transports.base_selector :members: +Composite Data Transports +========================= + +.. automodule:: easynetwork.lowlevel.api_sync.transports.composite + :members: + +.. autotypevar:: _T_SendStreamTransport + :no-index: + +.. autotypevar:: _T_ReceiveStreamTransport + :no-index: + +.. autotypevar:: _T_SendDatagramTransport + :no-index: + +.. autotypevar:: _T_ReceiveDatagramTransport + :no-index: + + Socket Transport Implementations ================================ diff --git a/src/easynetwork/lowlevel/api_async/transports/composite.py b/src/easynetwork/lowlevel/api_async/transports/composite.py new file mode 100644 index 00000000..e3c6c318 --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/transports/composite.py @@ -0,0 +1,249 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""Low-level asynchronous transport composite module. + +.. versionadded:: 1.1 +""" + +from __future__ import annotations + +__all__ = [ + "AsyncStapledDatagramTransport", + "AsyncStapledStreamTransport", +] + +import contextlib +from collections.abc import AsyncIterator, Callable, Iterable, Mapping +from dataclasses import dataclass, field as dataclass_field +from typing import TYPE_CHECKING, Any, Generic, TypeVar, final + +from ... import _utils +from ..._final import runtime_final_class +from . import abc as _transports +from .utils import aclose_forcefully + +if TYPE_CHECKING: + from _typeshed import WriteableBuffer + + from ..backend.abc import AsyncBackend + + +_T_SendStreamTransport = TypeVar("_T_SendStreamTransport", bound=_transports.AsyncStreamWriteTransport) +_T_ReceiveStreamTransport = TypeVar("_T_ReceiveStreamTransport", bound=_transports.AsyncStreamReadTransport) + +_T_SendDatagramTransport = TypeVar("_T_SendDatagramTransport", bound=_transports.AsyncDatagramWriteTransport) +_T_ReceiveDatagramTransport = TypeVar("_T_ReceiveDatagramTransport", bound=_transports.AsyncDatagramReadTransport) + + +@final +@runtime_final_class +@dataclass(frozen=True, slots=True) +class AsyncStapledStreamTransport(_transports.AsyncStreamTransport, Generic[_T_SendStreamTransport, _T_ReceiveStreamTransport]): + """ + An asynchronous continous stream data transport that merges two transports. + + Extra attributes will be provided from both transports, with the receive stream providing the values in case of a conflict. + + .. versionadded:: 1.1 + """ + + send_transport: _T_SendStreamTransport + """The write part of the transport.""" + + receive_transport: _T_ReceiveStreamTransport + """The read part of the transport.""" + + _backend: AsyncBackend = dataclass_field(init=False) + + def __post_init__(self) -> None: + backend = _check_stapled_transports_consistency(self.send_transport, self.receive_transport) + object.__setattr__(self, "_backend", backend) + + async def aclose(self) -> None: + """ + Closes both transports. + + Warning: + :meth:`aclose` performs a graceful close, waiting for the transports to close. + + If :meth:`aclose` is cancelled, the transports are closed using :func:`.aclose_forcefully`. + """ + await _close_stapled_transports(self.send_transport, self.receive_transport) + + def is_closing(self) -> bool: + """ + Checks if both the transports are closed or in the process of being closed. + + Returns: + :data:`True` if the transports are closing. + """ + return self.send_transport.is_closing() and self.receive_transport.is_closing() + + async def recv(self, bufsize: int) -> bytes: + """ + Calls :meth:`self.receive_transport.recv() <.AsyncStreamReadTransport.recv>`. + """ + return await self.receive_transport.recv(bufsize) + + async def recv_into(self, buffer: WriteableBuffer) -> int: + """ + Calls :meth:`self.receive_transport.recv_into() <.AsyncStreamReadTransport.recv_into>`. + """ + return await self.receive_transport.recv_into(buffer) + + async def send_all(self, data: bytes | bytearray | memoryview) -> None: + """ + Calls :meth:`self.send_transport.send_all() <.AsyncStreamWriteTransport.send_all>`. + """ + return await self.send_transport.send_all(data) + + async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None: + """ + Calls :meth:`self.send_transport.send_all_from_iterable() <.AsyncStreamWriteTransport.send_all_from_iterable>`. + """ + return await self.send_transport.send_all_from_iterable(iterable_of_data) + + async def send_eof(self) -> None: + """ + Closes the write end of the stream after the buffered write data is flushed. + + If :meth:`self.send_transport.send_eof() <.AsyncStreamTransport.send_eof>` then this calls it. Otherwise, this calls + :meth:`self.send_transport.aclose() <.AsyncBaseTransport.aclose>`. + + Note: + This method handles the case where :meth:`self.send_transport.send_eof() <.AsyncStreamTransport.send_eof>` + raises :exc:`NotImplementedError` or :exc:`.UnsupportedOperation`; + :meth:`self.send_transport.aclose() <.AsyncBaseTransport.aclose>` will be called as a fallback. + """ + try: + if not isinstance(self.send_transport, _transports.AsyncStreamTransport): + raise NotImplementedError("not a full-duplex transport") + # send_eof() can raise UnsupportedOperation, subclass of NotImplementedError + await self.send_transport.send_eof() + except NotImplementedError: + await self.send_transport.aclose() + + @_utils.inherit_doc(_transports.AsyncStreamTransport) + def backend(self) -> AsyncBackend: + return self._backend + + @property + @_utils.inherit_doc(_transports.AsyncBaseTransport) + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self.send_transport.extra_attributes, + **self.receive_transport.extra_attributes, + } + + +@final +@runtime_final_class +@dataclass(frozen=True, slots=True) +class AsyncStapledDatagramTransport( + _transports.AsyncDatagramTransport, + Generic[_T_SendDatagramTransport, _T_ReceiveDatagramTransport], +): + """ + An asynchronous transport of unreliable packets of data that merges two transports. + + Extra attributes will be provided from both transports, with the receive stream providing the values in case of a conflict. + + .. versionadded:: 1.1 + """ + + send_transport: _T_SendDatagramTransport + """The write part of the transport.""" + + receive_transport: _T_ReceiveDatagramTransport + """The read part of the transport.""" + + _backend: AsyncBackend = dataclass_field(init=False) + + def __post_init__(self) -> None: + backend = _check_stapled_transports_consistency(self.send_transport, self.receive_transport) + object.__setattr__(self, "_backend", backend) + + async def aclose(self) -> None: + """ + Closes both transports. + + Warning: + :meth:`aclose` performs a graceful close, waiting for the transports to close. + + If :meth:`aclose` is cancelled, the transports are closed using :func:`.aclose_forcefully`. + """ + await _close_stapled_transports(self.send_transport, self.receive_transport) + + def is_closing(self) -> bool: + """ + Checks if both the transports are closed or in the process of being closed. + + Returns: + :data:`True` if the transports are closing. + """ + return self.send_transport.is_closing() and self.receive_transport.is_closing() + + async def recv(self) -> bytes: + """ + Calls :meth:`self.receive_transport.recv() <.AsyncDatagramReadTransport.recv>`. + """ + return await self.receive_transport.recv() + + async def send(self, data: bytes | bytearray | memoryview) -> None: + """ + Calls :meth:`self.send_transport.send() <.AsyncDatagramWriteTransport.send>`. + """ + return await self.send_transport.send(data) + + @_utils.inherit_doc(_transports.AsyncDatagramTransport) + def backend(self) -> AsyncBackend: + return self._backend + + @property + @_utils.inherit_doc(_transports.AsyncBaseTransport) + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self.send_transport.extra_attributes, + **self.receive_transport.extra_attributes, + } + + +def _check_stapled_transports_consistency( + send_transport: _transports.AsyncBaseTransport, + receive_transport: _transports.AsyncBaseTransport, +) -> AsyncBackend: + if (backend := send_transport.backend()) is not receive_transport.backend(): + raise RuntimeError("transport backend inconsistency") + return backend + + +async def _close_stapled_transports( + send_transport: _transports.AsyncBaseTransport, + receive_transport: _transports.AsyncBaseTransport, +) -> None: + async with contextlib.AsyncExitStack() as exit_stack: + await exit_stack.enter_async_context(_try_graceful_close(receive_transport)) + await exit_stack.enter_async_context(_try_graceful_close(send_transport)) + + +@contextlib.asynccontextmanager +async def _try_graceful_close(transport: _transports.AsyncBaseTransport) -> AsyncIterator[None]: + try: + yield + except BaseException: + await aclose_forcefully(transport) + raise + else: + await transport.aclose() diff --git a/src/easynetwork/lowlevel/api_sync/transports/composite.py b/src/easynetwork/lowlevel/api_sync/transports/composite.py new file mode 100644 index 00000000..8b6c10c4 --- /dev/null +++ b/src/easynetwork/lowlevel/api_sync/transports/composite.py @@ -0,0 +1,199 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""Low-level synchronous transport composite module. + +.. versionadded:: 1.1 +""" + +from __future__ import annotations + +__all__ = [ + "StapledDatagramTransport", + "StapledStreamTransport", +] + +from collections.abc import Callable, Iterable, Mapping +from dataclasses import dataclass +from typing import TYPE_CHECKING, Any, Generic, TypeVar, final + +from ... import _utils +from ..._final import runtime_final_class +from . import abc as _transports + +if TYPE_CHECKING: + from _typeshed import WriteableBuffer + + +_T_SendStreamTransport = TypeVar("_T_SendStreamTransport", bound=_transports.StreamWriteTransport) +_T_ReceiveStreamTransport = TypeVar("_T_ReceiveStreamTransport", bound=_transports.StreamReadTransport) + +_T_SendDatagramTransport = TypeVar("_T_SendDatagramTransport", bound=_transports.DatagramWriteTransport) +_T_ReceiveDatagramTransport = TypeVar("_T_ReceiveDatagramTransport", bound=_transports.DatagramReadTransport) + + +@final +@runtime_final_class +@dataclass(frozen=True, slots=True) +class StapledStreamTransport(_transports.StreamTransport, Generic[_T_SendStreamTransport, _T_ReceiveStreamTransport]): + """ + A continous stream data transport that merges two transports. + + Extra attributes will be provided from both transports, with the receive stream providing the values in case of a conflict. + + .. versionadded:: 1.1 + """ + + send_transport: _T_SendStreamTransport + """The write part of the transport.""" + + receive_transport: _T_ReceiveStreamTransport + """The read part of the transport.""" + + def close(self) -> None: + """ + Closes both transports. + """ + _close_stapled_transports(self.send_transport, self.receive_transport) + + def is_closed(self) -> bool: + """ + Checks if :meth:`close` has been called on both transports. + + Returns: + :data:`True` if the transports are closed. + """ + return self.send_transport.is_closed() and self.receive_transport.is_closed() + + def recv(self, bufsize: int, timeout: float) -> bytes: + """ + Calls :meth:`self.receive_transport.recv() <.StreamReadTransport.recv>`. + """ + return self.receive_transport.recv(bufsize, timeout) + + def recv_into(self, buffer: WriteableBuffer, timeout: float) -> int: + """ + Calls :meth:`self.receive_transport.recv_into() <.StreamReadTransport.recv_into>`. + """ + return self.receive_transport.recv_into(buffer, timeout) + + def send(self, data: bytes | bytearray | memoryview, timeout: float) -> int: + """ + Calls :meth:`self.send_transport.send() <.StreamWriteTransport.send>`. + """ + return self.send_transport.send(data, timeout) + + def send_all(self, data: bytes | bytearray | memoryview, timeout: float) -> None: + """ + Calls :meth:`self.send_transport.send_all() <.StreamWriteTransport.send_all>`. + """ + return self.send_transport.send_all(data, timeout) + + def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview], timeout: float) -> None: + """ + Calls :meth:`self.send_transport.send_all_from_iterable() <.StreamWriteTransport.send_all_from_iterable>`. + """ + return self.send_transport.send_all_from_iterable(iterable_of_data, timeout) + + def send_eof(self) -> None: + """ + Closes the write end of the stream after the buffered write data is flushed. + + If :meth:`self.send_transport.send_eof() <.StreamTransport.send_eof>` then this calls it. Otherwise, this calls + :meth:`self.send_transport.close() <.BaseTransport.close>`. + + Note: + This method handles the case where :meth:`self.send_transport.send_eof() <.StreamTransport.send_eof>` + raises :exc:`NotImplementedError` or :exc:`.UnsupportedOperation`; + :meth:`self.send_transport.close() <.BaseTransport.close>` will be called as a fallback. + """ + try: + if not isinstance(self.send_transport, _transports.StreamTransport): + raise NotImplementedError("not a full-duplex transport") + # send_eof() can raise UnsupportedOperation, subclass of NotImplementedError + self.send_transport.send_eof() + except NotImplementedError: + self.send_transport.close() + + @property + @_utils.inherit_doc(_transports.BaseTransport) + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self.send_transport.extra_attributes, + **self.receive_transport.extra_attributes, + } + + +@final +@runtime_final_class +@dataclass(frozen=True, slots=True) +class StapledDatagramTransport(_transports.DatagramTransport, Generic[_T_SendDatagramTransport, _T_ReceiveDatagramTransport]): + """ + A transport of unreliable packets of data that merges two transports. + + Extra attributes will be provided from both transports, with the receive stream providing the values in case of a conflict. + + .. versionadded:: 1.1 + """ + + send_transport: _T_SendDatagramTransport + """The write part of the transport.""" + + receive_transport: _T_ReceiveDatagramTransport + """The read part of the transport.""" + + def close(self) -> None: + """ + Closes both transports. + """ + _close_stapled_transports(self.send_transport, self.receive_transport) + + def is_closed(self) -> bool: + """ + Checks if :meth:`close` has been called on both transports. + + Returns: + :data:`True` if the transports are closed. + """ + return self.send_transport.is_closed() and self.receive_transport.is_closed() + + def recv(self, timeout: float) -> bytes: + """ + Calls :meth:`self.receive_transport.recv() <.DatagramReadTransport.recv>`. + """ + return self.receive_transport.recv(timeout) + + def send(self, data: bytes | bytearray | memoryview, timeout: float) -> None: + """ + Calls :meth:`self.send_transport.send() <.DatagramWriteTransport.send>`. + """ + return self.send_transport.send(data, timeout) + + @property + @_utils.inherit_doc(_transports.BaseTransport) + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return { + **self.send_transport.extra_attributes, + **self.receive_transport.extra_attributes, + } + + +def _close_stapled_transports( + send_transport: _transports.BaseTransport, + receive_transport: _transports.BaseTransport, +) -> None: + try: + send_transport.close() + finally: + receive_transport.close() diff --git a/tests/unit_test/_utils.py b/tests/unit_test/_utils.py index 344d142b..f2a01833 100644 --- a/tests/unit_test/_utils.py +++ b/tests/unit_test/_utils.py @@ -1,9 +1,10 @@ from __future__ import annotations +import contextlib import functools import inspect import threading -from collections.abc import Awaitable, Callable, Coroutine, Sequence +from collections.abc import Awaitable, Callable, Coroutine, Iterator, Sequence from socket import AF_INET, AF_INET6, IPPROTO_TCP, IPPROTO_UDP, SOCK_DGRAM, SOCK_STREAM from types import TracebackType from typing import TYPE_CHECKING, Any @@ -260,3 +261,21 @@ def decorator(f: Callable[..., Coroutine[Any, Any, Any]]) -> AsyncMock: return stub return decorator + + +@contextlib.contextmanager +def restore_mock_side_effect(mock: MagicMock) -> Iterator[None]: + default_side_effect = mock.side_effect + default_return_value = mock.side_effect + try: + yield + finally: + mock.side_effect = default_side_effect + mock.return_value = default_return_value + + +@contextlib.contextmanager +def temporary_mock_side_effect(mock: MagicMock, side_effect: Any) -> Iterator[None]: + with restore_mock_side_effect(mock): + mock.side_effect = side_effect + yield diff --git a/tests/unit_test/test_async/mock_tools.py b/tests/unit_test/test_async/mock_tools.py index f8c2da33..ed68c5fb 100644 --- a/tests/unit_test/test_async/mock_tools.py +++ b/tests/unit_test/test_async/mock_tools.py @@ -22,4 +22,5 @@ def close_side_effect() -> None: mock_transport.aclose.side_effect = close_side_effect mock_transport.backend.return_value = backend + mock_transport.extra_attributes = {} return mock_transport diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_composite.py b/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_composite.py new file mode 100644 index 00000000..00d29730 --- /dev/null +++ b/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_composite.py @@ -0,0 +1,376 @@ +from __future__ import annotations + +from asyncio.exceptions import CancelledError +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING, Literal, assert_never + +from easynetwork.exceptions import UnsupportedOperation +from easynetwork.lowlevel.api_async.backend._asyncio.backend import AsyncIOBackend +from easynetwork.lowlevel.api_async.backend.utils import new_builtin_backend +from easynetwork.lowlevel.api_async.transports.abc import ( + AsyncBaseTransport, + AsyncDatagramReadTransport, + AsyncDatagramTransport, + AsyncDatagramWriteTransport, + AsyncStreamReadTransport, + AsyncStreamTransport, + AsyncStreamWriteTransport, +) +from easynetwork.lowlevel.api_async.transports.composite import AsyncStapledDatagramTransport, AsyncStapledStreamTransport + +import pytest +import pytest_asyncio + +from ...._utils import restore_mock_side_effect +from ...mock_tools import make_transport_mock + +if TYPE_CHECKING: + from unittest.mock import MagicMock + + from pytest_mock import MockerFixture + + +@pytest.mark.asyncio +class BaseAsyncStapledTransportTests: + + async def test____aclose____close_both_transports( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncBaseTransport, + ) -> None: + # Arrange + assert not stapled_transport.is_closing() + + # Act + await stapled_transport.aclose() + + # Assert + assert stapled_transport.is_closing() + mock_send_transport.aclose.assert_awaited_once_with() + mock_receive_transport.aclose.assert_awaited_once_with() + + @pytest.mark.parametrize("transport_cancelled", ["send", "receive"]) + async def test____aclose____close_both_transports____even_upon_cancellation( + self, + transport_cancelled: Literal["send", "receive"], + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncBaseTransport, + ) -> None: + # Arrange + match transport_cancelled: + case "send": + mock_send_transport.aclose.side_effect = CancelledError + case "receive": + mock_receive_transport.aclose.side_effect = CancelledError + case _: + assert_never(transport_cancelled) + + # Act + with pytest.raises(CancelledError): + await stapled_transport.aclose() + + # Assert + mock_send_transport.aclose.assert_awaited_once_with() + mock_receive_transport.aclose.assert_awaited_once_with() + + @pytest.mark.parametrize( + ["send_transport_is_closing", "receive_transport_is_closing", "expected_state"], + [ + pytest.param(False, False, False), + pytest.param(True, False, False), + pytest.param(False, True, False), + pytest.param(True, True, True), + ], + ) + async def test____is_closing____expected_state( + self, + send_transport_is_closing: bool, + receive_transport_is_closing: bool, + expected_state: bool, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncBaseTransport, + ) -> None: + # Arrange + if send_transport_is_closing: + await mock_send_transport.aclose() + if receive_transport_is_closing: + await mock_receive_transport.aclose() + + # Act + is_closing = stapled_transport.is_closing() + + # Assert + assert is_closing is expected_state + + async def test____get_backend____default( + self, + stapled_transport: AsyncBaseTransport, + asyncio_backend: AsyncIOBackend, + ) -> None: + # Arrange + + # Act & Assert + assert stapled_transport.backend() is asyncio_backend + + async def test____extra_attributes____both_transports_attributes( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncBaseTransport, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_send_transport.extra_attributes = { + mocker.sentinel.send_only_attr: lambda: mocker.sentinel.send_only_value, + mocker.sentinel.attr_conflict: lambda: mocker.sentinel.send_won, + } + mock_receive_transport.extra_attributes = { + mocker.sentinel.recv_only_attr: lambda: mocker.sentinel.recv_only_value, + mocker.sentinel.attr_conflict: lambda: mocker.sentinel.recv_won, + } + + # Act & Assert + assert stapled_transport.extra(mocker.sentinel.send_only_attr) is mocker.sentinel.send_only_value + assert stapled_transport.extra(mocker.sentinel.recv_only_attr) is mocker.sentinel.recv_only_value + assert stapled_transport.extra(mocker.sentinel.attr_conflict) is mocker.sentinel.recv_won + + +class TestAsyncStapledStreamTransport(BaseAsyncStapledTransportTests): + @pytest.fixture(params=[AsyncStreamWriteTransport, AsyncStreamTransport]) + @staticmethod + def mock_send_transport( + request: pytest.FixtureRequest, + asyncio_backend: AsyncIOBackend, + mocker: MockerFixture, + ) -> MagicMock: + return make_transport_mock(mocker=mocker, spec=request.param, backend=asyncio_backend) + + @pytest.fixture(params=[AsyncStreamReadTransport, AsyncStreamTransport]) + @staticmethod + def mock_receive_transport( + request: pytest.FixtureRequest, + asyncio_backend: AsyncIOBackend, + mocker: MockerFixture, + ) -> MagicMock: + return make_transport_mock(mocker=mocker, spec=request.param, backend=asyncio_backend) + + @pytest_asyncio.fixture + @staticmethod + async def stapled_transport( + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + ) -> AsyncIterator[AsyncStapledStreamTransport[MagicMock, MagicMock]]: + transport = AsyncStapledStreamTransport(mock_send_transport, mock_receive_transport) + mock_send_transport.reset_mock() + mock_receive_transport.reset_mock() + async with transport: + with restore_mock_side_effect(mock_send_transport.aclose), restore_mock_side_effect(mock_receive_transport.aclose): + yield transport + + async def test____dunder_init___transports_inconsistency_error( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + ) -> None: + # Arrange + mock_send_transport.backend.side_effect = [new_builtin_backend("asyncio")] + mock_receive_transport.backend.side_effect = [new_builtin_backend("asyncio")] + + # Act & Assert + with pytest.raises(RuntimeError, match=r"^transport backend inconsistency$"): + _ = AsyncStapledStreamTransport(mock_send_transport, mock_receive_transport) + + async def test____recv____calls_receive_transport_recv( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncStapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_receive_transport.recv.return_value = mocker.sentinel.recv_result + + # Act + data = await stapled_transport.recv(mocker.sentinel.recv_bufsize) + + # Assert + assert data is mocker.sentinel.recv_result + assert mock_receive_transport.mock_calls == [mocker.call.recv(mocker.sentinel.recv_bufsize)] + assert mock_send_transport.mock_calls == [] + + async def test____recv_into____calls_receive_transport_recv_into( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncStapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_receive_transport.recv_into.return_value = mocker.sentinel.recv_into_result + + # Act + nbytes = await stapled_transport.recv_into(mocker.sentinel.recv_buffer) + + # Assert + assert nbytes is mocker.sentinel.recv_into_result + assert mock_receive_transport.mock_calls == [mocker.call.recv_into(mocker.sentinel.recv_buffer)] + assert mock_send_transport.mock_calls == [] + + async def test____send_all____calls_send_transport_send_all( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncStapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_send_transport.send_all.return_value = None + + # Act + await stapled_transport.send_all(mocker.sentinel.send_data) + + # Assert + assert mock_send_transport.mock_calls == [mocker.call.send_all(mocker.sentinel.send_data)] + assert mock_receive_transport.mock_calls == [] + + async def test____send_all_from_iterable____calls_send_transport_send_all_from_iterable( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncStapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_send_transport.send_all_from_iterable.return_value = None + + # Act + await stapled_transport.send_all_from_iterable(mocker.sentinel.send_data) + + # Assert + assert mock_send_transport.mock_calls == [mocker.call.send_all_from_iterable(mocker.sentinel.send_data)] + assert mock_receive_transport.mock_calls == [] + + async def test____send_eof____calls_send_transport_send_eof_if_exists_else_aclose( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncStapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + if hasattr(mock_send_transport, "send_eof"): + mock_send_transport.send_eof.return_value = None + + # Act + await stapled_transport.send_eof() + + # Assert + if hasattr(mock_send_transport, "send_eof"): + assert mock_send_transport.mock_calls == [mocker.call.send_eof()] + else: + assert mock_send_transport.mock_calls == [mocker.call.aclose()] + assert mock_receive_transport.mock_calls == [] + + @pytest.mark.parametrize("mock_send_transport", [AsyncStreamTransport], indirect=True) + @pytest.mark.parametrize("send_eof_error", [UnsupportedOperation, NotImplementedError]) + async def test____send_eof____calls_send_transport_aclos_if_send_eof_is_not_implemented( + self, + send_eof_error: type[Exception], + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncStapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_send_transport.send_eof.side_effect = send_eof_error + + # Act + await stapled_transport.send_eof() + + # Assert + assert mock_send_transport.mock_calls == [mocker.call.send_eof(), mocker.call.aclose()] + assert mock_receive_transport.mock_calls == [] + + +class TestAsyncStapledDatagramTransport(BaseAsyncStapledTransportTests): + @pytest.fixture(params=[AsyncDatagramWriteTransport, AsyncDatagramTransport]) + @staticmethod + def mock_send_transport( + request: pytest.FixtureRequest, + asyncio_backend: AsyncIOBackend, + mocker: MockerFixture, + ) -> MagicMock: + return make_transport_mock(mocker=mocker, spec=request.param, backend=asyncio_backend) + + @pytest.fixture(params=[AsyncDatagramReadTransport, AsyncDatagramTransport]) + @staticmethod + def mock_receive_transport( + request: pytest.FixtureRequest, + asyncio_backend: AsyncIOBackend, + mocker: MockerFixture, + ) -> MagicMock: + return make_transport_mock(mocker=mocker, spec=request.param, backend=asyncio_backend) + + @pytest_asyncio.fixture + @staticmethod + async def stapled_transport( + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + ) -> AsyncIterator[AsyncStapledDatagramTransport[MagicMock, MagicMock]]: + transport = AsyncStapledDatagramTransport(mock_send_transport, mock_receive_transport) + mock_send_transport.reset_mock() + mock_receive_transport.reset_mock() + async with transport: + with restore_mock_side_effect(mock_send_transport.aclose), restore_mock_side_effect(mock_receive_transport.aclose): + yield transport + + async def test____dunder_init___transports_inconsistency_error( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + ) -> None: + # Arrange + mock_send_transport.backend.side_effect = [new_builtin_backend("asyncio")] + mock_receive_transport.backend.side_effect = [new_builtin_backend("asyncio")] + + # Act & Assert + with pytest.raises(RuntimeError, match=r"^transport backend inconsistency$"): + _ = AsyncStapledDatagramTransport(mock_send_transport, mock_receive_transport) + + async def test____recv____calls_receive_transport_recv( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncStapledDatagramTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_receive_transport.recv.return_value = mocker.sentinel.recv_result + + # Act + data = await stapled_transport.recv() + + # Assert + assert data is mocker.sentinel.recv_result + assert mock_receive_transport.mock_calls == [mocker.call.recv()] + assert mock_send_transport.mock_calls == [] + + async def test____send____calls_send_transport_send( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: AsyncStapledDatagramTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_send_transport.send.return_value = None + + # Act + await stapled_transport.send(mocker.sentinel.send_data) + + # Assert + assert mock_send_transport.mock_calls == [mocker.call.send(mocker.sentinel.send_data)] + assert mock_receive_transport.mock_calls == [] diff --git a/tests/unit_test/test_sync/mock_tools.py b/tests/unit_test/test_sync/mock_tools.py index b20f989f..3d220a95 100644 --- a/tests/unit_test/test_sync/mock_tools.py +++ b/tests/unit_test/test_sync/mock_tools.py @@ -19,4 +19,5 @@ def close_side_effect() -> None: mock_transport.is_closed.return_value = True mock_transport.close.side_effect = close_side_effect + mock_transport.extra_attributes = {} return mock_transport diff --git a/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_composite.py b/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_composite.py new file mode 100644 index 00000000..73677525 --- /dev/null +++ b/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_composite.py @@ -0,0 +1,363 @@ +from __future__ import annotations + +from collections.abc import Iterator +from typing import TYPE_CHECKING, Literal, assert_never + +from easynetwork.exceptions import UnsupportedOperation +from easynetwork.lowlevel.api_sync.transports.abc import ( + BaseTransport, + DatagramReadTransport, + DatagramTransport, + DatagramWriteTransport, + StreamReadTransport, + StreamTransport, + StreamWriteTransport, +) +from easynetwork.lowlevel.api_sync.transports.composite import StapledDatagramTransport, StapledStreamTransport + +import pytest + +from ...._utils import restore_mock_side_effect +from ...mock_tools import make_transport_mock + +if TYPE_CHECKING: + from unittest.mock import MagicMock + + from pytest_mock import MockerFixture + + +class BaseStapledTransportTests: + + def test____close____close_both_transports( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: BaseTransport, + ) -> None: + # Arrange + assert not stapled_transport.is_closed() + + # Act + stapled_transport.close() + + # Assert + assert stapled_transport.is_closed() + mock_send_transport.close.assert_called_once_with() + mock_receive_transport.close.assert_called_once_with() + + @pytest.mark.parametrize("transport_cancelled", ["send", "receive"]) + def test____close____close_both_transports____even_upon_interrupt( + self, + transport_cancelled: Literal["send", "receive"], + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: BaseTransport, + ) -> None: + # Arrange + match transport_cancelled: + case "send": + mock_send_transport.close.side_effect = KeyboardInterrupt + case "receive": + mock_receive_transport.close.side_effect = KeyboardInterrupt + case _: + assert_never(transport_cancelled) + + # Act + with pytest.raises(KeyboardInterrupt): + stapled_transport.close() + + # Assert + mock_send_transport.close.assert_called_once_with() + mock_receive_transport.close.assert_called_once_with() + + @pytest.mark.parametrize( + ["send_transport_is_closed", "receive_transport_is_closed", "expected_state"], + [ + pytest.param(False, False, False), + pytest.param(True, False, False), + pytest.param(False, True, False), + pytest.param(True, True, True), + ], + ) + def test____is_closed____expected_state( + self, + send_transport_is_closed: bool, + receive_transport_is_closed: bool, + expected_state: bool, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: BaseTransport, + ) -> None: + # Arrange + if send_transport_is_closed: + mock_send_transport.close() + if receive_transport_is_closed: + mock_receive_transport.close() + + # Act + is_closed = stapled_transport.is_closed() + + # Assert + assert is_closed is expected_state + + def test____extra_attributes____both_transports_attributes( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: BaseTransport, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_send_transport.extra_attributes = { + mocker.sentinel.send_only_attr: lambda: mocker.sentinel.send_only_value, + mocker.sentinel.attr_conflict: lambda: mocker.sentinel.send_won, + } + mock_receive_transport.extra_attributes = { + mocker.sentinel.recv_only_attr: lambda: mocker.sentinel.recv_only_value, + mocker.sentinel.attr_conflict: lambda: mocker.sentinel.recv_won, + } + + # Act & Assert + assert stapled_transport.extra(mocker.sentinel.send_only_attr) is mocker.sentinel.send_only_value + assert stapled_transport.extra(mocker.sentinel.recv_only_attr) is mocker.sentinel.recv_only_value + assert stapled_transport.extra(mocker.sentinel.attr_conflict) is mocker.sentinel.recv_won + + +class TestStapledStreamTransport(BaseStapledTransportTests): + @pytest.fixture(params=[StreamWriteTransport, StreamTransport]) + @staticmethod + def mock_send_transport( + request: pytest.FixtureRequest, + mocker: MockerFixture, + ) -> MagicMock: + return make_transport_mock(mocker=mocker, spec=request.param) + + @pytest.fixture(params=[StreamReadTransport, StreamTransport]) + @staticmethod + def mock_receive_transport( + request: pytest.FixtureRequest, + mocker: MockerFixture, + ) -> MagicMock: + return make_transport_mock(mocker=mocker, spec=request.param) + + @pytest.fixture + @staticmethod + def stapled_transport( + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + ) -> Iterator[StapledStreamTransport[MagicMock, MagicMock]]: + transport = StapledStreamTransport(mock_send_transport, mock_receive_transport) + mock_send_transport.reset_mock() + mock_receive_transport.reset_mock() + with transport: + with restore_mock_side_effect(mock_send_transport.close), restore_mock_side_effect(mock_receive_transport.close): + yield transport + + def test____recv____calls_receive_transport_recv( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: StapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_receive_transport.recv.return_value = mocker.sentinel.recv_result + + # Act + data = stapled_transport.recv(mocker.sentinel.recv_bufsize, mocker.sentinel.recv_timeout) + + # Assert + assert data is mocker.sentinel.recv_result + assert mock_receive_transport.mock_calls == [ + mocker.call.recv(mocker.sentinel.recv_bufsize, mocker.sentinel.recv_timeout), + ] + assert mock_send_transport.mock_calls == [] + + def test____recv_into____calls_receive_transport_recv_into( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: StapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_receive_transport.recv_into.return_value = mocker.sentinel.recv_into_result + + # Act + nbytes = stapled_transport.recv_into(mocker.sentinel.recv_buffer, mocker.sentinel.recv_timeout) + + # Assert + assert nbytes is mocker.sentinel.recv_into_result + assert mock_receive_transport.mock_calls == [ + mocker.call.recv_into(mocker.sentinel.recv_buffer, mocker.sentinel.recv_timeout), + ] + assert mock_send_transport.mock_calls == [] + + def test____send____calls_send_transport_send( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: StapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_send_transport.send.return_value = mocker.sentinel.send_result + + # Act + nbytes = stapled_transport.send(mocker.sentinel.send_data, mocker.sentinel.send_timeout) + + # Assert + assert nbytes is mocker.sentinel.send_result + assert mock_send_transport.mock_calls == [ + mocker.call.send(mocker.sentinel.send_data, mocker.sentinel.send_timeout), + ] + assert mock_receive_transport.mock_calls == [] + + def test____send_all____calls_send_transport_send_all( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: StapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_send_transport.send_all.return_value = None + + # Act + stapled_transport.send_all(mocker.sentinel.send_data, mocker.sentinel.send_timeout) + + # Assert + assert mock_send_transport.mock_calls == [ + mocker.call.send_all(mocker.sentinel.send_data, mocker.sentinel.send_timeout), + ] + assert mock_receive_transport.mock_calls == [] + + def test____send_all_from_iterable____calls_send_transport_send_all_from_iterable( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: StapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_send_transport.send_all_from_iterable.return_value = None + + # Act + stapled_transport.send_all_from_iterable(mocker.sentinel.send_data, mocker.sentinel.send_timeout) + + # Assert + assert mock_send_transport.mock_calls == [ + mocker.call.send_all_from_iterable(mocker.sentinel.send_data, mocker.sentinel.send_timeout), + ] + assert mock_receive_transport.mock_calls == [] + + def test____send_eof____calls_send_transport_send_eof_if_exists_else_close( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: StapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + if hasattr(mock_send_transport, "send_eof"): + mock_send_transport.send_eof.return_value = None + + # Act + stapled_transport.send_eof() + + # Assert + if hasattr(mock_send_transport, "send_eof"): + assert mock_send_transport.mock_calls == [mocker.call.send_eof()] + else: + assert mock_send_transport.mock_calls == [mocker.call.close()] + assert mock_receive_transport.mock_calls == [] + + @pytest.mark.parametrize("mock_send_transport", [StreamTransport], indirect=True) + @pytest.mark.parametrize("send_eof_error", [UnsupportedOperation, NotImplementedError]) + def test____send_eof____calls_send_transport_aclos_if_send_eof_is_not_implemented( + self, + send_eof_error: type[Exception], + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: StapledStreamTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_send_transport.send_eof.side_effect = send_eof_error + + # Act + stapled_transport.send_eof() + + # Assert + assert mock_send_transport.mock_calls == [mocker.call.send_eof(), mocker.call.close()] + assert mock_receive_transport.mock_calls == [] + + +class TestStapledDatagramTransport(BaseStapledTransportTests): + @pytest.fixture(params=[DatagramWriteTransport, DatagramTransport]) + @staticmethod + def mock_send_transport( + request: pytest.FixtureRequest, + mocker: MockerFixture, + ) -> MagicMock: + return make_transport_mock(mocker=mocker, spec=request.param) + + @pytest.fixture(params=[DatagramReadTransport, DatagramTransport]) + @staticmethod + def mock_receive_transport( + request: pytest.FixtureRequest, + mocker: MockerFixture, + ) -> MagicMock: + return make_transport_mock(mocker=mocker, spec=request.param) + + @pytest.fixture + @staticmethod + def stapled_transport( + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + ) -> Iterator[StapledDatagramTransport[MagicMock, MagicMock]]: + transport = StapledDatagramTransport(mock_send_transport, mock_receive_transport) + mock_send_transport.reset_mock() + mock_receive_transport.reset_mock() + with transport: + with restore_mock_side_effect(mock_send_transport.close), restore_mock_side_effect(mock_receive_transport.close): + yield transport + + def test____recv____calls_receive_transport_recv( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: StapledDatagramTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_receive_transport.recv.return_value = mocker.sentinel.recv_result + + # Act + data = stapled_transport.recv(mocker.sentinel.recv_timeout) + + # Assert + assert data is mocker.sentinel.recv_result + assert mock_receive_transport.mock_calls == [ + mocker.call.recv(mocker.sentinel.recv_timeout), + ] + assert mock_send_transport.mock_calls == [] + + def test____send____calls_send_transport_send( + self, + mock_send_transport: MagicMock, + mock_receive_transport: MagicMock, + stapled_transport: StapledDatagramTransport[MagicMock, MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + mock_send_transport.send.return_value = None + + # Act + stapled_transport.send(mocker.sentinel.send_data, mocker.sentinel.send_timeout) + + # Assert + assert mock_send_transport.mock_calls == [ + mocker.call.send(mocker.sentinel.send_data, mocker.sentinel.send_timeout), + ] + assert mock_receive_transport.mock_calls == []