diff --git a/src/easynetwork/lowlevel/_utils.py b/src/easynetwork/lowlevel/_utils.py index 51ef3d1a..b2914efa 100644 --- a/src/easynetwork/lowlevel/_utils.py +++ b/src/easynetwork/lowlevel/_utils.py @@ -50,7 +50,7 @@ from abc import abstractmethod from collections import deque from collections.abc import Callable, Iterable, Iterator -from typing import TYPE_CHECKING, Any, Concatenate, Final, ParamSpec, Protocol, Self, TypeGuard, TypeVar, overload +from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Protocol, Self, TypeGuard, TypeVar, overload try: import ssl as _ssl @@ -180,9 +180,6 @@ def check_real_socket_state(socket: ISocket) -> None: raise error_from_errno(errno) -_HAS_SENDMSG: Final[bool] = hasattr(_socket.socket, "sendmsg") - - class _SupportsSocketSendMSG(Protocol): @abstractmethod def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> int: ... @@ -190,7 +187,7 @@ def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> int: ... def supports_socket_sendmsg(sock: _socket.socket) -> TypeGuard[_SupportsSocketSendMSG]: assert isinstance(sock, _socket.SocketType) # nosec assert_used - return _HAS_SENDMSG + return hasattr(sock, "sendmsg") def is_ssl_socket(socket: _socket.socket) -> TypeGuard[_SSLSocket]: @@ -236,6 +233,8 @@ def iter_bytes(b: bytes | bytearray | memoryview) -> Iterator[bytes]: def adjust_leftover_buffer(buffers: deque[memoryview], nbytes: int) -> None: while nbytes > 0: b = buffers.popleft() + if b.itemsize != 1: + b = b.cast("B") b_len = len(b) if b_len <= nbytes: nbytes -= b_len @@ -277,13 +276,13 @@ def ensure_datagram_socket_bound(sock: _socket.socket) -> None: def set_reuseport(sock: SupportsSocketOptions) -> None: - if not hasattr(_socket, "SO_REUSEPORT"): - raise ValueError("reuse_port not supported by socket module") - else: + if hasattr(_socket, "SO_REUSEPORT"): try: sock.setsockopt(_socket.SOL_SOCKET, _socket.SO_REUSEPORT, True) except OSError: raise ValueError("reuse_port not supported by socket module, SO_REUSEPORT defined but not implemented.") from None + else: + raise ValueError("reuse_port not supported by socket module") def open_listener_sockets_from_getaddrinfo_result( diff --git a/src/easynetwork/lowlevel/api_async/transports/abc.py b/src/easynetwork/lowlevel/api_async/transports/abc.py index f41c3cca..d7f1bbfc 100644 --- a/src/easynetwork/lowlevel/api_async/transports/abc.py +++ b/src/easynetwork/lowlevel/api_async/transports/abc.py @@ -173,7 +173,7 @@ async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytear Parameters: iterable_of_data: An :term:`iterable` yielding the bytes to send. """ - for data in iterable_of_data: + for data in list(iterable_of_data): await self.send_all(data) diff --git a/src/easynetwork/lowlevel/api_sync/transports/abc.py b/src/easynetwork/lowlevel/api_sync/transports/abc.py index df8f41ae..b0da0dc4 100644 --- a/src/easynetwork/lowlevel/api_sync/transports/abc.py +++ b/src/easynetwork/lowlevel/api_sync/transports/abc.py @@ -202,7 +202,7 @@ def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | ValueError: Negative `timeout`. TimeoutError: Operation timed out. """ - for data in iterable_of_data: + for data in list(iterable_of_data): with _utils.ElapsedTime() as elapsed: self.send_all(data, timeout) timeout = elapsed.recompute_timeout(timeout) diff --git a/src/easynetwork/lowlevel/api_sync/transports/socket.py b/src/easynetwork/lowlevel/api_sync/transports/socket.py index 45c4ee5d..1cc0b090 100644 --- a/src/easynetwork/lowlevel/api_sync/transports/socket.py +++ b/src/easynetwork/lowlevel/api_sync/transports/socket.py @@ -130,21 +130,22 @@ def send_noblock(self, data: bytes | bytearray | memoryview) -> int: @_utils.inherit_doc(base_selector.SelectorStreamTransport) def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview], timeout: float) -> None: - _sock = self.__socket - if constants.SC_IOV_MAX <= 0 or not _utils.supports_socket_sendmsg(_sock): + if constants.SC_IOV_MAX <= 0: return super().send_all_from_iterable(iterable_of_data, timeout) - buffers: deque[memoryview] = deque(memoryview(data).cast("B") for data in iterable_of_data) - del iterable_of_data + socket = self.__socket + socket_fileno = self.__socket.fileno + if not _utils.supports_socket_sendmsg(socket): + return super().send_all_from_iterable(iterable_of_data, timeout) - sock_sendmsg = _sock.sendmsg - del _sock + buffers: deque[memoryview] = deque(map(memoryview, iterable_of_data)) + del iterable_of_data def try_sendmsg() -> int: try: - return sock_sendmsg(itertools.islice(buffers, constants.SC_IOV_MAX)) + return socket.sendmsg(itertools.islice(buffers, constants.SC_IOV_MAX)) except (BlockingIOError, InterruptedError): - raise base_selector.WouldBlockOnWrite(self.__socket.fileno()) from None + raise base_selector.WouldBlockOnWrite(socket_fileno()) from None while buffers: sent, timeout = self._retry(try_sendmsg, timeout) @@ -291,6 +292,13 @@ def send_noblock(self, data: bytes | bytearray | memoryview) -> int: except _ssl_module.SSLZeroReturnError if _ssl_module else () as exc: raise _utils.error_from_errno(errno.ECONNRESET) from exc + @_utils.inherit_doc(base_selector.SelectorStreamTransport) + def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview], timeout: float) -> None: + # Send a whole chunk to minimize TLS exchanges + data = b"".join(iterable_of_data) + del iterable_of_data + return self.send_all(data, timeout) + @_utils.inherit_doc(base_selector.SelectorStreamTransport) def send_eof(self) -> None: # ssl.SSLSocket.shutdown() would close both read and write streams diff --git a/tests/unit_test/base.py b/tests/unit_test/base.py index df7d122a..20979a8a 100644 --- a/tests/unit_test/base.py +++ b/tests/unit_test/base.py @@ -1,7 +1,7 @@ from __future__ import annotations from collections.abc import Generator -from socket import AF_INET, AF_INET6, socket as Socket +from socket import AF_INET, AF_INET6 from typing import TYPE_CHECKING, Any import pytest @@ -102,14 +102,6 @@ def SC_IOV_MAX(request: pytest.FixtureRequest, monkeypatch: pytest.MonkeyPatch) monkeypatch.setattr("easynetwork.lowlevel.constants.SC_IOV_MAX", value) return value - @pytest.fixture(autouse=True) - @staticmethod - def supports_socket_sendmsg(mocker: MockerFixture) -> None: - def supports_socket_sendmsg(sock: Socket) -> bool: - return hasattr(sock, "sendmsg") - - mocker.patch("easynetwork.lowlevel._utils.supports_socket_sendmsg", supports_socket_sendmsg) - class BaseTestWithStreamProtocol: @pytest.fixture diff --git a/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_socket.py b/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_socket.py index 5b2eb8b5..a971f7d8 100644 --- a/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_socket.py +++ b/tests/unit_test/test_sync/test_lowlevel_api/test_transports/test_socket.py @@ -562,6 +562,13 @@ def mock_transport_retry(mocker: MockerFixture) -> MagicMock: mock_transport_retry.side_effect = _retry_side_effect return mock_transport_retry + @pytest.fixture + @staticmethod + def mock_transport_send_all(mocker: MockerFixture) -> MagicMock: + mock_transport_send_all = mocker.patch.object(SSLStreamTransport, "send_all", spec=lambda data, timeout: None) + mock_transport_send_all.return_value = None + return mock_transport_send_all + @pytest.fixture @staticmethod def socket_fileno(request: pytest.FixtureRequest) -> int: @@ -1048,6 +1055,19 @@ def test____send_noblock____SSLZeroReturnError( mock_ssl_socket.fileno.assert_not_called() assert exc_info.value.errno == errno.ECONNRESET + def test____send_all_from_iterable____concatenate_data( + self, + transport: SSLStreamTransport, + mock_transport_send_all: MagicMock, + ) -> None: + # Arrange + + # Act + transport.send_all_from_iterable(iter([b"data", b" to ", b"send"]), 123456) + + # Assert + mock_transport_send_all.assert_called_once_with(b"data to send", 123456) + def test____send_eof____default( self, transport: SSLStreamTransport, diff --git a/tests/unit_test/test_tools/test_utils.py b/tests/unit_test/test_tools/test_utils.py index 4ad559f5..ef958e66 100644 --- a/tests/unit_test/test_tools/test_utils.py +++ b/tests/unit_test/test_tools/test_utils.py @@ -10,7 +10,7 @@ from collections import deque from collections.abc import Callable from errno import EINVAL, ENOTCONN, errorcode as errno_errorcode -from socket import SO_ERROR, SOL_SOCKET, SocketType +from socket import SO_ERROR, SOL_SOCKET from typing import TYPE_CHECKING, Any from easynetwork.exceptions import BusyResourceError @@ -368,12 +368,31 @@ def test____check_socket_family____invalid_family(socket_family: int) -> None: check_socket_family(socket_family) -def test____supports_socket_sendmsg____checks_socket_type(mock_socket_factory: Callable[[], MagicMock]) -> None: +def test____supports_socket_sendmsg____have_sendmsg_method( + mock_socket_factory: Callable[[], MagicMock], + mocker: MockerFixture, +) -> None: + # Arrange + mock_socket = mock_socket_factory() + mock_socket.sendmsg = mocker.MagicMock( + spec=lambda *args: None, + side_effect=lambda buffers, *args: sum(map(len, map(memoryview, buffers))), + ) + + # Act & Assert + assert supports_socket_sendmsg(mock_socket) + + +def test____supports_socket_sendmsg____dont_have_sendmsg_method( + mock_socket_factory: Callable[[], MagicMock], + mocker: MockerFixture, +) -> None: # Arrange mock_socket = mock_socket_factory() + del mock_socket.sendmsg # Act & Assert - assert supports_socket_sendmsg(mock_socket) is hasattr(SocketType, "sendmsg") + assert not supports_socket_sendmsg(mock_socket) def test____is_ssl_socket____regular_socket(mock_socket_factory: Callable[[], MagicMock]) -> None: @@ -530,6 +549,21 @@ def test____adjust_leftover_buffer____partial_buffer_remove() -> None: assert list(buffers) == list(map(memoryview, [b"e", b"fgh"])) +def test____adjust_leftover_buffer____handle_view_with_different_item_sizes() -> None: + # Arrange + import array + + item = array.array("i", [56, 23, 45, -4]) + + buffers: deque[memoryview] = deque(map(memoryview, [item])) + + # Act + adjust_leftover_buffer(buffers, item.itemsize * 2) + + # Assert + assert list(buffers) == list(map(memoryview, [bytes(item[2:])])) + + def test____is_socket_connected____getpeername_returns(mock_tcp_socket: MagicMock) -> None: # Arrange mock_tcp_socket.getpeername.return_value = ("127.0.0.1", 12345)