Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Trio backend: Better way to connect a socket #353

Merged
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
7 changes: 5 additions & 2 deletions src/easynetwork/lowlevel/_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -166,7 +166,7 @@ def check_socket_family(family: int) -> None:
raise ValueError("Only these families are supported: AF_INET, AF_INET6")


def check_real_socket_state(socket: ISocket) -> None:
def check_real_socket_state(socket: ISocket, error_msg: str | None = None) -> None:
"""Verify socket saved error and raise OSError if there is one

There are some functions such as socket.send() which do not immediately fail and save the errno
Expand All @@ -180,7 +180,10 @@ def check_real_socket_state(socket: ISocket) -> None:
errno = socket.getsockopt(_socket.SOL_SOCKET, _socket.SO_ERROR)
if errno != 0:
# The SO_ERROR is automatically reset to zero after getting the value
raise error_from_errno(errno)
if error_msg:
raise error_from_errno(errno, error_msg)
else:
raise error_from_errno(errno)


class _SupportsSocketSendMSG(Protocol):
Expand Down
28 changes: 11 additions & 17 deletions src/easynetwork/lowlevel/api_async/backend/_trio/dns_resolver.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,30 +21,24 @@

import socket as _socket

import trio.socket
import trio

from .... import _utils
from .._common.dns_resolver import BaseAsyncDNSResolver


class TrioDNSResolver(BaseAsyncDNSResolver):
__slots__ = ()

async def connect_socket(self, socket: _socket.socket, address: tuple[str, int] | tuple[str, int, int, int]) -> None:
# TL;DR: Why not directly use trio.socket.socket() function?
# When giving a fileno, it tries to guess the real family, type and proto of the file descriptor
# by calling getsockopt(). This extra operation is useless here.
async_socket = trio.socket.from_stdlib_socket(
_socket.socket(socket.family, socket.type, socket.proto, fileno=socket.fileno())
)
await trio.lowlevel.checkpoint_if_cancelled()
try:
await async_socket.connect(address)
except BaseException:
# If connect() raises an exception, let trio close the socket.
# NOTE: connect() already closes the socket if trio.Cancelled is raised.
socket.detach()
raise
socket.connect(address)
except BlockingIOError:
pass
else:
# The operation has succeeded, remove the ownership to the temporary socket.
async_socket.detach()
finally:
async_socket.close()
await trio.lowlevel.cancel_shielded_checkpoint()
return

await trio.lowlevel.wait_writable(socket)
_utils.check_real_socket_state(socket, error_msg=f"Could not connect to {address!r}: {{strerror}}")
100 changes: 47 additions & 53 deletions tests/unit_test/test_async/test_trio_backend/test_dns_resolver.py
Original file line number Diff line number Diff line change
@@ -1,40 +1,30 @@
from __future__ import annotations

import errno
import socket
from collections.abc import AsyncIterator
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any

import pytest

from ....fixtures.trio import trio_fixture

if TYPE_CHECKING:
from trio import SocketListener, SocketStream
from unittest.mock import AsyncMock, MagicMock

from easynetwork.lowlevel.api_async.backend._trio.dns_resolver import TrioDNSResolver

from pytest_mock import MockerFixture


@pytest.mark.feature_trio(async_test_auto_mark=True)
class TestTrioDNSResolver:
@trio_fixture
@pytest.fixture(autouse=True)
@staticmethod
async def listener() -> AsyncIterator[SocketListener]:
def mock_trio_lowlevel_wait_writable(mocker: MockerFixture) -> AsyncMock:
import trio

async with (await trio.open_tcp_listeners(0, host="127.0.0.1"))[0] as listener:
yield listener

@pytest.fixture
@staticmethod
def listener_address(listener: SocketListener) -> tuple[str, int]:
return listener.socket.getsockname()
async def wait_writable(sock: Any) -> None:
await trio.lowlevel.checkpoint()

@pytest.fixture
@staticmethod
def client_sock(listener: SocketListener) -> socket.socket:
sock = socket.socket(family=listener.socket.family, type=listener.socket.type)
sock.setblocking(False)
return sock
return mocker.patch("trio.lowlevel.wait_writable", autospec=True, side_effect=wait_writable)

@pytest.fixture
@staticmethod
Expand All @@ -43,61 +33,65 @@ def dns_resolver() -> TrioDNSResolver:

return TrioDNSResolver()

async def test____connect_socket____works(
async def test____connect_socket____works____non_blocking(
self,
dns_resolver: TrioDNSResolver,
listener: SocketListener,
listener_address: tuple[str, int],
client_sock: socket.socket,
mock_tcp_socket: MagicMock,
mock_trio_lowlevel_wait_writable: AsyncMock,
mocker: MockerFixture,
) -> None:
# Arrange
import trio
mock_tcp_socket.getsockopt.return_value = 0
mock_tcp_socket.connect.return_value = None

# Act
server_stream: SocketStream | None = None
async with trio.open_nursery() as nursery:
nursery.cancel_scope.deadline = trio.current_time() + 1
nursery.start_soon(dns_resolver.connect_socket, client_sock, listener_address)

await trio.sleep(0.5)
server_stream = await listener.accept()
await dns_resolver.connect_socket(mock_tcp_socket, ("127.0.0.1", 12345))

# Assert
assert server_stream is not None
assert client_sock.fileno() > 0
assert mock_tcp_socket.mock_calls == [mocker.call.connect(("127.0.0.1", 12345))]
mock_trio_lowlevel_wait_writable.assert_not_awaited()

async with server_stream, trio.SocketStream(trio.socket.from_stdlib_socket(client_sock)) as client_stream:
await client_stream.send_all(b"data")
assert (await server_stream.receive_some()) == b"data"

async def test____connect_socket____close_on_cancel(
async def test____connect_socket____works____blocking(
self,
dns_resolver: TrioDNSResolver,
listener_address: tuple[str, int],
client_sock: socket.socket,
mock_tcp_socket: MagicMock,
mock_trio_lowlevel_wait_writable: AsyncMock,
mocker: MockerFixture,
) -> None:
# Arrange
import trio
mock_tcp_socket.getsockopt.side_effect = [0]
mock_tcp_socket.connect.side_effect = [BlockingIOError]

# Act
with trio.move_on_after(0) as scope:
await dns_resolver.connect_socket(client_sock, listener_address)
await dns_resolver.connect_socket(mock_tcp_socket, ("127.0.0.1", 12345))

# Assert
assert scope.cancelled_caught
assert client_sock.fileno() < 0

async def test____connect_socket____close_on_error(
assert mock_tcp_socket.mock_calls == [
mocker.call.connect(("127.0.0.1", 12345)),
mocker.call.fileno(),
mocker.call.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR),
]
mock_trio_lowlevel_wait_writable.assert_awaited_once_with(mock_tcp_socket)

async def test____connect_socket____works____blocking____connection_error(
self,
dns_resolver: TrioDNSResolver,
client_sock: socket.socket,
mock_tcp_socket: MagicMock,
mock_trio_lowlevel_wait_writable: AsyncMock,
mocker: MockerFixture,
) -> None:
# Arrange
listener_address = ("unknown_address", 12345)
mock_tcp_socket.getsockopt.side_effect = [errno.ECONNREFUSED, 0]
mock_tcp_socket.connect.side_effect = [BlockingIOError]

# Act
with pytest.raises(OSError):
await dns_resolver.connect_socket(client_sock, listener_address)
with pytest.raises(ConnectionRefusedError, match=r"^\[Errno \d+\] Could not connect to .+: .*$"):
await dns_resolver.connect_socket(mock_tcp_socket, ("127.0.0.1", 12345))

# Assert
assert client_sock.fileno() < 0
assert mock_tcp_socket.mock_calls == [
mocker.call.connect(("127.0.0.1", 12345)),
mocker.call.fileno(),
mocker.call.getsockopt(socket.SOL_SOCKET, socket.SO_ERROR),
]
mock_trio_lowlevel_wait_writable.assert_awaited_once_with(mock_tcp_socket)
20 changes: 20 additions & 0 deletions tests/unit_test/test_tools/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -338,6 +338,26 @@ def test____check_real_socket_state____socket_with_error(mock_tcp_socket: MagicM
mock_error_from_errno.assert_called_once_with(errno)


def test____check_real_socket_state____socket_with_error____custom_message(
mock_tcp_socket: MagicMock,
mocker: MockerFixture,
) -> None:
# Arrange
errno = 123456
exception = OSError(errno, "errno message")
mock_tcp_socket.getsockopt.return_value = errno
mock_error_from_errno = mocker.patch(f"{error_from_errno.__module__}.{error_from_errno.__qualname__}", return_value=exception)

# Act
with pytest.raises(OSError) as exc_info:
check_real_socket_state(mock_tcp_socket, error_msg="unrelated error: {strerror}")

# Assert
assert exc_info.value is exception
mock_tcp_socket.getsockopt.assert_called_once_with(SOL_SOCKET, SO_ERROR)
mock_error_from_errno.assert_called_once_with(errno, "unrelated error: {strerror}")


def test____check_real_socket_state____closed_socket(mock_tcp_socket: MagicMock, mocker: MockerFixture) -> None:
# Arrange
mock_tcp_socket.fileno.return_value = -1
Expand Down