From 1d58bb53d9ea7f0e57f11de3cb49975f91a76f6a Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sat, 18 Nov 2023 11:51:49 +0100 Subject: [PATCH] [FIX] Implement zero copy writes for TCP socket (sync and async) transports --- docs/source/howto/serializers.rst | 7 +- src/easynetwork/lowlevel/_utils.py | 33 ++- .../lowlevel/api_async/transports/abc.py | 14 +- .../lowlevel/api_sync/transports/abc.py | 18 +- .../lowlevel/api_sync/transports/socket.py | 29 ++- .../lowlevel/asyncio/_asyncio_utils.py | 35 ++- src/easynetwork/lowlevel/asyncio/socket.py | 33 ++- .../lowlevel/asyncio/stream/socket.py | 6 + src/easynetwork/lowlevel/constants.py | 18 ++ .../test_async/test_server/test_tcp.py | 2 +- tests/tools.py | 9 + tests/unit_test/base.py | 26 ++- .../test_api/test_client/test_tcp.py | 8 +- .../test_asyncio_backend/test_socket.py | 160 ++++++++++++- .../test_asyncio_backend/test_stream.py | 23 +- .../test_asyncio_backend/test_utils.py | 47 ++++ .../test_transports/test_abc.py | 3 +- .../test_sync/test_client/test_tcp.py | 6 +- .../test_transports/test_abc.py | 20 +- .../test_transports/test_socket.py | 216 ++++++++++++++++-- tests/unit_test/test_tools/test_utils.py | 46 +++- 21 files changed, 691 insertions(+), 68 deletions(-) diff --git a/docs/source/howto/serializers.rst b/docs/source/howto/serializers.rst index b5ea5605..6e295224 100644 --- a/docs/source/howto/serializers.rst +++ b/docs/source/howto/serializers.rst @@ -204,11 +204,8 @@ Most of the time, you will have a single :keyword:`yield`. The goal is: each :ke .. note:: - The endpoint implementation can (and most likely will) decide to concatenate all the pieces and do one big send. - This is the optimized way to send a large byte buffer. - - However, it may be more attractive to do something else with the returned bytes. - :meth:`~.AbstractIncrementalPacketSerializer.incremental_serialize` is here to give endpoints this freedom. + The endpoint implementation can decide to concatenate all the pieces and do one big send. However, it may be more attractive to do something else + with the returned bytes. :meth:`~.AbstractIncrementalPacketSerializer.incremental_serialize` is here to give endpoints this freedom. The Purpose Of ``incremental_deserialize()`` diff --git a/src/easynetwork/lowlevel/_utils.py b/src/easynetwork/lowlevel/_utils.py index be82d93c..0f8692e5 100644 --- a/src/easynetwork/lowlevel/_utils.py +++ b/src/easynetwork/lowlevel/_utils.py @@ -16,6 +16,7 @@ __all__ = [ "ElapsedTime", + "adjust_leftover_buffer", "check_real_socket_state", "check_socket_family", "check_socket_no_ssl", @@ -30,6 +31,7 @@ "remove_traceback_frames_in_place", "replace_kwargs", "set_reuseport", + "supports_socket_sendmsg", "validate_timeout_delay", ] @@ -41,8 +43,10 @@ import socket as _socket import threading import time +from abc import abstractmethod +from collections import deque from collections.abc import Callable, Iterable, Iterator -from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Self, TypeGuard, TypeVar +from typing import TYPE_CHECKING, Any, Concatenate, Final, ParamSpec, Protocol, Self, TypeGuard, TypeVar try: import ssl as _ssl @@ -58,6 +62,8 @@ if TYPE_CHECKING: from ssl import SSLError as _SSLError, SSLSocket as _SSLSocket + from _typeshed import ReadableBuffer + from .socket import ISocket, SupportsSocketOptions _P = ParamSpec("_P") @@ -130,6 +136,20 @@ def check_real_socket_state(socket: ISocket) -> None: raise error_from_errno(errno) +_HAS_SENDMSG: Final[bool] = hasattr(_socket.socket, "sendmsg") + + +class _SupportsSocketSendMSG(Protocol): + @abstractmethod + def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> int: + ... + + +def supports_socket_sendmsg(sock: _socket.socket) -> TypeGuard[_SupportsSocketSendMSG]: + assert isinstance(sock, _socket.SocketType) # nosec assert_used + return _HAS_SENDMSG + + def is_ssl_socket(socket: _socket.socket) -> TypeGuard[_SSLSocket]: if ssl is None: return False @@ -170,6 +190,17 @@ def iter_bytes(b: bytes | bytearray | memoryview) -> Iterator[bytes]: return map(int.to_bytes, b) +def adjust_leftover_buffer(buffers: deque[memoryview], nbytes: int) -> None: + while nbytes > 0: + b = buffers.popleft() + b_len = len(b) + if b_len <= nbytes: + nbytes -= b_len + else: + buffers.appendleft(b[nbytes:]) + break + + def is_socket_connected(sock: ISocket) -> bool: try: sock.getpeername() diff --git a/src/easynetwork/lowlevel/api_async/transports/abc.py b/src/easynetwork/lowlevel/api_async/transports/abc.py index b9b9820f..8198d1c6 100644 --- a/src/easynetwork/lowlevel/api_async/transports/abc.py +++ b/src/easynetwork/lowlevel/api_async/transports/abc.py @@ -118,19 +118,15 @@ async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytear """ An efficient way to send a bunch of data via the transport. - Currently, the default implementation concatenates the arguments and - calls :meth:`send_all` on the result. + Like :meth:`send_all`, this method continues to send data from bytes until either all data has been sent or an error + occurs. :data:`None` is returned on success. On error, an exception is raised, and there is no way to determine how much + data, if any, was successfully sent. Parameters: iterable_of_data: An :term:`iterable` yielding the bytes to send. """ - iterable_of_data = list(iterable_of_data) - if len(iterable_of_data) == 1: - data = iterable_of_data[0] - else: - data = b"".join(iterable_of_data) - del iterable_of_data - return await self.send_all(data) + for data in iterable_of_data: + await self.send_all(data) class AsyncStreamTransport(AsyncStreamWriteTransport, AsyncStreamReadTransport): diff --git a/src/easynetwork/lowlevel/api_sync/transports/abc.py b/src/easynetwork/lowlevel/api_sync/transports/abc.py index 8012ced8..22372eb7 100644 --- a/src/easynetwork/lowlevel/api_sync/transports/abc.py +++ b/src/easynetwork/lowlevel/api_sync/transports/abc.py @@ -129,7 +129,7 @@ def send_all(self, data: bytes | bytearray | memoryview, timeout: float) -> None """ total_sent: int = 0 - with memoryview(data) as data: + with memoryview(data) as data, data.cast("B") as data: nb_bytes_to_send = len(data) if nb_bytes_to_send == 0: sent = self.send(data, timeout) @@ -148,8 +148,9 @@ def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | """ An efficient way to send a bunch of data via the transport. - Currently, the default implementation concatenates the arguments and - calls :meth:`send_all` on the result. + Like :meth:`send_all`, this method continues to send data from bytes until either all data has been sent or an error + occurs. :data:`None` is returned on success. On error, an exception is raised, and there is no way to determine how much + data, if any, was successfully sent. Parameters: iterable_of_data: An :term:`iterable` yielding the bytes to send. @@ -159,13 +160,10 @@ def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | ValueError: Negative `timeout`. TimeoutError: Operation timed out. """ - iterable_of_data = list(iterable_of_data) - if len(iterable_of_data) == 1: - data = iterable_of_data[0] - else: - data = b"".join(iterable_of_data) - del iterable_of_data - return self.send_all(data, timeout) + for data in iterable_of_data: + with _utils.ElapsedTime() as elapsed: + self.send_all(data, timeout) + timeout = elapsed.recompute_timeout(timeout) class StreamTransport(StreamWriteTransport, StreamReadTransport): diff --git a/src/easynetwork/lowlevel/api_sync/transports/socket.py b/src/easynetwork/lowlevel/api_sync/transports/socket.py index c9f32e5f..42c4f47e 100644 --- a/src/easynetwork/lowlevel/api_sync/transports/socket.py +++ b/src/easynetwork/lowlevel/api_sync/transports/socket.py @@ -22,10 +22,11 @@ "SocketStreamTransport", ] +import itertools import selectors import socket -from collections import ChainMap -from collections.abc import Callable, Mapping +from collections import ChainMap, deque +from collections.abc import Callable, Iterable, Mapping from typing import TYPE_CHECKING, Any, ParamSpec, TypeVar try: @@ -96,6 +97,30 @@ def send_noblock(self, data: bytes | bytearray | memoryview) -> int: except (BlockingIOError, InterruptedError): raise base_selector.WouldBlockOnWrite(self.__socket.fileno()) from None + @_utils.inherit_doc(base_selector.SelectorStreamTransport) + def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview], timeout: float) -> None: + _sock = self.__socket + if constants.SC_IOV_MAX <= 0 or not _utils.supports_socket_sendmsg(_sock): + return super().send_all_from_iterable(iterable_of_data, timeout) + + buffers: deque[memoryview] = deque(memoryview(data).cast("B") for data in iterable_of_data) + del iterable_of_data + + sock_sendmsg = _sock.sendmsg + del _sock + + def try_sendmsg() -> int: + try: + return sock_sendmsg(itertools.islice(buffers, constants.SC_IOV_MAX)) + except (BlockingIOError, InterruptedError): + raise base_selector.WouldBlockOnWrite(self.__socket.fileno()) from None + + while buffers: + with _utils.ElapsedTime() as elapsed: + sent: int = self._retry(try_sendmsg, timeout) + _utils.adjust_leftover_buffer(buffers, sent) + timeout = elapsed.recompute_timeout(timeout) + @_utils.inherit_doc(base_selector.SelectorStreamTransport) def send_eof(self) -> None: if self.__socket.fileno() < 0: diff --git a/src/easynetwork/lowlevel/asyncio/_asyncio_utils.py b/src/easynetwork/lowlevel/asyncio/_asyncio_utils.py index a2a6d736..988307e1 100644 --- a/src/easynetwork/lowlevel/asyncio/_asyncio_utils.py +++ b/src/easynetwork/lowlevel/asyncio/_asyncio_utils.py @@ -17,7 +17,12 @@ from __future__ import annotations -__all__ = ["create_connection", "open_listener_sockets_from_getaddrinfo_result"] +__all__ = [ + "create_connection", + "open_listener_sockets_from_getaddrinfo_result", + "wait_until_readable", + "wait_until_writable", +] import asyncio import contextlib @@ -216,3 +221,31 @@ def open_listener_sockets_from_getaddrinfo_result( socket_exit_stack.pop_all() return sockets + + +def wait_until_readable(sock: _socket.socket, loop: asyncio.AbstractEventLoop) -> asyncio.Future[None]: + def on_fut_done(f: asyncio.Future[None]) -> None: + loop.remove_reader(sock) + + def wakeup(f: asyncio.Future[None]) -> None: + if not f.done(): + f.set_result(None) + + f = loop.create_future() + loop.add_reader(sock, wakeup, f) + f.add_done_callback(on_fut_done) + return f + + +def wait_until_writable(sock: _socket.socket, loop: asyncio.AbstractEventLoop) -> asyncio.Future[None]: + def on_fut_done(f: asyncio.Future[None]) -> None: + loop.remove_writer(sock) + + def wakeup(f: asyncio.Future[None]) -> None: + if not f.done(): + f.set_result(None) + + f = loop.create_future() + loop.add_writer(sock, wakeup, f) + f.add_done_callback(on_fut_done) + return f diff --git a/src/easynetwork/lowlevel/asyncio/socket.py b/src/easynetwork/lowlevel/asyncio/socket.py index 0934e6e7..11ffe359 100644 --- a/src/easynetwork/lowlevel/asyncio/socket.py +++ b/src/easynetwork/lowlevel/asyncio/socket.py @@ -23,12 +23,16 @@ import asyncio.trsock import contextlib import errno as _errno +import itertools import socket as _socket -from collections.abc import Iterator -from typing import TYPE_CHECKING, Literal, Self, TypeAlias +from collections import deque +from collections.abc import Iterable, Iterator +from typing import TYPE_CHECKING, Literal, Self, TypeAlias, cast from weakref import WeakSet -from .. import _utils +from ...exceptions import UnsupportedOperation +from .. import _utils, constants +from . import _asyncio_utils from .tasks import CancelScope, TaskUtils if TYPE_CHECKING: @@ -120,6 +124,26 @@ async def sendall(self, data: ReadableBuffer, /) -> None: socket = self.__check_not_closed() await self.__loop.sock_sendall(socket, data) + async def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> None: + with self.__conflict_detection("send", abort_errno=_errno.ECONNABORTED): + socket = self.__check_not_closed() + if constants.SC_IOV_MAX <= 0 or not _utils.supports_socket_sendmsg(_sock := socket): + raise UnsupportedOperation("sendmsg() is not supported") + + loop = self.__loop + buffers = cast("deque[memoryview]", deque(memoryview(data).cast("B") for data in buffers)) + + sock_sendmsg = _sock.sendmsg + del _sock + + while buffers: + try: + sent: int = sock_sendmsg(itertools.islice(buffers, constants.SC_IOV_MAX)) + except (BlockingIOError, InterruptedError): + await _asyncio_utils.wait_until_writable(socket, loop) + else: + _utils.adjust_leftover_buffer(buffers, sent) + async def sendto(self, data: ReadableBuffer, address: _socket._Address, /) -> None: with self.__conflict_detection("send", abort_errno=_errno.ECONNABORTED): socket = self.__check_not_closed() @@ -152,7 +176,8 @@ def __conflict_detection(self, task_id: _SocketTaskId, *, abort_errno: int = _er if task_id in self.__waiters: raise _utils.error_from_errno(_errno.EBUSY) - _ = TaskUtils.current_asyncio_task(self.__loop) + # Checks if we are within the bound loop + TaskUtils.current_asyncio_task(self.__loop) # type: ignore[unused-awaitable] with CancelScope() as scope, contextlib.ExitStack() as stack: self.__scopes.add(scope) diff --git a/src/easynetwork/lowlevel/asyncio/stream/socket.py b/src/easynetwork/lowlevel/asyncio/stream/socket.py index 9cdf7277..e5c26e25 100644 --- a/src/easynetwork/lowlevel/asyncio/stream/socket.py +++ b/src/easynetwork/lowlevel/asyncio/stream/socket.py @@ -166,6 +166,12 @@ async def send_all(self, data: bytes | bytearray | memoryview) -> None: async def send_eof(self) -> None: await self.__socket.shutdown(_socket.SHUT_WR) + async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None: + try: + await self.__socket.sendmsg(iterable_of_data) + except UnsupportedOperation: + await super().send_all_from_iterable(iterable_of_data) + @property def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: socket = self.__socket.socket diff --git a/src/easynetwork/lowlevel/constants.py b/src/easynetwork/lowlevel/constants.py index 7f6af608..347991a4 100644 --- a/src/easynetwork/lowlevel/constants.py +++ b/src/easynetwork/lowlevel/constants.py @@ -22,6 +22,7 @@ "DEFAULT_STREAM_BUFSIZE", "MAX_DATAGRAM_BUFSIZE", "NOT_CONNECTED_SOCKET_ERRNOS", + "SC_IOV_MAX", "SSL_HANDSHAKE_TIMEOUT", "SSL_SHUTDOWN_TIMEOUT", "_DEFAULT_LIMIT", @@ -80,3 +81,20 @@ # Buffer size limit when waiting for a byte sequence _DEFAULT_LIMIT: Final[int] = 64 * 1024 # 64 KiB + + +def __get_sysconf(name: str, /) -> int: + import os + + try: + # os.sysconf() can return a negative value if 'name' is not defined + return os.sysconf(name) # type: ignore[attr-defined,unused-ignore] + except (AttributeError, OSError): + return -1 + + +# Maximum number of buffer that can accept sendmsg(2) +# Can be a negative value +SC_IOV_MAX: Final[int] = __get_sysconf("SC_IOV_MAX") + +del __get_sysconf diff --git a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py index d1f87f22..40a53876 100644 --- a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py +++ b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py @@ -283,7 +283,7 @@ def use_ssl(request: Any) -> bool: case "USE_SSL": return True case _: - raise SystemError + pytest.fail(f"Invalid parameter: {request.param}") @pytest.fixture @staticmethod diff --git a/tests/tools.py b/tests/tools.py index b4234ab6..81f42f74 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -1,5 +1,6 @@ from __future__ import annotations +import asyncio import sys import time from collections.abc import Generator @@ -56,3 +57,11 @@ def __exit__(self, exc_type: type[Exception] | None, exc_value: Exception | None return assert self.start_time >= 0 assert end_time - self.start_time == pytest.approx(self.expected_time, rel=self.approx) + + +def is_proactor_event_loop(event_loop: asyncio.AbstractEventLoop) -> bool: + try: + ProactorEventLoop: type[asyncio.AbstractEventLoop] = getattr(asyncio, "ProactorEventLoop") + except AttributeError: + return False + return isinstance(event_loop, ProactorEventLoop) diff --git a/tests/unit_test/base.py b/tests/unit_test/base.py index 18ad7507..a13bd93f 100644 --- a/tests/unit_test/base.py +++ b/tests/unit_test/base.py @@ -1,15 +1,19 @@ from __future__ import annotations -from socket import AF_INET, AF_INET6 +from socket import AF_INET, AF_INET6, socket as Socket from typing import TYPE_CHECKING from easynetwork.lowlevel.socket import AddressFamily +import pytest + from ._utils import get_all_socket_families if TYPE_CHECKING: from unittest.mock import MagicMock + from pytest_mock import MockerFixture + SUPPORTED_FAMILIES: tuple[str, ...] = tuple(sorted(AddressFamily.__members__)) UNSUPPORTED_FAMILIES: tuple[str, ...] = tuple(sorted(get_all_socket_families().difference(SUPPORTED_FAMILIES))) @@ -83,3 +87,23 @@ def configure_socket_mock_to_raise_ENOTCONN(cls, mock_socket: MagicMock) -> OSEr enotconn_exception = OSError(errno.ENOTCONN, os.strerror(errno.ENOTCONN)) mock_socket.getpeername.side_effect = enotconn_exception return enotconn_exception + + +class MixinTestSocketSendMSG: + @pytest.fixture(autouse=True) + @staticmethod + def SC_IOV_MAX(request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch) -> int: + try: + value: int = request.param + except AttributeError: + value = 1024 + monkeypatch.setattr("easynetwork.lowlevel.constants.SC_IOV_MAX", value) + return value + + @pytest.fixture(autouse=True) + @staticmethod + def supports_socket_sendmsg(mocker: MockerFixture) -> None: + def supports_socket_sendmsg(sock: Socket) -> bool: + return hasattr(sock, "sendmsg") + + mocker.patch("easynetwork.lowlevel._utils.supports_socket_sendmsg", supports_socket_sendmsg) diff --git a/tests/unit_test/test_async/test_api/test_client/test_tcp.py b/tests/unit_test/test_async/test_api/test_client/test_tcp.py index 5f46af3b..8e001b6c 100644 --- a/tests/unit_test/test_async/test_api/test_client/test_tcp.py +++ b/tests/unit_test/test_async/test_api/test_client/test_tcp.py @@ -3,6 +3,7 @@ import contextlib import errno import os +from collections.abc import Generator from socket import AF_INET6, IPPROTO_TCP, SO_ERROR, SO_KEEPALIVE, SOL_SOCKET, TCP_NODELAY from typing import TYPE_CHECKING, Any @@ -126,9 +127,10 @@ def set_default_socket_mock_configuration( @pytest.fixture # DO NOT set autouse=True @staticmethod def setup_producer_mock(mock_stream_protocol: MagicMock) -> None: - mock_stream_protocol.generate_chunks.side_effect = lambda packet: iter( - [str(packet).encode("ascii").removeprefix(b"sentinel.") + b"\n"] - ) + def generate_chunks_side_effect(packet: Any) -> Generator[bytes, None, None]: + yield str(packet).removeprefix("sentinel.").encode("ascii") + b"\n" + + mock_stream_protocol.generate_chunks.side_effect = generate_chunks_side_effect @pytest.fixture # DO NOT set autouse=True @staticmethod diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_socket.py b/tests/unit_test/test_async/test_asyncio_backend/test_socket.py index b168ac7d..3249971d 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_socket.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_socket.py @@ -4,18 +4,22 @@ import asyncio import contextlib -from collections.abc import Callable, Coroutine, Iterator +from collections.abc import Callable, Coroutine, Iterable, Iterator from errno import EBUSY, ECONNABORTED, EINTR, ENOTSOCK from socket import SHUT_RD, SHUT_RDWR, SHUT_WR, socket as Socket from typing import TYPE_CHECKING, Any, final +from easynetwork.exceptions import UnsupportedOperation from easynetwork.lowlevel.asyncio.socket import AsyncSocket import pytest +from ...base import MixinTestSocketSendMSG + if TYPE_CHECKING: from unittest.mock import MagicMock + from _typeshed import ReadableBuffer from pytest_mock import MockerFixture @@ -52,6 +56,31 @@ async def sock_method_patch(sock: Socket, *args: Any, **kwargs: Any) -> Any: monkeypatch.setattr(event_loop, event_loop_method, sock_method_patch) + @pytest.fixture(autouse=True) + @classmethod + def event_loop_mock_event_handlers(cls, event_loop: asyncio.AbstractEventLoop, mocker: MockerFixture) -> None: + to_patch = [ + ("add_reader", "remove_reader"), + ("add_writer", "remove_writer"), + ] + + for add_event_func_name, remove_event_func_name in to_patch: + cls.__patch_event_handler_method(event_loop, add_event_func_name, remove_event_func_name, mocker) + + @staticmethod + def __patch_event_handler_method( + event_loop: asyncio.AbstractEventLoop, + add_event_func_name: str, + remove_event_func_name: str, + mocker: MockerFixture, + ) -> None: + mocker.patch.object( + event_loop, + add_event_func_name, + side_effect=lambda sock, cb, *args: event_loop.call_soon(cb, *args), + ) + mocker.patch.object(event_loop, remove_event_func_name, return_value=None) + @staticmethod @contextlib.contextmanager def _set_sock_method_in_blocking_state(mock_socket_method: MagicMock) -> Iterator[None]: @@ -326,13 +355,20 @@ async def test____accept____returns_socket( mock_tcp_listener_socket.accept.assert_called_once_with() -class TestAsyncStreamSocket(MixinTestAsyncSocketBusy): +class TestAsyncStreamSocket(MixinTestAsyncSocketBusy, MixinTestSocketSendMSG): @pytest.fixture @staticmethod - def mock_tcp_socket(mock_tcp_socket: MagicMock) -> MagicMock: + def mock_tcp_socket(mock_tcp_socket: MagicMock, mocker: MockerFixture) -> MagicMock: mock_tcp_socket.recv.return_value = b"data" mock_tcp_socket.send.side_effect = len mock_tcp_socket.shutdown.return_value = None + + # Always create a new mock instance because sendmsg() is not available on all platforms + # therefore the mocker's autospec will consider sendmsg() unknown on these ones. + mock_tcp_socket.sendmsg = mocker.MagicMock( + spec=lambda *args: None, + side_effect=lambda buffers, *args: sum(map(len, map(memoryview, buffers))), + ) return mock_tcp_socket @pytest.fixture @@ -345,7 +381,7 @@ def socket( mock_tcp_socket.reset_mock() return socket - @pytest.fixture(params=["sendall", "recv"]) + @pytest.fixture(params=["sendall", "sendmsg", "recv"]) @staticmethod def sock_method_name(request: Any) -> str: return request.param @@ -361,10 +397,12 @@ def socket_method(sock_method_name: str, socket: AsyncSocket) -> Callable[[], Co match sock_method_name: case "sendall": return lambda: socket.sendall(b"data") + case "sendmsg": + return lambda: socket.sendmsg([b"data", b"to", b"send"]) case "recv": return lambda: socket.recv(1024) case _: - raise SystemError + pytest.fail(f"Invalid parameter: {sock_method_name}") @pytest.fixture @staticmethod @@ -372,10 +410,12 @@ def mock_socket_method(sock_method_name: str, mock_tcp_socket: MagicMock) -> Mag match sock_method_name: case "sendall": return mock_tcp_socket.send + case "sendmsg": + return mock_tcp_socket.sendmsg case "recv": return mock_tcp_socket.recv case _: - raise SystemError + pytest.fail(f"Invalid parameter: {sock_method_name}") @pytest.fixture @staticmethod @@ -395,6 +435,110 @@ async def test____sendall____sends_data_to_stdlib_socket( # Assert mock_tcp_socket.send.assert_called_once_with(b"data") + async def test____sendmsg____sends_several_buffers_to_stdlib_socket( + self, + socket: AsyncSocket, + mock_tcp_socket: MagicMock, + ) -> None: + # Arrange + chunks: list[list[bytes]] = [] + + def sendmsg_side_effect(buffers: Iterable[ReadableBuffer]) -> int: + buffers = list(buffers) + chunks.append(list(map(bytes, buffers))) + return sum(map(len, map(memoryview, buffers))) + + mock_tcp_socket.sendmsg.side_effect = sendmsg_side_effect + + # Act + await socket.sendmsg(iter([b"data", b"to", b"send"])) + + # Assert + mock_tcp_socket.sendmsg.assert_called_once() + assert chunks == [[b"data", b"to", b"send"]] + + @pytest.mark.parametrize("SC_IOV_MAX", [2], ids=lambda p: f"SC_IOV_MAX=={p}", indirect=True) + async def test____sendmsg____nb_buffers_greather_than_SC_IOV_MAX( + self, + socket: AsyncSocket, + mock_tcp_socket: MagicMock, + ) -> None: + # Arrange + chunks: list[list[bytes]] = [] + + def sendmsg_side_effect(buffers: Iterable[ReadableBuffer]) -> int: + buffers = list(buffers) + chunks.append(list(map(bytes, buffers))) + return sum(map(len, map(memoryview, buffers))) + + mock_tcp_socket.sendmsg.side_effect = sendmsg_side_effect + + # Act + await socket.sendmsg(iter([b"a", b"b", b"c", b"d", b"e"])) + + # Assert + assert mock_tcp_socket.sendmsg.call_count == 3 + assert chunks == [ + [b"a", b"b"], + [b"c", b"d"], + [b"e"], + ] + + async def test____sendmsg____adjust_leftover_buffer( + self, + socket: AsyncSocket, + mock_tcp_socket: MagicMock, + ) -> None: + # Arrange + chunks: list[list[bytes]] = [] + + def sendmsg_side_effect(buffers: Iterable[ReadableBuffer]) -> int: + buffers = list(buffers) + chunks.append(list(map(bytes, buffers))) + return min(sum(map(len, map(memoryview, buffers))), 3) + + mock_tcp_socket.sendmsg.side_effect = sendmsg_side_effect + + # Act + await socket.sendmsg(iter([b"abcd", b"efg", b"hijkl", b"mnop"])) + + # Assert + assert mock_tcp_socket.sendmsg.call_count == 6 + assert chunks == [ + [b"abcd", b"efg", b"hijkl", b"mnop"], + [b"d", b"efg", b"hijkl", b"mnop"], + [b"g", b"hijkl", b"mnop"], + [b"jkl", b"mnop"], + [b"mnop"], + [b"p"], + ] + + async def test____sendmsg____unavailable( + self, + socket: AsyncSocket, + mock_tcp_socket: MagicMock, + ) -> None: + # Arrange + del mock_tcp_socket.sendmsg + + # Act & Assert + with pytest.raises(UnsupportedOperation, match=r"^sendmsg\(\) is not supported$"): + await socket.sendmsg(iter([b"data", b"to", b"send"])) + + @pytest.mark.parametrize("SC_IOV_MAX", [-1, 0], ids=lambda p: f"SC_IOV_MAX=={p}", indirect=True) + async def test____sendmsg____available_but_no_defined_limit( + self, + socket: AsyncSocket, + mock_tcp_socket: MagicMock, + ) -> None: + # Arrange + + # Act & Assert + with pytest.raises(UnsupportedOperation, match=r"^sendmsg\(\) is not supported$"): + await socket.sendmsg(iter([b"data", b"to", b"send"])) + + mock_tcp_socket.sendmsg.assert_not_called() + async def test____recv____receives_data_from_stdlib_socket( self, socket: AsyncSocket, @@ -559,7 +703,7 @@ def socket_method(sock_method_name: str, socket: AsyncSocket) -> Callable[[], Co case "recvfrom": return lambda: socket.recvfrom(1024) case _: - raise SystemError + pytest.fail(f"Invalid parameter: {sock_method_name}") @pytest.fixture @staticmethod @@ -570,7 +714,7 @@ def mock_socket_method(sock_method_name: str, mock_udp_socket: MagicMock) -> Mag case "recvfrom": return mock_udp_socket.recvfrom case _: - raise SystemError + pytest.fail(f"Invalid parameter: {sock_method_name}") @pytest.fixture @staticmethod diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py index 733253b9..265b9367 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py @@ -1000,19 +1000,36 @@ async def test____send_all____sends_data_to_async_socket( # Assert mock_async_socket.sendall.assert_awaited_once_with(b"data") - async def test____send_all_from_iterable____sends_concatenated_data_to_async_socket( + async def test____send_all_from_iterable____use_async_socket_sendmsg( self, socket: RawStreamSocketAdapter, mock_async_socket: MagicMock, ) -> None: # Arrange - mock_async_socket.sendall.return_value = None + mock_async_socket.sendmsg.return_value = None + + # Act + await socket.send_all_from_iterable([b"data", b"to", b"send"]) + + # Assert + mock_async_socket.sendmsg.assert_awaited_once_with([b"data", b"to", b"send"]) + mock_async_socket.sendall.assert_not_called() + + async def test____send_all_from_iterable____fallback_to_sendall( + self, + socket: RawStreamSocketAdapter, + mock_async_socket: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_async_socket.sendmsg.side_effect = UnsupportedOperation # Act await socket.send_all_from_iterable([b"data", b"to", b"send"]) # Assert - mock_async_socket.sendall.assert_awaited_once_with(b"datatosend") + mock_async_socket.sendmsg.assert_awaited_once() + assert mock_async_socket.sendall.await_args_list == list(map(mocker.call, [b"data", b"to", b"send"])) async def test____send_eof____shutdown_socket( self, diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_utils.py b/tests/unit_test/test_async/test_asyncio_backend/test_utils.py index d1cb6bb1..29b349a6 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_utils.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_utils.py @@ -21,6 +21,7 @@ SOCK_DGRAM, SOCK_STREAM, SOL_SOCKET, + SocketType, gaierror, ) from typing import TYPE_CHECKING, Any, Literal, assert_never, cast @@ -30,10 +31,13 @@ create_connection, ensure_resolved, open_listener_sockets_from_getaddrinfo_result, + wait_until_readable, + wait_until_writable, ) import pytest +from ....tools import is_proactor_event_loop from ..._utils import datagram_addrinfo_list, stream_addrinfo_list if TYPE_CHECKING: @@ -635,3 +639,46 @@ def test____open_listener_sockets_from_getaddrinfo_result____ipv6_scope_id_not_p # Assert assert sockets == [mock_socket_ipv6] mock_socket_ipv6.bind.assert_called_once_with(("4e76:f928:6bbc:53ce:c01e:00d5:cdd5:6bbb", 65432, 0, 6)) + + +@pytest.mark.asyncio +@pytest.mark.parametrize( + ["waiter", "event_loop_add_event_func_name", "event_loop_remove_event_func_name"], + [ + pytest.param(wait_until_readable, "add_reader", "remove_reader", id="read"), + pytest.param(wait_until_writable, "add_writer", "remove_writer", id="write"), + ], +) +@pytest.mark.parametrize("future_cancelled", [False, True], ids=lambda p: f"future_cancelled=={p}") +async def test____wait_until___event_wakeup( + waiter: Callable[[SocketType, asyncio.AbstractEventLoop], asyncio.Future[None]], + event_loop: asyncio.AbstractEventLoop, + event_loop_add_event_func_name: str, + event_loop_remove_event_func_name: str, + future_cancelled: bool, + mock_socket_factory: Callable[[], MagicMock], + mocker: MockerFixture, +) -> None: + # Arrange + if is_proactor_event_loop(event_loop): + pytest.skip(f"event_loop.{event_loop_add_event_func_name}() is not supported on asyncio.ProactorEventLoop") + + event_loop_add_event = mocker.patch.object( + event_loop, + event_loop_add_event_func_name, + side_effect=lambda sock, cb, *args: event_loop.call_soon(cb, *args), + ) + event_loop_remove_event = mocker.patch.object(event_loop, event_loop_remove_event_func_name) + mock_socket = mock_socket_factory() + + # Act + fut = waiter(mock_socket, event_loop) + if future_cancelled: + fut.cancel() + await asyncio.sleep(0) + else: + await fut + + # Assert + event_loop_add_event.assert_called_once_with(mock_socket, mocker.ANY, mocker.ANY) + event_loop_remove_event.assert_called_once_with(mock_socket) diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_abc.py b/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_abc.py index dcf6f2fa..ea99a65b 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_abc.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_transports/test_abc.py @@ -22,6 +22,7 @@ def mock_transport(mocker: MockerFixture) -> MagicMock: async def test____send_all_from_iterable____concatenates_chunks_and_call_send_all( self, mock_transport: MagicMock, + mocker: MockerFixture, ) -> None: # Arrange mock_transport.send_all.return_value = None @@ -31,7 +32,7 @@ async def test____send_all_from_iterable____concatenates_chunks_and_call_send_al await AsyncStreamTransport.send_all_from_iterable(mock_transport, chunks) # Assert - mock_transport.send_all.assert_awaited_once_with(b"abc") + assert mock_transport.send_all.await_args_list == list(map(mocker.call, chunks)) async def test____send_all_from_iterable____single_yield____no_copy( self, diff --git a/tests/unit_test/test_sync/test_client/test_tcp.py b/tests/unit_test/test_sync/test_client/test_tcp.py index c496926c..8a9b85bd 100644 --- a/tests/unit_test/test_sync/test_client/test_tcp.py +++ b/tests/unit_test/test_sync/test_client/test_tcp.py @@ -28,7 +28,7 @@ from pytest_mock import MockerFixture from ..._utils import DummyLock -from ...base import UNSUPPORTED_FAMILIES +from ...base import UNSUPPORTED_FAMILIES, MixinTestSocketSendMSG from .base import BaseTestClient @@ -37,7 +37,7 @@ def remove_ssl_OP_IGNORE_UNEXPECTED_EOF(monkeypatch: pytest.MonkeyPatch) -> None monkeypatch.delattr("ssl.OP_IGNORE_UNEXPECTED_EOF", raising=False) -class TestTCPNetworkClient(BaseTestClient): +class TestTCPNetworkClient(BaseTestClient, MixinTestSocketSendMSG): @pytest.fixture(scope="class", params=["AF_INET", "AF_INET6"]) @staticmethod def socket_family(request: Any) -> Any: @@ -76,6 +76,8 @@ def use_ssl(request: Any) -> bool: def mock_tcp_socket(mock_tcp_socket: MagicMock, socket_family: int, socket_fileno: int) -> MagicMock: mock_tcp_socket.family = socket_family mock_tcp_socket.fileno.return_value = socket_fileno + with contextlib.suppress(AttributeError): + del mock_tcp_socket.sendmsg return mock_tcp_socket @pytest.fixture diff --git a/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_abc.py b/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_abc.py index 13653b5e..6a368844 100644 --- a/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_abc.py +++ b/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_abc.py @@ -53,8 +53,8 @@ def test____send_all____several_call( self, mock_transport: MagicMock, mock_transport_send: MagicMock, - mocker: MockerFixture, mock_time_perfcounter: MagicMock, + mocker: MockerFixture, ) -> None: # Arrange mock_transport_send.side_effect = [len(b"pack"), len(b"et"), len(b"\n")] @@ -99,8 +99,20 @@ def test____send_all____invalid_send_return_value( def test____send_all_from_iterable____concatenates_chunks_and_call_send_all( self, mock_transport: MagicMock, + mock_time_perfcounter: MagicMock, + mocker: MockerFixture, ) -> None: # Arrange + now = 12345 + mock_time_perfcounter.side_effect = [ + now, + now + 5, + now + 5, + now + 8, + now + 8, + now + 14, + ] + timeout: float = 123456789 mock_transport.send_all.return_value = None chunks: list[bytes | bytearray | memoryview] = [b"a", bytearray(b"b"), memoryview(b"c")] @@ -108,7 +120,11 @@ def test____send_all_from_iterable____concatenates_chunks_and_call_send_all( StreamTransport.send_all_from_iterable(mock_transport, chunks, 123456789) # Assert - mock_transport.send_all.assert_called_once_with(b"abc", 123456789) + assert mock_transport.send_all.call_args_list == [ + mocker.call(b"a", timeout), + mocker.call(bytearray(b"b"), timeout - 5), + mocker.call(memoryview(b"c"), timeout - 8), + ] def test____send_all_from_iterable____single_yield____no_copy( self, diff --git a/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_socket.py b/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_socket.py index d5454928..c8e81fd4 100644 --- a/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_socket.py +++ b/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_socket.py @@ -6,7 +6,7 @@ import math import os import ssl -from collections.abc import Callable +from collections.abc import Callable, Iterable from socket import SHUT_RDWR, SHUT_WR from typing import TYPE_CHECKING, Any @@ -23,13 +23,38 @@ import pytest +from ....base import MixinTestSocketSendMSG + if TYPE_CHECKING: from unittest.mock import MagicMock + from _typeshed import ReadableBuffer from pytest_mock import MockerFixture -class TestSocketStreamTransport: +def _retry_side_effect(callback: Callable[[], Any], timeout: float) -> Any: + while True: + try: + return callback() + except (WouldBlockOnRead, WouldBlockOnWrite): + pass + + +class TestSocketStreamTransport(MixinTestSocketSendMSG): + @pytest.fixture(autouse=True) + @staticmethod + def mock_transport_retry(mocker: MockerFixture) -> MagicMock: + mock_transport_retry = mocker.patch.object(SocketStreamTransport, "_retry") + mock_transport_retry.side_effect = _retry_side_effect + return mock_transport_retry + + @pytest.fixture + @staticmethod + def mock_transport_send_all(mocker: MockerFixture) -> MagicMock: + mock_transport_send_all = mocker.patch.object(SocketStreamTransport, "send_all", spec=lambda data, timeout: None) + mock_transport_send_all.return_value = None + return mock_transport_send_all + @pytest.fixture @staticmethod def socket_fileno(request: pytest.FixtureRequest) -> int: @@ -37,10 +62,17 @@ def socket_fileno(request: pytest.FixtureRequest) -> int: @pytest.fixture @staticmethod - def mock_tcp_socket(mock_tcp_socket: MagicMock, socket_fileno: int) -> MagicMock: + def mock_tcp_socket(mock_tcp_socket: MagicMock, socket_fileno: int, mocker: MockerFixture) -> MagicMock: mock_tcp_socket.fileno.return_value = socket_fileno mock_tcp_socket.getsockname.return_value = ("local_address", 11111) mock_tcp_socket.getpeername.return_value = ("remote_address", 12345) + + # Always create a new mock instance because sendmsg() is not available on all platforms + # therefore the mocker's autospec will consider sendmsg() unknown on these ones. + mock_tcp_socket.sendmsg = mocker.MagicMock( + spec=lambda *args: None, + side_effect=lambda buffers, *args: sum(map(len, map(memoryview, buffers))), + ) return mock_tcp_socket @pytest.fixture @@ -208,6 +240,168 @@ def test____send_noblock____blocking_error( mock_tcp_socket.fileno.assert_called_once() assert exc_info.value.fileno is mock_tcp_socket.fileno.return_value + def test____send_all_from_iterable____use_socket_sendmsg_when_available( + self, + transport: SocketStreamTransport, + mock_tcp_socket: MagicMock, + mock_transport_retry: MagicMock, + mock_transport_send_all: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + chunks: list[list[bytes]] = [] + + def sendmsg_side_effect(buffers: Iterable[ReadableBuffer]) -> int: + buffers = list(buffers) + chunks.append(list(map(bytes, buffers))) + return sum(map(len, map(memoryview, buffers))) + + mock_tcp_socket.sendmsg.side_effect = sendmsg_side_effect + + # Act + transport.send_all_from_iterable(iter([b"data", b"to", b"send"]), 123456) + + # Assert + mock_transport_send_all.assert_not_called() + mock_transport_retry.assert_called_once_with(mocker.ANY, 123456) + mock_tcp_socket.sendmsg.assert_called_once() + assert chunks == [[b"data", b"to", b"send"]] + + @pytest.mark.parametrize("SC_IOV_MAX", [2], ids=lambda p: f"SC_IOV_MAX=={p}", indirect=True) + def test____send_all_from_iterable____nb_buffers_greather_than_SC_IOV_MAX( + self, + transport: SocketStreamTransport, + mock_tcp_socket: MagicMock, + ) -> None: + # Arrange + chunks: list[list[bytes]] = [] + + def sendmsg_side_effect(buffers: Iterable[ReadableBuffer]) -> int: + buffers = list(buffers) + chunks.append(list(map(bytes, buffers))) + return sum(map(len, map(memoryview, buffers))) + + mock_tcp_socket.sendmsg.side_effect = sendmsg_side_effect + + # Act + transport.send_all_from_iterable(iter([b"a", b"b", b"c", b"d", b"e"]), 123456) + + # Assert + assert mock_tcp_socket.sendmsg.call_count == 3 + assert chunks == [ + [b"a", b"b"], + [b"c", b"d"], + [b"e"], + ] + + def test____send_all_from_iterable____adjust_leftover_buffer( + self, + transport: SocketStreamTransport, + mock_tcp_socket: MagicMock, + ) -> None: + # Arrange + chunks: list[list[bytes]] = [] + + def sendmsg_side_effect(buffers: Iterable[ReadableBuffer]) -> int: + buffers = list(buffers) + chunks.append(list(map(bytes, buffers))) + return min(sum(map(len, map(memoryview, buffers))), 3) + + mock_tcp_socket.sendmsg.side_effect = sendmsg_side_effect + + # Act + transport.send_all_from_iterable(iter([b"abcd", b"efg", b"hijkl", b"mnop"]), 123456) + + # Assert + assert mock_tcp_socket.sendmsg.call_count == 6 + assert chunks == [ + [b"abcd", b"efg", b"hijkl", b"mnop"], + [b"d", b"efg", b"hijkl", b"mnop"], + [b"g", b"hijkl", b"mnop"], + [b"jkl", b"mnop"], + [b"mnop"], + [b"p"], + ] + + @pytest.mark.parametrize("error", [BlockingIOError, InterruptedError]) + def test____send_all_from_iterable____blocking_error( + self, + error: type[OSError], + transport: SocketStreamTransport, + mock_tcp_socket: MagicMock, + mock_transport_retry: MagicMock, + mock_transport_send_all: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + to_raise: list[type[OSError]] = [error] + chunks: list[list[bytes]] = [] + + def sendmsg_side_effect(buffers: Iterable[ReadableBuffer]) -> int: + if to_raise: + raise to_raise.pop(0) + buffers = list(buffers) + chunks.append(list(map(bytes, buffers))) + return sum(map(len, map(memoryview, buffers))) + + mock_tcp_socket.sendmsg.side_effect = sendmsg_side_effect + + # Act + transport.send_all_from_iterable(iter([b"data"]), 123456) + + # Assert + mock_transport_send_all.assert_not_called() + mock_transport_retry.assert_called_once_with(mocker.ANY, 123456) + assert mock_tcp_socket.sendmsg.call_count == 2 + assert chunks == [[b"data"]] + + def test____send_all_from_iterable____fallback_to_send_all____sendmsg_unavailable( + self, + transport: SocketStreamTransport, + mock_tcp_socket: MagicMock, + mock_transport_retry: MagicMock, + mock_transport_send_all: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + del mock_tcp_socket.sendmsg + + # Act + transport.send_all_from_iterable(iter([b"data", b"to", b"send"]), 123456) + + # Assert + mock_transport_retry.assert_not_called() + assert mock_transport_send_all.call_args_list == list( + map( + lambda data: mocker.call(data, mocker.ANY), + [b"data", b"to", b"send"], + ) + ) + + @pytest.mark.parametrize("SC_IOV_MAX", [-1, 0], ids=lambda p: f"SC_IOV_MAX=={p}", indirect=True) + def test____send_all_from_iterable____fallback_to_send_all____sendmsg_available_but_no_defined_limit( + self, + transport: SocketStreamTransport, + mock_tcp_socket: MagicMock, + mock_transport_retry: MagicMock, + mock_transport_send_all: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + + # Act + transport.send_all_from_iterable(iter([b"data", b"to", b"send"]), 123456) + + # Assert + mock_transport_retry.assert_not_called() + mock_tcp_socket.sendmsg.assert_not_called() + assert mock_transport_send_all.call_args_list == list( + map( + lambda data: mocker.call(data, mocker.ANY), + [b"data", b"to", b"send"], + ) + ) + @pytest.mark.parametrize( "os_error", [pytest.param(None)] + list(map(pytest.param, sorted(NOT_CONNECTED_SOCKET_ERRNOS | CLOSED_SOCKET_ERRNOS))), @@ -283,14 +477,6 @@ def test____extra_attributes____address_lookup_raises_OSError( assert transport.extra(extra_attribute, mocker.sentinel.default_value) is mocker.sentinel.default_value -def _retry_side_effect(callback: Callable[[], Any], timeout: float) -> Any: - while True: - try: - return callback() - except (WouldBlockOnRead, WouldBlockOnWrite): - pass - - class TestSSLStreamTransport: @pytest.fixture(autouse=True) @staticmethod @@ -308,8 +494,8 @@ def socket_fileno(request: pytest.FixtureRequest) -> int: @staticmethod def mock_ssl_socket(mock_ssl_socket: MagicMock, socket_fileno: int) -> MagicMock: mock_ssl_socket.fileno.return_value = socket_fileno - mock_ssl_socket.do_handshake.side_effect = [WouldBlockOnRead(socket_fileno), WouldBlockOnWrite(socket_fileno), None] - mock_ssl_socket.unwrap.side_effect = [WouldBlockOnRead(socket_fileno), WouldBlockOnWrite(socket_fileno), None] + mock_ssl_socket.do_handshake.side_effect = [ssl.SSLWantReadError, ssl.SSLWantWriteError, None] + mock_ssl_socket.unwrap.side_effect = [ssl.SSLWantReadError, ssl.SSLWantWriteError, None] mock_ssl_socket.getsockname.return_value = ("local_address", 11111) mock_ssl_socket.getpeername.return_value = ("remote_address", 12345) @@ -544,7 +730,7 @@ def test____close____default( ) -> None: # Arrange if unwrap_error is not None: - mock_ssl_socket.unwrap.side_effect = [WouldBlockOnRead(socket_fileno), WouldBlockOnWrite(socket_fileno), unwrap_error] + mock_ssl_socket.unwrap.side_effect = [ssl.SSLWantReadError, ssl.SSLWantWriteError, unwrap_error] if shutdown_error is not None: mock_ssl_socket.shutdown.side_effect = shutdown_error mock_transport_retry.reset_mock() @@ -556,7 +742,9 @@ def test____close____default( if standard_compatible: assert mock_ssl_socket.mock_calls == [ mocker.call.unwrap(), + mocker.call.fileno(), mocker.call.unwrap(), + mocker.call.fileno(), mocker.call.unwrap(), mocker.call.shutdown(SHUT_RDWR), mocker.call.close(), diff --git a/tests/unit_test/test_tools/test_utils.py b/tests/unit_test/test_tools/test_utils.py index de58de13..41d6780c 100644 --- a/tests/unit_test/test_tools/test_utils.py +++ b/tests/unit_test/test_tools/test_utils.py @@ -5,15 +5,17 @@ import os import ssl import threading +from collections import deque from collections.abc import Callable from errno import EINVAL, ENOTCONN, errorcode as errno_errorcode -from socket import SO_ERROR, SOL_SOCKET +from socket import SO_ERROR, SOL_SOCKET, SocketType from typing import TYPE_CHECKING, Any from easynetwork.exceptions import BusyResourceError from easynetwork.lowlevel._utils import ( ElapsedTime, ResourceGuard, + adjust_leftover_buffer, check_real_socket_state, check_socket_family, check_socket_is_connected, @@ -31,6 +33,7 @@ remove_traceback_frames_in_place, replace_kwargs, set_reuseport, + supports_socket_sendmsg, validate_timeout_delay, ) from easynetwork.lowlevel.constants import NOT_CONNECTED_SOCKET_ERRNOS @@ -227,6 +230,14 @@ def test____check_socket_family____invalid_family(socket_family: int) -> None: check_socket_family(socket_family) +def test____supports_socket_sendmsg____checks_socket_type(mock_socket_factory: Callable[[], MagicMock]) -> None: + # Arrange + mock_socket = mock_socket_factory() + + # Act & Assert + assert supports_socket_sendmsg(mock_socket) is hasattr(SocketType, "sendmsg") + + def test____is_ssl_socket____regular_socket(mock_socket_factory: Callable[[], MagicMock]) -> None: # Arrange mock_socket = mock_socket_factory() @@ -348,6 +359,39 @@ def test____iter_bytes____iterate_over_bytes_returning_one_byte() -> None: assert result == expected_result +def test____adjust_leftover_buffer____consume_whole_buffer() -> None: + # Arrange + buffers: deque[memoryview] = deque(map(memoryview, [b"abc", b"de", b"fgh"])) + + # Act + adjust_leftover_buffer(buffers, 8) + + # Assert + assert not buffers + + +def test____adjust_leftover_buffer____remove_some_buffers() -> None: + # Arrange + buffers: deque[memoryview] = deque(map(memoryview, [b"abc", b"de", b"fgh"])) + + # Act + adjust_leftover_buffer(buffers, 5) + + # Assert + assert list(buffers) == list(map(memoryview, [b"fgh"])) + + +def test____adjust_leftover_buffer____partial_buffer_remove() -> None: + # Arrange + buffers: deque[memoryview] = deque(map(memoryview, [b"abc", b"de", b"fgh"])) + + # Act + adjust_leftover_buffer(buffers, 4) + + # Assert + assert list(buffers) == list(map(memoryview, [b"e", b"fgh"])) + + def test____is_socket_connected____getpeername_returns(mock_tcp_socket: MagicMock) -> None: # Arrange mock_tcp_socket.getpeername.return_value = ("127.0.0.1", 12345)