From 47d9de37488ab61bef2317867b747f145bc47545 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Thu, 26 Sep 2024 08:03:35 +0200 Subject: [PATCH 1/2] Trio backend: Better way to connect a socket --- src/easynetwork/lowlevel/_utils.py | 4 +- .../api_async/backend/_trio/dns_resolver.py | 28 ++--- .../test_trio_backend/test_dns_resolver.py | 100 ++++++++---------- 3 files changed, 60 insertions(+), 72 deletions(-) diff --git a/src/easynetwork/lowlevel/_utils.py b/src/easynetwork/lowlevel/_utils.py index 1d1a289a..1be19f7b 100644 --- a/src/easynetwork/lowlevel/_utils.py +++ b/src/easynetwork/lowlevel/_utils.py @@ -166,7 +166,7 @@ def check_socket_family(family: int) -> None: raise ValueError("Only these families are supported: AF_INET, AF_INET6") -def check_real_socket_state(socket: ISocket) -> None: +def check_real_socket_state(socket: ISocket, error_msg: str = "{strerror}") -> None: """Verify socket saved error and raise OSError if there is one There are some functions such as socket.send() which do not immediately fail and save the errno @@ -180,7 +180,7 @@ def check_real_socket_state(socket: ISocket) -> None: errno = socket.getsockopt(_socket.SOL_SOCKET, _socket.SO_ERROR) if errno != 0: # The SO_ERROR is automatically reset to zero after getting the value - raise error_from_errno(errno) + raise error_from_errno(errno, error_msg) class _SupportsSocketSendMSG(Protocol): diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/dns_resolver.py b/src/easynetwork/lowlevel/api_async/backend/_trio/dns_resolver.py index a6ee7522..f44726bd 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_trio/dns_resolver.py +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/dns_resolver.py @@ -21,8 +21,9 @@ import socket as _socket -import trio.socket +import trio +from .... import _utils from .._common.dns_resolver import BaseAsyncDNSResolver @@ -30,21 +31,14 @@ class TrioDNSResolver(BaseAsyncDNSResolver): __slots__ = () async def connect_socket(self, socket: _socket.socket, address: tuple[str, int] | tuple[str, int, int, int]) -> None: - # TL;DR: Why not directly use trio.socket.socket() function? - # When giving a fileno, it tries to guess the real family, type and proto of the file descriptor - # by calling getsockopt(). This extra operation is useless here. - async_socket = trio.socket.from_stdlib_socket( - _socket.socket(socket.family, socket.type, socket.proto, fileno=socket.fileno()) - ) + await trio.lowlevel.checkpoint_if_cancelled() try: - await async_socket.connect(address) - except BaseException: - # If connect() raises an exception, let trio close the socket. - # NOTE: connect() already closes the socket if trio.Cancelled is raised. - socket.detach() - raise + socket.connect(address) + except BlockingIOError: + pass else: - # The operation has succeeded, remove the ownership to the temporary socket. - async_socket.detach() - finally: - async_socket.close() + await trio.lowlevel.cancel_shielded_checkpoint() + return + + await trio.lowlevel.wait_writable(socket) + _utils.check_real_socket_state(socket, error_msg=f"Could not connect to {address!r}: {{strerror}}") diff --git a/tests/unit_test/test_async/test_trio_backend/test_dns_resolver.py b/tests/unit_test/test_async/test_trio_backend/test_dns_resolver.py index bece30ec..b763cb09 100644 --- a/tests/unit_test/test_async/test_trio_backend/test_dns_resolver.py +++ b/tests/unit_test/test_async/test_trio_backend/test_dns_resolver.py @@ -1,40 +1,30 @@ from __future__ import annotations +import errno import socket -from collections.abc import AsyncIterator -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any import pytest -from ....fixtures.trio import trio_fixture - if TYPE_CHECKING: - from trio import SocketListener, SocketStream + from unittest.mock import AsyncMock, MagicMock from easynetwork.lowlevel.api_async.backend._trio.dns_resolver import TrioDNSResolver + from pytest_mock import MockerFixture + @pytest.mark.feature_trio(async_test_auto_mark=True) class TestTrioDNSResolver: - @trio_fixture + @pytest.fixture(autouse=True) @staticmethod - async def listener() -> AsyncIterator[SocketListener]: + def mock_trio_lowlevel_wait_writable(mocker: MockerFixture) -> AsyncMock: import trio - async with (await trio.open_tcp_listeners(0, host="127.0.0.1"))[0] as listener: - yield listener - - @pytest.fixture - @staticmethod - def listener_address(listener: SocketListener) -> tuple[str, int]: - return listener.socket.getsockname() + async def wait_writable(sock: Any) -> None: + await trio.lowlevel.checkpoint() - @pytest.fixture - @staticmethod - def client_sock(listener: SocketListener) -> socket.socket: - sock = socket.socket(family=listener.socket.family, type=listener.socket.type) - sock.setblocking(False) - return sock + return mocker.patch("trio.lowlevel.wait_writable", autospec=True, side_effect=wait_writable) @pytest.fixture @staticmethod @@ -43,61 +33,65 @@ def dns_resolver() -> TrioDNSResolver: return TrioDNSResolver() - async def test____connect_socket____works( + async def test____connect_socket____works____non_blocking( self, dns_resolver: TrioDNSResolver, - listener: SocketListener, - listener_address: tuple[str, int], - client_sock: socket.socket, + mock_tcp_socket: MagicMock, + mock_trio_lowlevel_wait_writable: AsyncMock, + mocker: MockerFixture, ) -> None: # Arrange - import trio + mock_tcp_socket.getsockopt.return_value = 0 + mock_tcp_socket.connect.return_value = None # Act - server_stream: SocketStream | None = None - async with trio.open_nursery() as nursery: - nursery.cancel_scope.deadline = trio.current_time() + 1 - nursery.start_soon(dns_resolver.connect_socket, client_sock, listener_address) - - await trio.sleep(0.5) - server_stream = await listener.accept() + await dns_resolver.connect_socket(mock_tcp_socket, ("127.0.0.1", 12345)) # Assert - assert server_stream is not None - assert client_sock.fileno() > 0 + assert mock_tcp_socket.mock_calls == [mocker.call.connect(("127.0.0.1", 12345))] + mock_trio_lowlevel_wait_writable.assert_not_awaited() - async with server_stream, trio.SocketStream(trio.socket.from_stdlib_socket(client_sock)) as client_stream: - await client_stream.send_all(b"data") - assert (await server_stream.receive_some()) == b"data" - - async def test____connect_socket____close_on_cancel( + async def test____connect_socket____works____blocking( self, dns_resolver: TrioDNSResolver, - listener_address: tuple[str, int], - client_sock: socket.socket, + mock_tcp_socket: MagicMock, + mock_trio_lowlevel_wait_writable: AsyncMock, + mocker: MockerFixture, ) -> None: # Arrange - import trio + mock_tcp_socket.getsockopt.side_effect = [0] + mock_tcp_socket.connect.side_effect = [BlockingIOError] # Act - with trio.move_on_after(0) as scope: - await dns_resolver.connect_socket(client_sock, listener_address) + await dns_resolver.connect_socket(mock_tcp_socket, ("127.0.0.1", 12345)) # Assert - assert scope.cancelled_caught - assert client_sock.fileno() < 0 - - async def test____connect_socket____close_on_error( + assert mock_tcp_socket.mock_calls == [ + mocker.call.connect(("127.0.0.1", 12345)), + mocker.call.fileno(), + mocker.call.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR), + ] + mock_trio_lowlevel_wait_writable.assert_awaited_once_with(mock_tcp_socket) + + async def test____connect_socket____works____blocking____connection_error( self, dns_resolver: TrioDNSResolver, - client_sock: socket.socket, + mock_tcp_socket: MagicMock, + mock_trio_lowlevel_wait_writable: AsyncMock, + mocker: MockerFixture, ) -> None: # Arrange - listener_address = ("unknown_address", 12345) + mock_tcp_socket.getsockopt.side_effect = [errno.ECONNREFUSED, 0] + mock_tcp_socket.connect.side_effect = [BlockingIOError] # Act - with pytest.raises(OSError): - await dns_resolver.connect_socket(client_sock, listener_address) + with pytest.raises(ConnectionRefusedError, match=r"^\[Errno \d+\] Could not connect to .+: .*$"): + await dns_resolver.connect_socket(mock_tcp_socket, ("127.0.0.1", 12345)) # Assert - assert client_sock.fileno() < 0 + assert mock_tcp_socket.mock_calls == [ + mocker.call.connect(("127.0.0.1", 12345)), + mocker.call.fileno(), + mocker.call.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR), + ] + mock_trio_lowlevel_wait_writable.assert_awaited_once_with(mock_tcp_socket) From b0ddf933d62ee795c05743b8918075b4b30e60b9 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Thu, 26 Sep 2024 18:39:14 +0200 Subject: [PATCH 2/2] Fixed unit tests --- src/easynetwork/lowlevel/_utils.py | 7 +++++-- tests/unit_test/test_tools/test_utils.py | 20 ++++++++++++++++++++ 2 files changed, 25 insertions(+), 2 deletions(-) diff --git a/src/easynetwork/lowlevel/_utils.py b/src/easynetwork/lowlevel/_utils.py index 1be19f7b..e5afc1cd 100644 --- a/src/easynetwork/lowlevel/_utils.py +++ b/src/easynetwork/lowlevel/_utils.py @@ -166,7 +166,7 @@ def check_socket_family(family: int) -> None: raise ValueError("Only these families are supported: AF_INET, AF_INET6") -def check_real_socket_state(socket: ISocket, error_msg: str = "{strerror}") -> None: +def check_real_socket_state(socket: ISocket, error_msg: str | None = None) -> None: """Verify socket saved error and raise OSError if there is one There are some functions such as socket.send() which do not immediately fail and save the errno @@ -180,7 +180,10 @@ def check_real_socket_state(socket: ISocket, error_msg: str = "{strerror}") -> N errno = socket.getsockopt(_socket.SOL_SOCKET, _socket.SO_ERROR) if errno != 0: # The SO_ERROR is automatically reset to zero after getting the value - raise error_from_errno(errno, error_msg) + if error_msg: + raise error_from_errno(errno, error_msg) + else: + raise error_from_errno(errno) class _SupportsSocketSendMSG(Protocol): diff --git a/tests/unit_test/test_tools/test_utils.py b/tests/unit_test/test_tools/test_utils.py index 75042105..e965eada 100644 --- a/tests/unit_test/test_tools/test_utils.py +++ b/tests/unit_test/test_tools/test_utils.py @@ -338,6 +338,26 @@ def test____check_real_socket_state____socket_with_error(mock_tcp_socket: MagicM mock_error_from_errno.assert_called_once_with(errno) +def test____check_real_socket_state____socket_with_error____custom_message( + mock_tcp_socket: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + errno = 123456 + exception = OSError(errno, "errno message") + mock_tcp_socket.getsockopt.return_value = errno + mock_error_from_errno = mocker.patch(f"{error_from_errno.__module__}.{error_from_errno.__qualname__}", return_value=exception) + + # Act + with pytest.raises(OSError) as exc_info: + check_real_socket_state(mock_tcp_socket, error_msg="unrelated error: {strerror}") + + # Assert + assert exc_info.value is exception + mock_tcp_socket.getsockopt.assert_called_once_with(SOL_SOCKET, SO_ERROR) + mock_error_from_errno.assert_called_once_with(errno, "unrelated error: {strerror}") + + def test____check_real_socket_state____closed_socket(mock_tcp_socket: MagicMock, mocker: MockerFixture) -> None: # Arrange mock_tcp_socket.fileno.return_value = -1