From d4303f92574ec2f3e9ac5a69e33fba4db761a303 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francis=20Clairicia-Rose-Claire-Jos=C3=A9phine?= Date: Sun, 3 Dec 2023 17:08:06 +0100 Subject: [PATCH] Added 'family' parameter to UDP clients (#180) --- src/easynetwork/api_async/client/udp.py | 3 + src/easynetwork/api_sync/client/udp.py | 6 +- .../lowlevel/api_async/backend/abc.py | 1 + .../lowlevel/asyncio/_asyncio_utils.py | 93 +++++++++--- src/easynetwork/lowlevel/asyncio/backend.py | 12 +- tests/unit_test/_utils.py | 6 +- .../test_api/test_client/test_udp.py | 72 ++++++++- .../test_asyncio_backend/test_backend.py | 11 +- .../test_asyncio_backend/test_utils.py | 143 +++++++++++++----- .../test_sync/test_client/test_udp.py | 117 ++++++++++++-- 10 files changed, 373 insertions(+), 91 deletions(-) diff --git a/src/easynetwork/api_async/client/udp.py b/src/easynetwork/api_async/client/udp.py index 333b8a3f..06876b9d 100644 --- a/src/easynetwork/api_async/client/udp.py +++ b/src/easynetwork/api_async/client/udp.py @@ -76,6 +76,7 @@ def __init__( protocol: DatagramProtocol[_SentPacketT, _ReceivedPacketT], *, local_address: tuple[str, int] | None = ..., + family: int = ..., backend: str | AsyncBackend | None = ..., backend_kwargs: Mapping[str, Any] | None = ..., ) -> None: @@ -147,6 +148,8 @@ def __init__( case (str(host), int(port)): if kwargs.get("local_address") is None: kwargs["local_address"] = ("localhost", 0) + if (family := kwargs.get("family", _socket.AF_UNSPEC)) != _socket.AF_UNSPEC: + _utils.check_socket_family(family) socket_factory = _utils.make_callback(backend.create_udp_endpoint, host, port, **kwargs) case _: # pragma: no cover raise TypeError("Invalid arguments") diff --git a/src/easynetwork/api_sync/client/udp.py b/src/easynetwork/api_sync/client/udp.py index 2ed27fe9..4084a07c 100644 --- a/src/easynetwork/api_sync/client/udp.py +++ b/src/easynetwork/api_sync/client/udp.py @@ -53,6 +53,7 @@ def __init__( protocol: DatagramProtocol[_SentPacketT, _ReceivedPacketT], *, local_address: tuple[str, int] | None = ..., + family: int = ..., reuse_port: bool = ..., retry_interval: float = ..., ) -> None: @@ -105,6 +106,8 @@ def __init__( case _socket.socket() as socket if not kwargs: _utils.ensure_datagram_socket_bound(socket) case (str(host), int(port)): + if (family := kwargs.get("family", _socket.AF_UNSPEC)) != _socket.AF_UNSPEC: + _utils.check_socket_family(family) socket = _create_udp_socket(remote_address=(host, port), **kwargs) case _: # pragma: no cover raise TypeError("Invalid arguments") @@ -291,6 +294,7 @@ def _create_udp_socket( *, local_address: tuple[str, int] | None = None, remote_address: tuple[str, int] | None = None, + family: int = _socket.AF_UNSPEC, reuse_port: bool = False, ) -> _socket.socket: local_host: str | None @@ -310,7 +314,7 @@ def _create_udp_socket( for family, _, proto, _, sockaddr in _socket.getaddrinfo( *(remote_address or local_address), - family=_socket.AF_UNSPEC, + family=family, type=_socket.SOCK_DGRAM, flags=flags, ): diff --git a/src/easynetwork/lowlevel/api_async/backend/abc.py b/src/easynetwork/lowlevel/api_async/backend/abc.py index 9fe583f6..875736df 100644 --- a/src/easynetwork/lowlevel/api_async/backend/abc.py +++ b/src/easynetwork/lowlevel/api_async/backend/abc.py @@ -1046,6 +1046,7 @@ async def create_udp_endpoint( remote_port: int, *, local_address: tuple[str, int] | None = ..., + family: int = ..., ) -> transports.AsyncDatagramTransport: """ Opens an endpoint using the UDP protocol. diff --git a/src/easynetwork/lowlevel/asyncio/_asyncio_utils.py b/src/easynetwork/lowlevel/asyncio/_asyncio_utils.py index 988307e1..828a6739 100644 --- a/src/easynetwork/lowlevel/asyncio/_asyncio_utils.py +++ b/src/easynetwork/lowlevel/asyncio/_asyncio_utils.py @@ -19,6 +19,7 @@ __all__ = [ "create_connection", + "create_datagram_connection", "open_listener_sockets_from_getaddrinfo_result", "wait_until_readable", "wait_until_writable", @@ -82,33 +83,14 @@ async def resolve_local_addresses( return sorted(infos) -async def create_connection( - host: str, - port: int, +async def _create_connection_impl( + *, loop: asyncio.AbstractEventLoop, - local_address: tuple[str, int] | None = None, - socktype: int = _socket.SOCK_STREAM, + remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]], + local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None, ) -> _socket.socket: - remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await ensure_resolved( - host, - port, - family=_socket.AF_UNSPEC, - type=socktype, - loop=loop, - ) - local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None = None - if local_address is not None: - local_host, local_port = local_address - local_addrinfo = await ensure_resolved( - local_host, - local_port, - family=_socket.AF_UNSPEC, - type=socktype, - loop=loop, - ) - errors: list[OSError] = [] - for family, _, proto, _, remote_sockaddr in remote_addrinfo: + for family, socktype, proto, _, remote_sockaddr in remote_addrinfo: try: socket = _socket.socket(family, socktype, proto) except OSError as exc: @@ -162,6 +144,69 @@ async def create_connection( errors.clear() +async def create_connection( + host: str, + port: int, + loop: asyncio.AbstractEventLoop, + local_address: tuple[str, int] | None = None, +) -> _socket.socket: + remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await ensure_resolved( + host, + port, + family=_socket.AF_UNSPEC, + type=_socket.SOCK_STREAM, + loop=loop, + ) + local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None = None + if local_address is not None: + local_host, local_port = local_address + local_addrinfo = await ensure_resolved( + local_host, + local_port, + family=_socket.AF_UNSPEC, + type=_socket.SOCK_STREAM, + loop=loop, + ) + + return await _create_connection_impl( + loop=loop, + remote_addrinfo=remote_addrinfo, + local_addrinfo=local_addrinfo, + ) + + +async def create_datagram_connection( + host: str, + port: int, + loop: asyncio.AbstractEventLoop, + local_address: tuple[str, int] | None = None, + family: int = _socket.AF_UNSPEC, +) -> _socket.socket: + remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await ensure_resolved( + host, + port, + family=family, + type=_socket.SOCK_DGRAM, + loop=loop, + ) + local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None = None + if local_address is not None: + local_host, local_port = local_address + local_addrinfo = await ensure_resolved( + local_host, + local_port, + family=family, + type=_socket.SOCK_DGRAM, + loop=loop, + ) + + return await _create_connection_impl( + loop=loop, + remote_addrinfo=remote_addrinfo, + local_addrinfo=local_addrinfo, + ) + + def open_listener_sockets_from_getaddrinfo_result( infos: Iterable[tuple[int, int, int, str, tuple[Any, ...]]], *, diff --git a/src/easynetwork/lowlevel/asyncio/backend.py b/src/easynetwork/lowlevel/asyncio/backend.py index 2cf894a5..294367cb 100644 --- a/src/easynetwork/lowlevel/asyncio/backend.py +++ b/src/easynetwork/lowlevel/asyncio/backend.py @@ -42,7 +42,12 @@ from ...exceptions import UnsupportedOperation from ..api_async.backend.abc import AsyncBackend as AbstractAsyncBackend from ..api_async.backend.sniffio import current_async_library_cvar as _sniffio_current_async_library_cvar -from ._asyncio_utils import create_connection, open_listener_sockets_from_getaddrinfo_result, resolve_local_addresses +from ._asyncio_utils import ( + create_connection, + create_datagram_connection, + open_listener_sockets_from_getaddrinfo_result, + resolve_local_addresses, +) from .datagram.endpoint import create_datagram_endpoint from .datagram.listener import AsyncioTransportDatagramListenerSocketAdapter, RawDatagramListenerSocketAdapter from .datagram.socket import AsyncioTransportDatagramSocketAdapter, RawDatagramSocketAdapter @@ -309,14 +314,15 @@ async def create_udp_endpoint( remote_port: int, *, local_address: tuple[str, int] | None = None, + family: int = _socket.AF_UNSPEC, ) -> AsyncioTransportDatagramSocketAdapter | RawDatagramSocketAdapter: loop = asyncio.get_running_loop() - socket = await create_connection( + socket = await create_datagram_connection( remote_host, remote_port, loop, local_address=local_address, - socktype=_socket.SOCK_DGRAM, + family=family, ) return await self.wrap_connected_datagram_socket(socket) diff --git a/tests/unit_test/_utils.py b/tests/unit_test/_utils.py index 75ac042f..1fd7396e 100644 --- a/tests/unit_test/_utils.py +++ b/tests/unit_test/_utils.py @@ -96,7 +96,11 @@ def get_all_socket_families() -> frozenset[str]: def _get_all_socket_families() -> frozenset[str]: import socket - return frozenset(v for v in dir(socket) if v.startswith("AF_")) + to_exclude = { + "AF_UNSPEC", + } + + return frozenset(v for v in dir(socket) if v.startswith("AF_") and v not in to_exclude) def __addrinfo_list( diff --git a/tests/unit_test/test_async/test_api/test_client/test_udp.py b/tests/unit_test/test_async/test_api/test_client/test_udp.py index 91d8f86b..da170dcb 100644 --- a/tests/unit_test/test_async/test_api/test_client/test_udp.py +++ b/tests/unit_test/test_async/test_api/test_client/test_udp.py @@ -2,7 +2,7 @@ import errno import os -from socket import AF_INET6, SO_ERROR, SOL_SOCKET +from socket import AF_INET6, AF_UNSPEC, SO_ERROR, SOL_SOCKET from typing import TYPE_CHECKING, Any from easynetwork.api_async.client.udp import AsyncUDPNetworkClient @@ -163,6 +163,76 @@ async def test____dunder_init____with_remote_address( ] assert isinstance(client.socket, SocketProxy) + async def test____dunder_init____with_remote_address____socket_family( + self, + remote_address: tuple[str, int], + socket_family: int, + mock_datagram_protocol: MagicMock, + mock_backend: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + + # Act + client: AsyncUDPNetworkClient[Any, Any] = AsyncUDPNetworkClient( + remote_address, + mock_datagram_protocol, + family=socket_family, + local_address=mocker.sentinel.local_address, + ) + await client.wait_connected() + + # Assert + mock_backend.create_udp_endpoint.assert_awaited_once_with( + *remote_address, + family=socket_family, + local_address=mocker.sentinel.local_address, + ) + + async def test____dunder_init____with_remote_address____explicit_AF_UNSPEC( + self, + remote_address: tuple[str, int], + mock_datagram_protocol: MagicMock, + mock_backend: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + + # Act + client: AsyncUDPNetworkClient[Any, Any] = AsyncUDPNetworkClient( + remote_address, + mock_datagram_protocol, + family=AF_UNSPEC, + local_address=mocker.sentinel.local_address, + ) + await client.wait_connected() + + # Assert + mock_backend.create_udp_endpoint.assert_awaited_once_with( + *remote_address, + family=AF_UNSPEC, + local_address=mocker.sentinel.local_address, + ) + + @pytest.mark.parametrize("socket_family", list(UNSUPPORTED_FAMILIES), indirect=True) + async def test____dunder_init____with_remote_address____invalid_socket_family( + self, + remote_address: tuple[str, int], + socket_family: int, + mock_datagram_protocol: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + + # Act & Assert + with pytest.raises(ValueError, match=r"^Only these families are supported: .+$"): + _ = AsyncUDPNetworkClient( + remote_address, + mock_datagram_protocol, + family=socket_family, + local_address=mocker.sentinel.local_address, + ) + async def test____dunder_init____with_remote_address____force_local_address( self, remote_address: tuple[str, int], diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_backend.py b/tests/unit_test/test_async/test_asyncio_backend/test_backend.py index a8c8f70b..a7ae0478 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_backend.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_backend.py @@ -1213,11 +1213,13 @@ async def test____create_tcp_listeners____invalid_backlog( mock_open_listeners.assert_not_called() mock_ListenerSocketAdapter.assert_not_called() + @pytest.mark.parametrize("socket_family", [None, AF_INET, AF_INET6], ids=lambda p: f"family=={p}") async def test____create_udp_endpoint____use_loop_create_datagram_endpoint( self, event_loop: asyncio.AbstractEventLoop, local_address: tuple[str, int] | None, remote_address: tuple[str, int], + socket_family: int | None, backend: AsyncIOBackend, mock_datagram_endpoint_factory: Callable[[], MagicMock], use_asyncio_transport: bool, @@ -1239,20 +1241,23 @@ async def test____create_udp_endpoint____use_loop_create_datagram_endpoint( return_value=mock_endpoint, ) mock_own_create_connection: AsyncMock = mocker.patch( - "easynetwork.lowlevel.asyncio.backend.create_connection", + "easynetwork.lowlevel.asyncio.backend.create_datagram_connection", new_callable=mocker.AsyncMock, return_value=mock_udp_socket, ) # Act - socket = await backend.create_udp_endpoint(*remote_address, local_address=local_address) + if socket_family is None: + socket = await backend.create_udp_endpoint(*remote_address, local_address=local_address) + else: + socket = await backend.create_udp_endpoint(*remote_address, local_address=local_address, family=socket_family) # Assert mock_own_create_connection.assert_awaited_once_with( *remote_address, event_loop, local_address=local_address, - socktype=SOCK_DGRAM, + family=AF_UNSPEC if socket_family is None else socket_family, ) if use_asyncio_transport: mock_create_datagram_endpoint.assert_awaited_once_with(sock=mock_udp_socket) diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_utils.py b/tests/unit_test/test_async/test_asyncio_backend/test_utils.py index 29b349a6..a437acd2 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_utils.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_utils.py @@ -24,11 +24,12 @@ SocketType, gaierror, ) -from typing import TYPE_CHECKING, Any, Literal, assert_never, cast +from typing import TYPE_CHECKING, Any, Literal, Protocol as TypingProtocol, assert_never, cast from easynetwork.lowlevel._utils import error_from_errno from easynetwork.lowlevel.asyncio._asyncio_utils import ( create_connection, + create_datagram_connection, ensure_resolved, open_listener_sockets_from_getaddrinfo_result, wait_until_readable, @@ -203,12 +204,56 @@ async def test____ensure_resolved____propagate_unrelated_gaierror( mock_getaddrinfo.assert_not_awaited() +class _CreateConnectionCallable(TypingProtocol): + async def __call__( + self, + host: str, + port: int, + loop: asyncio.AbstractEventLoop, + local_address: tuple[str, int] | None = None, + ) -> SocketType: + ... + + +class _AddrInfoListFactory(TypingProtocol): + def __call__( + self, + port: int, + families: Sequence[int] = ..., + ) -> Sequence[tuple[int, int, int, str, tuple[Any, ...]]]: + ... + + +@pytest.fixture(params=[SOCK_STREAM, SOCK_DGRAM], ids=lambda sock_type: f"sock_type=={sock_type!r}") +def connection_socktype(request: pytest.FixtureRequest) -> int: + return request.param + + +@pytest.fixture +def create_connection_of_socktype(connection_socktype: int) -> _CreateConnectionCallable: + if connection_socktype == SOCK_STREAM: + return create_connection + if connection_socktype == SOCK_DGRAM: + return create_datagram_connection + pytest.fail("Invalid fixture argument") + + +@pytest.fixture +def addrinfo_list_factory(connection_socktype: int) -> _AddrInfoListFactory: + if connection_socktype == SOCK_STREAM: + return stream_addrinfo_list + if connection_socktype == SOCK_DGRAM: + return datagram_addrinfo_list + pytest.fail("Invalid fixture argument") + + @pytest.mark.asyncio @pytest.mark.parametrize("with_local_address", [False, True], ids=lambda boolean: f"with_local_address=={boolean}") -@pytest.mark.parametrize("sock_type", [None, SOCK_STREAM, SOCK_DGRAM], ids=lambda sock_type: f"sock_type=={sock_type!r}") async def test____create_connection____default( event_loop: asyncio.AbstractEventLoop, - sock_type: int | None, + create_connection_of_socktype: _CreateConnectionCallable, + addrinfo_list_factory: _AddrInfoListFactory, + connection_socktype: int, with_local_address: bool, mock_socket_cls: MagicMock, mock_socket_ipv4: MagicMock, @@ -218,40 +263,30 @@ async def test____create_connection____default( mocker: MockerFixture, ) -> None: # Arrange - expected_sock_type = SOCK_STREAM if sock_type is None else sock_type - expected_proto = IPPROTO_TCP if expected_sock_type == SOCK_STREAM else IPPROTO_UDP + expected_proto = IPPROTO_TCP if connection_socktype == SOCK_STREAM else IPPROTO_UDP remote_host, remote_port = "localhost", 12345 local_address: tuple[str, int] | None = ("localhost", 11111) if with_local_address else None if local_address is None: - if expected_sock_type == SOCK_STREAM: - mock_getaddrinfo.side_effect = [stream_addrinfo_list(remote_port)] - else: - mock_getaddrinfo.side_effect = [datagram_addrinfo_list(remote_port)] + mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] else: - if expected_sock_type == SOCK_STREAM: - mock_getaddrinfo.side_effect = [stream_addrinfo_list(remote_port), stream_addrinfo_list(local_address[1])] - else: - mock_getaddrinfo.side_effect = [datagram_addrinfo_list(remote_port), datagram_addrinfo_list(local_address[1])] + mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port), addrinfo_list_factory(local_address[1])] # Act - if sock_type is None: - socket = await create_connection(remote_host, remote_port, event_loop, local_address=local_address) - else: - socket = await create_connection(remote_host, remote_port, event_loop, local_address=local_address, socktype=sock_type) + socket = await create_connection_of_socktype(remote_host, remote_port, event_loop, local_address=local_address) # Assert if local_address is None: assert mock_getaddrinfo.await_args_list == [ - mocker.call(remote_host, remote_port, family=AF_UNSPEC, type=expected_sock_type, proto=0, flags=0), + mocker.call(remote_host, remote_port, family=AF_UNSPEC, type=connection_socktype, proto=0, flags=0), ] else: assert mock_getaddrinfo.await_args_list == [ - mocker.call(remote_host, remote_port, family=AF_UNSPEC, type=expected_sock_type, proto=0, flags=0), - mocker.call(*local_address, family=AF_UNSPEC, type=expected_sock_type, proto=0, flags=0), + mocker.call(remote_host, remote_port, family=AF_UNSPEC, type=connection_socktype, proto=0, flags=0), + mocker.call(*local_address, family=AF_UNSPEC, type=connection_socktype, proto=0, flags=0), ] - mock_socket_cls.assert_called_once_with(AF_INET, expected_sock_type, expected_proto) + mock_socket_cls.assert_called_once_with(AF_INET, connection_socktype, expected_proto) assert socket is mock_socket_ipv4 mock_socket_ipv4.setblocking.assert_called_once_with(False) @@ -271,6 +306,9 @@ async def test____create_connection____default( async def test____create_connection____first_failed( event_loop: asyncio.AbstractEventLoop, fail_on: Literal["socket", "bind", "connect"], + create_connection_of_socktype: _CreateConnectionCallable, + addrinfo_list_factory: _AddrInfoListFactory, + connection_socktype: int, mock_socket_cls: MagicMock, mock_socket_ipv4: MagicMock, mock_socket_ipv6: MagicMock, @@ -283,9 +321,9 @@ async def test____create_connection____first_failed( local_address: tuple[str, int] | None = ("localhost", 11111) if fail_on == "bind" else None if local_address is None: - mock_getaddrinfo.side_effect = [stream_addrinfo_list(remote_port)] + mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] else: - mock_getaddrinfo.side_effect = [stream_addrinfo_list(remote_port), stream_addrinfo_list(local_address[1])] + mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port), addrinfo_list_factory(local_address[1])] match fail_on: case "socket": @@ -298,13 +336,19 @@ async def test____create_connection____first_failed( assert_never(fail_on) # Act - socket = await create_connection(remote_host, remote_port, event_loop, local_address=local_address) + socket = await create_connection_of_socktype(remote_host, remote_port, event_loop, local_address=local_address) # Assert - assert mock_socket_cls.call_args_list == [ - mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), - ] + if connection_socktype == SOCK_STREAM: + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), + ] + else: + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET, SOCK_DGRAM, IPPROTO_UDP), + mocker.call(AF_INET6, SOCK_DGRAM, IPPROTO_UDP), + ] assert socket is mock_socket_ipv6 if fail_on != "socket": @@ -336,6 +380,9 @@ async def test____create_connection____first_failed( async def test____create_connection____all_failed( event_loop: asyncio.AbstractEventLoop, fail_on: Literal["socket", "bind", "connect"], + create_connection_of_socktype: _CreateConnectionCallable, + addrinfo_list_factory: _AddrInfoListFactory, + connection_socktype: int, mock_socket_cls: MagicMock, mock_socket_ipv4: MagicMock, mock_socket_ipv6: MagicMock, @@ -348,9 +395,9 @@ async def test____create_connection____all_failed( local_address: tuple[str, int] | None = ("localhost", 11111) if fail_on == "bind" else None if local_address is None: - mock_getaddrinfo.side_effect = [stream_addrinfo_list(remote_port)] + mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] else: - mock_getaddrinfo.side_effect = [stream_addrinfo_list(remote_port), stream_addrinfo_list(local_address[1])] + mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port), addrinfo_list_factory(local_address[1])] match fail_on: case "socket": @@ -365,7 +412,7 @@ async def test____create_connection____all_failed( # Act with pytest.raises(ExceptionGroup) as exc_info: - await create_connection(remote_host, remote_port, event_loop, local_address=local_address) + await create_connection_of_socktype(remote_host, remote_port, event_loop, local_address=local_address) # Assert os_errors, exc = exc_info.value.split(OSError) @@ -375,10 +422,16 @@ async def test____create_connection____all_failed( assert all(isinstance(exc, OSError) for exc in os_errors.exceptions) del os_errors - assert mock_socket_cls.call_args_list == [ - mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), - ] + if connection_socktype == SOCK_STREAM: + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), + ] + else: + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET, SOCK_DGRAM, IPPROTO_UDP), + mocker.call(AF_INET6, SOCK_DGRAM, IPPROTO_UDP), + ] if fail_on != "socket": mock_socket_ipv4.setblocking.assert_called_once_with(False) @@ -407,6 +460,8 @@ async def test____create_connection____all_failed( async def test____create_connection____unrelated_exception( event_loop: asyncio.AbstractEventLoop, fail_on: Literal["socket", "connect"], + create_connection_of_socktype: _CreateConnectionCallable, + addrinfo_list_factory: _AddrInfoListFactory, mock_socket_cls: MagicMock, mock_socket_ipv4: MagicMock, mock_getaddrinfo: AsyncMock, @@ -415,7 +470,7 @@ async def test____create_connection____unrelated_exception( # Arrange remote_host, remote_port = "localhost", 12345 - mock_getaddrinfo.side_effect = [stream_addrinfo_list(remote_port)] + mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] expected_failure_exception = BaseException() match fail_on: @@ -428,7 +483,7 @@ async def test____create_connection____unrelated_exception( # Act with pytest.raises(BaseException) as exc_info: - await create_connection(remote_host, remote_port, event_loop) + await create_connection_of_socktype(remote_host, remote_port, event_loop) # Assert assert exc_info.value is expected_failure_exception @@ -441,6 +496,8 @@ async def test____create_connection____unrelated_exception( async def test____create_connection____getaddrinfo_returned_empty_list( event_loop: asyncio.AbstractEventLoop, fail_on: Literal["remote_address", "local_address"], + create_connection_of_socktype: _CreateConnectionCallable, + addrinfo_list_factory: _AddrInfoListFactory, mock_socket_cls: MagicMock, mock_socket_ipv4: MagicMock, mock_socket_ipv6: MagicMock, @@ -455,13 +512,13 @@ async def test____create_connection____getaddrinfo_returned_empty_list( case "remote_address": mock_getaddrinfo.side_effect = [[]] case "local_address": - mock_getaddrinfo.side_effect = [stream_addrinfo_list(remote_port), []] + mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port), []] case _: assert_never(fail_on) # Act with pytest.raises(OSError, match=r"^getaddrinfo\('localhost'\) returned empty list$"): - await create_connection(remote_host, remote_port, event_loop, local_address=local_address) + await create_connection_of_socktype(remote_host, remote_port, event_loop, local_address=local_address) # Assert mock_socket_cls.assert_not_called() @@ -473,6 +530,8 @@ async def test____create_connection____getaddrinfo_returned_empty_list( @pytest.mark.asyncio async def test____create_connection____getaddrinfo_return_mismatch( event_loop: asyncio.AbstractEventLoop, + create_connection_of_socktype: _CreateConnectionCallable, + addrinfo_list_factory: _AddrInfoListFactory, mock_socket_ipv4: MagicMock, mock_socket_ipv6: MagicMock, mock_getaddrinfo: AsyncMock, @@ -483,13 +542,13 @@ async def test____create_connection____getaddrinfo_return_mismatch( local_address: tuple[str, int] = ("localhost", 11111) mock_getaddrinfo.side_effect = [ - stream_addrinfo_list(remote_port, families=[AF_INET6]), - stream_addrinfo_list(local_address[1], families=[AF_INET]), + addrinfo_list_factory(remote_port, families=[AF_INET6]), + addrinfo_list_factory(local_address[1], families=[AF_INET]), ] # Act with pytest.raises(ExceptionGroup) as exc_info: - await create_connection(remote_host, remote_port, event_loop, local_address=local_address) + await create_connection_of_socktype(remote_host, remote_port, event_loop, local_address=local_address) # Assert os_errors, exc = exc_info.value.split(OSError) diff --git a/tests/unit_test/test_sync/test_client/test_udp.py b/tests/unit_test/test_sync/test_client/test_udp.py index bf8ec5c9..e3721bc6 100644 --- a/tests/unit_test/test_sync/test_client/test_udp.py +++ b/tests/unit_test/test_sync/test_client/test_udp.py @@ -197,6 +197,7 @@ def test____dunder_init____create_datagram_endpoint____default( def test____dunder_init____create_datagram_endpoint____with_parameters( self, reuse_port: bool, + socket_family: int, local_address: tuple[str, int] | None, remote_address: tuple[str, int], mock_udp_socket: MagicMock, @@ -210,12 +211,14 @@ def test____dunder_init____create_datagram_endpoint____with_parameters( _ = UDPNetworkClient( remote_address, mock_datagram_protocol, + family=socket_family, local_address=local_address, reuse_port=reuse_port, ) # Assert mock_create_udp_socket.assert_called_once_with( + family=socket_family, local_address=local_address, remote_address=remote_address, reuse_port=reuse_port, @@ -225,6 +228,56 @@ def test____dunder_init____create_datagram_endpoint____with_parameters( mocker.call.setblocking(False), ] + @pytest.mark.parametrize( + "local_address", [None, ("local_address", 12345)], ids=lambda p: f"local_address=={p}", indirect=True + ) + @pytest.mark.parametrize("reuse_port", [False, True], ids=lambda p: f"reuse_port=={p}") + def test____dunder_init____create_datagram_endpoint____with_parameters____explicit_AF_UNSPEC( + self, + reuse_port: bool, + local_address: tuple[str, int] | None, + remote_address: tuple[str, int], + mock_datagram_protocol: MagicMock, + mock_create_udp_socket: MagicMock, + ) -> None: + # Arrange + + # Act + _ = UDPNetworkClient( + remote_address, + mock_datagram_protocol, + family=AF_UNSPEC, + local_address=local_address, + reuse_port=reuse_port, + ) + + # Assert + mock_create_udp_socket.assert_called_once_with( + family=AF_UNSPEC, + local_address=local_address, + remote_address=remote_address, + reuse_port=reuse_port, + ) + + @pytest.mark.parametrize("socket_family", list(UNSUPPORTED_FAMILIES), indirect=True) + def test____dunder_init____create_datagram_endpoint____invalid_family( + self, + socket_family: int, + remote_address: tuple[str, int], + mock_datagram_protocol: MagicMock, + mock_create_udp_socket: MagicMock, + ) -> None: + # Arrange + + # Act & Assert + with pytest.raises(ValueError, match=r"^Only these families are supported: .+$"): + _ = UDPNetworkClient( + remote_address, + mock_datagram_protocol, + family=socket_family, + ) + mock_create_udp_socket.assert_not_called() + @pytest.mark.parametrize("bound", [False, True], ids=lambda p: f"bound=={p}") def test____dunder_init____use_given_socket____default( self, @@ -848,11 +901,22 @@ def test____create_udp_socket____default( @pytest.mark.parametrize("with_local_address", [False, True], ids=lambda boolean: f"with_local_address=={boolean}") @pytest.mark.parametrize("with_remote_address", [False, True], ids=lambda boolean: f"with_remote_address=={boolean}") @pytest.mark.parametrize("set_reuse_port", [False, True], ids=lambda boolean: f"set_reuse_port=={boolean}") + @pytest.mark.parametrize( + ["family", "socket_families"], + [ + pytest.param(AF_UNSPEC, (AF_INET, AF_INET6), id="AF_UNSPEC"), + pytest.param(AF_INET, (AF_INET,), id="AF_INET"), + pytest.param(AF_INET6, (AF_INET6,), id="AF_INET6"), + ], + indirect=["socket_families"], + ) def test____create_udp_socket____with_parameters( self, with_local_address: bool, with_remote_address: bool, set_reuse_port: bool, + family: int, + socket_families: tuple[int, ...], mock_socket_cls: MagicMock, mock_getaddrinfo: MagicMock, mock_socket_ipv4: MagicMock, @@ -865,40 +929,61 @@ def test____create_udp_socket____with_parameters( expected_local_address: tuple[str, int] = local_address if local_address is not None else ("localhost", 0) # Act - udp_socket = create_udp_socket(local_address=local_address, remote_address=remote_address, reuse_port=set_reuse_port) + udp_socket = create_udp_socket( + local_address=local_address, + remote_address=remote_address, + family=family, + reuse_port=set_reuse_port, + ) # Assert - assert udp_socket is mock_socket_ipv4 if remote_address is None: mock_getaddrinfo.assert_called_once_with( *expected_local_address, - family=AF_UNSPEC, + family=family, type=SOCK_DGRAM, flags=AI_PASSIVE, ) else: mock_getaddrinfo.assert_called_once_with( *remote_address, - family=AF_UNSPEC, + family=family, type=SOCK_DGRAM, flags=0, ) - mock_socket_cls.assert_called_once_with(AF_INET, SOCK_DGRAM, IPPROTO_UDP) + if AF_INET in socket_families: + assert udp_socket is mock_socket_ipv4 + mock_socket_cls.assert_called_once_with(AF_INET, SOCK_DGRAM, IPPROTO_UDP) + used_socket = mock_socket_ipv4 + not_used_socket = mock_socket_ipv6 + else: + assert udp_socket is mock_socket_ipv6 + mock_socket_cls.assert_called_once_with(AF_INET6, SOCK_DGRAM, IPPROTO_UDP) + used_socket = mock_socket_ipv6 + not_used_socket = mock_socket_ipv4 if set_reuse_port: - mock_socket_ipv4.setsockopt.assert_called_once_with(SOL_SOCKET, SO_REUSEPORT, True) + used_socket.setsockopt.assert_called_once_with(SOL_SOCKET, SO_REUSEPORT, True) else: - mock_socket_ipv4.setsockopt.assert_not_called() - if remote_address is None: - mock_socket_ipv4.bind.assert_called_once_with(("127.0.0.1", expected_local_address[1])) - mock_socket_ipv4.connect.assert_not_called() + used_socket.setsockopt.assert_not_called() + if used_socket is mock_socket_ipv4: + if remote_address is None: + used_socket.bind.assert_called_once_with(("127.0.0.1", expected_local_address[1])) + used_socket.connect.assert_not_called() + else: + used_socket.bind.assert_called_once_with(expected_local_address) + used_socket.connect.assert_called_once_with(("127.0.0.1", 12345)) else: - mock_socket_ipv4.bind.assert_called_once_with(expected_local_address) - mock_socket_ipv4.connect.assert_called_once_with(("127.0.0.1", 12345)) - mock_socket_ipv4.close.assert_not_called() + if remote_address is None: + used_socket.bind.assert_called_once_with(("::1", expected_local_address[1], 0, 0)) + used_socket.connect.assert_not_called() + else: + used_socket.bind.assert_called_once_with(expected_local_address) + used_socket.connect.assert_called_once_with(("::1", 12345, 0, 0)) + used_socket.close.assert_not_called() - mock_socket_ipv6.setsockopt.assert_not_called() - mock_socket_ipv6.bind.assert_not_called() - mock_socket_ipv6.connect.assert_not_called() + not_used_socket.setsockopt.assert_not_called() + not_used_socket.bind.assert_not_called() + not_used_socket.connect.assert_not_called() @pytest.mark.parametrize("with_remote_address", [False, True], ids=lambda boolean: f"with_remote_address=={boolean}") @pytest.mark.parametrize("fail_on", ["socket", "bind", "connect"], ids=lambda fail_on: f"fail_on=={fail_on}")