From fa15bff1b7963a3db4db1059ebccffd7f413e4c1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Francis=20Clairicia-Rose-Claire-Jos=C3=A9phine?= Date: Sat, 3 Aug 2024 23:48:30 +0200 Subject: [PATCH] Implemented `trio` backend (#337) --- .vscode/settings.example.json | 3 +- README.md | 22 +- benchmark_server/run_benchmark | 188 ++- .../servers/asyncio_tcp_echoserver.py | 20 +- .../servers/asyncio_udp_echoserver.py | 3 +- .../servers/easynetwork_tcp_echoserver.py | 47 +- .../servers/easynetwork_udp_echoserver.py | 58 +- benchmark_server/servers/requirements.txt | 7 + .../servers/trio_tcp_echoserver.py | 146 +++ .../servers/trio_udp_echoserver.py | 67 ++ ...ple2.py => connection_example2_asyncio.py} | 2 +- .../connection_example2_backend_api.py | 29 + .../api_async/connection_example2_trio.py | 29 + .../howto/tcp_clients/usage/api_async.py | 54 +- .../request_handler_explanation.py | 83 +- .../tcp_servers/standalone/server_trio.py | 54 + .../howto/udp_clients/usage/api_async.py | 54 +- .../request_handler_explanation.py | 99 +- .../udp_servers/standalone/server_trio.py | 54 + ...sync_client.py => async_client_asyncio.py} | 0 .../async_client_trio.py | 31 + ...sync_server.py => async_server_asyncio.py} | 10 +- .../async_server_trio.py | 25 + .../echo_client_server_tcp/server.py | 10 +- ...sync_client.py => async_client_asyncio.py} | 0 .../async_client_trio.py | 31 + ...sync_server.py => async_server_asyncio.py} | 10 +- .../async_server_trio.py | 25 + .../echo_client_server_udp/server.py | 10 +- ...sync_server.py => async_server_asyncio.py} | 10 +- .../tutorials/ftp_server/async_server_trio.py | 43 + docs/source/_include/sync-async-variants.rst | 2 +- docs/source/_static/css/details.css | 8 + docs/source/api/lowlevel/async/backend.rst | 88 +- docs/source/conf.py | 2 + .../howto/advanced/standalone_servers.rst | 27 +- docs/source/howto/tcp_clients.rst | 92 +- docs/source/howto/tcp_servers.rst | 66 +- docs/source/howto/udp_clients.rst | 64 +- docs/source/howto/udp_servers.rst | 66 +- docs/source/quickstart/install.rst | 11 + .../tutorials/echo_client_server_tcp.rst | 20 +- .../tutorials/echo_client_server_udp.rst | 20 +- docs/source/tutorials/ftp_server.rst | 10 +- pdm.lock | 92 +- pyproject.toml | 14 +- src/easynetwork/lowlevel/_utils.py | 16 +- .../backend/_asyncio/_asyncio_utils.py | 254 ---- .../api_async/backend/_asyncio/backend.py | 69 +- .../backend/_asyncio/dns_resolver.py | 36 + .../api_async/backend/_asyncio/threads.py | 12 +- .../api_async/backend/_common/__init__.py | 19 + .../api_async/backend/_common/dns_resolver.py | 297 +++++ .../api_async/backend/_trio/__init__.py | 20 + .../api_async/backend/_trio/_trio_utils.py | 77 ++ .../api_async/backend/_trio/backend.py | 278 +++++ .../backend/_trio/datagram/__init__.py | 20 + .../backend/_trio/datagram/listener.py | 103 ++ .../backend/_trio/datagram/socket.py | 81 ++ .../api_async/backend/_trio/dns_resolver.py | 50 + .../backend/_trio/stream/__init__.py | 20 + .../backend/_trio/stream/_sendmsg.py | 38 + .../backend/_trio/stream/listener.py | 96 ++ .../api_async/backend/_trio/stream/socket.py | 112 ++ .../lowlevel/api_async/backend/_trio/tasks.py | 302 +++++ .../api_async/backend/_trio/threads.py | 212 ++++ .../lowlevel/api_async/backend/abc.py | 93 +- .../lowlevel/api_async/backend/utils.py | 40 +- tests/fixtures/__init__.py | 0 tests/fixtures/trio.py | 17 + .../test_backend/test_asyncio_backend.py | 280 +++-- .../test_backend/test_trio_backend.py | 1034 +++++++++++++++++ .../test_async/test_futures.py | 23 +- .../test_async/test_server/test_tcp.py | 11 +- .../test_communication/test_end2end.py | 175 +++ tests/pytest_plugins/async_finalizer.py | 20 + tests/pytest_plugins/extra_features.py | 17 + tests/tools.py | 47 +- tests/unit_test/conftest.py | 40 +- tests/unit_test/test_async/conftest.py | 12 - .../test_asyncio_backend/test_backend.py | 489 +++----- .../test_asyncio_backend/test_dns_resolver.py | 30 + .../test_asyncio_backend/test_stream.py | 4 +- .../test_asyncio_backend/test_utils.py | 744 +----------- .../test_backend/test_backend.py | 68 +- .../test_common_tools/__init__.py | 0 .../test_common_tools/test_dns_resolver.py | 963 +++++++++++++++ .../test_backend/test_utils.py | 94 +- .../test_async/test_trio_backend/__init__.py | 0 .../test_async/test_trio_backend/conftest.py | 146 +++ .../test_trio_backend/test_backend.py | 701 +++++++++++ .../test_trio_backend/test_datagram.py | 336 ++++++ .../test_trio_backend/test_dns_resolver.py | 103 ++ .../test_trio_backend/test_stream.py | 688 +++++++++++ .../test_trio_backend/test_tasks.py | 471 ++++++++ .../test_trio_backend/test_threads.py | 32 + .../test_trio_backend/test_utils.py | 105 ++ tests/unit_test/test_tools/test_utils.py | 42 +- tox.ini | 47 +- 99 files changed, 9049 insertions(+), 1741 deletions(-) create mode 100755 benchmark_server/servers/trio_tcp_echoserver.py create mode 100755 benchmark_server/servers/trio_udp_echoserver.py rename docs/source/_include/examples/howto/tcp_clients/basics/api_async/{connection_example2.py => connection_example2_asyncio.py} (91%) create mode 100644 docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2_backend_api.py create mode 100644 docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2_trio.py create mode 100644 docs/source/_include/examples/howto/tcp_servers/standalone/server_trio.py create mode 100644 docs/source/_include/examples/howto/udp_servers/standalone/server_trio.py rename docs/source/_include/examples/tutorials/echo_client_server_tcp/{async_client.py => async_client_asyncio.py} (100%) create mode 100644 docs/source/_include/examples/tutorials/echo_client_server_tcp/async_client_trio.py rename docs/source/_include/examples/tutorials/echo_client_server_tcp/{async_server.py => async_server_asyncio.py} (76%) create mode 100644 docs/source/_include/examples/tutorials/echo_client_server_tcp/async_server_trio.py rename docs/source/_include/examples/tutorials/echo_client_server_udp/{async_client.py => async_client_asyncio.py} (100%) create mode 100644 docs/source/_include/examples/tutorials/echo_client_server_udp/async_client_trio.py rename docs/source/_include/examples/tutorials/echo_client_server_udp/{async_server.py => async_server_asyncio.py} (76%) create mode 100644 docs/source/_include/examples/tutorials/echo_client_server_udp/async_server_trio.py rename docs/source/_include/examples/tutorials/ftp_server/{async_server.py => async_server_asyncio.py} (85%) create mode 100644 docs/source/_include/examples/tutorials/ftp_server/async_server_trio.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_asyncio/dns_resolver.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_common/__init__.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_common/dns_resolver.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/__init__.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/_trio_utils.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/backend.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/datagram/__init__.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/datagram/listener.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/datagram/socket.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/dns_resolver.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/stream/__init__.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/stream/_sendmsg.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/stream/listener.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/stream/socket.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/tasks.py create mode 100644 src/easynetwork/lowlevel/api_async/backend/_trio/threads.py create mode 100644 tests/fixtures/__init__.py create mode 100644 tests/fixtures/trio.py create mode 100644 tests/functional_test/test_async/test_backend/test_trio_backend.py create mode 100644 tests/functional_test/test_communication/test_end2end.py create mode 100644 tests/unit_test/test_async/test_asyncio_backend/test_dns_resolver.py create mode 100644 tests/unit_test/test_async/test_lowlevel_api/test_backend/test_common_tools/__init__.py create mode 100644 tests/unit_test/test_async/test_lowlevel_api/test_backend/test_common_tools/test_dns_resolver.py create mode 100644 tests/unit_test/test_async/test_trio_backend/__init__.py create mode 100644 tests/unit_test/test_async/test_trio_backend/conftest.py create mode 100644 tests/unit_test/test_async/test_trio_backend/test_backend.py create mode 100644 tests/unit_test/test_async/test_trio_backend/test_datagram.py create mode 100644 tests/unit_test/test_async/test_trio_backend/test_dns_resolver.py create mode 100644 tests/unit_test/test_async/test_trio_backend/test_stream.py create mode 100644 tests/unit_test/test_async/test_trio_backend/test_tasks.py create mode 100644 tests/unit_test/test_async/test_trio_backend/test_threads.py create mode 100644 tests/unit_test/test_async/test_trio_backend/test_utils.py diff --git a/.vscode/settings.example.json b/.vscode/settings.example.json index e923c2a6..6f85cc82 100644 --- a/.vscode/settings.example.json +++ b/.vscode/settings.example.json @@ -47,5 +47,6 @@ "reportUnsupportedDunderAll": "warning", "reportShadowedImports": "none" }, - "python.analysis.autoImportCompletions": true + "python.analysis.autoImportCompletions": true, + "css.format.spaceAroundSelectorSeparator": true } diff --git a/README.md b/README.md index b12dc431..e952b82d 100644 --- a/README.md +++ b/README.md @@ -82,7 +82,8 @@ class EchoRequestHandler(AsyncStreamRequestHandler[RequestType, ResponseType]): self, client: AsyncStreamClient[ResponseType], ) -> AsyncGenerator[None, RequestType]: - data: Any = yield # A JSON request has been sent by this client + # A JSON request has been sent by this client + data: Any = yield self.logger.info(f"{client!r} sent {data!r}") @@ -107,14 +108,14 @@ async def main() -> None: ) async with AsyncTCPNetworkServer(host, port, protocol, handler) as server: - try: - await server.serve_forever() - except asyncio.CancelledError: - pass + await server.serve_forever() if __name__ == "__main__": - asyncio.run(main()) + try: + asyncio.run(main()) + except* KeyboardInterrupt: + pass ``` ### TCP Echo client with JSON data @@ -145,8 +146,7 @@ if __name__ == "__main__": main() ``` -
-Asynchronous version ( with async def ) +#### Asynchronous version ( with `async def` ) ```py import asyncio @@ -158,7 +158,7 @@ from easynetwork.clients import AsyncTCPNetworkClient async def main() -> None: async with AsyncTCPNetworkClient(("localhost", 9000), JSONProtocol()) as client: await client.send_packet({"data": {"my_body": ["as json"]}}) - response = await client.recv_packet() # response should be the sent dictionary + response = await client.recv_packet() print(response) # prints {'data': {'my_body': ['as json']}} @@ -171,6 +171,6 @@ if __name__ == "__main__": ## License This project is licensed under the terms of the [Apache Software License 2.0](https://github.com/francis-clairicia/EasyNetwork/blob/main/LICENSE). -### `easynetwork.lowlevel.typed_attr` +### AnyIO's typed attributes -AnyIO's typed attributes incorporated in `easynetwork.lowlevel.typed_attr` from [anyio 4.2](https://github.com/agronholm/anyio/tree/4.2.0), which is distributed under the [MIT license](https://github.com/agronholm/anyio/blob/4.2.0/LICENSE). +AnyIO's typed attributes is incorporated in `easynetwork.lowlevel.typed_attr` from [anyio 4.2](https://github.com/agronholm/anyio/tree/4.2.0), which is distributed under the [MIT license](https://github.com/agronholm/anyio/blob/4.2.0/LICENSE). diff --git a/benchmark_server/run_benchmark b/benchmark_server/run_benchmark index e507db74..f32dd828 100755 --- a/benchmark_server/run_benchmark +++ b/benchmark_server/run_benchmark @@ -148,6 +148,22 @@ BENCHMARKS_DEF: Final[Sequence[_BenchmarkDef]] = ( }, "client": _tcp_echoclient, }, + { + "name": "tcpecho-easynetwork-trio", + "title": "TCP echo server (easynetwork+trio)", + "server": ( + *_python_cmd, + "/usr/src/servers/easynetwork_tcp_echoserver.py", + f"--port={EXPOSED_PORT}", + "--trio", + ), + "ping": { + "server_address": _tcp_server_address, + "ping_request": b"ping\n", + "socket_type": SOCK_STREAM, + }, + "client": _tcp_echoclient, + }, { "name": "tcpecho-easynetwork-buffered-asyncio", "title": "TCP echo server (easynetwork+buffered+asyncio)", @@ -181,6 +197,23 @@ BENCHMARKS_DEF: Final[Sequence[_BenchmarkDef]] = ( }, "client": _tcp_echoclient, }, + { + "name": "tcpecho-easynetwork-buffered-trio", + "title": "TCP echo server (easynetwork+buffered+trio)", + "server": ( + *_python_cmd, + "/usr/src/servers/easynetwork_tcp_echoserver.py", + f"--port={EXPOSED_PORT}", + "--buffered", + "--trio", + ), + "ping": { + "server_address": _tcp_server_address, + "ping_request": b"ping\n", + "socket_type": SOCK_STREAM, + }, + "client": _tcp_echoclient, + }, { "name": "tcpecho-asyncio-sockets", "title": "TCP echo server (asyncio/sockets)", @@ -245,6 +278,37 @@ BENCHMARKS_DEF: Final[Sequence[_BenchmarkDef]] = ( }, "client": _tcp_echoclient, }, + { + "name": "tcpecho-trio-sockets", + "title": "TCP echo server (trio/sockets)", + "server": ( + *_python_cmd, + "/usr/src/servers/trio_tcp_echoserver.py", + f"--port={EXPOSED_PORT}", + ), + "ping": { + "server_address": _tcp_server_address, + "ping_request": b"ping\n", + "socket_type": SOCK_STREAM, + }, + "client": _tcp_echoclient, + }, + { + "name": "tcpecho-trio-streams", + "title": "TCP echo server (trio/streams)", + "server": ( + *_python_cmd, + "/usr/src/servers/trio_tcp_echoserver.py", + f"--port={EXPOSED_PORT}", + "--streams", + ), + "ping": { + "server_address": _tcp_server_address, + "ping_request": b"ping\n", + "socket_type": SOCK_STREAM, + }, + "client": _tcp_echoclient, + }, ############################################################################## ################################ TCP readline ################################ ############################################################################## @@ -281,6 +345,23 @@ BENCHMARKS_DEF: Final[Sequence[_BenchmarkDef]] = ( }, "client": _tcp_readline_client, }, + { + "name": "readline-easynetwork-trio", + "title": "TCP readline server (easynetwork+trio)", + "server": ( + *_python_cmd, + "/usr/src/servers/easynetwork_tcp_echoserver.py", + f"--port={EXPOSED_PORT}", + "--readline", + "--trio", + ), + "ping": { + "server_address": _tcp_server_address, + "ping_request": b"ping\n", + "socket_type": SOCK_STREAM, + }, + "client": _tcp_readline_client, + }, { "name": "readline-easynetwork-buffered-asyncio", "title": "TCP readline server (easynetwork+buffered+asyncio)", @@ -316,6 +397,24 @@ BENCHMARKS_DEF: Final[Sequence[_BenchmarkDef]] = ( }, "client": _tcp_readline_client, }, + { + "name": "readline-easynetwork-buffered-trio", + "title": "TCP readline server (easynetwork+buffered+trio)", + "server": ( + *_python_cmd, + "/usr/src/servers/easynetwork_tcp_echoserver.py", + f"--port={EXPOSED_PORT}", + "--readline", + "--buffered", + "--trio", + ), + "ping": { + "server_address": _tcp_server_address, + "ping_request": b"ping\n", + "socket_type": SOCK_STREAM, + }, + "client": _tcp_readline_client, + }, { "name": "readline-asyncio-streams", "title": "TCP readline server (asyncio/streams)", @@ -387,6 +486,24 @@ BENCHMARKS_DEF: Final[Sequence[_BenchmarkDef]] = ( }, "client": _ssl_over_tcp_echoclient, }, + { + "name": "sslecho-easynetwork-trio", + "title": "TCP+SSL echo server (easynetwork+trio)", + "server": ( + *_python_cmd, + "/usr/src/servers/easynetwork_tcp_echoserver.py", + f"--port={EXPOSED_PORT}", + "--ssl", + "--trio", + ), + "ping": { + "server_address": _tcp_server_address, + "ping_request": b"ping\n", + "socket_type": SOCK_STREAM, + "ssl": True, + }, + "client": _ssl_over_tcp_echoclient, + }, { "name": "sslecho-easynetwork-buffered-asyncio", "title": "TCP+SSL echo server (easynetwork+buffered+asyncio)", @@ -424,6 +541,25 @@ BENCHMARKS_DEF: Final[Sequence[_BenchmarkDef]] = ( }, "client": _ssl_over_tcp_echoclient, }, + { + "name": "sslecho-easynetwork-buffered-trio", + "title": "TCP+SSL echo server (easynetwork+buffered+trio)", + "server": ( + *_python_cmd, + "/usr/src/servers/easynetwork_tcp_echoserver.py", + f"--port={EXPOSED_PORT}", + "--ssl", + "--buffered", + "--trio", + ), + "ping": { + "server_address": _tcp_server_address, + "ping_request": b"ping\n", + "socket_type": SOCK_STREAM, + "ssl": True, + }, + "client": _ssl_over_tcp_echoclient, + }, { "name": "sslecho-asyncio-streams", "title": "TCP+SSL echo server (asyncio/streams)", @@ -461,6 +597,24 @@ BENCHMARKS_DEF: Final[Sequence[_BenchmarkDef]] = ( }, "client": _ssl_over_tcp_echoclient, }, + { + "name": "sslecho-trio-streams", + "title": "TCP+SSL echo server (trio/streams)", + "server": ( + *_python_cmd, + "/usr/src/servers/trio_tcp_echoserver.py", + f"--port={EXPOSED_PORT}", + "--ssl", + "--streams", + ), + "ping": { + "server_address": _tcp_server_address, + "ping_request": b"ping\n", + "socket_type": SOCK_STREAM, + "ssl": True, + }, + "client": _ssl_over_tcp_echoclient, + }, ########################################################################## ################################ UDP echo ################################ ########################################################################## @@ -495,6 +649,22 @@ BENCHMARKS_DEF: Final[Sequence[_BenchmarkDef]] = ( }, "client": _udp_echoclient, }, + { + "name": "udpecho-easynetwork-trio", + "title": "UDP echo server (easynetwork+trio)", + "server": ( + *_python_cmd, + "/usr/src/servers/easynetwork_udp_echoserver.py", + f"--port={EXPOSED_PORT}", + "--trio", + ), + "ping": { + "server_address": _udp_server_address, + "ping_request": b"ping", + "socket_type": SOCK_DGRAM, + }, + "client": _udp_echoclient, + }, { "name": "udpecho-asyncio-sockets", "title": "UDP echo server (asyncio/sockets)", @@ -561,6 +731,21 @@ BENCHMARKS_DEF: Final[Sequence[_BenchmarkDef]] = ( }, "client": _udp_echoclient, }, + { + "name": "udpecho-trio-sockets", + "title": "UDP echo server (trio/sockets)", + "server": ( + *_python_cmd, + "/usr/src/servers/trio_udp_echoserver.py", + f"--port={EXPOSED_PORT}", + ), + "ping": { + "server_address": _udp_server_address, + "ping_request": b"ping", + "socket_type": SOCK_DGRAM, + }, + "client": _udp_echoclient, + }, ) @@ -604,7 +789,6 @@ def _start_docker_instance( server_cmd, name=container_name, remove=True, - tty=True, detach=True, ports=ports, ), @@ -829,7 +1013,7 @@ def main() -> None: print("Warming up server...") warmup_cmd = benchmark["client"] + warmup - print(" ".join(warmup_cmd)) + print(shlex.join(warmup_cmd)) subprocess.check_output(warmup_cmd) print() diff --git a/benchmark_server/servers/asyncio_tcp_echoserver.py b/benchmark_server/servers/asyncio_tcp_echoserver.py index 7eb9e388..7b0f6105 100755 --- a/benchmark_server/servers/asyncio_tcp_echoserver.py +++ b/benchmark_server/servers/asyncio_tcp_echoserver.py @@ -39,12 +39,14 @@ async def _echo_client(loop: asyncio.AbstractEventLoop, client: socket.socket, a except (OSError, NameError): pass + lock = asyncio.Lock() with client: while True: data = await loop.sock_recv(client, 102400) if not data: break - await loop.sock_sendall(client, data) + async with lock: + await loop.sock_sendall(client, data) LOGGER.info(f"{addr}: Connection closed") @@ -55,14 +57,18 @@ async def echo_client_streams(reader: asyncio.StreamReader, writer: asyncio.Stre sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) except (OSError, NameError): pass + writer.transport.set_write_buffer_limits(0) LOGGER.info(f"Connection from {addr}") + + lock = asyncio.Lock() with contextlib.closing(writer): while True: data = await reader.read(102400) if not data: break - writer.write(data) - await writer.drain() + async with lock: + writer.write(data) + await writer.drain() LOGGER.info(f"{addr}: Connection closed") @@ -73,14 +79,18 @@ async def readline_client_streams(reader: asyncio.StreamReader, writer: asyncio. sock.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) except (OSError, NameError): pass + writer.transport.set_write_buffer_limits(0) LOGGER.info(f"Connection from {addr}") + + lock = asyncio.Lock() with contextlib.closing(writer): while True: data = await reader.readline() if not data: break - writer.write(data) - await writer.drain() + async with lock: + writer.write(data) + await writer.drain() LOGGER.info(f"{addr}: Connection closed") diff --git a/benchmark_server/servers/asyncio_udp_echoserver.py b/benchmark_server/servers/asyncio_udp_echoserver.py index 4b10fee4..c76ba3b3 100755 --- a/benchmark_server/servers/asyncio_udp_echoserver.py +++ b/benchmark_server/servers/asyncio_udp_echoserver.py @@ -22,6 +22,7 @@ async def echo_server(address: tuple[str, int]) -> NoReturn: sock.bind(address) sock.setblocking(False) LOGGER.info(f"Server listening at {sock.getsockname()}") + async with contextlib.AsyncExitStack() as stack: stack.enter_context(sock) task_group = await stack.enter_async_context(asyncio.TaskGroup()) @@ -44,8 +45,8 @@ async def _echo_datagram_client( async def echo_server_stream(address: tuple[str, int]) -> NoReturn: stream = await asyncio_dgram.bind(address) - LOGGER.info(f"Server listening at {stream.sockname}") + async with contextlib.AsyncExitStack() as stack: stack.enter_context(contextlib.closing(stream)) task_group = await stack.enter_async_context(asyncio.TaskGroup()) diff --git a/benchmark_server/servers/easynetwork_tcp_echoserver.py b/benchmark_server/servers/easynetwork_tcp_echoserver.py index aaeb1766..2a64af4f 100755 --- a/benchmark_server/servers/easynetwork_tcp_echoserver.py +++ b/benchmark_server/servers/easynetwork_tcp_echoserver.py @@ -7,7 +7,7 @@ import ssl import sys from collections.abc import AsyncGenerator, Generator -from typing import Any +from typing import Any, Literal from easynetwork.protocol import BufferedStreamProtocol, StreamProtocol from easynetwork.serializers.abc import BufferedIncrementalPacketSerializer @@ -61,23 +61,33 @@ async def handle(self, client: AsyncStreamClient[Any]) -> AsyncGenerator[None, A await client.send_packet(request) +def _get_runner_and_options_from_arg( + runner: Literal["asyncio", "uvloop", "trio"] +) -> tuple[Literal["asyncio", "trio"], dict[str, Any]]: + match runner: + case "asyncio": + print("using asyncio event loop") + return ("asyncio", {}) + case "uvloop": + import uvloop + + print("using uvloop") + return ("asyncio", {"loop_factory": uvloop.new_event_loop}) + case "trio": + print("using trio") + return ("trio", {}) + + def create_tcp_server( *, port: int, over_ssl: bool, - use_uvloop: bool, + runner: Literal["asyncio", "uvloop", "trio"], buffered: bool, readline: bool, context_reuse: bool, ) -> StandaloneTCPNetworkServer[Any, Any]: - asyncio_options = {} - if use_uvloop: - import uvloop - - asyncio_options["loop_factory"] = uvloop.new_event_loop - print("using uvloop") - else: - print("using asyncio event loop") + backend, options = _get_runner_and_options_from_arg(runner) ssl_context: ssl.SSLContext | None = None if over_ssl: ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) @@ -108,7 +118,8 @@ def create_tcp_server( protocol, EchoRequestHandlerInnerLoop() if context_reuse else EchoRequestHandler(), ssl=ssl_context, - runner_options=asyncio_options, + backend=backend, + runner_options=options, max_recv_size=65536, # Default buffer limit of asyncio streams ) @@ -132,11 +143,6 @@ def main() -> None: type=int, default=25000, ) - parser.add_argument( - "--uvloop", - dest="use_uvloop", - action="store_true", - ) parser.add_argument( "--ssl", dest="over_ssl", @@ -158,6 +164,11 @@ def main() -> None: action="store_true", ) + runner_parser = parser.add_mutually_exclusive_group() + runner_parser.add_argument("--uvloop", dest="runner", action="store_const", const="uvloop") + runner_parser.add_argument("--trio", dest="runner", action="store_const", const="trio") + runner_parser.set_defaults(runner="asyncio") + args = parser.parse_args() logging.basicConfig(level=getattr(logging, args.log_level), format="[ %(levelname)s ] [ %(name)s ] %(message)s") @@ -165,7 +176,7 @@ def main() -> None: print(f"Python version: {sys.version}") with create_tcp_server( port=args.port, - use_uvloop=args.use_uvloop, + runner=args.runner, over_ssl=args.over_ssl, buffered=args.buffered, readline=args.readline, @@ -177,5 +188,5 @@ def main() -> None: if __name__ == "__main__": try: main() - except KeyboardInterrupt: + except* KeyboardInterrupt: pass diff --git a/benchmark_server/servers/easynetwork_udp_echoserver.py b/benchmark_server/servers/easynetwork_udp_echoserver.py index d03db2e3..e071d2e6 100755 --- a/benchmark_server/servers/easynetwork_udp_echoserver.py +++ b/benchmark_server/servers/easynetwork_udp_echoserver.py @@ -6,7 +6,7 @@ import sys from collections.abc import AsyncGenerator from contextlib import AsyncExitStack -from typing import Any +from typing import Any, Literal from easynetwork.protocol import DatagramProtocol from easynetwork.serializers.abc import AbstractPacketSerializer @@ -30,11 +30,15 @@ def __init__(self, eager_tasks: bool) -> None: self._eager_tasks: bool = bool(eager_tasks) async def service_init(self, exit_stack: AsyncExitStack, server: Any) -> None: - import asyncio - if self._eager_tasks: - loop = asyncio.get_running_loop() - loop.set_task_factory(getattr(asyncio, "eager_task_factory")) + import asyncio + + try: + loop = asyncio.get_running_loop() + except RuntimeError: + pass + else: + loop.set_task_factory(getattr(asyncio, "eager_task_factory")) class EchoRequestHandlerNoTTL(_BaseRequestHandler): @@ -59,21 +63,31 @@ async def handle(self, client: AsyncDatagramClient[Any]) -> AsyncGenerator[float await client.send_packet(request) +def _get_runner_and_options_from_arg( + runner: Literal["asyncio", "uvloop", "trio"] +) -> tuple[Literal["asyncio", "trio"], dict[str, Any]]: + match runner: + case "asyncio": + print("using asyncio event loop") + return ("asyncio", {}) + case "uvloop": + import uvloop + + print("using uvloop") + return ("asyncio", {"loop_factory": uvloop.new_event_loop}) + case "trio": + print("using trio") + return ("trio", {}) + + def create_udp_server( *, port: int, - use_uvloop: bool, + runner: Literal["asyncio", "uvloop", "trio"], eager_tasks: bool, client_ttl: float, ) -> StandaloneUDPNetworkServer[Any, Any]: - asyncio_options = {} - if use_uvloop: - import uvloop - - asyncio_options["loop_factory"] = uvloop.new_event_loop - print("using uvloop") - else: - print("using asyncio event loop") + backend, options = _get_runner_and_options_from_arg(runner) if eager_tasks: print("with eager task start") if client_ttl > 0: @@ -88,7 +102,8 @@ def create_udp_server( port, DatagramProtocol(NoSerializer()), handler, - runner_options=asyncio_options, + backend=backend, + runner_options=options, ) @@ -111,11 +126,6 @@ def main() -> None: type=int, default=25000, ) - parser.add_argument( - "--uvloop", - dest="use_uvloop", - action="store_true", - ) parser.add_argument( "--eager-tasks", dest="eager_tasks", @@ -127,6 +137,10 @@ def main() -> None: type=float, default=0.0, ) + runner_parser = parser.add_mutually_exclusive_group() + runner_parser.add_argument("--uvloop", dest="runner", action="store_const", const="uvloop") + runner_parser.add_argument("--trio", dest="runner", action="store_const", const="trio") + runner_parser.set_defaults(runner="asyncio") args = parser.parse_args() @@ -135,7 +149,7 @@ def main() -> None: print(f"Python version: {sys.version}") with create_udp_server( port=args.port, - use_uvloop=args.use_uvloop, + runner=args.runner, eager_tasks=args.eager_tasks, client_ttl=args.client_ttl, ) as server: @@ -145,5 +159,5 @@ def main() -> None: if __name__ == "__main__": try: main() - except KeyboardInterrupt: + except* KeyboardInterrupt: pass diff --git a/benchmark_server/servers/requirements.txt b/benchmark_server/servers/requirements.txt index 061f846d..53ea1af7 100644 --- a/benchmark_server/servers/requirements.txt +++ b/benchmark_server/servers/requirements.txt @@ -2,5 +2,12 @@ # Please do not edit it manually. asyncio-dgram==2.1.2 +attrs==23.2.0 +cffi==1.16.0 +idna==3.7 +outcome==1.3.0.post0 +pycparser==2.22 sniffio==1.3.1 +sortedcontainers==2.4.0 +trio==0.26.0 uvloop==0.19.0; os_name == "posix" diff --git a/benchmark_server/servers/trio_tcp_echoserver.py b/benchmark_server/servers/trio_tcp_echoserver.py new file mode 100755 index 00000000..03d944e9 --- /dev/null +++ b/benchmark_server/servers/trio_tcp_echoserver.py @@ -0,0 +1,146 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import logging +import pathlib +import socket +import ssl +import sys +from typing import Any, Final, NoReturn + +import trio + +LOGGER: Final[logging.Logger] = logging.getLogger("trio server") + +ROOT_DIR: Final[pathlib.Path] = pathlib.Path(__file__).parent + + +async def echo_server(address: tuple[str, int]) -> NoReturn: + sock = trio.socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.setsockopt(socket.SOL_SOCKET, socket.SO_REUSEADDR, True) + await sock.bind(address) + sock.listen(256) + LOGGER.info(f"Server listening at {address}") + with sock: + async with trio.open_nursery() as nursery: + while True: + client, addr = await sock.accept() + LOGGER.info(f"Connection from {addr}") + nursery.start_soon(_echo_client, client, addr) + + +async def _echo_client(client: trio.socket.SocketType, addr: Any) -> None: + try: + client.setsockopt(socket.IPPROTO_TCP, socket.TCP_NODELAY, True) + except (OSError, NameError): + pass + + lock = trio.Lock() + with client: + while True: + data = await client.recv(102400) + if not data: + break + async with lock: + while data: + sent = await client.send(data) + data = data[sent:] + LOGGER.info(f"{addr}: Connection closed") + + +async def echo_stream_server(port: int, ssl_context: ssl.SSLContext | None) -> NoReturn: + listeners: list[trio.SocketListener] = await trio.open_tcp_listeners(port) + + LOGGER.info(f"Server listening at {', '.join(str(listener.socket.getsockname()) for listener in listeners)}") + + if ssl_context: + await trio.serve_listeners(echo_client_streams, [trio.SSLListener(listener, ssl_context) for listener in listeners]) + else: + await trio.serve_listeners(echo_client_streams, listeners) + + +def _getaddr_from_stream(stream: trio.SocketStream | trio.SSLStream[trio.SocketStream]) -> tuple[Any, ...]: + match stream: + case trio.SSLStream(transport_stream=socket_stream): + return socket_stream.socket.getpeername() + case _: + return stream.socket.getpeername() + + +async def echo_client_streams(stream: trio.SocketStream | trio.SSLStream[trio.SocketStream]) -> None: + addr = _getaddr_from_stream(stream) + LOGGER.info(f"Connection from {addr}") + + lock = trio.Lock() + async with stream: + while True: + try: + data = await stream.receive_some(102400) + except trio.BrokenResourceError: + break + if not data: + break + async with lock: + await stream.send_all(data) + LOGGER.info(f"{addr}: Connection closed") + + +def main() -> None: + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + server_mode_parser_group = parser.add_mutually_exclusive_group() + server_mode_parser_group.add_argument( + "--streams", + action="store_true", + ) + server_mode_parser_group.add_argument( + "--readline", + action="store_true", + ) + + parser.add_argument( + "-p", + "--port", + dest="port", + type=int, + default=25000, + ) + parser.add_argument( + "--ssl", + dest="over_ssl", + action="store_true", + ) + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="[ %(levelname)s ] [ %(name)s ] %(message)s") + + print(f"Python version: {sys.version}") + ssl_context: ssl.SSLContext | None = None + if args.over_ssl: + ssl_context = ssl.create_default_context(ssl.Purpose.CLIENT_AUTH) + ssl_context.load_cert_chain( + ROOT_DIR / "certs" / "ssl_cert.pem", + ROOT_DIR / "certs" / "ssl_key.pem", + ) + ssl_context.check_hostname = False + ssl_context.verify_mode = ssl.CERT_NONE + + port: int = args.port + if args.readline: + raise NotImplementedError("readline server model is not implemented") + elif args.streams: + trio.run(echo_stream_server, port, ssl_context) + else: + if ssl_context: + raise OSError("loop.sock_sendall() and loop.sock_recv() do not support SSL") + trio.run(echo_server, ("0.0.0.0", port)) + + +if __name__ == "__main__": + try: + main() + except* KeyboardInterrupt: + pass diff --git a/benchmark_server/servers/trio_udp_echoserver.py b/benchmark_server/servers/trio_udp_echoserver.py new file mode 100755 index 00000000..eea9e2a6 --- /dev/null +++ b/benchmark_server/servers/trio_udp_echoserver.py @@ -0,0 +1,67 @@ +#!/usr/bin/env python3 + +from __future__ import annotations + +import argparse +import logging +import pathlib +import socket +import sys +from typing import Any, Final, NoReturn + +import trio + +LOGGER: Final[logging.Logger] = logging.getLogger("trio server") + +ROOT_DIR: Final[pathlib.Path] = pathlib.Path(__file__).parent + + +async def echo_server(address: tuple[str, int]) -> NoReturn: + sock = trio.socket.socket(socket.AF_INET, socket.SOCK_DGRAM) + await sock.bind(address) + LOGGER.info(f"Server listening at {address}") + + with sock: + async with trio.open_nursery() as nursery: + while True: + datagram, addr = await sock.recvfrom(65536) + nursery.start_soon(_echo_client, sock, datagram, addr) + + +async def _echo_client(sock: trio.socket.SocketType, datagram: bytes, addr: Any) -> None: + await sock.sendto(datagram, addr) + + +def main() -> None: + parser = argparse.ArgumentParser(formatter_class=argparse.ArgumentDefaultsHelpFormatter) + + server_mode_parser_group = parser.add_mutually_exclusive_group() + server_mode_parser_group.add_argument( + "--streams", + action="store_true", + ) + + parser.add_argument( + "-p", + "--port", + dest="port", + type=int, + default=25000, + ) + + args = parser.parse_args() + + logging.basicConfig(level=logging.INFO, format="[ %(levelname)s ] [ %(name)s ] %(message)s") + + print(f"Python version: {sys.version}") + + port: int = args.port + + trio.run(echo_server, ("0.0.0.0", port)) + + +if __name__ == "__main__": + try: + main() + except* KeyboardInterrupt: + pass diff --git a/docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2.py b/docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2_asyncio.py similarity index 91% rename from docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2.py rename to docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2_asyncio.py index 2f358591..8a1cfcf5 100644 --- a/docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2.py +++ b/docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2_asyncio.py @@ -12,8 +12,8 @@ async def main() -> None: address = ("localhost", 9000) try: + client = AsyncTCPNetworkClient(address, protocol) async with asyncio.timeout(30): - client = AsyncTCPNetworkClient(address, protocol) await client.wait_connected() except TimeoutError: print(f"Could not connect to {address} after 30 seconds") diff --git a/docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2_backend_api.py b/docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2_backend_api.py new file mode 100644 index 00000000..5e88913c --- /dev/null +++ b/docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2_backend_api.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import asyncio + +from easynetwork.clients import AsyncTCPNetworkClient +from easynetwork.protocol import StreamProtocol +from easynetwork.serializers import JSONSerializer + + +async def main() -> None: + protocol = StreamProtocol(JSONSerializer()) + address = ("localhost", 9000) + + try: + client = AsyncTCPNetworkClient(address, protocol) + with client.backend().timeout(30): + await client.wait_connected() + except TimeoutError: + print(f"Could not connect to {address} after 30 seconds") + return + + async with client: + print(f"Remote address: {client.get_remote_address()}") + + ... + + +if __name__ == "__main__": + asyncio.run(main()) diff --git a/docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2_trio.py b/docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2_trio.py new file mode 100644 index 00000000..a61b6fdd --- /dev/null +++ b/docs/source/_include/examples/howto/tcp_clients/basics/api_async/connection_example2_trio.py @@ -0,0 +1,29 @@ +from __future__ import annotations + +import trio + +from easynetwork.clients import AsyncTCPNetworkClient +from easynetwork.protocol import StreamProtocol +from easynetwork.serializers import JSONSerializer + + +async def main() -> None: + protocol = StreamProtocol(JSONSerializer()) + address = ("localhost", 9000) + + try: + client = AsyncTCPNetworkClient(address, protocol) + with trio.fail_after(30): + await client.wait_connected() + except trio.TooSlowError: + print(f"Could not connect to {address} after 30 seconds") + return + + async with client: + print(f"Remote address: {client.get_remote_address()}") + + ... + + +if __name__ == "__main__": + trio.run(main) diff --git a/docs/source/_include/examples/howto/tcp_clients/usage/api_async.py b/docs/source/_include/examples/howto/tcp_clients/usage/api_async.py index 168fc34f..4868cc73 100644 --- a/docs/source/_include/examples/howto/tcp_clients/usage/api_async.py +++ b/docs/source/_include/examples/howto/tcp_clients/usage/api_async.py @@ -4,6 +4,8 @@ import socket from typing import Any +import trio + from easynetwork.clients import AsyncTCPNetworkClient from easynetwork.exceptions import StreamProtocolParseError from easynetwork.protocol import StreamProtocol @@ -25,7 +27,7 @@ async def recv_packet_example1(client: AsyncTCPNetworkClient[Any, Any]) -> None: print(f"Received packet: {packet!r}") -async def recv_packet_example2(client: AsyncTCPNetworkClient[Any, Any]) -> None: +async def recv_packet_example2_asyncio(client: AsyncTCPNetworkClient[Any, Any]) -> None: # [start] try: async with asyncio.timeout(30): @@ -36,7 +38,29 @@ async def recv_packet_example2(client: AsyncTCPNetworkClient[Any, Any]) -> None: print(f"Received packet: {packet!r}") -async def recv_packet_example3(client: AsyncTCPNetworkClient[Any, Any]) -> None: +async def recv_packet_example2_trio(client: AsyncTCPNetworkClient[Any, Any]) -> None: + # [start] + try: + with trio.fail_after(30): + packet = await client.recv_packet() + except trio.TooSlowError: + print("Timed out") + else: + print(f"Received packet: {packet!r}") + + +async def recv_packet_example2_backend_api(client: AsyncTCPNetworkClient[Any, Any]) -> None: + # [start] + try: + with client.backend().timeout(30): + packet = await client.recv_packet() + except TimeoutError: + print("Timed out") + else: + print(f"Received packet: {packet!r}") + + +async def recv_packet_example3_asyncio(client: AsyncTCPNetworkClient[Any, Any]) -> None: # [start] try: async with asyncio.timeout(30): @@ -49,6 +73,32 @@ async def recv_packet_example3(client: AsyncTCPNetworkClient[Any, Any]) -> None: print(f"Received packet: {packet!r}") +async def recv_packet_example3_trio(client: AsyncTCPNetworkClient[Any, Any]) -> None: + # [start] + try: + with trio.fail_after(30): + packet = await client.recv_packet() + except StreamProtocolParseError: + print("Received something, but was not valid") + except trio.TooSlowError: + print("Timed out") + else: + print(f"Received packet: {packet!r}") + + +async def recv_packet_example3_backend_api(client: AsyncTCPNetworkClient[Any, Any]) -> None: + # [start] + try: + with client.backend().timeout(30): + packet = await client.recv_packet() + except StreamProtocolParseError: + print("Received something, but was not valid") + except TimeoutError: + print("Timed out") + else: + print(f"Received packet: {packet!r}") + + async def recv_packet_example4(client: AsyncTCPNetworkClient[Any, Any]) -> None: # [start] all_packets = [p async for p in client.iter_received_packets()] diff --git a/docs/source/_include/examples/howto/tcp_servers/request_handler_explanation.py b/docs/source/_include/examples/howto/tcp_servers/request_handler_explanation.py index cb98f22b..87e335c6 100644 --- a/docs/source/_include/examples/howto/tcp_servers/request_handler_explanation.py +++ b/docs/source/_include/examples/howto/tcp_servers/request_handler_explanation.py @@ -8,6 +8,8 @@ from collections.abc import AsyncGenerator from typing import ClassVar +import trio + from easynetwork.exceptions import StreamProtocolParseError from easynetwork.lowlevel.socket import SocketAddress from easynetwork.servers import AsyncTCPNetworkServer @@ -171,7 +173,7 @@ async def handle( await client.send_packet(Response()) -class TimeoutContextRequestHandler(AsyncStreamRequestHandler[Request, Response]): +class TimeoutContextRequestHandlerAsyncIO(AsyncStreamRequestHandler[Request, Response]): async def handle( self, client: AsyncStreamClient[Response], @@ -186,6 +188,36 @@ async def handle( await client.send_packet(Response()) +class TimeoutContextRequestHandlerTrio(AsyncStreamRequestHandler[Request, Response]): + async def handle( + self, + client: AsyncStreamClient[Response], + ) -> AsyncGenerator[None, Request]: + try: + with trio.fail_after(30): + # The client has 30 seconds to send the request to the server. + request: Request = yield + except trio.TooSlowError: + await client.send_packet(TimedOut()) + else: + await client.send_packet(Response()) + + +class TimeoutContextRequestHandlerWithClientBackend(AsyncStreamRequestHandler[Request, Response]): + async def handle( + self, + client: AsyncStreamClient[Response], + ) -> AsyncGenerator[None, Request]: + try: + with client.backend().timeout(30): + # The client has 30 seconds to send the request to the server. + request: Request = yield + except TimeoutError: + await client.send_packet(TimedOut()) + else: + await client.send_packet(Response()) + + class TimeoutYieldedRequestHandler(AsyncStreamRequestHandler[Request, Response]): async def handle( self, @@ -252,7 +284,7 @@ async def handle( await client.send_packet(Response()) -class ServiceInitializationHookRequestHandler(AsyncStreamRequestHandler[Request, Response]): +class ServiceInitializationHookRequestHandlerAsyncIO(AsyncStreamRequestHandler[Request, Response]): async def service_init( self, exit_stack: contextlib.AsyncExitStack, @@ -275,6 +307,53 @@ def _service_quit(self) -> None: print("Service stopped") +class ServiceInitializationHookRequestHandlerTrio(AsyncStreamRequestHandler[Request, Response]): + async def service_init( + self, + exit_stack: contextlib.AsyncExitStack, + server: AsyncTCPNetworkServer[Request, Response], + ) -> None: + exit_stack.callback(self._service_quit) + + self.background_tasks = await exit_stack.enter_async_context(trio.open_nursery()) + + self.background_tasks.start_soon(self._service_actions) + + async def _service_actions(self) -> None: + while True: + await trio.sleep(1) + + # Do some stuff each second in background + ... + + def _service_quit(self) -> None: + print("Service stopped") + + +class ServiceInitializationHookRequestHandlerWithServerBackend(AsyncStreamRequestHandler[Request, Response]): + async def service_init( + self, + exit_stack: contextlib.AsyncExitStack, + server: AsyncTCPNetworkServer[Request, Response], + ) -> None: + exit_stack.callback(self._service_quit) + + self.backend = server.backend() + self.background_tasks = await exit_stack.enter_async_context(self.backend.create_task_group()) + + self.background_tasks.start_soon(self._service_actions) + + async def _service_actions(self) -> None: + while True: + await self.backend.sleep(1) + + # Do some stuff each second in background + ... + + def _service_quit(self) -> None: + print("Service stopped") + + class ClientContextRequestHandler(AsyncStreamRequestHandler[Request, Response]): client_addr_var: ClassVar[contextvars.ContextVar[SocketAddress]] client_addr_var = contextvars.ContextVar("client_addr") diff --git a/docs/source/_include/examples/howto/tcp_servers/standalone/server_trio.py b/docs/source/_include/examples/howto/tcp_servers/standalone/server_trio.py new file mode 100644 index 00000000..50f1acf7 --- /dev/null +++ b/docs/source/_include/examples/howto/tcp_servers/standalone/server_trio.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from collections.abc import AsyncGenerator +from contextlib import AsyncExitStack +from typing import Any + +import trio + +from easynetwork.protocol import StreamProtocol +from easynetwork.serializers import JSONSerializer +from easynetwork.servers import StandaloneTCPNetworkServer +from easynetwork.servers.handlers import AsyncStreamClient, AsyncStreamRequestHandler + + +class JSONProtocol(StreamProtocol[dict[str, Any], dict[str, Any]]): + def __init__(self) -> None: + super().__init__(JSONSerializer()) + + +class MyRequestHandler(AsyncStreamRequestHandler[dict[str, Any], dict[str, Any]]): + async def service_init(self, exit_stack: AsyncExitStack, server: Any) -> None: + # StandaloneTCPNetworkServer wraps an AsyncTCPNetworkServer instance. + # Therefore, "server" is still asynchronous. + + from easynetwork.servers import AsyncTCPNetworkServer + + assert isinstance(server, AsyncTCPNetworkServer) + + async def handle( + self, + client: AsyncStreamClient[dict[str, Any]], + ) -> AsyncGenerator[None, dict[str, Any]]: + request: dict[str, Any] = yield + + current_task = trio.lowlevel.current_task() + + response = {"task": current_task.name, "request": request} + await client.send_packet(response) + + +def main() -> None: + host, port = "localhost", 9000 + protocol = JSONProtocol() + handler = MyRequestHandler() + + # All the parameters are the same as AsyncTCPNetworkServer. + server = StandaloneTCPNetworkServer(host, port, protocol, handler, backend="trio") + + with server: + server.serve_forever() + + +if __name__ == "__main__": + main() diff --git a/docs/source/_include/examples/howto/udp_clients/usage/api_async.py b/docs/source/_include/examples/howto/udp_clients/usage/api_async.py index 21a73efe..0467cac3 100644 --- a/docs/source/_include/examples/howto/udp_clients/usage/api_async.py +++ b/docs/source/_include/examples/howto/udp_clients/usage/api_async.py @@ -4,6 +4,8 @@ import socket from typing import Any +import trio + from easynetwork.clients import AsyncUDPNetworkClient from easynetwork.exceptions import DatagramProtocolParseError @@ -23,7 +25,7 @@ async def recv_packet_example1(client: AsyncUDPNetworkClient[Any, Any]) -> None: print(f"Received packet: {packet!r}") -async def recv_packet_example2(client: AsyncUDPNetworkClient[Any, Any]) -> None: +async def recv_packet_example2_asyncio(client: AsyncUDPNetworkClient[Any, Any]) -> None: # [start] try: async with asyncio.timeout(30): @@ -34,7 +36,29 @@ async def recv_packet_example2(client: AsyncUDPNetworkClient[Any, Any]) -> None: print(f"Received packet: {packet!r}") -async def recv_packet_example3(client: AsyncUDPNetworkClient[Any, Any]) -> None: +async def recv_packet_example2_trio(client: AsyncUDPNetworkClient[Any, Any]) -> None: + # [start] + try: + with trio.fail_after(30): + packet = await client.recv_packet() + except trio.TooSlowError: + print("Timed out") + else: + print(f"Received packet: {packet!r}") + + +async def recv_packet_example2_backend_api(client: AsyncUDPNetworkClient[Any, Any]) -> None: + # [start] + try: + with client.backend().timeout(30): + packet = await client.recv_packet() + except TimeoutError: + print("Timed out") + else: + print(f"Received packet: {packet!r}") + + +async def recv_packet_example3_asyncio(client: AsyncUDPNetworkClient[Any, Any]) -> None: # [start] try: async with asyncio.timeout(30): @@ -47,6 +71,32 @@ async def recv_packet_example3(client: AsyncUDPNetworkClient[Any, Any]) -> None: print(f"Received packet: {packet!r}") +async def recv_packet_example3_trio(client: AsyncUDPNetworkClient[Any, Any]) -> None: + # [start] + try: + with trio.fail_after(30): + packet = await client.recv_packet() + except DatagramProtocolParseError: + print("Received something, but was not valid") + except trio.TooSlowError: + print("Timed out") + else: + print(f"Received packet: {packet!r}") + + +async def recv_packet_example3_backend_api(client: AsyncUDPNetworkClient[Any, Any]) -> None: + # [start] + try: + with client.backend().timeout(30): + packet = await client.recv_packet() + except DatagramProtocolParseError: + print("Received something, but was not valid") + except TimeoutError: + print("Timed out") + else: + print(f"Received packet: {packet!r}") + + async def recv_packet_example4(client: AsyncUDPNetworkClient[Any, Any]) -> None: # [start] all_packets = [p async for p in client.iter_received_packets()] diff --git a/docs/source/_include/examples/howto/udp_servers/request_handler_explanation.py b/docs/source/_include/examples/howto/udp_servers/request_handler_explanation.py index 40d33a58..72c873df 100644 --- a/docs/source/_include/examples/howto/udp_servers/request_handler_explanation.py +++ b/docs/source/_include/examples/howto/udp_servers/request_handler_explanation.py @@ -8,6 +8,8 @@ from collections.abc import AsyncGenerator from typing import ClassVar +import trio + from easynetwork.exceptions import DatagramProtocolParseError from easynetwork.lowlevel.socket import SocketAddress from easynetwork.servers import AsyncUDPNetworkServer @@ -115,7 +117,7 @@ def need_something_else(self, request: Request, client: AsyncDatagramClient[Resp return True -class TimeoutContextRequestHandler(AsyncDatagramRequestHandler[Request, Response]): +class TimeoutContextRequestHandlerAsyncIO(AsyncDatagramRequestHandler[Request, Response]): async def handle( self, client: AsyncDatagramClient[Response], @@ -138,6 +140,52 @@ async def handle( await client.send_packet(Response()) +class TimeoutContextRequestHandlerTrio(AsyncDatagramRequestHandler[Request, Response]): + async def handle( + self, + client: AsyncDatagramClient[Response], + ) -> AsyncGenerator[None, Request]: + # It is *never* useful to have a timeout for the 1st datagram + # because the datagram is already in the queue. + request: Request = yield + + ... + + await client.send_packet(Response()) + + try: + with trio.fail_after(30): + # The client has 30 seconds to send the 2nd request to the server. + another_request: Request = yield + except trio.TooSlowError: + await client.send_packet(TimedOut()) + else: + await client.send_packet(Response()) + + +class TimeoutContextRequestHandlerWithClientBackend(AsyncDatagramRequestHandler[Request, Response]): + async def handle( + self, + client: AsyncDatagramClient[Response], + ) -> AsyncGenerator[None, Request]: + # It is *never* useful to have a timeout for the 1st datagram + # because the datagram is already in the queue. + request: Request = yield + + ... + + await client.send_packet(Response()) + + try: + with client.backend().timeout(30): + # The client has 30 seconds to send the 2nd request to the server. + another_request: Request = yield + except TimeoutError: + await client.send_packet(TimedOut()) + else: + await client.send_packet(Response()) + + class TimeoutYieldedRequestHandler(AsyncDatagramRequestHandler[Request, Response]): async def handle( self, @@ -161,7 +209,7 @@ async def handle( await client.send_packet(Response()) -class ServiceInitializationHookRequestHandler(AsyncDatagramRequestHandler[Request, Response]): +class ServiceInitializationHookRequestHandlerAsyncIO(AsyncDatagramRequestHandler[Request, Response]): async def service_init( self, exit_stack: contextlib.AsyncExitStack, @@ -184,6 +232,53 @@ def _service_quit(self) -> None: print("Service stopped") +class ServiceInitializationHookRequestHandlerTrio(AsyncDatagramRequestHandler[Request, Response]): + async def service_init( + self, + exit_stack: contextlib.AsyncExitStack, + server: AsyncUDPNetworkServer[Request, Response], + ) -> None: + exit_stack.callback(self._service_quit) + + self.background_tasks = await exit_stack.enter_async_context(trio.open_nursery()) + + self.background_tasks.start_soon(self._service_actions) + + async def _service_actions(self) -> None: + while True: + await trio.sleep(1) + + # Do some stuff each second in background + ... + + def _service_quit(self) -> None: + print("Service stopped") + + +class ServiceInitializationHookRequestHandlerWithServerBackend(AsyncDatagramRequestHandler[Request, Response]): + async def service_init( + self, + exit_stack: contextlib.AsyncExitStack, + server: AsyncUDPNetworkServer[Request, Response], + ) -> None: + exit_stack.callback(self._service_quit) + + self.backend = server.backend() + self.background_tasks = await exit_stack.enter_async_context(self.backend.create_task_group()) + + self.background_tasks.start_soon(self._service_actions) + + async def _service_actions(self) -> None: + while True: + await self.backend.sleep(1) + + # Do some stuff each second in background + ... + + def _service_quit(self) -> None: + print("Service stopped") + + class ClientContextRequestHandler(AsyncDatagramRequestHandler[Request, Response]): client_addr_var: ClassVar[contextvars.ContextVar[SocketAddress]] client_addr_var = contextvars.ContextVar("client_addr") diff --git a/docs/source/_include/examples/howto/udp_servers/standalone/server_trio.py b/docs/source/_include/examples/howto/udp_servers/standalone/server_trio.py new file mode 100644 index 00000000..cc3ae620 --- /dev/null +++ b/docs/source/_include/examples/howto/udp_servers/standalone/server_trio.py @@ -0,0 +1,54 @@ +from __future__ import annotations + +from collections.abc import AsyncGenerator +from contextlib import AsyncExitStack +from typing import Any + +import trio + +from easynetwork.protocol import DatagramProtocol +from easynetwork.serializers import JSONSerializer +from easynetwork.servers import StandaloneUDPNetworkServer +from easynetwork.servers.handlers import AsyncDatagramClient, AsyncDatagramRequestHandler + + +class JSONProtocol(DatagramProtocol[dict[str, Any], dict[str, Any]]): + def __init__(self) -> None: + super().__init__(JSONSerializer()) + + +class MyRequestHandler(AsyncDatagramRequestHandler[dict[str, Any], dict[str, Any]]): + async def service_init(self, exit_stack: AsyncExitStack, server: Any) -> None: + # StandaloneUDPNetworkServer wraps an AsyncUDPNetworkServer instance. + # Therefore, "server" is still asynchronous. + + from easynetwork.servers import AsyncUDPNetworkServer + + assert isinstance(server, AsyncUDPNetworkServer) + + async def handle( + self, + client: AsyncDatagramClient[dict[str, Any]], + ) -> AsyncGenerator[None, dict[str, Any]]: + request: dict[str, Any] = yield + + current_task = trio.lowlevel.current_task() + + response = {"task": current_task.name, "request": request} + await client.send_packet(response) + + +def main() -> None: + host, port = "localhost", 9000 + protocol = JSONProtocol() + handler = MyRequestHandler() + + # All the parameters are the same as AsyncUDPNetworkServer. + server = StandaloneUDPNetworkServer(host, port, protocol, handler, backend="trio") + + with server: + server.serve_forever() + + +if __name__ == "__main__": + main() diff --git a/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_client.py b/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_client_asyncio.py similarity index 100% rename from docs/source/_include/examples/tutorials/echo_client_server_tcp/async_client.py rename to docs/source/_include/examples/tutorials/echo_client_server_tcp/async_client_asyncio.py diff --git a/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_client_trio.py b/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_client_trio.py new file mode 100644 index 00000000..91b9ab92 --- /dev/null +++ b/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_client_trio.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import sys + +import trio + +from easynetwork.clients import AsyncTCPNetworkClient + +from json_protocol import JSONProtocol + + +async def main() -> None: + host = "localhost" + port = 9000 + protocol = JSONProtocol() + + # Connect to server + async with AsyncTCPNetworkClient((host, port), protocol) as client: + # Send data + request = {"command-line arguments": sys.argv[1:]} + await client.send_packet(request) + + # Receive data from the server and shut down + response = await client.recv_packet() + + print(f"Sent: {request}") + print(f"Received: {response}") + + +if __name__ == "__main__": + trio.run(main) diff --git a/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_server.py b/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_server_asyncio.py similarity index 76% rename from docs/source/_include/examples/tutorials/echo_client_server_tcp/async_server.py rename to docs/source/_include/examples/tutorials/echo_client_server_tcp/async_server_asyncio.py index 398d1b28..d82f03a6 100644 --- a/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_server.py +++ b/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_server_asyncio.py @@ -15,11 +15,11 @@ async def main() -> None: handler = EchoRequestHandler() async with AsyncTCPNetworkServer(host, port, protocol, handler) as server: - try: - await server.serve_forever() - except asyncio.CancelledError: - pass + await server.serve_forever() if __name__ == "__main__": - asyncio.run(main()) + try: + asyncio.run(main()) + except* KeyboardInterrupt: + pass diff --git a/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_server_trio.py b/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_server_trio.py new file mode 100644 index 00000000..fb035ed1 --- /dev/null +++ b/docs/source/_include/examples/tutorials/echo_client_server_tcp/async_server_trio.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import trio + +from easynetwork.servers import AsyncTCPNetworkServer + +from echo_request_handler import EchoRequestHandler +from json_protocol import JSONProtocol + + +async def main() -> None: + host = None + port = 9000 + protocol = JSONProtocol() + handler = EchoRequestHandler() + + async with AsyncTCPNetworkServer(host, port, protocol, handler) as server: + await server.serve_forever() + + +if __name__ == "__main__": + try: + trio.run(main) + except* KeyboardInterrupt: + pass diff --git a/docs/source/_include/examples/tutorials/echo_client_server_tcp/server.py b/docs/source/_include/examples/tutorials/echo_client_server_tcp/server.py index 77851654..b7748b35 100644 --- a/docs/source/_include/examples/tutorials/echo_client_server_tcp/server.py +++ b/docs/source/_include/examples/tutorials/echo_client_server_tcp/server.py @@ -13,11 +13,11 @@ def main() -> None: handler = EchoRequestHandler() with StandaloneTCPNetworkServer(host, port, protocol, handler) as server: - try: - server.serve_forever() - except KeyboardInterrupt: - pass + server.serve_forever() if __name__ == "__main__": - main() + try: + main() + except* KeyboardInterrupt: + pass diff --git a/docs/source/_include/examples/tutorials/echo_client_server_udp/async_client.py b/docs/source/_include/examples/tutorials/echo_client_server_udp/async_client_asyncio.py similarity index 100% rename from docs/source/_include/examples/tutorials/echo_client_server_udp/async_client.py rename to docs/source/_include/examples/tutorials/echo_client_server_udp/async_client_asyncio.py diff --git a/docs/source/_include/examples/tutorials/echo_client_server_udp/async_client_trio.py b/docs/source/_include/examples/tutorials/echo_client_server_udp/async_client_trio.py new file mode 100644 index 00000000..0817ac7e --- /dev/null +++ b/docs/source/_include/examples/tutorials/echo_client_server_udp/async_client_trio.py @@ -0,0 +1,31 @@ +from __future__ import annotations + +import sys + +import trio + +from easynetwork.clients import AsyncUDPNetworkClient + +from json_protocol import JSONDatagramProtocol + + +async def main() -> None: + host = "localhost" + port = 9000 + protocol = JSONDatagramProtocol() + + # Connect to server + async with AsyncUDPNetworkClient((host, port), protocol) as client: + # Send data + request = {"command-line arguments": sys.argv[1:]} + await client.send_packet(request) + + # Receive data from the server and shut down + response = await client.recv_packet() + + print(f"Sent: {request}") + print(f"Received: {response}") + + +if __name__ == "__main__": + trio.run(main) diff --git a/docs/source/_include/examples/tutorials/echo_client_server_udp/async_server.py b/docs/source/_include/examples/tutorials/echo_client_server_udp/async_server_asyncio.py similarity index 76% rename from docs/source/_include/examples/tutorials/echo_client_server_udp/async_server.py rename to docs/source/_include/examples/tutorials/echo_client_server_udp/async_server_asyncio.py index 44c6daca..8c2a5517 100644 --- a/docs/source/_include/examples/tutorials/echo_client_server_udp/async_server.py +++ b/docs/source/_include/examples/tutorials/echo_client_server_udp/async_server_asyncio.py @@ -15,11 +15,11 @@ async def main() -> None: handler = EchoRequestHandler() async with AsyncUDPNetworkServer(host, port, protocol, handler) as server: - try: - await server.serve_forever() - except asyncio.CancelledError: - pass + await server.serve_forever() if __name__ == "__main__": - asyncio.run(main()) + try: + asyncio.run(main()) + except* KeyboardInterrupt: + pass diff --git a/docs/source/_include/examples/tutorials/echo_client_server_udp/async_server_trio.py b/docs/source/_include/examples/tutorials/echo_client_server_udp/async_server_trio.py new file mode 100644 index 00000000..61333ffd --- /dev/null +++ b/docs/source/_include/examples/tutorials/echo_client_server_udp/async_server_trio.py @@ -0,0 +1,25 @@ +from __future__ import annotations + +import trio + +from easynetwork.servers import AsyncUDPNetworkServer + +from echo_request_handler import EchoRequestHandler +from json_protocol import JSONDatagramProtocol + + +async def main() -> None: + host = None + port = 9000 + protocol = JSONDatagramProtocol() + handler = EchoRequestHandler() + + async with AsyncUDPNetworkServer(host, port, protocol, handler) as server: + await server.serve_forever() + + +if __name__ == "__main__": + try: + trio.run(main) + except* KeyboardInterrupt: + pass diff --git a/docs/source/_include/examples/tutorials/echo_client_server_udp/server.py b/docs/source/_include/examples/tutorials/echo_client_server_udp/server.py index 57bf3b29..a8970e73 100644 --- a/docs/source/_include/examples/tutorials/echo_client_server_udp/server.py +++ b/docs/source/_include/examples/tutorials/echo_client_server_udp/server.py @@ -13,11 +13,11 @@ def main() -> None: handler = EchoRequestHandler() with StandaloneUDPNetworkServer(host, port, protocol, handler) as server: - try: - server.serve_forever() - except KeyboardInterrupt: - pass + server.serve_forever() if __name__ == "__main__": - main() + try: + main() + except* KeyboardInterrupt: + pass diff --git a/docs/source/_include/examples/tutorials/ftp_server/async_server.py b/docs/source/_include/examples/tutorials/ftp_server/async_server_asyncio.py similarity index 85% rename from docs/source/_include/examples/tutorials/ftp_server/async_server.py rename to docs/source/_include/examples/tutorials/ftp_server/async_server_asyncio.py index 78047b21..67953643 100644 --- a/docs/source/_include/examples/tutorials/ftp_server/async_server.py +++ b/docs/source/_include/examples/tutorials/ftp_server/async_server_asyncio.py @@ -34,9 +34,9 @@ async def main() -> None: format="[ %(levelname)s ] [ %(name)s ] %(message)s", ) async with AsyncFTPServer() as server: - try: - await server.serve_forever() - except asyncio.CancelledError: - pass + await server.serve_forever() - asyncio.run(main()) + try: + asyncio.run(main()) + except* KeyboardInterrupt: + pass diff --git a/docs/source/_include/examples/tutorials/ftp_server/async_server_trio.py b/docs/source/_include/examples/tutorials/ftp_server/async_server_trio.py new file mode 100644 index 00000000..47c536ce --- /dev/null +++ b/docs/source/_include/examples/tutorials/ftp_server/async_server_trio.py @@ -0,0 +1,43 @@ +from __future__ import annotations + +from collections.abc import Sequence + +from easynetwork.servers import AsyncTCPNetworkServer + +from ftp_reply import FTPReply +from ftp_request import FTPRequest +from ftp_server_protocol import FTPServerProtocol +from ftp_server_request_handler import FTPRequestHandler + + +class AsyncFTPServer(AsyncTCPNetworkServer[FTPRequest, FTPReply]): + def __init__( + self, + host: str | Sequence[str] | None = None, + port: int = 21000, + ) -> None: + super().__init__( + host, + port, + FTPServerProtocol(), + FTPRequestHandler(), + ) + + +if __name__ == "__main__": + import logging + + import trio + + async def main() -> None: + logging.basicConfig( + level=logging.INFO, + format="[ %(levelname)s ] [ %(name)s ] %(message)s", + ) + async with AsyncFTPServer() as server: + await server.serve_forever() + + try: + trio.run(main) + except* KeyboardInterrupt: + pass diff --git a/docs/source/_include/sync-async-variants.rst b/docs/source/_include/sync-async-variants.rst index ce4f7ad6..e71f4c6b 100644 --- a/docs/source/_include/sync-async-variants.rst +++ b/docs/source/_include/sync-async-variants.rst @@ -6,5 +6,5 @@ * Asynchronous API with ``async def`` functions, using an :term:`asynchronous framework` to perform I/O operations. - All asynchronous API examples assume that you are using :mod:`asyncio`, + All asynchronous API examples assume that you are using either :mod:`asyncio` or :mod:`trio`, but you can use a different library thanks to the :doc:`asynchronous backend engine API `. diff --git a/docs/source/_static/css/details.css b/docs/source/_static/css/details.css index 4a556b99..a68db02b 100644 --- a/docs/source/_static/css/details.css +++ b/docs/source/_static/css/details.css @@ -1,3 +1,11 @@ details { margin-bottom: 1rem; } + +details > summary { + cursor: pointer; +} + +details[open] > summary { + margin-bottom: 1rem; +} diff --git a/docs/source/api/lowlevel/async/backend.rst b/docs/source/api/lowlevel/async/backend.rst index 81310011..d0a7fba9 100644 --- a/docs/source/api/lowlevel/async/backend.rst +++ b/docs/source/api/lowlevel/async/backend.rst @@ -13,10 +13,85 @@ Asynchronous Backend Engine API Introduction ============ -.. todo:: +In order not to depend on a single implementation of asynchronous operations (e.g. ``asyncio``), here is a mini-framework +to manage different implementations and keep EasyNetwork unaware of the library used. - Explain this big thing. +.. admonition:: Why not just use anyio directly? + Short answer: Because I don't want to. + + .. collapse:: Click here to expand/collapse the long answer + + The main problem with :mod:`anyio` is the simple fact that it is a framework that already encapsulates the sockets + without providing a way to manipulate the underlying transport, **and this is normal**; it would be a horror to maintain. + + But as a result, the high-level API does not expose the features I need, such as: + + * :class:`anyio.abc.ByteReceiveStream` not having a ``receive_into(buffer)`` method; + + * Implementing zero copy sending of multi-byte buffers with :meth:`~socket.socket.sendmsg`; + + * :func:`anyio.connect_tcp`, :func:`anyio.connect_unix` and :func:`anyio.create_udp_socket` not taking an already connected socket; + + * Managed tasks like :class:`asyncio.Task`; + + * and the list goes on... + + The second problem is having anyio as a dependency. :mod:`asyncio` is part of the standard library, so why would I use an external + (and large) project to manage :mod:`asyncio` and make it mandatory? + Also, it would be heavier to write :mod:`asyncio`-only code if :mod:`anyio` is not installed. + +Usage +----- + +Use The Interface Provided By The High-level API +^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +All asynchronous objects relying on an :class:`AsyncBackend` object have a ``backend()`` method: + +* High-level clients ( :meth:`.AbstractAsyncNetworkClient.backend` ). + +* High-level servers ( :meth:`.AbstractAsyncNetworkServer.backend` ). + + * This includes the clients created for the request handlers ( :meth:`.AsyncBaseClientInterface.backend` ). + +* Low-level endpoints ( :meth:`.AsyncStreamEndpoint.backend` and :meth:`.AsyncDatagramEndpoint.backend` ). + +* Low-level servers ( :meth:`.AsyncStreamServer.backend` and :meth:`.AsyncDatagramServer.backend` ). + + * This includes the clients created for the request handlers: + + * AsyncStreamServer: :meth:`.ConnectedStreamClient.backend`. + + * AsyncDatagramServer: :meth:`.DatagramClientContext.backend`. + +* Data transport adapters ( :meth:`.AsyncBaseTransport.backend` ). + +Obtain An Object By Yourself +^^^^^^^^^^^^^^^^^^^^^^^^^^^^ + +You can use :func:`.new_builtin_backend` to have a backend instance: + +>>> from easynetwork.lowlevel.api_async.backend.utils import new_builtin_backend +>>> new_builtin_backend("asyncio") + +>>> new_builtin_backend("trio") + + +You can also let :mod:`sniffio` determine which backend should be used via :func:`.ensure_backend`: + +>>> from easynetwork.lowlevel.api_async.backend.utils import ensure_backend +>>> import asyncio, trio +>>> +>>> async def main(): +... return ensure_backend(None) +... +>>> asyncio.run(main()) + +>>> trio.run(main) + + +------ Backend Interface ================= @@ -66,6 +141,8 @@ Shielding From Task Cancellation Creating Concurrent Tasks ^^^^^^^^^^^^^^^^^^^^^^^^^ +.. automethod:: AsyncBackend.gather + .. automethod:: AsyncBackend.create_task_group .. autoclass:: TaskGroup @@ -102,6 +179,13 @@ Timeouts Networking ---------- +DNS +^^^ + +.. automethod:: AsyncBackend.getaddrinfo + +.. automethod:: AsyncBackend.getnameinfo + Opening Network Connections ^^^^^^^^^^^^^^^^^^^^^^^^^^^ diff --git a/docs/source/conf.py b/docs/source/conf.py index f5b5935b..a429156c 100644 --- a/docs/source/conf.py +++ b/docs/source/conf.py @@ -117,6 +117,8 @@ "sniffio": ("https://sniffio.readthedocs.io/en/latest", None), "cbor2": ("https://cbor2.readthedocs.io/en/stable", None), "msgpack": ("https://msgpack-python.readthedocs.io/en/stable", None), + "trio": ("https://trio.readthedocs.io/en/stable", None), + "anyio": ("https://anyio.readthedocs.io/en/stable/", None), } diff --git a/docs/source/howto/advanced/standalone_servers.rst b/docs/source/howto/advanced/standalone_servers.rst index 5a02e1c2..db8d2c2c 100644 --- a/docs/source/howto/advanced/standalone_servers.rst +++ b/docs/source/howto/advanced/standalone_servers.rst @@ -16,8 +16,6 @@ Standalone servers are classes that can create ready-to-run servers with no pred (i.e., no :keyword:`async` / :keyword:`await`). They use (and block) a thread to accept requests, and their methods are meant to be used by other threads for control (e.g. shutdown). -By default, the runner is ``"asyncio"``, but it can be changed during object creation. - Server Object ============= @@ -37,6 +35,31 @@ Server Object :emphasize-lines: 10,21-26,46-50 +Use An Other Runner +------------------- + +By default, the runner is ``"asyncio"``, but it can be changed during object creation. + +.. seealso:: + + :func:`.new_builtin_backend` + The token is passed to this function. + +.. tabs:: + + .. group-tab:: StandaloneTCPNetworkServer + + .. literalinclude:: ../../_include/examples/howto/tcp_servers/standalone/server_trio.py + :linenos: + :emphasize-lines: 7,35,37,47 + + .. group-tab:: StandaloneUDPNetworkServer + + .. literalinclude:: ../../_include/examples/howto/udp_servers/standalone/server_trio.py + :linenos: + :emphasize-lines: 7,35,37,47 + + Run Server In Background ------------------------ diff --git a/docs/source/howto/tcp_clients.rst b/docs/source/howto/tcp_clients.rst index 01f316cb..2bb73e80 100644 --- a/docs/source/howto/tcp_clients.rst +++ b/docs/source/howto/tcp_clients.rst @@ -50,10 +50,28 @@ You need the host address (domain name or IP) and the port of connection in orde You can control the connection timeout by adding a timeout scope using the :term:`asynchronous framework`: - .. literalinclude:: ../_include/examples/howto/tcp_clients/basics/api_async/connection_example2.py - :pyobject: main - :lineno-match: - :emphasize-lines: 5-13 + .. tabs:: + + .. group-tab:: Using ``asyncio`` + + .. literalinclude:: ../_include/examples/howto/tcp_clients/basics/api_async/connection_example2_asyncio.py + :pyobject: main + :lineno-match: + :emphasize-lines: 5-13 + + .. group-tab:: Using ``trio`` + + .. literalinclude:: ../_include/examples/howto/tcp_clients/basics/api_async/connection_example2_trio.py + :pyobject: main + :lineno-match: + :emphasize-lines: 5-13 + + .. group-tab:: Using the ``AsyncBackend`` API + + .. literalinclude:: ../_include/examples/howto/tcp_clients/basics/api_async/connection_example2_backend_api.py + :pyobject: main + :lineno-match: + :emphasize-lines: 5-13 .. note:: @@ -152,13 +170,33 @@ You get the next available packet, already parsed. Extraneous data is kept for t :dedent: :linenos: - You can control the receive timeout by adding a timeout scope using the asynchronous framework: + You can control the receive timeout by adding a timeout scope using the :term:`asynchronous framework`: - .. literalinclude:: ../_include/examples/howto/tcp_clients/usage/api_async.py - :pyobject: recv_packet_example2 - :start-after: [start] - :dedent: - :linenos: + .. tabs:: + + .. group-tab:: Using ``asyncio`` + + .. literalinclude:: ../_include/examples/howto/tcp_clients/usage/api_async.py + :pyobject: recv_packet_example2_asyncio + :start-after: [start] + :dedent: + :linenos: + + .. group-tab:: Using ``trio`` + + .. literalinclude:: ../_include/examples/howto/tcp_clients/usage/api_async.py + :pyobject: recv_packet_example2_trio + :start-after: [start] + :dedent: + :linenos: + + .. group-tab:: Using the ``AsyncBackend`` API + + .. literalinclude:: ../_include/examples/howto/tcp_clients/usage/api_async.py + :pyobject: recv_packet_example2_backend_api + :start-after: [start] + :dedent: + :linenos: .. tip:: @@ -178,12 +216,34 @@ You get the next available packet, already parsed. Extraneous data is kept for t .. group-tab:: Asynchronous - .. literalinclude:: ../_include/examples/howto/tcp_clients/usage/api_async.py - :pyobject: recv_packet_example3 - :start-after: [start] - :dedent: - :linenos: - :emphasize-lines: 4-5 + .. tabs:: + + .. group-tab:: Using ``asyncio`` + + .. literalinclude:: ../_include/examples/howto/tcp_clients/usage/api_async.py + :pyobject: recv_packet_example3_asyncio + :start-after: [start] + :dedent: + :linenos: + :emphasize-lines: 4-5 + + .. group-tab:: Using ``trio`` + + .. literalinclude:: ../_include/examples/howto/tcp_clients/usage/api_async.py + :pyobject: recv_packet_example3_trio + :start-after: [start] + :dedent: + :linenos: + :emphasize-lines: 4-5 + + .. group-tab:: Using the ``AsyncBackend`` API + + .. literalinclude:: ../_include/examples/howto/tcp_clients/usage/api_async.py + :pyobject: recv_packet_example3_backend_api + :start-after: [start] + :dedent: + :linenos: + :emphasize-lines: 4-5 Receiving Multiple Packets At Once diff --git a/docs/source/howto/tcp_servers.rst b/docs/source/howto/tcp_servers.rst index 54d6385b..545411ec 100644 --- a/docs/source/howto/tcp_servers.rst +++ b/docs/source/howto/tcp_servers.rst @@ -170,18 +170,38 @@ Cancellation And Timeouts Since all :exc:`BaseException` subclasses are thrown into the generator, you can apply a timeout to the read stream using the :term:`asynchronous framework` (the cancellation exception is retrieved in the generator): - .. literalinclude:: ../_include/examples/howto/tcp_servers/request_handler_explanation.py - :pyobject: TimeoutContextRequestHandler.handle - :dedent: - :linenos: - :emphasize-lines: 6,9-10 + .. tabs:: + + .. group-tab:: Using ``asyncio`` + + .. literalinclude:: ../_include/examples/howto/tcp_servers/request_handler_explanation.py + :pyobject: TimeoutContextRequestHandlerAsyncIO.handle + :dedent: + :linenos: + :emphasize-lines: 6,9-10 + + .. group-tab:: Using ``trio`` + + .. literalinclude:: ../_include/examples/howto/tcp_servers/request_handler_explanation.py + :pyobject: TimeoutContextRequestHandlerTrio.handle + :dedent: + :linenos: + :emphasize-lines: 6,9-10 + + .. group-tab:: Using the ``AsyncBackend`` API + + .. literalinclude:: ../_include/examples/howto/tcp_servers/request_handler_explanation.py + :pyobject: TimeoutContextRequestHandlerWithClientBackend.handle + :dedent: + :linenos: + :emphasize-lines: 6,9-10 .. warning:: Note that this behavior works because the generator is always executed and closed in the same asynchronous task for the current implementation. - This feature is available so that features like ``anyio.CancelScope`` can be used. + This feature is available so that features like :class:`trio.CancelScope` can be used. However, it may be removed in a future release. @@ -224,12 +244,34 @@ at the beginning of the :meth:`~.AsyncTCPNetworkServer.serve_forever` task to se This allows you to do something like this: -.. literalinclude:: ../_include/examples/howto/tcp_servers/request_handler_explanation.py - :pyobject: ServiceInitializationHookRequestHandler - :start-after: ServiceInitializationHookRequestHandler - :dedent: - :linenos: - :emphasize-lines: 1 +.. tabs:: + + .. group-tab:: Using ``asyncio`` + + .. literalinclude:: ../_include/examples/howto/tcp_servers/request_handler_explanation.py + :pyobject: ServiceInitializationHookRequestHandlerAsyncIO + :start-after: ServiceInitializationHookRequestHandlerAsyncIO + :dedent: + :linenos: + :emphasize-lines: 1 + + .. group-tab:: Using ``trio`` + + .. literalinclude:: ../_include/examples/howto/tcp_servers/request_handler_explanation.py + :pyobject: ServiceInitializationHookRequestHandlerTrio + :start-after: ServiceInitializationHookRequestHandlerTrio + :dedent: + :linenos: + :emphasize-lines: 1 + + .. group-tab:: Using the ``AsyncBackend`` API + + .. literalinclude:: ../_include/examples/howto/tcp_servers/request_handler_explanation.py + :pyobject: ServiceInitializationHookRequestHandlerWithServerBackend + :start-after: ServiceInitializationHookRequestHandlerWithServerBackend + :dedent: + :linenos: + :emphasize-lines: 1,8,15 Per-client variables (``contextvars`` integration) diff --git a/docs/source/howto/udp_clients.rst b/docs/source/howto/udp_clients.rst index 4838012b..672dcbd4 100644 --- a/docs/source/howto/udp_clients.rst +++ b/docs/source/howto/udp_clients.rst @@ -146,11 +146,31 @@ You get the next available packet, already parsed. You can control the receive timeout by adding a timeout scope using the :term:`asynchronous framework` : - .. literalinclude:: ../_include/examples/howto/udp_clients/usage/api_async.py - :pyobject: recv_packet_example2 - :start-after: [start] - :dedent: - :linenos: + .. tabs:: + + .. group-tab:: Using ``asyncio`` + + .. literalinclude:: ../_include/examples/howto/udp_clients/usage/api_async.py + :pyobject: recv_packet_example2_asyncio + :start-after: [start] + :dedent: + :linenos: + + .. group-tab:: Using ``trio`` + + .. literalinclude:: ../_include/examples/howto/udp_clients/usage/api_async.py + :pyobject: recv_packet_example2_trio + :start-after: [start] + :dedent: + :linenos: + + .. group-tab:: Using the ``AsyncBackend`` API + + .. literalinclude:: ../_include/examples/howto/udp_clients/usage/api_async.py + :pyobject: recv_packet_example2_backend_api + :start-after: [start] + :dedent: + :linenos: .. tip:: @@ -170,12 +190,34 @@ You get the next available packet, already parsed. .. group-tab:: Asynchronous - .. literalinclude:: ../_include/examples/howto/udp_clients/usage/api_async.py - :pyobject: recv_packet_example3 - :start-after: [start] - :dedent: - :linenos: - :emphasize-lines: 4-5 + .. tabs:: + + .. group-tab:: Using ``asyncio`` + + .. literalinclude:: ../_include/examples/howto/udp_clients/usage/api_async.py + :pyobject: recv_packet_example3_asyncio + :start-after: [start] + :dedent: + :linenos: + :emphasize-lines: 4-5 + + .. group-tab:: Using ``trio`` + + .. literalinclude:: ../_include/examples/howto/udp_clients/usage/api_async.py + :pyobject: recv_packet_example3_trio + :start-after: [start] + :dedent: + :linenos: + :emphasize-lines: 4-5 + + .. group-tab:: Using the ``AsyncBackend`` API + + .. literalinclude:: ../_include/examples/howto/udp_clients/usage/api_async.py + :pyobject: recv_packet_example3_backend_api + :start-after: [start] + :dedent: + :linenos: + :emphasize-lines: 4-5 Receiving Multiple Packets At Once diff --git a/docs/source/howto/udp_servers.rst b/docs/source/howto/udp_servers.rst index deca639f..a0e771dd 100644 --- a/docs/source/howto/udp_servers.rst +++ b/docs/source/howto/udp_servers.rst @@ -129,18 +129,38 @@ Cancellation And Timeouts Since all :exc:`BaseException` subclasses are thrown into the generator, you can apply a timeout to the read stream using the :term:`asynchronous framework` (the cancellation exception is retrieved in the generator): - .. literalinclude:: ../_include/examples/howto/udp_servers/request_handler_explanation.py - :pyobject: TimeoutContextRequestHandler.handle - :dedent: - :linenos: - :emphasize-lines: 14,17-18 + .. tabs:: + + .. group-tab:: Using ``asyncio`` + + .. literalinclude:: ../_include/examples/howto/udp_servers/request_handler_explanation.py + :pyobject: TimeoutContextRequestHandlerAsyncIO.handle + :dedent: + :linenos: + :emphasize-lines: 14,17-18 + + .. group-tab:: Using ``trio`` + + .. literalinclude:: ../_include/examples/howto/udp_servers/request_handler_explanation.py + :pyobject: TimeoutContextRequestHandlerTrio.handle + :dedent: + :linenos: + :emphasize-lines: 14,17-18 + + .. group-tab:: Using the ``AsyncBackend`` API + + .. literalinclude:: ../_include/examples/howto/udp_servers/request_handler_explanation.py + :pyobject: TimeoutContextRequestHandlerWithClientBackend.handle + :dedent: + :linenos: + :emphasize-lines: 14,17-18 .. warning:: Note that this behavior works because the generator is always executed and closed in the same asynchronous task for the current implementation. - This feature is available so that features like ``anyio.CancelScope`` can be used. + This feature is available so that features like :class:`trio.CancelScope` can be used. However, it may be removed in a future release. @@ -152,12 +172,34 @@ at the beginning of the :meth:`~.AsyncUDPNetworkServer.serve_forever` task to se This allows you to do something like this: -.. literalinclude:: ../_include/examples/howto/udp_servers/request_handler_explanation.py - :pyobject: ServiceInitializationHookRequestHandler - :start-after: ServiceInitializationHookRequestHandler - :dedent: - :linenos: - :emphasize-lines: 1 +.. tabs:: + + .. group-tab:: Using ``asyncio`` + + .. literalinclude:: ../_include/examples/howto/udp_servers/request_handler_explanation.py + :pyobject: ServiceInitializationHookRequestHandlerAsyncIO + :start-after: ServiceInitializationHookRequestHandlerAsyncIO + :dedent: + :linenos: + :emphasize-lines: 1 + + .. group-tab:: Using ``trio`` + + .. literalinclude:: ../_include/examples/howto/udp_servers/request_handler_explanation.py + :pyobject: ServiceInitializationHookRequestHandlerTrio + :start-after: ServiceInitializationHookRequestHandlerTrio + :dedent: + :linenos: + :emphasize-lines: 1 + + .. group-tab:: Using the ``AsyncBackend`` API + + .. literalinclude:: ../_include/examples/howto/udp_servers/request_handler_explanation.py + :pyobject: ServiceInitializationHookRequestHandlerWithServerBackend + :start-after: ServiceInitializationHookRequestHandlerWithServerBackend + :dedent: + :linenos: + :emphasize-lines: 1,8,15 Per-client variables (``contextvars`` integration) diff --git a/docs/source/quickstart/install.rst b/docs/source/quickstart/install.rst index 42e89882..0a9f6319 100644 --- a/docs/source/quickstart/install.rst +++ b/docs/source/quickstart/install.rst @@ -24,6 +24,17 @@ Here is the full list: * ``msgpack``: Installs the required dependencies for :class:`.MessagePackSerializer`. +* Asynchronous I/O extensions: + + * ``trio``: Installs the *minimum* version supported of :github:repo:`trio `. + +.. warning:: + + :mod:`trio` is an alpha project and we reserve the right to make semantic changes to the backend implementation + and **update the minimum supported version at any time**. + + Also, to avoid having to make a new release for every 0.x release, the minor is *not* pinned. Keep in mind that this can lead + to `breaking changes `_ after updating trio. Example where the ``cbor`` and ``msgpack`` extensions are installed: diff --git a/docs/source/tutorials/echo_client_server_tcp.rst b/docs/source/tutorials/echo_client_server_tcp.rst index d35d2260..af1f91ab 100644 --- a/docs/source/tutorials/echo_client_server_tcp.rst +++ b/docs/source/tutorials/echo_client_server_tcp.rst @@ -104,9 +104,15 @@ and the request handler instance. :linenos: :caption: server.py - .. group-tab:: Asynchronous + .. group-tab:: Asynchronous (asyncio) - .. literalinclude:: ../_include/examples/tutorials/echo_client_server_tcp/async_server.py + .. literalinclude:: ../_include/examples/tutorials/echo_client_server_tcp/async_server_asyncio.py + :linenos: + :caption: server.py + + .. group-tab:: Asynchronous (trio) + + .. literalinclude:: ../_include/examples/tutorials/echo_client_server_tcp/async_server_trio.py :linenos: :caption: server.py @@ -129,9 +135,15 @@ This is the client side: :linenos: :caption: client.py - .. group-tab:: Asynchronous + .. group-tab:: Asynchronous (asyncio) + + .. literalinclude:: ../_include/examples/tutorials/echo_client_server_tcp/async_client_asyncio.py + :linenos: + :caption: client.py + + .. group-tab:: Asynchronous (trio) - .. literalinclude:: ../_include/examples/tutorials/echo_client_server_tcp/async_client.py + .. literalinclude:: ../_include/examples/tutorials/echo_client_server_tcp/async_client_trio.py :linenos: :caption: client.py diff --git a/docs/source/tutorials/echo_client_server_udp.rst b/docs/source/tutorials/echo_client_server_udp.rst index b3874cb1..ffe31841 100644 --- a/docs/source/tutorials/echo_client_server_udp.rst +++ b/docs/source/tutorials/echo_client_server_udp.rst @@ -62,9 +62,15 @@ and the request handler instance. :linenos: :caption: server.py - .. group-tab:: Asynchronous + .. group-tab:: Asynchronous (asyncio) - .. literalinclude:: ../_include/examples/tutorials/echo_client_server_udp/async_server.py + .. literalinclude:: ../_include/examples/tutorials/echo_client_server_udp/async_server_asyncio.py + :linenos: + :caption: server.py + + .. group-tab:: Asynchronous (trio) + + .. literalinclude:: ../_include/examples/tutorials/echo_client_server_udp/async_server_trio.py :linenos: :caption: server.py @@ -82,9 +88,15 @@ This is the client side: :linenos: :caption: client.py - .. group-tab:: Asynchronous + .. group-tab:: Asynchronous (asyncio) + + .. literalinclude:: ../_include/examples/tutorials/echo_client_server_udp/async_client_asyncio.py + :linenos: + :caption: client.py + + .. group-tab:: Asynchronous (trio) - .. literalinclude:: ../_include/examples/tutorials/echo_client_server_udp/async_client.py + .. literalinclude:: ../_include/examples/tutorials/echo_client_server_udp/async_client_trio.py :linenos: :caption: client.py diff --git a/docs/source/tutorials/ftp_server.rst b/docs/source/tutorials/ftp_server.rst index 7ca73552..0d848c8e 100644 --- a/docs/source/tutorials/ftp_server.rst +++ b/docs/source/tutorials/ftp_server.rst @@ -194,9 +194,15 @@ Start The Server :linenos: :caption: server.py - .. group-tab:: Asynchronous + .. group-tab:: Asynchronous (asyncio) - .. literalinclude:: ../_include/examples/tutorials/ftp_server/async_server.py + .. literalinclude:: ../_include/examples/tutorials/ftp_server/async_server_asyncio.py + :linenos: + :caption: server.py + + .. group-tab:: Asynchronous (trio) + + .. literalinclude:: ../_include/examples/tutorials/ftp_server/async_server_trio.py :linenos: :caption: server.py diff --git a/pdm.lock b/pdm.lock index 340529d4..f33eb402 100644 --- a/pdm.lock +++ b/pdm.lock @@ -2,10 +2,10 @@ # It is not intended for manual editing. [metadata] -groups = ["default", "bandit", "benchmark-servers", "benchmark-servers-deps", "build", "cbor", "coverage", "dev", "doc", "flake8", "format", "micro-benchmark", "msgpack", "mypy", "pre-commit", "test", "tox", "types-msgpack", "uvloop"] +groups = ["default", "bandit", "benchmark-servers", "benchmark-servers-deps", "build", "cbor", "coverage", "dev", "doc", "flake8", "format", "micro-benchmark", "msgpack", "mypy", "pre-commit", "test", "test-trio", "tox", "trio", "types-msgpack", "uvloop"] strategy = ["cross_platform", "inherit_metadata"] lock_version = "4.4.2" -content_hash = "sha256:37cee3bc78469542dadc46333f722ba9db2ac80a461bdcd8445821b8d4c4d5b7" +content_hash = "sha256:2146760187db05141fd8cf8be7dd1ef03fb5ea3bb60078b832e1898d184bd97f" [[package]] name = "alabaster" @@ -76,6 +76,17 @@ files = [ {file = "asyncio_dgram-2.1.2-py3-none-any.whl", hash = "sha256:9ef55fc760f93c8212709329a1e28a1cf1c1f0fc8222f1be0227c2b7606a10a2"}, ] +[[package]] +name = "attrs" +version = "23.2.0" +requires_python = ">=3.7" +summary = "Classes Without Boilerplate" +groups = ["benchmark-servers-deps", "test-trio", "trio"] +files = [ + {file = "attrs-23.2.0-py3-none-any.whl", hash = "sha256:99b87a485a5820b23b879f04c2305b44b951b502fd64be915879d77a7e8fc6f1"}, + {file = "attrs-23.2.0.tar.gz", hash = "sha256:935dc3b529c262f6cf76e50877d35a4bd3c1de194fd41f47a2b7ae8f19971f30"}, +] + [[package]] name = "autodocsumm" version = "0.2.12" @@ -280,7 +291,7 @@ name = "cffi" version = "1.16.0" requires_python = ">=3.8" summary = "Foreign Function Interface for Python calling C code." -groups = ["dev", "test"] +groups = ["benchmark-servers-deps", "dev", "test", "test-trio", "trio"] dependencies = [ "pycparser", ] @@ -391,7 +402,7 @@ name = "colorama" version = "0.4.6" requires_python = "!=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7" summary = "Cross-platform colored terminal text." -groups = ["bandit", "benchmark-servers", "build", "dev", "doc", "format", "micro-benchmark", "test", "tox"] +groups = ["bandit", "benchmark-servers", "build", "dev", "doc", "format", "micro-benchmark", "test", "test-trio", "tox"] files = [ {file = "colorama-0.4.6-py2.py3-none-any.whl", hash = "sha256:4f1d9991f5acc0ca119f9d443620b77f9d6b33703e51011c16baf57afb285fc6"}, {file = "colorama-0.4.6.tar.gz", hash = "sha256:08695f5cb7ed6e0531a20572697297273c47b8cae5a63ffc6d6ed5c201be6e44"}, @@ -807,7 +818,7 @@ name = "idna" version = "3.7" requires_python = ">=3.5" summary = "Internationalized Domain Names in Applications (IDNA)" -groups = ["benchmark-servers", "dev", "doc", "test"] +groups = ["benchmark-servers", "benchmark-servers-deps", "dev", "doc", "test", "test-trio", "trio"] files = [ {file = "idna-3.7-py3-none-any.whl", hash = "sha256:82fee1fc78add43492d3a1898bfa6d8a904cc97d8427f683ed8e798d07761aa0"}, {file = "idna-3.7.tar.gz", hash = "sha256:028ff3aadf0609c1fd278d8ea3089299412a7a8b9bd005dd08b9f8285bcb5cfc"}, @@ -843,7 +854,7 @@ name = "iniconfig" version = "2.0.0" requires_python = ">=3.7" summary = "brain-dead simple config-ini parsing" -groups = ["micro-benchmark", "test"] +groups = ["micro-benchmark", "test", "test-trio"] files = [ {file = "iniconfig-2.0.0-py3-none-any.whl", hash = "sha256:b6a85871a79d2e3b22d2d1b94ac2824226a63c6b741c88f7ae975f18b6778374"}, {file = "iniconfig-2.0.0.tar.gz", hash = "sha256:2d91e135bf72d31a410b17c16da610a82cb55f6b0477d1a902134b24a455b8b3"}, @@ -1126,12 +1137,26 @@ files = [ {file = "nodeenv-1.9.1.tar.gz", hash = "sha256:6ec12890a2dab7946721edbfbcd91f3319c6ccc9aec47be7c7e6b7011ee6645f"}, ] +[[package]] +name = "outcome" +version = "1.3.0.post0" +requires_python = ">=3.7" +summary = "Capture the outcome of Python function calls." +groups = ["benchmark-servers-deps", "test-trio", "trio"] +dependencies = [ + "attrs>=19.2.0", +] +files = [ + {file = "outcome-1.3.0.post0-py2.py3-none-any.whl", hash = "sha256:e771c5ce06d1415e356078d3bdd68523f284b4ce5419828922b6871e65eda82b"}, + {file = "outcome-1.3.0.post0.tar.gz", hash = "sha256:9dcf02e65f2971b80047b377468e72a268e15c0af3cf1238e6ff14f7f91143b8"}, +] + [[package]] name = "packaging" version = "24.1" requires_python = ">=3.8" summary = "Core utilities for Python packages" -groups = ["benchmark-servers", "build", "dev", "doc", "format", "micro-benchmark", "test", "tox"] +groups = ["benchmark-servers", "build", "dev", "doc", "format", "micro-benchmark", "test", "test-trio", "tox"] files = [ {file = "packaging-24.1-py3-none-any.whl", hash = "sha256:5b8f2217dbdbd2f7f384c41c628544e6d52f2d0f53c6d0c3ea61aa5d1d7ff124"}, {file = "packaging-24.1.tar.gz", hash = "sha256:026ed72c8ed3fcce5bf8950572258698927fd1dbda10a5e981cdf0ac37f4f002"}, @@ -1203,7 +1228,7 @@ name = "pluggy" version = "1.5.0" requires_python = ">=3.8" summary = "plugin and hook calling mechanisms for python" -groups = ["dev", "micro-benchmark", "test", "tox"] +groups = ["dev", "micro-benchmark", "test", "test-trio", "tox"] files = [ {file = "pluggy-1.5.0-py3-none-any.whl", hash = "sha256:44e1ad92c8ca002de6377e165f3e0f1be63266ab4d554740532335b9d75ea669"}, {file = "pluggy-1.5.0.tar.gz", hash = "sha256:2cffa88e94fdc978c4c574f15f9e59b7f4201d439195c3715ca9e2486f1d0cf1"}, @@ -1263,7 +1288,7 @@ name = "pycparser" version = "2.22" requires_python = ">=3.8" summary = "C parser in Python" -groups = ["dev", "test"] +groups = ["benchmark-servers-deps", "dev", "test", "test-trio", "trio"] files = [ {file = "pycparser-2.22-py3-none-any.whl", hash = "sha256:c3702b6d3dd8c7abc1afa565d7e63d53a1d0bd86cdc24edd75470f4de499cfcc"}, {file = "pycparser-2.22.tar.gz", hash = "sha256:491c8be9c040f5390f5bf44a5b07752bd07f56edf992381b05c701439eec10f6"}, @@ -1345,7 +1370,7 @@ name = "pytest" version = "7.4.4" requires_python = ">=3.7" summary = "pytest: simple powerful testing with Python" -groups = ["micro-benchmark", "test"] +groups = ["micro-benchmark", "test", "test-trio"] dependencies = [ "colorama; sys_platform == \"win32\"", "iniconfig", @@ -1445,6 +1470,22 @@ dependencies = [ "pytest>=7.0.0", ] +[[package]] +name = "pytest-trio" +version = "0.8.0" +requires_python = ">=3.7" +summary = "Pytest plugin for trio" +groups = ["test-trio"] +dependencies = [ + "outcome>=1.1.0", + "pytest>=7.2.0", + "trio>=0.22.0", +] +files = [ + {file = "pytest-trio-0.8.0.tar.gz", hash = "sha256:8363db6336a79e6c53375a2123a41ddbeccc4aa93f93788651641789a56fb52e"}, + {file = "pytest_trio-0.8.0-py3-none-any.whl", hash = "sha256:e6a7e7351ae3e8ec3f4564d30ee77d1ec66e1df611226e5618dbb32f9545c841"}, +] + [[package]] name = "pytest-xdist" version = "3.6.1" @@ -1628,7 +1669,7 @@ name = "sniffio" version = "1.3.1" requires_python = ">=3.7" summary = "Sniff out which async library your code is running under" -groups = ["benchmark-servers-deps", "default", "dev"] +groups = ["benchmark-servers-deps", "default", "dev", "test-trio", "trio"] files = [ {file = "sniffio-1.3.1-py3-none-any.whl", hash = "sha256:2f6da418d1f1e0fddd844478f41680e794e6051915791a034ff65e5f100525a2"}, {file = "sniffio-1.3.1.tar.gz", hash = "sha256:f4324edc670a0f49750a81b895f35c3adb843cca46f0530f79fc1babb23789dc"}, @@ -1644,6 +1685,16 @@ files = [ {file = "snowballstemmer-2.2.0.tar.gz", hash = "sha256:09b16deb8547d3412ad7b590689584cd0fe25ec8db3be37788be3810cbf19cb1"}, ] +[[package]] +name = "sortedcontainers" +version = "2.4.0" +summary = "Sorted Containers -- Sorted List, Sorted Dict, Sorted Set" +groups = ["benchmark-servers-deps", "test-trio", "trio"] +files = [ + {file = "sortedcontainers-2.4.0-py2.py3-none-any.whl", hash = "sha256:a163dcaede0f1c021485e957a39245190e74249897e2ae4b2aa38595db237ee0"}, + {file = "sortedcontainers-2.4.0.tar.gz", hash = "sha256:25caa5a06cc30b6b83d11423433f65d1f9d76c4c6a0c90e3379eaa43b9bfdb88"}, +] + [[package]] name = "soupsieve" version = "2.5" @@ -1977,6 +2028,25 @@ files = [ {file = "tox_pdm-0.7.2.tar.gz", hash = "sha256:a841a7e1e942a71805624703b9a6d286663bd6af79bba6130ba756975c315308"}, ] +[[package]] +name = "trio" +version = "0.26.0" +requires_python = ">=3.8" +summary = "A friendly Python library for async concurrency and I/O" +groups = ["benchmark-servers-deps", "test-trio", "trio"] +dependencies = [ + "attrs>=23.2.0", + "cffi>=1.14; os_name == \"nt\" and implementation_name != \"pypy\"", + "idna", + "outcome", + "sniffio>=1.3.0", + "sortedcontainers", +] +files = [ + {file = "trio-0.26.0-py3-none-any.whl", hash = "sha256:bb9c1b259591af941fccfbabbdc65bc7ed764bd2db76428454c894cd5e3d2032"}, + {file = "trio-0.26.0.tar.gz", hash = "sha256:67c5ec3265dd4abc7b1d1ab9ca4fe4c25b896f9c93dac73713778adab487f9c4"}, +] + [[package]] name = "trove-classifiers" version = "2024.7.2" diff --git a/pyproject.toml b/pyproject.toml index 63c74a34..ffed016b 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -56,6 +56,10 @@ cbor = [ msgpack = [ "msgpack>=1.0.7,<2", ] +trio = [ + "trio>=0.26,<1", + "outcome~=1.3", +] ############################ pdm configuration ############################ @@ -106,6 +110,9 @@ test = [ # Temporary use VCS to get the modifications added on main (c.f. https://github.com/str0zzapreti/pytest-retry/pull/39) "pytest-retry @ git+https://github.com/str0zzapreti/pytest-retry.git@bb465fff6f01f3f90a77229468f7e08a3bdbce20", ] +test-trio = [ + "pytest-trio~=0.8.0", +] coverage = [ "coverage~=7.0", ] @@ -126,7 +133,7 @@ benchmark-servers = [ "plotly~=5.18", ] benchmark-servers-deps = [ - "easynetwork[uvloop]", + "easynetwork[uvloop, trio]", "asyncio-dgram==2.1.2", ] @@ -187,8 +194,9 @@ __version_tuple__ = {version_tuple!r} profile = "black" line_length = 130 combine_as_imports = true -sections = ["FUTURE", "STDLIB", "EASYNETWORK", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] +sections = ["FUTURE", "STDLIB", "TRIO", "EASYNETWORK", "THIRDPARTY", "FIRSTPARTY", "LOCALFOLDER"] known_easynetwork = ["easynetwork"] +known_trio = ["trio", "outcome"] add_imports = ["from __future__ import annotations"] extend_skip = [ "docs/source/conf.py", @@ -222,7 +230,7 @@ local_partial_types = true enable_error_code = ["truthy-bool", "ignore-without-code", "unused-awaitable"] [[tool.mypy.overrides]] -module = ["docker.*", "plotly.*", "pytest_benchmark.*"] +module = ["docker.*", "plotly.*", "pytest_benchmark.*", "pytest_trio.*"] ignore_missing_imports = true [tool.pytest.ini_options] diff --git a/src/easynetwork/lowlevel/_utils.py b/src/easynetwork/lowlevel/_utils.py index f1cb46a3..1d1a289a 100644 --- a/src/easynetwork/lowlevel/_utils.py +++ b/src/easynetwork/lowlevel/_utils.py @@ -32,11 +32,13 @@ "lock_with_timeout", "make_callback", "missing_extra_deps", + "open_listener_sockets_from_getaddrinfo_result", "prepend_argument", "remove_traceback_frames_in_place", "replace_kwargs", "set_reuseport", "supports_socket_sendmsg", + "validate_listener_hosts", "validate_timeout_delay", ] @@ -50,7 +52,7 @@ import time from abc import abstractmethod from collections import deque -from collections.abc import Callable, Iterable, Iterator +from collections.abc import Callable, Iterable, Iterator, Sequence from typing import TYPE_CHECKING, Any, Concatenate, ParamSpec, Protocol, Self, TypeGuard, TypeVar, overload try: @@ -286,6 +288,18 @@ def set_reuseport(sock: SupportsSocketOptions) -> None: raise ValueError("reuse_port not supported by socket module") +def validate_listener_hosts(host: str | Sequence[str] | None) -> list[str | None]: + match host: + case "" | None: + return [None] + case str(): + return [host] + case _ if all(isinstance(h, str) for h in host): + return list(host) + case _: + raise TypeError(host) + + def open_listener_sockets_from_getaddrinfo_result( infos: Iterable[tuple[int, int, int, str, tuple[Any, ...]]], *, diff --git a/src/easynetwork/lowlevel/api_async/backend/_asyncio/_asyncio_utils.py b/src/easynetwork/lowlevel/api_async/backend/_asyncio/_asyncio_utils.py index 4af7050e..e12a8c16 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_asyncio/_asyncio_utils.py +++ b/src/easynetwork/lowlevel/api_async/backend/_asyncio/_asyncio_utils.py @@ -18,266 +18,12 @@ from __future__ import annotations __all__ = [ - "create_connection", - "create_datagram_connection", - "ensure_resolved", - "resolve_local_addresses", "wait_until_readable", "wait_until_writable", ] import asyncio -import itertools -import math import socket as _socket -from collections import OrderedDict -from collections.abc import Sequence -from typing import Any, cast - - -async def ensure_resolved( - host: str | None, - port: int, - family: int, - type: int, - proto: int = 0, - flags: int = 0, -) -> Sequence[tuple[int, int, int, str, tuple[Any, ...]]]: - try: - info = _socket.getaddrinfo( - host, port, family=family, type=type, proto=proto, flags=flags | _socket.AI_NUMERICHOST | _socket.AI_NUMERICSERV - ) - except _socket.gaierror as exc: - if exc.errno != _socket.EAI_NONAME: - raise - loop = asyncio.get_running_loop() - info = await loop.getaddrinfo(host, port, family=family, type=type, proto=proto, flags=flags) - if not info: - raise OSError(f"getaddrinfo({host!r}) returned empty list") - return info - - -async def resolve_local_addresses( - hosts: Sequence[str | None], - port: int, - socktype: int, -) -> Sequence[tuple[int, int, int, str, tuple[Any, ...]]]: - infos: set[tuple[int, int, int, str, tuple[Any, ...]]] = set( - itertools.chain.from_iterable( - await asyncio.gather( - *[ - ensure_resolved( - host, - port, - _socket.AF_UNSPEC, - socktype, - flags=_socket.AI_PASSIVE | _socket.AI_ADDRCONFIG, - ) - for host in hosts - ] - ) - ) - ) - return sorted(infos) - - -async def _create_connection_impl( - *, - remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]], - local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None, -) -> _socket.socket: - loop = asyncio.get_running_loop() - errors: list[OSError] = [] - for family, socktype, proto, _, remote_sockaddr in remote_addrinfo: - try: - socket = _socket.socket(family, socktype, proto) - except OSError as exc: - errors.append(exc) - continue - except BaseException: - errors.clear() - raise - try: - socket.setblocking(False) - - if local_addrinfo is not None: - bind_errors: list[OSError] = [] - try: - for lfamily, _, _, _, local_sockaddr in local_addrinfo: - # skip local addresses of different family - if lfamily != family: - continue - try: - socket.bind(local_sockaddr) - break - except OSError as exc: - msg = f"error while attempting to bind on address {local_sockaddr!r}: {exc.strerror.lower()}" - bind_errors.append(OSError(exc.errno, msg).with_traceback(exc.__traceback__)) - else: # all bind attempts failed - if bind_errors: - socket.close() - errors.extend(bind_errors) - continue - raise OSError(f"no matching local address with {family=} found") - finally: - bind_errors.clear() - del bind_errors - - await loop.sock_connect(socket, remote_sockaddr) - errors.clear() - return socket - except OSError as exc: - socket.close() - errors.append(exc) - continue - except BaseException: - errors.clear() - socket.close() - raise - - assert errors # nosec assert_used - try: - raise ExceptionGroup("create_connection() failed", errors) - finally: - errors.clear() - - -# Taken from asyncio library (https://github.com/python/cpython/tree/v3.12.0/Lib/asyncio) -def _interleave_addrinfos( - addrinfos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] -) -> list[tuple[int, int, int, str, tuple[Any, ...]]]: - """Interleave list of addrinfo tuples by family.""" - # Group addresses by family - addrinfos_by_family: OrderedDict[int, list[tuple[Any, ...]]] = OrderedDict() - for addr in addrinfos: - family = addr[0] - if family not in addrinfos_by_family: - addrinfos_by_family[family] = [] - addrinfos_by_family[family].append(addr) - addrinfos_lists = list(addrinfos_by_family.values()) - return [addr for addr in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists)) if addr is not None] - - -# Taken from anyio project (https://github.com/agronholm/anyio/tree/4.2.0) -def _prioritize_ipv6_over_ipv4( - addrinfos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] -) -> list[tuple[int, int, int, str, tuple[Any, ...]]]: - # Organize the list so that the first address is an IPv6 address (if available) - # and the second one is an IPv4 addresses. The rest can be in whatever order. - v6_found = v4_found = False - reordered: list[tuple[int, int, int, str, tuple[Any, ...]]] = [] - for addr in addrinfos: - family = addr[0] - if family == _socket.AF_INET6 and not v6_found: - v6_found = True - reordered.insert(0, addr) - elif family == _socket.AF_INET and not v4_found and v6_found: - v4_found = True - reordered.insert(1, addr) - else: - reordered.append(addr) - return reordered - - -async def _staggered_race_connection_impl( - *, - remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]], - local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None, - happy_eyeballs_delay: float, -) -> _socket.socket: - from .tasks import CancelScope - - remote_addrinfo = _interleave_addrinfos(_prioritize_ipv6_over_ipv4(remote_addrinfo)) - winner: _socket.socket | None = cast(_socket.socket | None, None) - errors: list[OSError | BaseExceptionGroup[OSError]] = [] - - async def try_connect(addr: tuple[int, int, int, str, tuple[Any, ...]]) -> None: - nonlocal winner - try: - socket = await _create_connection_impl(remote_addrinfo=[addr], local_addrinfo=local_addrinfo) - except* OSError as excgrp: - errors.extend(excgrp.exceptions) - else: - if winner is None: - winner = socket - connection_scope.cancel() - else: - socket.close() - - try: - with CancelScope() as connection_scope: - async with asyncio.TaskGroup() as task_group: - for addr in remote_addrinfo: - await asyncio.wait({task_group.create_task(try_connect(addr))}, timeout=happy_eyeballs_delay) - - if winner is None: - raise BaseExceptionGroup("create_connection() failed", errors) - return winner - except BaseException: - if winner is not None: - winner.close() - raise - finally: - errors.clear() - - -async def create_connection( - host: str, - port: int, - *, - local_address: tuple[str, int] | None = None, - happy_eyeballs_delay: float = math.inf, -) -> _socket.socket: - remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await ensure_resolved( - host, - port, - family=_socket.AF_UNSPEC, - type=_socket.SOCK_STREAM, - ) - local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None = None - if local_address is not None: - local_host, local_port = local_address - local_addrinfo = await ensure_resolved( - local_host, - local_port, - family=_socket.AF_UNSPEC, - type=_socket.SOCK_STREAM, - ) - - return await _staggered_race_connection_impl( - remote_addrinfo=remote_addrinfo, - local_addrinfo=local_addrinfo, - happy_eyeballs_delay=happy_eyeballs_delay, - ) - - -async def create_datagram_connection( - host: str, - port: int, - *, - local_address: tuple[str, int] | None = None, - family: int = _socket.AF_UNSPEC, -) -> _socket.socket: - remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await ensure_resolved( - host, - port, - family=family, - type=_socket.SOCK_DGRAM, - ) - local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None = None - if local_address is not None: - local_host, local_port = local_address - local_addrinfo = await ensure_resolved( - local_host, - local_port, - family=family, - type=_socket.SOCK_DGRAM, - ) - - return await _create_connection_impl( - remote_addrinfo=remote_addrinfo, - local_addrinfo=local_addrinfo, - ) def wait_until_readable(sock: _socket.socket, loop: asyncio.AbstractEventLoop) -> asyncio.Future[None]: diff --git a/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py b/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py index 92ad3ff3..c6c6b2fa 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py +++ b/src/easynetwork/lowlevel/api_async/backend/_asyncio/backend.py @@ -26,14 +26,13 @@ import socket as _socket import sys from collections.abc import Awaitable, Callable, Coroutine, Mapping, Sequence -from typing import Any, NoReturn, ParamSpec, TypeVar, TypeVarTuple +from typing import Any, NoReturn, TypeVar, TypeVarTuple from .... import _utils from ....constants import HAPPY_EYEBALLS_DELAY as _DEFAULT_HAPPY_EYEBALLS_DELAY from ...transports.abc import AsyncDatagramListener, AsyncDatagramTransport, AsyncListener, AsyncStreamTransport from ..abc import AsyncBackend as AbstractAsyncBackend, CancelScope, ICondition, IEvent, ILock, TaskGroup, TaskInfo, ThreadsPortal -_P = ParamSpec("_P") _T = TypeVar("_T") _T_co = TypeVar("_T_co", covariant=True) _T_PosArgs = TypeVarTuple("_T_PosArgs") @@ -45,11 +44,13 @@ class AsyncIOBackend(AbstractAsyncBackend): "__coro_yield", "__cancel_shielded_coro_yield", "__cancel_shielded_await", + "__dns_resolver", ) def __init__(self) -> None: import asyncio + from .dns_resolver import AsyncIODNSResolver from .tasks import TaskUtils self.__asyncio = asyncio @@ -58,6 +59,11 @@ def __init__(self) -> None: self.__cancel_shielded_coro_yield = TaskUtils.cancel_shielded_coro_yield self.__cancel_shielded_await = TaskUtils.cancel_shielded_await + self.__dns_resolver = AsyncIODNSResolver() + + def __repr__(self) -> str: + return f"<{type(self).__qualname__} object at {id(self):#x}>" + def bootstrap( self, coro_func: Callable[[*_T_PosArgs], Coroutine[Any, Any, _T]], @@ -113,6 +119,31 @@ def get_current_task(self) -> TaskInfo: current_task = TaskUtils.current_asyncio_task() return TaskUtils.create_task_info(current_task) + async def getaddrinfo( + self, + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> Sequence[tuple[int, int, int, str, tuple[str, int] | tuple[str, int, int, int]]]: + loop = self.__asyncio.get_running_loop() + + return await loop.getaddrinfo( + host, + port, + family=family, + type=type, + proto=proto, + flags=flags, + ) + + async def getnameinfo(self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int = 0) -> tuple[str, str]: + loop = self.__asyncio.get_running_loop() + + return await loop.getnameinfo(sockaddr, flags) + async def create_tcp_connection( self, host: str, @@ -124,9 +155,8 @@ async def create_tcp_connection( if happy_eyeballs_delay is None: happy_eyeballs_delay = _DEFAULT_HAPPY_EYEBALLS_DELAY - from ._asyncio_utils import create_connection - - socket = await create_connection( + socket = await self.__dns_resolver.create_stream_connection( + self, host, port, local_address=local_address, @@ -157,21 +187,15 @@ async def create_tcp_listeners( if not isinstance(backlog, int): raise TypeError("backlog: Expected an integer") - from ._asyncio_utils import resolve_local_addresses from .stream.listener import AcceptedSocketFactory, ListenerSocketAdapter reuse_address: bool = os.name not in ("nt", "cygwin") and sys.platform != "cygwin" - hosts: Sequence[str | None] - if host == "" or host is None: - hosts = [None] - elif isinstance(host, str): - hosts = [host] - else: - hosts = host + hosts: Sequence[str | None] = _utils.validate_listener_hosts(host) del host - infos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await resolve_local_addresses( + infos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await self.__dns_resolver.resolve_listener_addresses( + self, hosts, port, _socket.SOCK_STREAM, @@ -196,9 +220,8 @@ async def create_udp_endpoint( local_address: tuple[str, int] | None = None, family: int = _socket.AF_UNSPEC, ) -> AsyncDatagramTransport: - from ._asyncio_utils import create_datagram_connection - - socket = await create_datagram_connection( + socket = await self.__dns_resolver.create_datagram_connection( + self, remote_host, remote_port, local_address=local_address, @@ -221,22 +244,16 @@ async def create_udp_listeners( *, reuse_port: bool = False, ) -> Sequence[AsyncDatagramListener[tuple[Any, ...]]]: - from ._asyncio_utils import resolve_local_addresses from .datagram.listener import DatagramListenerProtocol, DatagramListenerSocketAdapter loop = self.__asyncio.get_running_loop() - hosts: Sequence[str | None] - if host == "" or host is None: - hosts = [None] - elif isinstance(host, str): - hosts = [host] - else: - hosts = host + hosts: Sequence[str | None] = _utils.validate_listener_hosts(host) del host - infos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await resolve_local_addresses( + infos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await self.__dns_resolver.resolve_listener_addresses( + self, hosts, port, _socket.SOCK_DGRAM, diff --git a/src/easynetwork/lowlevel/api_async/backend/_asyncio/dns_resolver.py b/src/easynetwork/lowlevel/api_async/backend/_asyncio/dns_resolver.py new file mode 100644 index 00000000..c387f2cf --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_asyncio/dns_resolver.py @@ -0,0 +1,36 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""asyncio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = ["AsyncIODNSResolver"] + +import asyncio +import socket as _socket +from typing import final + +from .._common.dns_resolver import BaseAsyncDNSResolver + + +@final +class AsyncIODNSResolver(BaseAsyncDNSResolver): + __slots__ = () + + async def connect_socket(self, socket: _socket.socket, address: tuple[str, int] | tuple[str, int, int, int]) -> None: + loop = asyncio.get_running_loop() + + await loop.sock_connect(socket, address) diff --git a/src/easynetwork/lowlevel/api_async/backend/_asyncio/threads.py b/src/easynetwork/lowlevel/api_async/backend/_asyncio/threads.py index 6163c528..571873f5 100644 --- a/src/easynetwork/lowlevel/api_async/backend/_asyncio/threads.py +++ b/src/easynetwork/lowlevel/api_async/backend/_asyncio/threads.py @@ -102,9 +102,7 @@ async def coroutine(waiter: asyncio.Future[None]) -> None: except BaseException as exc: if future.set_running_or_notify_cancel(): future.set_exception(exc) - if not isinstance(exc, Exception): - raise - elif future.cancelled(): + else: loop = asyncio.get_running_loop() loop.call_soon( loop.call_exception_handler, @@ -118,13 +116,13 @@ async def coroutine(waiter: asyncio.Future[None]) -> None: if future.set_running_or_notify_cancel(): future.set_result(result) - def schedule_task() -> concurrent.futures.Future[_T]: + def schedule_task() -> None: loop = asyncio.get_running_loop() waiter = self.__register_waiter(self.__call_soon_waiters, loop) _ = self.__task_group.create_task(coroutine(waiter), name=TaskUtils.compute_task_name_from_func(coro_func)) - return future - return self.run_sync(schedule_task) + self.run_sync_soon(schedule_task) + return future def run_sync_soon(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> concurrent.futures.Future[_T]: import sniffio @@ -142,8 +140,6 @@ def callback() -> None: raise _utils.exception_with_notes(TypeError(msg), note) except BaseException as exc: future.set_exception(exc) - if not isinstance(exc, Exception): - raise else: future.set_result(result) diff --git a/src/easynetwork/lowlevel/api_async/backend/_common/__init__.py b/src/easynetwork/lowlevel/api_async/backend/_common/__init__.py new file mode 100644 index 00000000..b77c3697 --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_common/__init__.py @@ -0,0 +1,19 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""Common utility functions and classes for asynchronous backends.""" + +from __future__ import annotations + +__all__ = [] # type: list[str] diff --git a/src/easynetwork/lowlevel/api_async/backend/_common/dns_resolver.py b/src/easynetwork/lowlevel/api_async/backend/_common/dns_resolver.py new file mode 100644 index 00000000..5913afb3 --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_common/dns_resolver.py @@ -0,0 +1,297 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""DNS Resolver module.""" + +from __future__ import annotations + +__all__ = ["BaseAsyncDNSResolver"] + +import itertools +import math +import socket as _socket +from abc import ABCMeta, abstractmethod +from collections import OrderedDict +from collections.abc import Sequence +from typing import Any, cast + +from ..abc import AsyncBackend, IEvent + + +class BaseAsyncDNSResolver(metaclass=ABCMeta): + __slots__ = () + + @abstractmethod + async def connect_socket(self, socket: _socket.socket, address: tuple[str, int] | tuple[str, int, int, int]) -> None: + raise NotImplementedError + + async def ensure_resolved( + self, + backend: AsyncBackend, + host: str | None, + port: int, + family: int, + type: int, + proto: int = 0, + flags: int = 0, + ) -> Sequence[tuple[int, int, int, str, tuple[str, int] | tuple[str, int, int, int]]]: + info: Sequence[tuple[int, int, int, str, tuple[str, int] | tuple[str, int, int, int]]] + try: + info = _socket.getaddrinfo( + host, port, family=family, type=type, proto=proto, flags=flags | _socket.AI_NUMERICHOST | _socket.AI_NUMERICSERV + ) + except _socket.gaierror as exc: + if exc.errno != _socket.EAI_NONAME: + raise + info = await backend.getaddrinfo(host, port, family=family, type=type, proto=proto, flags=flags) + if not info: + raise OSError(f"getaddrinfo({host!r}) returned empty list") + return info + + async def resolve_listener_addresses( + self, + backend: AsyncBackend, + hosts: Sequence[str | None], + port: int, + socktype: int, + ) -> Sequence[tuple[int, int, int, str, tuple[str, int] | tuple[str, int, int, int]]]: + infos = set( + itertools.chain.from_iterable( + await backend.gather( + *[ + self.ensure_resolved( + backend, + host, + port, + _socket.AF_UNSPEC, + socktype, + flags=_socket.AI_PASSIVE | _socket.AI_ADDRCONFIG, + ) + for host in hosts + ] + ) + ) + ) + return sorted(infos) + + async def create_stream_connection( + self, + backend: AsyncBackend, + host: str, + port: int, + *, + local_address: tuple[str, int] | None = None, + happy_eyeballs_delay: float = math.inf, + ) -> _socket.socket: + remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await self.ensure_resolved( + backend, + host, + port, + family=_socket.AF_UNSPEC, + type=_socket.SOCK_STREAM, + ) + local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None = None + if local_address is not None: + local_host, local_port = local_address + local_addrinfo = await self.ensure_resolved( + backend, + local_host, + local_port, + family=_socket.AF_UNSPEC, + type=_socket.SOCK_STREAM, + ) + + return await self._staggered_race_connection_impl( + backend=backend, + remote_addrinfo=remote_addrinfo, + local_addrinfo=local_addrinfo, + happy_eyeballs_delay=happy_eyeballs_delay, + ) + + async def create_datagram_connection( + self, + backend: AsyncBackend, + host: str, + port: int, + *, + local_address: tuple[str, int] | None = None, + family: int = _socket.AF_UNSPEC, + ) -> _socket.socket: + remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await self.ensure_resolved( + backend, + host, + port, + family=family, + type=_socket.SOCK_DGRAM, + ) + local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None = None + if local_address is not None: + local_host, local_port = local_address + local_addrinfo = await self.ensure_resolved( + backend, + local_host, + local_port, + family=family, + type=_socket.SOCK_DGRAM, + ) + + return await self._create_connection_impl( + remote_addrinfo=remote_addrinfo, + local_addrinfo=local_addrinfo, + ) + + async def _create_connection_impl( + self, + *, + remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]], + local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None, + ) -> _socket.socket: + errors: list[OSError] = [] + for family, socktype, proto, _, remote_sockaddr in remote_addrinfo: + try: + socket = _socket.socket(family, socktype, proto) + except OSError as exc: + errors.append(exc) + continue + except BaseException: + errors.clear() + raise + try: + socket.setblocking(False) + + if local_addrinfo is not None: + bind_errors: list[OSError] = [] + try: + for lfamily, _, _, _, local_sockaddr in local_addrinfo: + # skip local addresses of different family + if lfamily != family: + continue + try: + socket.bind(local_sockaddr) + break + except OSError as exc: + msg = f"error while attempting to bind on address {local_sockaddr!r}: {exc.strerror.lower()}" + bind_errors.append(OSError(exc.errno, msg).with_traceback(exc.__traceback__)) + else: # all bind attempts failed + if bind_errors: + socket.close() + errors.extend(bind_errors) + continue + raise OSError(f"no matching local address with {family=} found") + finally: + bind_errors.clear() + del bind_errors + + await self.connect_socket(socket, remote_sockaddr) + errors.clear() + return socket + except OSError as exc: + socket.close() + errors.append(exc) + continue + except BaseException: + errors.clear() + socket.close() + raise + + assert errors # nosec assert_used + try: + raise ExceptionGroup("create_connection() failed", errors) + finally: + errors.clear() + + # Taken from anyio project (https://github.com/agronholm/anyio/tree/4.2.0) + async def _staggered_race_connection_impl( + self, + backend: AsyncBackend, + *, + remote_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]], + local_addrinfo: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] | None, + happy_eyeballs_delay: float, + ) -> _socket.socket: + remote_addrinfo = _interleave_addrinfos(_prioritize_ipv6_over_ipv4(remote_addrinfo)) + winner: _socket.socket | None = cast(_socket.socket | None, None) + errors: list[OSError | BaseExceptionGroup[OSError]] = [] + + async def try_connect(addr: tuple[int, int, int, str, tuple[Any, ...]], done: IEvent) -> None: + nonlocal winner + try: + socket = await self._create_connection_impl(remote_addrinfo=[addr], local_addrinfo=local_addrinfo) + except* OSError as excgrp: + errors.extend(excgrp.exceptions) + else: + if winner is None: + winner = socket + connection_scope.cancel() + else: + socket.close() + finally: + done.set() + del done + + try: + with backend.open_cancel_scope() as connection_scope: + async with backend.create_task_group() as task_group: + for addr in remote_addrinfo: + done = backend.create_event() + task_group.start_soon(try_connect, addr, done) + with backend.move_on_after(happy_eyeballs_delay): + await done.wait() + + if winner is None: + raise BaseExceptionGroup("create_connection() failed", errors) + return winner + except BaseException: + if winner is not None: + winner.close() + raise + finally: + errors.clear() + + +# Taken from asyncio library (https://github.com/python/cpython/tree/v3.12.0/Lib/asyncio) +def _interleave_addrinfos( + addrinfos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] +) -> list[tuple[int, int, int, str, tuple[Any, ...]]]: + """Interleave list of addrinfo tuples by family.""" + # Group addresses by family + addrinfos_by_family: OrderedDict[int, list[tuple[Any, ...]]] = OrderedDict() + for addr in addrinfos: + family = addr[0] + if family not in addrinfos_by_family: + addrinfos_by_family[family] = [] + addrinfos_by_family[family].append(addr) + addrinfos_lists = list(addrinfos_by_family.values()) + return [addr for addr in itertools.chain.from_iterable(itertools.zip_longest(*addrinfos_lists)) if addr is not None] + + +# Taken from anyio project (https://github.com/agronholm/anyio/tree/4.2.0) +def _prioritize_ipv6_over_ipv4( + addrinfos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] +) -> list[tuple[int, int, int, str, tuple[Any, ...]]]: + # Organize the list so that the first address is an IPv6 address (if available) + # and the second one is an IPv4 addresses. The rest can be in whatever order. + v6_found = v4_found = False + reordered: list[tuple[int, int, int, str, tuple[Any, ...]]] = [] + for addr in addrinfos: + family = addr[0] + if family == _socket.AF_INET6 and not v6_found: + v6_found = True + reordered.insert(0, addr) + elif family == _socket.AF_INET and not v4_found and v6_found: + v4_found = True + reordered.insert(1, addr) + else: + reordered.append(addr) + return reordered diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/__init__.py b/src/easynetwork/lowlevel/api_async/backend/_trio/__init__.py new file mode 100644 index 00000000..344d3a9c --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = [] # type: list[str] diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/_trio_utils.py b/src/easynetwork/lowlevel/api_async/backend/_trio/_trio_utils.py new file mode 100644 index 00000000..bbf30d11 --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/_trio_utils.py @@ -0,0 +1,77 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = ["convert_trio_resource_errors"] + +import contextlib +import errno as _errno +import types + +import trio + +from .... import _utils + + +class convert_trio_resource_errors(contextlib.AbstractContextManager[None]): + def __init__(self, *, broken_resource_errno: int) -> None: + self.__broken_resource_errno: int = broken_resource_errno + + def __enter__(self) -> None: + return + + def __exit__( + self, + exc_type: type[BaseException] | None, + exc_value: BaseException | None, + traceback: types.TracebackType | None, + ) -> None: + if exc_type is None: + return + + if exc_value is None: + exc_value = exc_type() # pragma: no cover + + try: + if issubclass(exc_type, trio.ClosedResourceError): + raise self.__get_error_from_cause(exc_value, _errno.EBADF) + if issubclass(exc_type, trio.BrokenResourceError): + raise self.__get_error_from_cause(exc_value, self.__broken_resource_errno) + if issubclass(exc_type, trio.BusyResourceError): + raise self.__get_error_from_cause(exc_value, _errno.EBUSY) + except BaseException as new_exc: + _utils.remove_traceback_frames_in_place(new_exc, 1) + raise + finally: + del exc_value, traceback + + @staticmethod + def __get_error_from_cause( + exc_value: BaseException, + fallback_errno: int, + ) -> OSError: + match exc_value.__cause__: + case OSError() as error: + error.__cause__ = None + error.__suppress_context__ = True + return error + case _: + error = _utils.error_from_errno(fallback_errno) + error.__cause__ = exc_value + error.__suppress_context__ = True + return error.with_traceback(None) diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/backend.py b/src/easynetwork/lowlevel/api_async/backend/_trio/backend.py new file mode 100644 index 00000000..429654f3 --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/backend.py @@ -0,0 +1,278 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = ["TrioBackend"] + +import math +import os +import socket as _socket +import sys +from collections.abc import Awaitable, Callable, Coroutine, Mapping, Sequence +from typing import Any, NoReturn, TypeVar, TypeVarTuple + +from .... import _utils +from ....constants import HAPPY_EYEBALLS_DELAY as _DEFAULT_HAPPY_EYEBALLS_DELAY +from ...transports.abc import AsyncDatagramListener, AsyncDatagramTransport, AsyncListener, AsyncStreamTransport +from ..abc import AsyncBackend as AbstractAsyncBackend, CancelScope, ICondition, IEvent, ILock, TaskGroup, TaskInfo, ThreadsPortal + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_T_PosArgs = TypeVarTuple("_T_PosArgs") + + +class TrioBackend(AbstractAsyncBackend): + __slots__ = ( + "__trio", + "__dns_resolver", + ) + + def __init__(self) -> None: + try: + import trio + except ModuleNotFoundError as exc: + raise _utils.missing_extra_deps("trio") from exc + + from .dns_resolver import TrioDNSResolver + + self.__trio = trio + self.__dns_resolver = TrioDNSResolver() + + def __repr__(self) -> str: + return f"<{type(self).__qualname__} object at {id(self):#x}>" + + def bootstrap( + self, + coro_func: Callable[[*_T_PosArgs], Coroutine[Any, Any, _T]], + *args: *_T_PosArgs, + runner_options: Mapping[str, Any] | None = None, + ) -> _T: + runner_options = runner_options or {} + return self.__trio.run(coro_func, *args, **runner_options) + + async def coro_yield(self) -> None: + await self.__trio.lowlevel.checkpoint() + + async def cancel_shielded_coro_yield(self) -> None: + await self.__trio.lowlevel.cancel_shielded_checkpoint() + + def get_cancelled_exc_class(self) -> type[BaseException]: + return self.__trio.Cancelled + + async def ignore_cancellation(self, coroutine: Awaitable[_T_co]) -> _T_co: + with self.__trio.CancelScope(shield=True): + try: + return await coroutine + finally: + del coroutine + raise AssertionError("Expected code to be unreachable") + + def open_cancel_scope(self, *, deadline: float = math.inf) -> CancelScope: + from .tasks import CancelScope + + return CancelScope(deadline=deadline) + + def current_time(self) -> float: + return self.__trio.current_time() + + async def sleep(self, delay: float) -> None: + await self.__trio.sleep(delay) + + async def sleep_forever(self) -> NoReturn: + await self.__trio.sleep_forever() + raise AssertionError("Expected code to be unreachable") + + def create_task_group(self) -> TaskGroup: + from .tasks import TaskGroup + + return TaskGroup() + + def get_current_task(self) -> TaskInfo: + from .tasks import TaskUtils + + current_task = self.__trio.lowlevel.current_task() + return TaskUtils.create_task_info(current_task) + + async def getaddrinfo( + self, + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> Sequence[tuple[int, int, int, str, tuple[str, int] | tuple[str, int, int, int]]]: + return await self.__trio.socket.getaddrinfo( + host, + port, + family=family, + type=type, + proto=proto, + flags=flags, + ) + + async def getnameinfo(self, sockaddr: tuple[str, int] | tuple[str, int, int, int], flags: int = 0) -> tuple[str, str]: + return await self.__trio.socket.getnameinfo(sockaddr, flags) + + async def create_tcp_connection( + self, + host: str, + port: int, + *, + local_address: tuple[str, int] | None = None, + happy_eyeballs_delay: float | None = None, + ) -> AsyncStreamTransport: + if happy_eyeballs_delay is None: + happy_eyeballs_delay = _DEFAULT_HAPPY_EYEBALLS_DELAY + + socket = await self.__dns_resolver.create_stream_connection( + self, + host, + port, + local_address=local_address, + happy_eyeballs_delay=happy_eyeballs_delay, + ) + + return await self.wrap_stream_socket(socket) + + async def wrap_stream_socket(self, socket: _socket.socket) -> AsyncStreamTransport: + from .stream.socket import TrioStreamSocketAdapter + + _utils.check_socket_no_ssl(socket) + trio_socket = self.__trio.socket.from_stdlib_socket(socket) + trio_stream = self.__trio.SocketStream(trio_socket) + + return TrioStreamSocketAdapter(self, trio_stream) + + async def create_tcp_listeners( + self, + host: str | Sequence[str] | None, + port: int, + backlog: int, + *, + reuse_port: bool = False, + ) -> Sequence[AsyncListener[AsyncStreamTransport]]: + from .stream.listener import TrioListenerSocketAdapter + + reuse_address: bool = os.name not in ("nt", "cygwin") and sys.platform != "cygwin" + hosts: Sequence[str | None] = _utils.validate_listener_hosts(host) + + del host + + infos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await self.__dns_resolver.resolve_listener_addresses( + self, + hosts, + port, + _socket.SOCK_STREAM, + ) + + sockets: list[_socket.socket] = _utils.open_listener_sockets_from_getaddrinfo_result( + infos, + backlog=backlog, + reuse_address=reuse_address, + reuse_port=reuse_port, + ) + + listeners = [ + TrioListenerSocketAdapter(self, self.__trio.SocketListener(sock)) + for sock in map(self.__trio.socket.from_stdlib_socket, sockets) + ] + return listeners + + async def create_udp_endpoint( + self, + remote_host: str, + remote_port: int, + *, + local_address: tuple[str, int] | None = None, + family: int = _socket.AF_UNSPEC, + ) -> AsyncDatagramTransport: + socket = await self.__dns_resolver.create_datagram_connection( + self, + remote_host, + remote_port, + local_address=local_address, + family=family, + ) + return await self.wrap_connected_datagram_socket(socket) + + async def wrap_connected_datagram_socket(self, socket: _socket.socket) -> AsyncDatagramTransport: + from .datagram.socket import TrioDatagramSocketAdapter + + _utils.check_socket_no_ssl(socket) + trio_socket = self.__trio.socket.from_stdlib_socket(socket) + + return TrioDatagramSocketAdapter(self, trio_socket) + + async def create_udp_listeners( + self, + host: str | Sequence[str] | None, + port: int, + *, + reuse_port: bool = False, + ) -> Sequence[AsyncDatagramListener[tuple[Any, ...]]]: + from .datagram.listener import TrioDatagramListenerSocketAdapter + + hosts: Sequence[str | None] = _utils.validate_listener_hosts(host) + + del host + + infos: Sequence[tuple[int, int, int, str, tuple[Any, ...]]] = await self.__dns_resolver.resolve_listener_addresses( + self, + hosts, + port, + _socket.SOCK_DGRAM, + ) + + sockets: list[_socket.socket] = _utils.open_listener_sockets_from_getaddrinfo_result( + infos, + backlog=None, + reuse_address=False, + reuse_port=reuse_port, + ) + + listeners = [ + TrioDatagramListenerSocketAdapter(self, sock) for sock in map(self.__trio.socket.from_stdlib_socket, sockets) + ] + return listeners + + def create_lock(self) -> ILock: + return self.__trio.Lock() + + def create_event(self) -> IEvent: + return self.__trio.Event() + + def create_condition_var(self, lock: ILock | None = None) -> ICondition: + if lock is not None: + assert isinstance(lock, self.__trio.Lock) # nosec assert_used + + return self.__trio.Condition(lock) + + async def run_in_thread( + self, + func: Callable[[*_T_PosArgs], _T], + /, + *args: *_T_PosArgs, + abandon_on_cancel: bool = False, + ) -> _T: + return await self.__trio.to_thread.run_sync(func, *args, abandon_on_cancel=abandon_on_cancel) + + def create_threads_portal(self) -> ThreadsPortal: + from .threads import ThreadsPortal + + return ThreadsPortal() diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/__init__.py b/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/__init__.py new file mode 100644 index 00000000..344d3a9c --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = [] # type: list[str] diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/listener.py b/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/listener.py new file mode 100644 index 00000000..7404c432 --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/listener.py @@ -0,0 +1,103 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = ["TrioDatagramListenerSocketAdapter"] + +import contextlib +import socket as _socket +import warnings +from collections.abc import Callable, Coroutine, Mapping +from typing import Any, NoReturn, final + +import trio + +from ..... import _utils, socket as socket_tools +from ....transports.abc import AsyncDatagramListener +from ...abc import AsyncBackend, TaskGroup + + +@final +class TrioDatagramListenerSocketAdapter(AsyncDatagramListener[tuple[Any, ...]]): + __slots__ = ( + "__backend", + "__listener", + "__trsock", + "__serve_guard", + ) + + from .....constants import MAX_DATAGRAM_BUFSIZE + + def __init__(self, backend: AsyncBackend, sock: trio.socket.SocketType) -> None: + super().__init__() + + if sock.type != _socket.SOCK_DGRAM: + raise ValueError("A 'SOCK_DGRAM' socket is expected") + + self.__backend: AsyncBackend = backend + self.__listener: trio.socket.SocketType = sock + self.__trsock: socket_tools.SocketProxy = socket_tools.SocketProxy(sock) + self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard(f"{self.__class__.__name__}.serve() awaited twice.") + + def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: + try: + listener = self.__listener + except AttributeError: + listener = None + if listener is not None and listener.fileno() >= 0: + _warn(f"unclosed listener {self!r}", ResourceWarning, source=self) + listener.close() + + async def aclose(self) -> None: + self.__listener.close() + await trio.lowlevel.checkpoint() + + def is_closing(self) -> bool: + return self.__listener.fileno() < 0 + + async def serve( + self, + handler: Callable[[bytes, tuple[Any, ...]], Coroutine[Any, Any, None]], + task_group: TaskGroup | None = None, + ) -> NoReturn: + async with contextlib.AsyncExitStack() as stack: + stack.enter_context(self.__serve_guard) + if task_group is None: + task_group = await stack.enter_async_context(self.__backend.create_task_group()) + + buffer: memoryview = stack.enter_context(memoryview(bytearray(self.MAX_DATAGRAM_BUFSIZE))) + + listener = self.__listener + while True: + nbytes, client_address = await listener.recvfrom_into(buffer) + + task_group.start_soon(handler, bytes(buffer[:nbytes]), client_address) + # Always drop references on loop end + del client_address + + raise AssertionError("Expected code to be unreachable.") + + async def send_to(self, data: bytes | bytearray | memoryview, address: tuple[Any, ...]) -> None: + await self.__listener.sendto(data, address) + + def backend(self) -> AsyncBackend: + return self.__backend + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return socket_tools._get_socket_extra(self.__trsock, wrap_in_proxy=False) diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/socket.py b/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/socket.py new file mode 100644 index 00000000..b0ad0e2d --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/datagram/socket.py @@ -0,0 +1,81 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = ["TrioDatagramSocketAdapter"] + +import socket as _socket +import warnings +from collections.abc import Callable, Mapping +from typing import Any, final + +import trio + +from ..... import _utils, socket as socket_tools +from ....transports.abc import AsyncDatagramTransport +from ...abc import AsyncBackend + + +@final +class TrioDatagramSocketAdapter(AsyncDatagramTransport): + __slots__ = ( + "__backend", + "__socket", + "__trsock", + ) + + from .....constants import MAX_DATAGRAM_BUFSIZE + + def __init__(self, backend: AsyncBackend, sock: trio.socket.SocketType) -> None: + super().__init__() + + if sock.type != _socket.SOCK_DGRAM: + raise ValueError("A 'SOCK_DGRAM' socket is expected") + + self.__backend: AsyncBackend = backend + self.__socket: trio.socket.SocketType = sock + self.__trsock: socket_tools.SocketProxy = socket_tools.SocketProxy(sock) + + def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: + try: + socket = self.__socket + except AttributeError: + socket = None + if socket is not None and socket.fileno() >= 0: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + socket.close() + + async def aclose(self) -> None: + self.__socket.close() + await trio.lowlevel.checkpoint() + + def is_closing(self) -> bool: + return self.__socket.fileno() < 0 + + async def recv(self) -> bytes: + return await self.__socket.recv(self.MAX_DATAGRAM_BUFSIZE) + + async def send(self, data: bytes | bytearray | memoryview) -> None: + await self.__socket.send(data) + + def backend(self) -> AsyncBackend: + return self.__backend + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return socket_tools._get_socket_extra(self.__trsock, wrap_in_proxy=False) diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/dns_resolver.py b/src/easynetwork/lowlevel/api_async/backend/_trio/dns_resolver.py new file mode 100644 index 00000000..a6ee7522 --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/dns_resolver.py @@ -0,0 +1,50 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = ["TrioDNSResolver"] + +import socket as _socket + +import trio.socket + +from .._common.dns_resolver import BaseAsyncDNSResolver + + +class TrioDNSResolver(BaseAsyncDNSResolver): + __slots__ = () + + async def connect_socket(self, socket: _socket.socket, address: tuple[str, int] | tuple[str, int, int, int]) -> None: + # TL;DR: Why not directly use trio.socket.socket() function? + # When giving a fileno, it tries to guess the real family, type and proto of the file descriptor + # by calling getsockopt(). This extra operation is useless here. + async_socket = trio.socket.from_stdlib_socket( + _socket.socket(socket.family, socket.type, socket.proto, fileno=socket.fileno()) + ) + try: + await async_socket.connect(address) + except BaseException: + # If connect() raises an exception, let trio close the socket. + # NOTE: connect() already closes the socket if trio.Cancelled is raised. + socket.detach() + raise + else: + # The operation has succeeded, remove the ownership to the temporary socket. + async_socket.detach() + finally: + async_socket.close() diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/stream/__init__.py b/src/easynetwork/lowlevel/api_async/backend/_trio/stream/__init__.py new file mode 100644 index 00000000..344d3a9c --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/stream/__init__.py @@ -0,0 +1,20 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = [] # type: list[str] diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/stream/_sendmsg.py b/src/easynetwork/lowlevel/api_async/backend/_trio/stream/_sendmsg.py new file mode 100644 index 00000000..dd374eee --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/stream/_sendmsg.py @@ -0,0 +1,38 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = ["supports_async_socket_sendmsg"] + +from abc import abstractmethod +from collections.abc import Awaitable, Iterable +from typing import TYPE_CHECKING, Protocol, TypeGuard + +import trio + +if TYPE_CHECKING: + from _typeshed import ReadableBuffer + + +class _SupportsAsyncSocketSendMSG(Protocol): + @abstractmethod + def sendmsg(self, buffers: Iterable[ReadableBuffer], /) -> Awaitable[int]: ... + + +def supports_async_socket_sendmsg(sock: trio.socket.SocketType) -> TypeGuard[_SupportsAsyncSocketSendMSG]: + return hasattr(sock, "sendmsg") diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/stream/listener.py b/src/easynetwork/lowlevel/api_async/backend/_trio/stream/listener.py new file mode 100644 index 00000000..89fa1920 --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/stream/listener.py @@ -0,0 +1,96 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = ["TrioListenerSocketAdapter"] + +import contextlib +import warnings +from collections.abc import Callable, Coroutine, Mapping +from typing import Any, NoReturn, final + +import trio + +from ..... import _utils, socket as socket_tools +from ....transports.abc import AsyncListener +from ...abc import AsyncBackend, TaskGroup +from .._trio_utils import convert_trio_resource_errors +from .socket import TrioStreamSocketAdapter + + +@final +class TrioListenerSocketAdapter(AsyncListener[TrioStreamSocketAdapter]): + __slots__ = ( + "__backend", + "__listener", + "__trsock", + "__serve_guard", + ) + + def __init__(self, backend: AsyncBackend, listener: trio.SocketListener) -> None: + super().__init__() + + self.__backend: AsyncBackend = backend + self.__listener: trio.SocketListener = listener + self.__trsock: socket_tools.SocketProxy = socket_tools.SocketProxy(listener.socket) + self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard(f"{self.__class__.__name__}.serve() awaited twice.") + + def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: + try: + listener = self.__listener + except AttributeError: # pragma: no cover + # Technically possible but not with the common usage because this constructor does not raise. + listener = None + if listener is not None and listener.socket.fileno() >= 0: + _warn(f"unclosed listener {self!r}", ResourceWarning, source=self) + listener.socket.close() + + async def aclose(self) -> None: + return await self.__listener.aclose() + + def is_closing(self) -> bool: + return self.__listener.socket.fileno() < 0 + + async def serve( + self, + handler: Callable[[TrioStreamSocketAdapter], Coroutine[Any, Any, None]], + task_group: TaskGroup | None = None, + ) -> NoReturn: + from errno import EBADF + + async with contextlib.AsyncExitStack() as stack: + stack.enter_context(self.__serve_guard) + if task_group is None: + task_group = await stack.enter_async_context(self.__backend.create_task_group()) + stack.enter_context(convert_trio_resource_errors(broken_resource_errno=EBADF)) + + while True: + # Always drop socket reference on loop begin + client_socket: trio.SocketStream | None = None + + client_socket = await self.__listener.accept() + task_group.start_soon(handler, TrioStreamSocketAdapter(self.__backend, client_socket)) + + raise AssertionError("Expected code to be unreachable.") + + def backend(self) -> AsyncBackend: + return self.__backend + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return socket_tools._get_socket_extra(self.__trsock, wrap_in_proxy=False) diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/stream/socket.py b/src/easynetwork/lowlevel/api_async/backend/_trio/stream/socket.py new file mode 100644 index 00000000..90da8b3b --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/stream/socket.py @@ -0,0 +1,112 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = ["TrioStreamSocketAdapter"] + +import errno as _errno +import itertools +import warnings +from collections import deque +from collections.abc import Callable, Iterable, Mapping +from typing import TYPE_CHECKING, Any, final + +import trio + +from ..... import _utils, constants, socket as socket_tools +from ....transports.abc import AsyncStreamTransport +from ...abc import AsyncBackend +from .._trio_utils import convert_trio_resource_errors +from ._sendmsg import supports_async_socket_sendmsg + +if TYPE_CHECKING: + from _typeshed import WriteableBuffer + + +@final +class TrioStreamSocketAdapter(AsyncStreamTransport): + __slots__ = ( + "__backend", + "__stream", + "__trsock", + ) + + def __init__(self, backend: AsyncBackend, stream: trio.SocketStream) -> None: + super().__init__() + + self.__backend: AsyncBackend = backend + self.__stream: trio.SocketStream = stream + self.__trsock: socket_tools.SocketProxy = socket_tools.SocketProxy(stream.socket) + + def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: + try: + stream = self.__stream + except AttributeError: # pragma: no cover + # Technically possible but not with the common usage because this constructor does not raise. + stream = None + if stream is not None and stream.socket.fileno() >= 0: + _warn(f"unclosed transport {self!r}", ResourceWarning, source=self) + stream.socket.close() + + async def aclose(self) -> None: + await self.__stream.aclose() + + def is_closing(self) -> bool: + return self.__stream.socket.fileno() < 0 + + async def recv(self, bufsize: int) -> bytes: + with convert_trio_resource_errors(broken_resource_errno=_errno.ECONNABORTED): + return await self.__stream.receive_some(bufsize) + + async def recv_into(self, buffer: WriteableBuffer) -> int: + return await self.__stream.socket.recv_into(buffer) + + async def send_all(self, data: bytes | bytearray | memoryview) -> None: + with convert_trio_resource_errors(broken_resource_errno=_errno.ECONNABORTED): + return await self.__stream.send_all(data) + + async def send_all_from_iterable(self, iterable_of_data: Iterable[bytes | bytearray | memoryview]) -> None: + if constants.SC_IOV_MAX <= 0: + return await super().send_all_from_iterable(iterable_of_data) + + socket = self.__stream.socket + if not supports_async_socket_sendmsg(socket): + return await super().send_all_from_iterable(iterable_of_data) + + buffers: deque[memoryview] = deque(map(memoryview, iterable_of_data)) + del iterable_of_data + + if not buffers: + return await self.send_all(b"") + + while buffers: + # Do not send the islice directly because if sendmsg() blocks, + # it would retry with an already consumed iterator. + sent = await socket.sendmsg(list(itertools.islice(buffers, constants.SC_IOV_MAX))) + _utils.adjust_leftover_buffer(buffers, sent) + + async def send_eof(self) -> None: + with convert_trio_resource_errors(broken_resource_errno=_errno.ECONNABORTED): + await self.__stream.send_eof() + + def backend(self) -> AsyncBackend: + return self.__backend + + @property + def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: + return socket_tools._get_socket_extra(self.__trsock, wrap_in_proxy=False) diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/tasks.py b/src/easynetwork/lowlevel/api_async/backend/_trio/tasks.py new file mode 100644 index 00000000..3de0505a --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/tasks.py @@ -0,0 +1,302 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = ["CancelScope", "Task", "TaskGroup"] + +import contextlib +import copy +import math +from collections.abc import Awaitable, Callable, Coroutine +from types import TracebackType +from typing import Any, Generic, Self, TypeVar, TypeVarTuple, final + +import outcome +import trio + +from .... import _utils +from ...._final import runtime_final_class +from ..abc import CancelScope as AbstractCancelScope, Task as AbstractTask, TaskGroup as AbstractTaskGroup, TaskInfo + +_T = TypeVar("_T") +_T_co = TypeVar("_T_co", covariant=True) +_T_PosArgs = TypeVarTuple("_T_PosArgs") + + +@final +@runtime_final_class +class Task(AbstractTask[_T_co]): + __slots__ = ( + "__task", + "__scope", + "__outcome", + ) + + def __init__(self, *, task: trio.lowlevel.Task, scope: trio.CancelScope, outcome: _OutcomeCell[_T_co]) -> None: + self.__task: trio.lowlevel.Task = task + self.__scope: trio.CancelScope = scope + self.__outcome: _OutcomeCell[_T_co] = outcome + + def __repr__(self) -> str: + return repr(self.__task) + + @property + def info(self) -> TaskInfo: + return TaskUtils.create_task_info(self.__task) + + def done(self) -> bool: + return self.__outcome.peek() is not None + + def cancel(self) -> bool: + if self.__outcome.peek() is None: + self.__scope.cancel() + return True + return False + + def cancelled(self) -> bool: + match self.__outcome.peek(): + case outcome.Error(trio.Cancelled()): + return True + case _: + return False + + async def wait(self) -> None: + await self.__outcome.get_no_checkpoints() + + async def join(self) -> _T_co: + outcome = await self.__outcome.get_no_checkpoints() + # Copy object because outcome objects can be unwrapped only once + outcome = copy.copy(outcome) + try: + return outcome.unwrap() + finally: + del outcome, self # This is needed to avoid circular reference with raised exception + + async def join_or_cancel(self) -> _T_co: + try: + outcome = await self.__outcome.get_no_checkpoints() + except trio.Cancelled: + self.__scope.cancel() + with trio.CancelScope(shield=True): + outcome = await self.__outcome.get_no_checkpoints() + if self.cancelled(): + # Re-raise the current exception instead + raise + + # Copy object because outcome objects can be unwrapped only once + outcome = copy.copy(outcome) + try: + return outcome.unwrap() + finally: + del outcome, self # This is needed to avoid circular reference with raised exception + + +@final +@runtime_final_class +class TaskGroup(AbstractTaskGroup): + __slots__ = ("__nursery_ctx", "__nursery") + + def __init__(self) -> None: + super().__init__() + + self.__nursery_ctx: contextlib.AbstractAsyncContextManager[trio.Nursery] = trio.open_nursery(strict_exception_groups=True) + self.__nursery: trio.Nursery | None = None + + async def __aenter__(self) -> Self: + nursery_ctx = self.__nursery_ctx + self.__nursery = await nursery_ctx.__aenter__() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + nursery_ctx = self.__nursery_ctx + try: + await nursery_ctx.__aexit__(exc_type, exc_val, exc_tb) + finally: + del exc_val, exc_tb, nursery_ctx, self + + def start_soon( + self, + coro_func: Callable[[*_T_PosArgs], Coroutine[Any, Any, _T]], + /, + *args: *_T_PosArgs, + name: str | None = None, + ) -> None: + nursery = self.__check_nursery_started() + + nursery.start_soon(coro_func, *args, name=name) + + async def start( + self, + coro_func: Callable[[*_T_PosArgs], Coroutine[Any, Any, _T]], + /, + *args: *_T_PosArgs, + name: str | None = None, + ) -> AbstractTask[_T]: + nursery = self.__check_nursery_started() + + if name is None: + name = TaskUtils.compute_task_name_from_func(coro_func) + + return await nursery.start(self.__task_coroutine, coro_func, args, name=name) + + def __check_nursery_started(self) -> trio.Nursery: + if (n := self.__nursery) is None: + raise RuntimeError("TaskGroup not started") + return n + + @staticmethod + async def __task_coroutine( + coro_func: Callable[[*_T_PosArgs], Awaitable[_T]], + args: tuple[*_T_PosArgs], + *, + task_status: trio.TaskStatus[Task[_T]], + ) -> None: + with trio.CancelScope() as scope: + + coroutine = coro_func(*args) + del coro_func, args + + cell: _OutcomeCell[_T] = _OutcomeCell() + + task_status.started( + Task( + task=trio.lowlevel.current_task(), + scope=scope, + outcome=cell, + ) + ) + + result: _T + try: + result = await coroutine + except BaseException as exc: + cell.set(outcome.Error(_utils.remove_traceback_frames_in_place(exc, 1))) + raise + else: + cell.set(outcome.Value(result)) + finally: + del coroutine + + +@final +@runtime_final_class +class CancelScope(AbstractCancelScope): + __slots__ = ("__scope",) + + def __init__(self, *, deadline: float = math.inf) -> None: + super().__init__() + self.__validate_deadline(deadline) + + self.__scope: trio.CancelScope = trio.CancelScope(deadline=deadline) + + def __enter__(self) -> Self: + scope = self.__scope + type(scope).__enter__(scope) + return self + + def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> bool: + scope = self.__scope + try: + return type(scope).__exit__(scope, exc_type, exc_val, exc_tb) or False + finally: + del exc_val, exc_tb, scope, self + + def cancel(self) -> None: + return self.__scope.cancel() + + def cancel_called(self) -> bool: + return self.__scope.cancel_called + + def cancelled_caught(self) -> bool: + return self.__scope.cancelled_caught + + def when(self) -> float: + return self.__scope.deadline + + def reschedule(self, when: float, /) -> None: + self.__validate_deadline(when) + self.__scope.deadline = when + + def __validate_deadline(self, when: float) -> None: + if math.isnan(when): + raise ValueError("deadline is NaN") + + +@final +@runtime_final_class +class TaskUtils: + + @classmethod + def create_task_info(cls, task: trio.lowlevel.Task) -> TaskInfo: + return TaskInfo(id(task), task.name, task.coro) + + @classmethod + def compute_task_name_from_func(cls, func: Callable[..., Any]) -> str: + return _utils.get_callable_name(func) or repr(func) + + +class _OutcomeCell(Generic[_T_co]): + __slots__ = ( + "__result", + "__waiting_tasks", + ) + + def __init__(self) -> None: + self.__result: outcome.Outcome[_T_co] | None = None + self.__waiting_tasks: set[trio.lowlevel.Task] = set() + + def peek(self) -> outcome.Outcome[_T_co] | None: + return self.__result + + def get_nowait(self) -> outcome.Outcome[_T_co]: + if (result := self.__result) is None: + raise trio.WouldBlock + return result + + async def get_no_checkpoints(self) -> outcome.Outcome[_T_co]: + try: + result = self.get_nowait() + except trio.WouldBlock: + pass + else: + return result + + task = trio.lowlevel.current_task() + self.__waiting_tasks.add(task) + + def abort_fn(_: Any) -> trio.lowlevel.Abort: + self.__waiting_tasks.discard(task) + return trio.lowlevel.Abort.SUCCEEDED + + return await trio.lowlevel.wait_task_rescheduled(abort_fn) + + def set(self, result: outcome.Outcome[_T_co]) -> None: + if self.__result is not None: + raise AssertionError("Already set to a value") + + self.__result = result + + for task in self.__waiting_tasks: + trio.lowlevel.reschedule(task, outcome.Value(result)) + + self.__waiting_tasks.clear() diff --git a/src/easynetwork/lowlevel/api_async/backend/_trio/threads.py b/src/easynetwork/lowlevel/api_async/backend/_trio/threads.py new file mode 100644 index 00000000..4246677b --- /dev/null +++ b/src/easynetwork/lowlevel/api_async/backend/_trio/threads.py @@ -0,0 +1,212 @@ +# Copyright 2021-2024, Francis Clairicia-Rose-Claire-Josephine +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at +# +# http://www.apache.org/licenses/LICENSE-2.0 +# +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +# See the License for the specific language governing permissions and +# limitations under the License. +# +# +"""trio engine for easynetwork.api_async +""" + +from __future__ import annotations + +__all__ = ["ThreadsPortal"] + +import concurrent.futures +import contextlib +import contextvars +import inspect +import threading +from collections.abc import Awaitable, Callable +from types import TracebackType +from typing import ParamSpec, Self, TypeVar, final + +import trio + +from .... import _lock, _utils +from ...._final import runtime_final_class +from ..abc import ThreadsPortal as AbstractThreadsPortal +from .tasks import TaskGroup, TaskUtils + +_P = ParamSpec("_P") +_T = TypeVar("_T") + + +@final +@runtime_final_class +class ThreadsPortal(AbstractThreadsPortal): + __slots__ = ("__trio_token", "__lock", "__run_sync_soon_waiter", "__task_group") + + def __init__(self) -> None: + super().__init__() + + self.__lock = _lock.ForkSafeLock() + self.__trio_token: trio.lowlevel.TrioToken | None = None + self.__task_group: TaskGroup = TaskGroup() + self.__run_sync_soon_waiter = _PortalRunSyncSoonWaiter() + + async def __aenter__(self) -> Self: + if self.__trio_token is not None: + raise RuntimeError("ThreadsPortal entered twice.") + await self.__task_group.__aenter__() + self.__trio_token = trio.lowlevel.current_trio_token() + return self + + async def __aexit__( + self, + exc_type: type[BaseException] | None, + exc_val: BaseException | None, + exc_tb: TracebackType | None, + ) -> None: + try: + with self.__lock.get(): + self.__trio_token = None + + with trio.CancelScope(shield=True): + await self.__run_sync_soon_waiter.aclose() + await self.__task_group.__aexit__(exc_type, exc_val, exc_tb) + finally: + del self, exc_val, exc_tb + + def run_coroutine_soon( + self, + coro_func: Callable[_P, Awaitable[_T]], + /, + *args: _P.args, + **kwargs: _P.kwargs, + ) -> concurrent.futures.Future[_T]: + + future: concurrent.futures.Future[_T] = concurrent.futures.Future() + + def on_fut_done(payload: tuple[trio.lowlevel.TrioToken, trio.CancelScope], future: concurrent.futures.Future[_T]) -> None: + if future.cancelled(): + with contextlib.suppress(RuntimeError): + trio_token, cancel_scope = payload + trio_token.run_sync_soon(cancel_scope.cancel) + + def exception_handler(task: trio.lowlevel.Task, exc: BaseException) -> None: + import logging + + logger = logging.getLogger("trio") + + log_lines = [ + "Task exception was not retrieved because future object is cancelled", + f"task: {task!r}", + ] + + logger.error("\n".join(log_lines), exc_info=exc) + + async def coroutine() -> None: + with trio.CancelScope() as scope: + try: + future.add_done_callback(_utils.prepend_argument((trio.lowlevel.current_trio_token(), scope), on_fut_done)) + + result = await coro_func(*args, **kwargs) + except trio.Cancelled: + future.cancel() + future.set_running_or_notify_cancel() + raise + except BaseException as exc: + if future.set_running_or_notify_cancel(): + future.set_exception(exc) + else: + exception_handler(trio.lowlevel.current_task(), exc) + else: + if future.set_running_or_notify_cancel(): + future.set_result(result) + + def schedule_task() -> None: + self.__task_group.start_soon(coroutine, name=TaskUtils.compute_task_name_from_func(coro_func)) + + self.run_sync_soon(schedule_task) + return future + + def run_sync_soon(self, func: Callable[_P, _T], /, *args: _P.args, **kwargs: _P.kwargs) -> concurrent.futures.Future[_T]: + import sniffio + + run_sync_soon_waiter = self.__run_sync_soon_waiter + future: concurrent.futures.Future[_T] = concurrent.futures.Future() + + @trio.lowlevel.enable_ki_protection + def callback() -> None: + run_sync_soon_waiter.detach_in_trio_thread() + if not future.set_running_or_notify_cancel(): + return + try: + result = func(*args, **kwargs) + if inspect.iscoroutine(result): + result.close() # Prevent ResourceWarnings + msg = "func is a coroutine function." + note = "You should use run_coroutine() or run_coroutine_soon() instead." + raise _utils.exception_with_notes(TypeError(msg), note) + except BaseException as exc: + future.set_exception(exc) + else: + future.set_result(result) + + with self.__lock.get(): + trio_token = self.__check_current_token() + run_sync_soon_waiter.attach_from_any_thread() + + ctx: contextvars.Context = contextvars.copy_context() + # trio already sets sniffio.thread_local.name + ctx.run(sniffio.current_async_library_cvar.set, None) + + trio_token.run_sync_soon(ctx.run, callback) + return future + + def __check_current_token(self) -> trio.lowlevel.TrioToken: + trio_token = self.__trio_token + if trio_token is None: + raise RuntimeError("ThreadsPortal not running.") + if self.__is_in_this_loop_thread(trio_token): + raise RuntimeError("This function must be called in a different OS thread") + return trio_token + + @staticmethod + def __is_in_this_loop_thread(trio_token: trio.lowlevel.TrioToken) -> bool: + try: + current_trio_token = trio.lowlevel.current_trio_token() + except RuntimeError: + return False + return current_trio_token is trio_token + + +class _PortalRunSyncSoonWaiter: + __slots__ = ("__done", "__thread_lock", "__waiter_count", "__closing") + + def __init__(self) -> None: + self.__done: trio.Event = trio.Event() + self.__thread_lock = _lock.ForkSafeLock(threading.Lock) + self.__waiter_count: int = 0 + self.__closing: bool = False + + async def aclose(self) -> None: + with self.__thread_lock.get(): + self.__closing = True + if not self.__waiter_count: + self.__done.set() + + await self.__done.wait() + + def attach_from_any_thread(self) -> None: + with self.__thread_lock.get(): + if self.__done.is_set(): + raise AssertionError("currently closed") + self.__waiter_count += 1 + + @trio.lowlevel.enable_ki_protection + def detach_in_trio_thread(self) -> None: + with self.__thread_lock.get(): + self.__waiter_count -= 1 + if self.__waiter_count < 0: + raise AssertionError("self.__waiter_count < 0") + if not self.__waiter_count and self.__closing: + self.__done.set() diff --git a/src/easynetwork/lowlevel/api_async/backend/abc.py b/src/easynetwork/lowlevel/api_async/backend/abc.py index a22b7a52..2333cf56 100644 --- a/src/easynetwork/lowlevel/api_async/backend/abc.py +++ b/src/easynetwork/lowlevel/api_async/backend/abc.py @@ -36,7 +36,7 @@ from collections.abc import Awaitable, Callable, Coroutine, Mapping, Sequence from contextlib import AbstractContextManager from types import TracebackType -from typing import Any, Generic, NoReturn, ParamSpec, Protocol, Self, TypeVar, TypeVarTuple, Unpack +from typing import Any, Generic, Literal, NoReturn, ParamSpec, Protocol, Self, TypeVar, TypeVarTuple, Unpack from ..transports import abc as _transports @@ -82,7 +82,7 @@ async def __aexit__( exc_val: BaseException | None, exc_tb: TracebackType | None, /, - ) -> bool | None: ... + ) -> Literal[False] | None: ... @abstractmethod async def acquire(self) -> Any: @@ -398,7 +398,7 @@ def deadline(self) -> float: scope.reschedule(scope.when() + 30) - It is also possible to remove the timeout by deleting the attribute:: + It is also possible to remove the timeout by "deleting" the attribute:: del scope.deadline """ @@ -862,6 +862,56 @@ async def sleep_until(self, deadline: float) -> None: """ return await self.sleep(max(deadline - self.current_time(), 0)) + async def gather(self, *coroutines: Awaitable[_T_co]) -> list[_T_co]: + """ + Run awaitable objects in the `coroutines` sequence concurrently. + + Parameters: + coroutines: any awaitable object. + + Returns: + If all awaitables are completed successfully, the result is an aggregate list of returned values. + The order of result values corresponds to the order of awaitables in `coroutines`. + + Raises: + ExceptionGroup: If one or more awaitable(s) fails. + """ + + if not coroutines: + # Fast path. + return [] + + from ..._utils import remove_traceback_frames_in_place + + async def _await(coro: Awaitable[_T_co]) -> _T_co: + try: + return await coro + except BaseException as exc: + remove_traceback_frames_in_place(exc, 1) + raise + finally: + del coro + + coro_to_task: dict[Awaitable[_T_co], Task[_T_co]] = {} + + children: list[Task[_T_co]] = [] + + async with self.create_task_group() as task_group: + for coro in coroutines: + if coro in coro_to_task: + task = coro_to_task[coro] + else: + task = await self.ignore_cancellation(task_group.start(_await, coro)) + coro_to_task[coro] = task + children.append(task) + + coro_to_task.clear() + + # task_group should raise an ExceptionGroup if one of the coroutine raises an exception + # At this point, all the tasks should be done and join() would neither block nor raise. + assert all(child.done() for child in children) # nosec assert_used + return [await child.join() for child in children] + @abstractmethod def create_task_group(self) -> TaskGroup: """ @@ -887,6 +937,38 @@ def get_current_task(self) -> TaskInfo: """ raise NotImplementedError + async def getaddrinfo( + self, + host: bytes | str | None, + port: bytes | str | int | None, + family: int = 0, + type: int = 0, + proto: int = 0, + flags: int = 0, + ) -> Sequence[tuple[int, int, int, str, tuple[str, int] | tuple[str, int, int, int]]]: + """ + Asynchronous version of :func:`socket.getaddrinfo`. + """ + from ..._utils import make_callback + + getaddrinfo = make_callback(_socket.getaddrinfo, host, port, family=family, type=type, proto=proto, flags=flags) + + return await self.run_in_thread(getaddrinfo, abandon_on_cancel=True) + + async def getnameinfo( + self, + sockaddr: tuple[str, int] | tuple[str, int, int, int], + flags: int = 0, + ) -> tuple[str, str]: + """ + Asynchronous version of :func:`socket.getnameinfo`. + """ + from ..._utils import make_callback + + getnameinfo = make_callback(_socket.getnameinfo, sockaddr, flags) + + return await self.run_in_thread(getnameinfo, abandon_on_cancel=True) + @abstractmethod async def create_tcp_connection( self, @@ -988,6 +1070,7 @@ async def create_udp_endpoint( remote_host: The host IP/domain name. remote_port: Port of connection. local_address: If given, is a ``(local_host, local_port)`` tuple used to bind the socket locally. + family: The address family Raises: OSError: unrelated OS error occurred. @@ -1151,6 +1234,6 @@ def __enter__(self) -> CancelScope: return self.scope.__enter__() def __exit__(self, exc_type: type[BaseException] | None, exc_val: BaseException | None, exc_tb: TracebackType | None) -> None: - cancelled_caught = self.scope.__exit__(exc_type, exc_val, exc_tb) - if cancelled_caught: + self.scope.__exit__(exc_type, exc_val, exc_tb) + if self.scope.cancelled_caught(): raise TimeoutError("timed out") diff --git a/src/easynetwork/lowlevel/api_async/backend/utils.py b/src/easynetwork/lowlevel/api_async/backend/utils.py index 71bbfade..ce71a3fa 100644 --- a/src/easynetwork/lowlevel/api_async/backend/utils.py +++ b/src/easynetwork/lowlevel/api_async/backend/utils.py @@ -19,27 +19,53 @@ __all__ = [ "BuiltinAsyncBackendLiteral", "ensure_backend", + "new_builtin_backend", ] -from typing import Literal, TypeAlias, cast +from typing import Literal, TypeAlias, assert_never, cast import sniffio from .abc import AsyncBackend -BuiltinAsyncBackendLiteral: TypeAlias = Literal["asyncio"] +BuiltinAsyncBackendLiteral: TypeAlias = Literal["asyncio", "trio"] """Supported asynchronous framework names.""" +def new_builtin_backend(name: BuiltinAsyncBackendLiteral) -> AsyncBackend: + """ + Obtain an interface for the given `backend`. + + Here is the list of the supported libraries: + + * ``"asyncio"`` + + * ``"trio"`` + """ + match name: + case "asyncio": + from ._asyncio.backend import AsyncIOBackend + + return AsyncIOBackend() + case "trio": + from ._trio.backend import TrioBackend + + return TrioBackend() + case str(): + raise NotImplementedError(name) + case _: # pragma: no cover + assert_never(name) + + def ensure_backend(backend: AsyncBackend | BuiltinAsyncBackendLiteral | None) -> AsyncBackend: """ - Obtain an interface for the give `backend`. + Obtain an interface for the given `backend`. * If `backend` is already an :class:`.AsyncBackend`, this object is returned. * If `backend` is a string token and matches one of the built-in implementation, a new object is returned. - * Currently, only ``"asyncio"`` is recognized. + * See also :func:`new_builtin_backend`. * If :data:`None`, the function tries to guess the library currently used with :mod:`sniffio`. @@ -51,13 +77,9 @@ def ensure_backend(backend: AsyncBackend | BuiltinAsyncBackendLiteral | None) -> backend = cast(BuiltinAsyncBackendLiteral, sniffio.current_async_library()) match backend: - case "asyncio": - from ._asyncio.backend import AsyncIOBackend - - return AsyncIOBackend() case AsyncBackend(): return backend case str(): - raise NotImplementedError(backend) + return new_builtin_backend(backend) case _: raise TypeError(f"Expected either a string literal or a backend instance, got {backend!r}") diff --git a/tests/fixtures/__init__.py b/tests/fixtures/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/fixtures/trio.py b/tests/fixtures/trio.py new file mode 100644 index 00000000..701137cf --- /dev/null +++ b/tests/fixtures/trio.py @@ -0,0 +1,17 @@ +from __future__ import annotations + +from collections.abc import Callable +from typing import Any, TypeVar + +import pytest + +_F = TypeVar("_F", bound=Callable[..., Any]) + + +def trio_fixture(fixture_function: _F) -> _F: + try: + import pytest_trio + except ImportError: + return pytest.fixture(fixture_function) + else: + return pytest_trio.trio_fixture(fixture_function) diff --git a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py index b7a912e3..648148ee 100644 --- a/tests/functional_test/test_async/test_backend/test_asyncio_backend.py +++ b/tests/functional_test/test_async/test_backend/test_asyncio_backend.py @@ -1,14 +1,15 @@ from __future__ import annotations import asyncio +import contextlib import time from collections.abc import Awaitable, Callable, Iterator from concurrent.futures import CancelledError as FutureCancelledError, wait as wait_concurrent_futures from contextlib import ExitStack from typing import TYPE_CHECKING, Any, Literal, Required, TypedDict -from easynetwork.lowlevel.api_async.backend._asyncio.backend import AsyncIOBackend -from easynetwork.lowlevel.api_async.backend.abc import TaskInfo +from easynetwork.lowlevel.api_async.backend.abc import AsyncBackend, TaskInfo +from easynetwork.lowlevel.api_async.backend.utils import new_builtin_backend import pytest import sniffio @@ -33,14 +34,14 @@ class ExceptionCaughtDict(TypedDict, total=False): @pytest.mark.flaky(retries=3, delay=0) class TestAsyncioBackendBootstrap: - @pytest.fixture + @pytest.fixture(scope="class") @staticmethod - def backend() -> AsyncIOBackend: - return AsyncIOBackend() + def backend() -> AsyncBackend: + return new_builtin_backend("asyncio") def test____bootstrap____sniffio_thread_local_reset( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: assert sniffio.thread_local.name is None @@ -71,14 +72,14 @@ def event_loop_exceptions_caught( with temporary_exception_handler(event_loop, handler_stub): yield event_loop_exceptions_caught - @pytest.fixture + @pytest.fixture(scope="class") @staticmethod - def backend() -> AsyncIOBackend: - return AsyncIOBackend() + def backend() -> AsyncBackend: + return new_builtin_backend("asyncio") async def test____cancel_shielded_coro_yield____mute_cancellation( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: task: asyncio.Task[None] = asyncio.create_task(backend.cancel_shielded_coro_yield()) @@ -94,7 +95,7 @@ async def test____cancel_shielded_coro_yield____mute_cancellation( async def test____cancel_shielded_coro_yield____cancel_at_the_next_checkpoint( self, cancel_message: str | None, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: test_list: list[str] = [] @@ -125,7 +126,7 @@ async def coroutine() -> None: async def test____ignore_cancellation____always_continue_on_cancellation( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() checkpoints: list[int] = [] @@ -150,7 +151,7 @@ async def coroutine() -> int: async def test____ignore_cancellation____exception_raised_in_task( self, direct_raise: bool, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() exception = Exception("error") @@ -177,7 +178,7 @@ async def coroutine() -> None: async def test____ignore_cancellation____runs_in_current_task( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> asyncio.Task[Any]: task = asyncio.current_task() @@ -188,7 +189,7 @@ async def coroutine() -> asyncio.Task[Any]: async def test____ignore_cancellation____remove_future_blocking_flag_like_default_implementation( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -202,7 +203,7 @@ async def coroutine() -> None: async def test____ignore_cancellation____forbid_await_itself_like_default_implementation( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> None: task = asyncio.current_task() @@ -218,7 +219,7 @@ async def test____ignore_cancellation____coroutine_cancelled_itself( self, cancel_method: Literal["fut_cancel", "raise"], with_delay: bool, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -248,14 +249,14 @@ async def self_cancellation() -> None: async def test____timeout____respected( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: with backend.timeout(1): assert await asyncio.sleep(0.5, 42) == 42 async def test____timeout____timeout_error( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: with pytest.raises(TimeoutError): with backend.timeout(0.25): @@ -263,7 +264,7 @@ async def test____timeout____timeout_error( async def test____timeout____cancellation( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -279,7 +280,7 @@ async def coroutine() -> None: async def test____timeout_at____respected( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -288,7 +289,7 @@ async def test____timeout_at____respected( async def test____timeout_at____timeout_error( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -298,7 +299,7 @@ async def test____timeout_at____timeout_error( async def test____timeout_at____cancellation( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -314,7 +315,7 @@ async def coroutine() -> None: async def test____move_on_after____respected( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: with backend.move_on_after(1) as scope: assert await asyncio.sleep(0.5, 42) == 42 @@ -323,7 +324,7 @@ async def test____move_on_after____respected( async def test____move_on_after____timeout_error( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: with backend.move_on_after(0.25) as scope: await asyncio.sleep(0.5, 42) @@ -332,7 +333,7 @@ async def test____move_on_after____timeout_error( async def test____move_on_after____cancellation( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -348,7 +349,7 @@ async def coroutine() -> None: async def test____move_on_at____respected( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -359,7 +360,7 @@ async def test____move_on_at____respected( async def test____move_on_at____timeout_error( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -370,7 +371,7 @@ async def test____move_on_at____timeout_error( async def test____move_on_at____cancellation( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -386,7 +387,7 @@ async def coroutine() -> None: async def test____sleep_forever____sleep_until_cancellation( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -399,7 +400,7 @@ async def test____sleep_forever____sleep_until_cancellation( async def test____open_cancel_scope____unbound_cancel_scope____cancel_when_entering( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> None: current_task = asyncio.current_task() @@ -423,7 +424,7 @@ async def coroutine() -> None: async def test____open_cancel_scope____unbound_cancel_scope____deadline_scheduled_when_entering( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -447,7 +448,7 @@ async def coroutine() -> None: async def test____open_cancel_scope____overwrite_defined_deadline( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> None: current_task = asyncio.current_task() @@ -467,14 +468,14 @@ async def coroutine() -> None: async def test____open_cancel_scope____invalid_deadline( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: with pytest.raises(ValueError): _ = backend.open_cancel_scope(deadline=float("nan")) async def test____open_cancel_scope____context_reuse( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: with backend.open_cancel_scope() as scope: with pytest.raises(RuntimeError, match=r"^CancelScope entered twice$"): @@ -487,14 +488,14 @@ async def test____open_cancel_scope____context_reuse( async def test____open_cancel_scope____context_exit_before_enter( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: with pytest.raises(RuntimeError, match=r"^This cancel scope is not active$"), ExitStack() as stack: stack.push(backend.open_cancel_scope()) async def test____open_cancel_scope____task_misnesting( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> ExitStack: stack = ExitStack() @@ -507,7 +508,7 @@ async def coroutine() -> ExitStack: async def test____open_cancel_scope____scope_misnesting( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: stack = ExitStack() stack.enter_context(backend.open_cancel_scope()) @@ -518,9 +519,70 @@ async def test____open_cancel_scope____scope_misnesting( stack.close() stack.pop_all() + async def test____gather____no_parameters( + self, + backend: AsyncBackend, + ) -> None: + result: list[int] = await backend.gather() + assert result == [] + + async def test____gather____concurrent_await( + self, + backend: AsyncBackend, + ) -> None: + async def coroutine(value: int) -> int: + return await asyncio.sleep(0.5, value) + + with backend.timeout(0.6): + result = await backend.gather( + coroutine(42), + coroutine(54), + ) + + assert result == [42, 54] + + async def test____gather____concurrent_await____exception_raises( + self, + backend: AsyncBackend, + ) -> None: + async def coroutine(value: int) -> int: + return await backend.ignore_cancellation(asyncio.sleep(0.5, value)) + + async def coroutine_error(exception: Exception) -> int: + await backend.ignore_cancellation(asyncio.sleep(0.5)) + raise exception + + with pytest.raises(ExceptionGroup) as exc_info: + await backend.gather( + coroutine(42), + coroutine_error(ValueError("conversion error")), + coroutine(54), + coroutine_error(KeyError("unknown")), + ) + + assert len(exc_info.value.exceptions) == 2 + + async def test____gather____duplicate_awaitable( + self, + backend: AsyncBackend, + mocker: MockerFixture, + ) -> None: + awaited = mocker.async_stub() + + async def coroutine(value: int) -> int: + await awaited() + return await asyncio.sleep(0.5, value) + + awaitable = coroutine(42) + + result = await backend.gather(awaitable, awaitable, awaitable) + + assert result == [42, 42, 42] + awaited.assert_awaited_once() + async def test____create_task_group____start_soon( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: tasks: list[TaskInfo] = [] @@ -536,7 +598,7 @@ async def coroutine(value: int) -> int: async def test____create_task_group____start_soon____set_name( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: tasks: list[TaskInfo] = [] @@ -553,7 +615,7 @@ async def coroutine(value: int) -> int: async def test____create_task_group____start_and_wait( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: tasks: list[TaskInfo] = [] @@ -589,7 +651,7 @@ async def coroutine(value: int) -> int: async def test____create_task_group____start_and_wait____set_name( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine(value: int) -> int: return await asyncio.sleep(0.5, value) @@ -603,7 +665,7 @@ async def coroutine(value: int) -> int: async def test____create_task_group____start_and_wait____waiter_cancelled( self, - backend: AsyncIOBackend, + backend: AsyncBackend, mocker: MockerFixture, ) -> None: awaited = mocker.async_stub() @@ -620,7 +682,7 @@ async def coroutine(value: int) -> int: async def test____create_task_group____task_cancellation( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine(value: int) -> int: return await asyncio.sleep(0.5, value) @@ -650,7 +712,7 @@ async def coroutine(value: int) -> int: async def test____create_task_group____task_join_cancel_shielding( self, join_method: Literal["join", "join_or_cancel", "wait"], - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -686,7 +748,7 @@ async def coroutine(value: int) -> int: async def test____create_task_group____task_wait( self, task_state: Literal["result", "exception", "cancelled"], - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() future: asyncio.Future[None] = event_loop.create_future() @@ -710,7 +772,7 @@ async def coroutine() -> None: event_loop.call_later(0.1, set_future_result) - try: + with pytest.raises(ExceptionGroup) if task_state == "exception" else contextlib.nullcontext() as exc_info: async with backend.create_task_group() as task_group: task = await task_group.start(coroutine) @@ -720,12 +782,13 @@ async def coroutine() -> None: # Must not yield if task is already done async with asyncio.timeout(0): await task.wait() - except* FutureException: - pass + + if exc_info is not None: + assert isinstance(exc_info.value.exceptions[0], FutureException) async def test____run_in_thread____cannot_be_cancelled_by_default( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() task = asyncio.create_task(backend.run_in_thread(time.sleep, 0.5)) @@ -740,7 +803,7 @@ async def test____run_in_thread____cannot_be_cancelled_by_default( async def test____run_in_thread____abandon_on_cancel( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() task = asyncio.create_task(backend.run_in_thread(time.sleep, 0.5, abandon_on_cancel=True)) @@ -751,7 +814,7 @@ async def test____run_in_thread____abandon_on_cancel( assert task.cancelled() - async def test____run_in_thread____sniffio_contextvar_reset(self, backend: AsyncIOBackend) -> None: + async def test____run_in_thread____sniffio_contextvar_reset(self, backend: AsyncBackend) -> None: sniffio.current_async_library_cvar.set("asyncio") def callback() -> str | None: @@ -765,7 +828,7 @@ def callback() -> str | None: async def test____create_threads_portal____run_coroutine_from_thread( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() threads_portal = backend.create_threads_portal() @@ -793,7 +856,7 @@ def thread() -> int: async def test____create_threads_portal____run_coroutine_from_thread____can_be_called_from_other_event_loop( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -814,7 +877,7 @@ async def main() -> int: @pytest.mark.parametrize("exception_cls", [Exception, BaseException]) async def test____create_threads_portal____run_coroutine_from_thread____exception_raised( self, - backend: AsyncIOBackend, + backend: AsyncBackend, exception_cls: type[BaseException], event_loop_exceptions_caught: list[ExceptionCaughtDict], ) -> None: @@ -827,27 +890,16 @@ def thread() -> int: return threads_portal.run_coroutine(coroutine, 42) threads_portal = backend.create_threads_portal() - if issubclass(exception_cls, Exception): - async with threads_portal: - with pytest.raises(BaseException) as exc_info: - await backend.run_in_thread(thread) - - assert exc_info.value is expected_exception - assert len(event_loop_exceptions_caught) == 0 - else: - with pytest.raises(BaseExceptionGroup) as exc_group_info: - async with threads_portal: - with pytest.raises(BaseException) as exc_info: - await backend.run_in_thread(thread) - - assert exc_info.value is expected_exception + async with threads_portal: + with pytest.raises(BaseException) as exc_info: + await backend.run_in_thread(thread) - assert len(event_loop_exceptions_caught) == 0 - assert exc_group_info.value.exceptions[0] is expected_exception + assert exc_info.value is expected_exception + assert len(event_loop_exceptions_caught) == 0 async def test____create_threads_portal____run_coroutine_from_thread____coroutine_cancelled( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine(value: int) -> int: task = asyncio.current_task() @@ -865,7 +917,7 @@ def thread() -> int: async def test____create_threads_portal____run_coroutine_from_thread____explicit_concurrent_future_Cancelled( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine(value: int) -> int: raise FutureCancelledError() @@ -880,7 +932,7 @@ def thread() -> int: async def test____create_threads_portal____run_coroutine_from_thread____sniffio_contextvar_reset( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: sniffio.current_async_library_cvar.set("main") @@ -899,7 +951,7 @@ def thread() -> str | None: async def test____create_threads_portal____run_sync_from_thread_in_event_loop( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() threads_portal = backend.create_threads_portal() @@ -927,7 +979,7 @@ def thread() -> int: async def test____create_threads_portal____run_sync_from_thread_in_event_loop____can_be_called_from_other_event_loop( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -948,7 +1000,7 @@ async def main() -> int: @pytest.mark.parametrize("exception_cls", [Exception, BaseException]) async def test____create_threads_portal____run_sync_from_thread_in_event_loop____exception_raised( self, - backend: AsyncIOBackend, + backend: AsyncBackend, exception_cls: type[BaseException], event_loop_exceptions_caught: list[ExceptionCaughtDict], ) -> None: @@ -961,25 +1013,16 @@ def thread() -> int: return threads_portal.run_sync(not_threadsafe_func, 42) threads_portal = backend.create_threads_portal() - if issubclass(exception_cls, Exception): - async with threads_portal: - with pytest.raises(BaseException) as exc_info: - await backend.run_in_thread(thread) - - assert exc_info.value is expected_exception - assert len(event_loop_exceptions_caught) == 0 - else: - async with threads_portal: - with pytest.raises(BaseException) as exc_info: - await backend.run_in_thread(thread) + async with threads_portal: + with pytest.raises(BaseException) as exc_info: + await backend.run_in_thread(thread) - assert exc_info.value is expected_exception - assert len(event_loop_exceptions_caught) == 1 - assert event_loop_exceptions_caught[0]["exception"] is expected_exception + assert exc_info.value is expected_exception + assert len(event_loop_exceptions_caught) == 0 async def test____create_threads_portal____run_sync_from_thread_in_event_loop____explicit_concurrent_future_Cancelled( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: def not_threadsafe_func(value: int) -> int: raise FutureCancelledError() @@ -994,7 +1037,7 @@ def thread() -> int: async def test____create_threads_portal____run_sync_from_thread_in_event_loop____async_function_given( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> None: raise AssertionError("Should not be called") @@ -1010,7 +1053,7 @@ def thread() -> None: async def test____create_threads_portal____run_sync_from_thread_in_event_loop____sniffio_contextvar_reset( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: sniffio.current_async_library_cvar.set("main") @@ -1029,7 +1072,7 @@ def thread() -> str | None: async def test____create_threads_portal____run_sync_soon____future_cancelled_before_call( self, - backend: AsyncIOBackend, + backend: AsyncBackend, mocker: MockerFixture, ) -> None: event_loop = asyncio.get_running_loop() @@ -1054,7 +1097,7 @@ def thread() -> None: async def test____create_threads_portal____run_coroutine_soon____future_cancelled( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: def thread() -> None: future = threads_portal.run_coroutine_soon(asyncio.sleep, 1) @@ -1073,7 +1116,7 @@ def thread() -> None: async def test____create_threads_portal____run_coroutine_soon____future_cancelled____cancellation_ignored( self, value: int | Exception, - backend: AsyncIOBackend, + backend: AsyncBackend, event_loop_exceptions_caught: list[ExceptionCaughtDict], mocker: MockerFixture, ) -> None: @@ -1117,7 +1160,7 @@ def thread() -> None: async def test____create_threads_portal____run_coroutine_soon____future_cancelled_before_await( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() checkpoints: list[str] = [] @@ -1155,7 +1198,7 @@ def event_loop_slowdown() -> None: # Drastically slow down event loop @pytest.mark.skipif(not hasattr(asyncio, "eager_task_factory"), reason="asyncio.eager_task_factory not implemented") async def test____create_threads_portal____run_coroutine_soon____eager_task( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() event_loop.set_task_factory(getattr(asyncio, "eager_task_factory")) @@ -1165,7 +1208,6 @@ async def coroutine() -> int: def thread() -> None: future = threads_portal.run_coroutine_soon(coroutine) - assert future.done() assert future.result() == 42 async with backend.create_threads_portal() as threads_portal: @@ -1173,7 +1215,7 @@ def thread() -> None: async def test____create_threads_portal____context_exit____wait_scheduled_call_soon( self, - backend: AsyncIOBackend, + backend: AsyncBackend, mocker: MockerFixture, ) -> None: event_loop = asyncio.get_running_loop() @@ -1194,7 +1236,7 @@ def thread() -> None: async def test____create_threads_portal____context_exit____wait_scheduled_call_soon_for_coroutine( self, - backend: AsyncIOBackend, + backend: AsyncBackend, mocker: MockerFixture, ) -> None: event_loop = asyncio.get_running_loop() @@ -1215,7 +1257,7 @@ def thread() -> None: async def test____create_threads_portal____entered_twice( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async with backend.create_threads_portal() as threads_portal: with pytest.raises(RuntimeError, match=r"ThreadsPortal entered twice\."): @@ -1224,10 +1266,10 @@ async def test____create_threads_portal____entered_twice( @pytest.mark.asyncio class TestAsyncioBackendShieldedCancellation: - @pytest.fixture + @pytest.fixture(scope="class") @staticmethod - def backend() -> AsyncIOBackend: - return AsyncIOBackend() + def backend() -> AsyncBackend: + return new_builtin_backend("asyncio") @pytest.fixture( params=[ @@ -1241,7 +1283,7 @@ def backend() -> AsyncIOBackend: @staticmethod def cancel_shielded_coroutine( request: pytest.FixtureRequest, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> Callable[[], Awaitable[Any]]: match getattr(request, "param"): case "cancel_shielded_coro_yield": @@ -1269,7 +1311,7 @@ async def cancel_shielded_wait_asyncio_futures() -> None: async def test____cancel_shielded_coroutine____do_not_cancel_at_timeout_end( self, cancel_shielded_coroutine: Callable[[], Awaitable[Any]], - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: checkpoints: list[str] = [] @@ -1292,7 +1334,7 @@ async def coroutine(value: int) -> int: async def test____cancel_shielded_coroutine____cancel_at_timeout_end_if_nested( self, cancel_shielded_coroutine: Callable[[], Awaitable[Any]], - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: checkpoints: list[str] = [] @@ -1327,7 +1369,7 @@ async def coroutine(value: int) -> int: async def test____timeout____cancel_at_timeout_end_if_task_cancellation_were_already_delayed( self, cancel_shielded_coroutine: Callable[[], Awaitable[Any]], - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() checkpoints: list[str] = [] @@ -1357,7 +1399,7 @@ async def coroutine(value: int) -> int: async def test____cancel_shielded_coroutine____cancel_at_timeout_end_if_task_cancellation_does_not_come_from_scope( self, cancel_shielded_coroutine: Callable[[], Awaitable[Any]], - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() checkpoints: list[str] = [] @@ -1384,7 +1426,7 @@ async def coroutine(value: int) -> int: async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_1( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> None: current_task = asyncio.current_task() @@ -1407,7 +1449,7 @@ async def coroutine() -> None: async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_2( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> None: current_task = asyncio.current_task() @@ -1433,7 +1475,7 @@ async def coroutine() -> None: async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_3( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> None: current_task = asyncio.current_task() @@ -1452,7 +1494,7 @@ async def coroutine() -> None: async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_4( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> None: current_task = asyncio.current_task() @@ -1471,7 +1513,7 @@ async def coroutine() -> None: async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_5( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> None: current_task = asyncio.current_task() @@ -1501,7 +1543,7 @@ async def coroutine() -> None: async def test____cancel_shielded_coroutine____scope_cancellation_edge_case_6( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: async def coroutine() -> None: current_task = asyncio.current_task() @@ -1533,7 +1575,7 @@ async def coroutine() -> None: async def test____ignore_cancellation____do_not_reschedule_if_inner_task_raises_CancelledError( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() @@ -1562,7 +1604,7 @@ async def coroutine() -> None: async def test____ignore_cancellation____reschedule_erased_cancel_from_parent( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: checkpoints: list[str] = [] diff --git a/tests/functional_test/test_async/test_backend/test_trio_backend.py b/tests/functional_test/test_async/test_backend/test_trio_backend.py new file mode 100644 index 00000000..5e310886 --- /dev/null +++ b/tests/functional_test/test_async/test_backend/test_trio_backend.py @@ -0,0 +1,1034 @@ +from __future__ import annotations + +import concurrent.futures +import contextlib +import time +from typing import TYPE_CHECKING, Literal + +from easynetwork.lowlevel.api_async.backend.abc import AsyncBackend, Task, TaskInfo +from easynetwork.lowlevel.api_async.backend.utils import new_builtin_backend + +import pytest +import sniffio + +from ....tools import call_later_with_nursery + +if TYPE_CHECKING: + from trio import Nursery + + from pytest_mock import MockerFixture + + +@pytest.mark.feature_trio +class TestTrioBackendBootstrap: + @pytest.fixture(scope="class") + @staticmethod + def backend() -> AsyncBackend: + return new_builtin_backend("trio") + + def test____bootstrap____sniffio_thread_local_reset( + self, + backend: AsyncBackend, + ) -> None: + assert sniffio.thread_local.name is None + + async def main() -> str | None: + return sniffio.thread_local.name + + thread_local_inner = backend.bootstrap(main) + + assert thread_local_inner == "trio" + assert sniffio.thread_local.name is None + + +@pytest.mark.feature_trio(async_test_auto_mark=True) +@pytest.mark.flaky(retries=3, delay=0) +class TestTrioBackend: + + @pytest.fixture(scope="class") + @staticmethod + def backend() -> AsyncBackend: + return new_builtin_backend("trio") + + async def test____cancel_shielded_coro_yield____mute_cancellation( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with trio.CancelScope() as scope: + scope.cancel() + await backend.cancel_shielded_coro_yield() + + assert not scope.cancelled_caught + + async def test____ignore_cancellation____always_continue_on_cancellation( + self, + backend: AsyncBackend, + ) -> None: + import trio + + checkpoints: list[int] = [] + + async def coroutine() -> int: + for i in range(2): + await trio.sleep(0.25) + checkpoints.append(i) + return 42 + + value: int = 0 + with trio.CancelScope() as scope: + scope.cancel() + value = await backend.ignore_cancellation(coroutine()) + + assert not scope.cancelled_caught + assert value == 42 + + async def test____ignore_cancellation____runs_in_current_task( + self, + backend: AsyncBackend, + ) -> None: + import trio + + async def coroutine() -> trio.lowlevel.Task: + return trio.lowlevel.current_task() + + assert (await backend.ignore_cancellation(coroutine())) is trio.lowlevel.current_task() + + async def test____timeout____respected( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with backend.timeout(1): + await trio.sleep(0.5) + + async def test____timeout____timeout_error( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with pytest.raises(TimeoutError): + with backend.timeout(0.25): + await trio.sleep(0.5) + + async def test____timeout____cancellation( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with trio.move_on_after(0.10) as root_scope: + with backend.timeout(0.25): + await trio.sleep(0.5) + + assert root_scope.cancelled_caught + + async def test____timeout_at____respected( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with backend.timeout_at(trio.current_time() + 1): + await trio.sleep(0.5) + + async def test____timeout_at____timeout_error( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with pytest.raises(TimeoutError): + with backend.timeout_at(trio.current_time() + 0.25): + await trio.sleep(0.5) + + async def test____timeout_at____cancellation( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with trio.move_on_after(0.10) as root_scope: + with backend.timeout_at(trio.current_time() + 0.25): + await trio.sleep(0.5) + + assert root_scope.cancelled_caught + + async def test____move_on_after____respected( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with backend.move_on_after(1) as scope: + await trio.sleep(0.5) + + assert not scope.cancelled_caught() + + async def test____move_on_after____timeout_error( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with backend.move_on_after(0.25) as scope: + await trio.sleep(0.5) + + assert scope.cancelled_caught() + + async def test____move_on_after____cancellation( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with trio.move_on_after(0.10) as root_scope: + with backend.move_on_after(0.25) as inner_scope: + await trio.sleep(0.5) + + assert not inner_scope.cancelled_caught() + assert root_scope.cancelled_caught + + async def test____move_on_at____respected( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with backend.move_on_at(trio.current_time() + 1) as scope: + await trio.sleep(0.5) + + assert not scope.cancelled_caught() + + async def test____move_on_at____timeout_error( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with backend.move_on_at(trio.current_time() + 0.25) as scope: + await trio.sleep(0.5) + + assert scope.cancelled_caught() + + async def test____move_on_at____cancellation( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with trio.move_on_after(0.10) as root_scope: + with backend.move_on_at(trio.current_time() + 0.25) as inner_scope: + await trio.sleep(0.5) + + assert not inner_scope.cancelled_caught() + assert root_scope.cancelled_caught + + async def test____sleep_forever____sleep_until_cancellation( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with trio.move_on_after(0.5): + await backend.sleep_forever() + + async def test____open_cancel_scope____unbound_cancel_scope____cancel_when_entering( + self, + backend: AsyncBackend, + ) -> None: + scope = backend.open_cancel_scope() + scope.cancel() + assert scope.cancel_called() + + await backend.sleep(0.1) + + with scope: + scope.cancel() + await backend.coro_yield() + + assert scope.cancelled_caught() + + async def test____open_cancel_scope____overwrite_defined_deadline( + self, + backend: AsyncBackend, + ) -> None: + with backend.move_on_after(1) as scope: + await backend.sleep(0.5) + scope.deadline += 1 + await backend.sleep(1) + del scope.deadline + assert scope.deadline == float("+inf") + await backend.sleep(1) + + assert not scope.cancelled_caught() + + async def test____open_cancel_scope____invalid_deadline( + self, + backend: AsyncBackend, + ) -> None: + with pytest.raises(ValueError): + _ = backend.open_cancel_scope(deadline=float("nan")) + + async def test____gather____no_parameters( + self, + backend: AsyncBackend, + ) -> None: + result: list[int] = await backend.gather() + assert result == [] + + async def test____gather____concurrent_await( + self, + backend: AsyncBackend, + ) -> None: + import trio + + async def coroutine(value: int) -> int: + await trio.sleep(0.5) + return value + + with backend.timeout(0.6): + result = await backend.gather( + coroutine(42), + coroutine(54), + ) + + assert result == [42, 54] + + async def test____gather____concurrent_await____exception_raises( + self, + backend: AsyncBackend, + ) -> None: + import trio + + async def coroutine(value: int) -> int: + with trio.CancelScope(shield=True): + await trio.sleep(0.5) + return value + + async def coroutine_error(exception: Exception) -> int: + with trio.CancelScope(shield=True): + await trio.sleep(0.5) + raise exception + + with pytest.raises(ExceptionGroup) as exc_info: + await backend.gather( + coroutine(42), + coroutine_error(ValueError("conversion error")), + coroutine(54), + coroutine_error(KeyError("unknown")), + ) + + assert len(exc_info.value.exceptions) == 2 + + async def test____gather____duplicate_awaitable( + self, + backend: AsyncBackend, + mocker: MockerFixture, + ) -> None: + import trio + + awaited = mocker.async_stub() + + async def coroutine(value: int) -> int: + await awaited() + await trio.sleep(0.5) + return value + + awaitable = coroutine(42) + + result = await backend.gather(awaitable, awaitable, awaitable) + + assert result == [42, 42, 42] + awaited.assert_awaited_once() + + async def test____create_task_group____start_soon( + self, + backend: AsyncBackend, + ) -> None: + import trio + + tasks: list[TaskInfo] = [] + + async def coroutine(value: int) -> int: + tasks.append(backend.get_current_task()) + await trio.sleep(0.5) + return value + + async with backend.create_task_group() as tg: + tg.start_soon(coroutine, 42) + tg.start_soon(coroutine, 54) + await trio.lowlevel.checkpoint() + + assert len(tasks) == 2 + + async def test____create_task_group____start_soon____not_entered( + self, + backend: AsyncBackend, + ) -> None: + import trio + + tasks: list[TaskInfo] = [] + + async def coroutine(value: int) -> int: + tasks.append(backend.get_current_task()) + await trio.sleep(0.5) + return value + + tg = backend.create_task_group() + with pytest.raises(RuntimeError, match=r"^TaskGroup not started$"): + tg.start_soon(coroutine, 42) + with pytest.raises(RuntimeError, match=r"^TaskGroup not started$"): + tg.start_soon(coroutine, 54) + + await trio.sleep(0.5) + + assert len(tasks) == 0 + + async def test____create_task_group____start_soon____set_name( + self, + backend: AsyncBackend, + ) -> None: + import trio + + tasks: list[TaskInfo] = [] + + async def coroutine(value: int) -> int: + tasks.append(backend.get_current_task()) + await trio.sleep(0.5) + return value + + async with backend.create_task_group() as tg: + tg.start_soon(coroutine, 42, name="compute 42") + tg.start_soon(coroutine, 54, name="compute 54") + await trio.lowlevel.checkpoint() + + assert sorted(t.name for t in tasks) == ["compute 42", "compute 54"] + + async def test____create_task_group____start_and_wait( + self, + backend: AsyncBackend, + ) -> None: + import trio + + tasks: list[TaskInfo] = [] + + async def coroutine(value: int) -> int: + tasks.append(backend.get_current_task()) + await trio.sleep(0.5) + return value + + async with backend.create_task_group() as tg: + task_42 = await tg.start(coroutine, 42) + task_54 = await tg.start(coroutine, 54) + + assert len(tasks) == 2 + assert tasks == [task_42.info, task_54.info] + assert not task_42.done() + assert not task_54.done() + + assert task_42.done() + assert task_54.done() + assert not task_42.cancelled() + assert not task_54.cancelled() + assert await task_42.join() == 42 + assert await task_54.join() == 54 + + # Join several should not raise + assert await task_42.join() == 42 + assert await task_54.join() == 54 + + # Task already done cannot be cancelled + assert not task_42.cancel() + assert not task_54.cancel() + assert await task_42.join_or_cancel() == 42 + assert await task_54.join_or_cancel() == 54 + + async def test____create_task_group____start_and_wait____not_entered( + self, + backend: AsyncBackend, + ) -> None: + import trio + + tasks: list[TaskInfo] = [] + + async def coroutine(value: int) -> int: + tasks.append(backend.get_current_task()) + await trio.sleep(0.5) + return value + + tg = backend.create_task_group() + with pytest.raises(RuntimeError, match=r"^TaskGroup not started$"): + await tg.start(coroutine, 42) + with pytest.raises(RuntimeError, match=r"^TaskGroup not started$"): + await tg.start(coroutine, 54) + + await trio.sleep(0.5) + + assert len(tasks) == 0 + + async def test____create_task_group____start_and_wait____set_name( + self, + backend: AsyncBackend, + ) -> None: + import trio + + async def coroutine(value: int) -> int: + await trio.sleep(0.5) + return value + + async with backend.create_task_group() as tg: + task_42 = await tg.start(coroutine, 42, name="compute 42") + task_54 = await tg.start(coroutine, 54, name="compute 54") + + assert task_42.info.name == "compute 42" + assert task_54.info.name == "compute 54" + + async def test____create_task_group____task_cancellation( + self, + backend: AsyncBackend, + ) -> None: + import trio + + async def coroutine(value: int) -> int: + await trio.sleep(0.5) + return value + + async with backend.create_task_group() as tg: + task_42 = await tg.start(coroutine, 42) + task_54 = await tg.start(coroutine, 54) + + await trio.lowlevel.checkpoint() + assert not task_42.done() + assert not task_54.done() + + assert task_42.cancel() + + assert task_42.done() + assert task_54.done() + assert task_42.cancelled() + assert not task_54.cancelled() + with pytest.raises(trio.Cancelled): + await task_42.join() + assert await task_54.join() == 54 + + # Tasks cannot be cancelled twice + assert not task_42.cancel() + + # We can unwrap twice or more + with pytest.raises(trio.Cancelled): + await task_42.join() + + @pytest.mark.parametrize("join_method", ["join", "join_or_cancel", "wait"]) + async def test____create_task_group____task_join_cancel_shielding( + self, + join_method: Literal["join", "join_or_cancel", "wait"], + backend: AsyncBackend, + nursery: Nursery, + ) -> None: + import trio + + async def coroutine(value: int) -> int: + await trio.sleep(0.5) + return value + + async with backend.create_task_group() as task_group: + inner_task = await task_group.start(coroutine, 42) + + with trio.CancelScope() as outer_scope: + call_later_with_nursery(nursery, 0.2, outer_scope.cancel) + + match join_method: + case "join": + await inner_task.join() + case "join_or_cancel": + await inner_task.join_or_cancel() + case "wait": + await inner_task.wait() + case _: + pytest.fail("invalid argument") + + assert outer_scope.cancelled_caught + if join_method == "join_or_cancel": + assert inner_task.cancelled() + else: + assert not inner_task.cancelled() + assert await inner_task.join() == 42 + + async def test____create_task_group____task_join_erase_cancel( + self, + backend: AsyncBackend, + nursery: Nursery, + ) -> None: + import trio + + async def coroutine(value: int) -> int: + with contextlib.suppress(trio.Cancelled): + await trio.sleep(0.5) + return value + + async with backend.create_task_group() as task_group: + inner_task = await task_group.start(coroutine, 42) + + with trio.CancelScope() as outer_scope: + call_later_with_nursery(nursery, 0.1, outer_scope.cancel) + + await inner_task.join_or_cancel() + + assert outer_scope.cancel_called + assert not outer_scope.cancelled_caught + assert not inner_task.cancelled() + assert await inner_task.join() == 42 + + @pytest.mark.parametrize("task_state", ["result", "exception", "cancelled"]) + async def test____create_task_group____task_wait( + self, + task_state: Literal["result", "exception", "cancelled"], + backend: AsyncBackend, + nursery: Nursery, + ) -> None: + import outcome + import trio + + tx, rx = trio.open_memory_channel[outcome.Outcome[None]](1) + + class FutureException(Exception): + pass + + def set_future_result(task: Task[None]) -> None: + with contextlib.closing(tx): + match task_state: + case "result": + tx.send_nowait(outcome.Value(None)) + case "exception": + tx.send_nowait(outcome.Error(FutureException("Error"))) + case "cancelled": + task.cancel() + case _: + pytest.fail("invalid argument") + + async def coroutine() -> None: + async with rx: + return (await rx.receive()).unwrap() + + with pytest.raises(ExceptionGroup) if task_state == "exception" else contextlib.nullcontext() as exc_info: + async with backend.create_task_group() as task_group: + task = await task_group.start(coroutine) + + call_later_with_nursery(nursery, 0.1, set_future_result, task) + + await task.wait() + assert task.done() + + # Must not yield if task is already done + with backend.timeout(0): + await task.wait() + + if exc_info is not None: + assert isinstance(exc_info.value.exceptions[0], FutureException) + + async def test____run_in_thread____cannot_be_cancelled_by_default( + self, + backend: AsyncBackend, + nursery: Nursery, + ) -> None: + import trio + + with trio.CancelScope() as scope: + call_later_with_nursery(nursery, 0.1, scope.cancel) + call_later_with_nursery(nursery, 0.2, scope.cancel) + call_later_with_nursery(nursery, 0.3, scope.cancel) + call_later_with_nursery(nursery, 0.4, scope.cancel) + await backend.run_in_thread(time.sleep, 0.5) + + assert not scope.cancelled_caught + + async def test____run_in_thread____abandon_on_cancel( + self, + backend: AsyncBackend, + ) -> None: + import trio + + with trio.move_on_after(0.1) as scope: + await backend.run_in_thread(time.sleep, 0.5, abandon_on_cancel=True) + + assert scope.cancelled_caught + + async def test____run_in_thread____sniffio_contextvar_reset(self, backend: AsyncBackend) -> None: + sniffio.current_async_library_cvar.set("trio") + + def callback() -> str | None: + return sniffio.current_async_library_cvar.get() + + cvar_inner = await backend.run_in_thread(callback) + cvar_outer = sniffio.current_async_library_cvar.get() + + assert cvar_inner is None + assert cvar_outer == "trio" + + async def test____create_threads_portal____run_coroutine_from_thread( + self, + backend: AsyncBackend, + ) -> None: + import trio + + trio_token = trio.lowlevel.current_trio_token() + threads_portal = backend.create_threads_portal() + + async def coroutine(value: int) -> int: + assert trio.lowlevel.current_trio_token() is trio_token + await trio.sleep(0.5) + return value + + def thread() -> int: + with pytest.raises(RuntimeError): + trio.lowlevel.current_trio_token() + return threads_portal.run_coroutine(coroutine, 42) + + with pytest.raises(RuntimeError): + await backend.run_in_thread(thread) + + async with threads_portal: + with pytest.raises(RuntimeError): + threads_portal.run_coroutine(coroutine, 42) + + assert await backend.run_in_thread(thread) == 42 + + with pytest.raises(RuntimeError): + await backend.run_in_thread(thread) + + async def test____create_threads_portal____run_coroutine_from_thread____can_be_called_from_other_event_loop( + self, + backend: AsyncBackend, + ) -> None: + import trio + + trio_token = trio.lowlevel.current_trio_token() + threads_portal = backend.create_threads_portal() + + async def coroutine(value: int) -> int: + assert trio.lowlevel.current_trio_token() is trio_token + await trio.sleep(0.5) + return value + + def thread() -> int: + async def main() -> int: + assert trio.lowlevel.current_trio_token() is not trio_token + return threads_portal.run_coroutine(coroutine, 42) + + return trio.run(main) + + async with backend.create_threads_portal() as threads_portal: + assert await backend.run_in_thread(thread) == 42 + + async def test____create_threads_portal____run_coroutine_from_thread____coroutine_cancelled( + self, + backend: AsyncBackend, + ) -> None: + import outcome + import trio + + async def coroutine(value: int) -> int: + with trio.move_on_after(0.5): + result = await outcome.acapture(backend.sleep_forever) + result.unwrap() + raise AssertionError("Not cancelled") + + def thread() -> int: + return threads_portal.run_coroutine(coroutine, 42) + + async with backend.create_threads_portal() as threads_portal: + with pytest.raises(concurrent.futures.CancelledError): + await backend.run_in_thread(thread) + + @pytest.mark.parametrize("exception_cls", [Exception, BaseException]) + async def test____create_threads_portal____run_coroutine_from_thread____exception_raised( + self, + backend: AsyncBackend, + exception_cls: type[BaseException], + ) -> None: + expected_exception = exception_cls("Why not?") + + async def coroutine(value: int) -> int: + raise expected_exception + + def thread() -> int: + return threads_portal.run_coroutine(coroutine, 42) + + threads_portal = backend.create_threads_portal() + async with threads_portal: + with pytest.raises(BaseException) as exc_info: + await backend.run_in_thread(thread) + + assert exc_info.value is expected_exception + + async def test____create_threads_portal____run_coroutine_from_thread____explicit_concurrent_future_Cancelled( + self, + backend: AsyncBackend, + ) -> None: + async def coroutine(value: int) -> int: + raise concurrent.futures.CancelledError() + + def thread() -> int: + with pytest.raises(concurrent.futures.CancelledError): + return threads_portal.run_coroutine(coroutine, 42) + return 54 + + async with backend.create_threads_portal() as threads_portal: + assert await backend.run_in_thread(thread) == 54 + + async def test____create_threads_portal____run_coroutine_from_thread____sniffio_contextvar_reset( + self, + backend: AsyncBackend, + ) -> None: + sniffio.current_async_library_cvar.set("main") + + async def coroutine() -> str | None: + return sniffio.current_async_library_cvar.get() + + def thread() -> str | None: + return threads_portal.run_coroutine(coroutine) + + async with backend.create_threads_portal() as threads_portal: + cvar_inner = await backend.run_in_thread(thread) + cvar_outer = sniffio.current_async_library_cvar.get() + + assert cvar_inner is None + assert cvar_outer == "main" + + async def test____create_threads_portal____run_sync_from_thread_in_event_loop( + self, + backend: AsyncBackend, + ) -> None: + import trio + + trio_token = trio.lowlevel.current_trio_token() + threads_portal = backend.create_threads_portal() + + def not_threadsafe_func(value: int) -> int: + assert trio.lowlevel.current_trio_token() is trio_token + return value + + def thread() -> int: + with pytest.raises(RuntimeError): + trio.lowlevel.current_trio_token() + return threads_portal.run_sync(not_threadsafe_func, 42) + + with pytest.raises(RuntimeError): + await backend.run_in_thread(thread) + + async with threads_portal: + with pytest.raises(RuntimeError): + threads_portal.run_sync(not_threadsafe_func, 42) + + assert await backend.run_in_thread(thread) == 42 + + with pytest.raises(RuntimeError): + await backend.run_in_thread(thread) + + async def test____create_threads_portal____run_sync_from_thread_in_event_loop____can_be_called_from_other_event_loop( + self, + backend: AsyncBackend, + ) -> None: + import trio + + trio_token = trio.lowlevel.current_trio_token() + + def not_threadsafe_func(value: int) -> int: + assert trio.lowlevel.current_trio_token() is trio_token + return value + + def thread() -> int: + async def main() -> int: + assert trio.lowlevel.current_trio_token() is not trio_token + return threads_portal.run_sync(not_threadsafe_func, 42) + + return trio.run(main) + + async with backend.create_threads_portal() as threads_portal: + assert await backend.run_in_thread(thread) == 42 + + @pytest.mark.parametrize("exception_cls", [Exception, BaseException]) + async def test____create_threads_portal____run_sync_from_thread_in_event_loop____exception_raised( + self, + backend: AsyncBackend, + exception_cls: type[BaseException], + ) -> None: + expected_exception = exception_cls("Why not?") + + def not_threadsafe_func(value: int) -> int: + raise expected_exception + + def thread() -> int: + return threads_portal.run_sync(not_threadsafe_func, 42) + + threads_portal = backend.create_threads_portal() + + async with threads_portal: + with pytest.raises(BaseException) as exc_info: + await backend.run_in_thread(thread) + + assert exc_info.value is expected_exception + + async def test____create_threads_portal____run_sync_from_thread_in_event_loop____explicit_concurrent_future_Cancelled( + self, + backend: AsyncBackend, + ) -> None: + def not_threadsafe_func(value: int) -> int: + raise concurrent.futures.CancelledError() + + def thread() -> int: + with pytest.raises(concurrent.futures.CancelledError): + return threads_portal.run_sync(not_threadsafe_func, 42) + return 54 + + async with backend.create_threads_portal() as threads_portal: + assert await backend.run_in_thread(thread) == 54 + + async def test____create_threads_portal____run_sync_from_thread_in_event_loop____async_function_given( + self, + backend: AsyncBackend, + ) -> None: + async def coroutine() -> None: + raise AssertionError("Should not be called") + + def thread() -> None: + with pytest.raises(TypeError, match=r"^func is a coroutine function.$") as exc_info: + _ = threads_portal.run_sync(coroutine) + + assert exc_info.value.__notes__ == ["You should use run_coroutine() or run_coroutine_soon() instead."] + + async with backend.create_threads_portal() as threads_portal: + await backend.run_in_thread(thread) + + async def test____create_threads_portal____run_sync_from_thread_in_event_loop____sniffio_contextvar_reset( + self, + backend: AsyncBackend, + ) -> None: + sniffio.current_async_library_cvar.set("main") + + def callback() -> str | None: + return sniffio.current_async_library_cvar.get() + + def thread() -> str | None: + return threads_portal.run_sync(callback) + + async with backend.create_threads_portal() as threads_portal: + cvar_inner = await backend.run_in_thread(thread) + cvar_outer = sniffio.current_async_library_cvar.get() + + assert cvar_inner is None + assert cvar_outer == "main" + + async def test____create_threads_portal____run_sync_soon____future_cancelled_before_call( + self, + backend: AsyncBackend, + mocker: MockerFixture, + ) -> None: + import trio + + trio_token = trio.lowlevel.current_trio_token() + func_stub = mocker.stub() + + def thread() -> None: + trio_token.run_sync_soon(time.sleep, 1) # Drastically slow down event loop + + future = threads_portal.run_sync_soon(func_stub, 42) + + with pytest.raises(TimeoutError): + future.exception(timeout=0.2) + + future.cancel() + concurrent.futures.wait({future}, timeout=5) # Test if future.set_running_or_notify_cancel() have been called + assert future.cancelled() + + async with backend.create_threads_portal() as threads_portal: + await backend.run_in_thread(thread) + + func_stub.assert_not_called() + + async def test____create_threads_portal____run_coroutine_soon____future_cancelled( + self, + backend: AsyncBackend, + ) -> None: + def thread() -> None: + future = threads_portal.run_coroutine_soon(backend.sleep, 1) + + with pytest.raises(TimeoutError): + future.exception(timeout=0.2) + + future.cancel() + concurrent.futures.wait({future}, timeout=0.2) # Test if future.set_running_or_notify_cancel() have been called + assert future.cancelled() + + async with backend.create_threads_portal() as threads_portal: + await backend.run_in_thread(thread) + + @pytest.mark.parametrize("value", [42, ValueError("Not caught")], ids=repr) + async def test____create_threads_portal____run_coroutine_soon____future_cancelled____cancellation_ignored( + self, + value: int | Exception, + backend: AsyncBackend, + mocker: MockerFixture, + caplog: pytest.LogCaptureFixture, + ) -> None: + import logging + + import trio + + caplog.set_level(logging.ERROR, "trio") + + cancellation_ignored = mocker.stub() + + subcoroutine_task: list[trio.lowlevel.Task] = [] + + async def coroutine() -> int: + try: + await trio.sleep(1) + except trio.Cancelled: + pass + await trio.lowlevel.cancel_shielded_checkpoint() + cancellation_ignored() + subcoroutine_task.append(trio.lowlevel.current_task()) + if isinstance(value, Exception): + raise value + return value + + def thread() -> None: + future = threads_portal.run_coroutine_soon(coroutine) + + with pytest.raises(TimeoutError): + future.exception(timeout=0.2) + + future.cancel() + concurrent.futures.wait({future}, timeout=0.2) # Test if future.set_running_or_notify_cancel() have been called + assert future.cancelled() + + async with backend.create_threads_portal() as threads_portal: + await backend.run_in_thread(thread) + + cancellation_ignored.assert_called_once() + if isinstance(value, Exception): + assert len(caplog.records) == 1 + assert caplog.records[0].levelno == logging.ERROR + assert caplog.records[0].exc_info is not None + assert caplog.records[0].exc_info[1] is value + assert caplog.records[0].message == "\n".join( + [ + "Task exception was not retrieved because future object is cancelled", + f"task: {subcoroutine_task[0]!r}", + ] + ) + else: + assert len(caplog.records) == 0 + + async def test____create_threads_portal____entered_twice( + self, + backend: AsyncBackend, + ) -> None: + async with backend.create_threads_portal() as threads_portal: + with pytest.raises(RuntimeError, match=r"ThreadsPortal entered twice\."): + await threads_portal.__aenter__() diff --git a/tests/functional_test/test_async/test_futures.py b/tests/functional_test/test_async/test_futures.py index f8a3ae07..e567b4e2 100644 --- a/tests/functional_test/test_async/test_futures.py +++ b/tests/functional_test/test_async/test_futures.py @@ -6,7 +6,8 @@ from collections.abc import AsyncIterator from typing import Any -from easynetwork.lowlevel.api_async.backend._asyncio.backend import AsyncIOBackend +from easynetwork.lowlevel.api_async.backend.abc import AsyncBackend +from easynetwork.lowlevel.api_async.backend.utils import new_builtin_backend from easynetwork.lowlevel.futures import AsyncExecutor, unwrap_future import pytest @@ -126,12 +127,12 @@ async def test____shutdown____cancel_futures( class TestUnwrapFuture: @pytest.fixture @staticmethod - def backend() -> AsyncIOBackend: - return AsyncIOBackend() + def backend() -> AsyncBackend: + return new_builtin_backend("asyncio") async def test____unwrap_future____wait_until_done( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() future: concurrent.futures.Future[int] = concurrent.futures.Future() @@ -143,7 +144,7 @@ async def test____unwrap_future____wait_until_done( async def test____unwrap_future____cancel_future_if_task_is_cancelled____result( self, future_running: str | None, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() future: concurrent.futures.Future[int] = concurrent.futures.Future() @@ -176,7 +177,7 @@ async def test____unwrap_future____cancel_future_if_task_is_cancelled____result( async def test____unwrap_future____cancel_future_if_task_is_cancelled____exception( self, future_running: str | None, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() expected_error = Exception("error") @@ -208,7 +209,7 @@ async def test____unwrap_future____cancel_future_if_task_is_cancelled____excepti async def test____unwrap_future____future_is_cancelled( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() future: concurrent.futures.Future[int] = concurrent.futures.Future() @@ -222,7 +223,7 @@ async def test____unwrap_future____future_is_cancelled( async def test____unwrap_future____already_done( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: future: concurrent.futures.Future[int] = concurrent.futures.Future() future.set_result(42) @@ -231,7 +232,7 @@ async def test____unwrap_future____already_done( async def test____unwrap_future____already_cancelled( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: future: concurrent.futures.Future[int] = concurrent.futures.Future() future.cancel() @@ -241,7 +242,7 @@ async def test____unwrap_future____already_cancelled( async def test____unwrap_future____already_cancelled____task_cancelled_too( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() future: concurrent.futures.Future[int] = concurrent.futures.Future() @@ -255,7 +256,7 @@ async def test____unwrap_future____already_cancelled____task_cancelled_too( async def test____unwrap_future____task_cancellation_prevails_over_future_cancellation( self, - backend: AsyncIOBackend, + backend: AsyncBackend, ) -> None: event_loop = asyncio.get_running_loop() future: concurrent.futures.Future[int] = concurrent.futures.Future() diff --git a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py index 38d06ad7..1892a1fa 100644 --- a/tests/functional_test/test_communication/test_async/test_server/test_tcp.py +++ b/tests/functional_test/test_communication/test_async/test_server/test_tcp.py @@ -16,8 +16,8 @@ IncrementalDeserializeError, StreamProtocolParseError, ) -from easynetwork.lowlevel.api_async.backend._asyncio._asyncio_utils import create_connection from easynetwork.lowlevel.api_async.backend._asyncio.backend import AsyncIOBackend +from easynetwork.lowlevel.api_async.backend._asyncio.dns_resolver import AsyncIODNSResolver from easynetwork.lowlevel.api_async.backend._asyncio.stream.listener import ListenerSocketAdapter from easynetwork.lowlevel.socket import SocketAddress, enable_socket_linger from easynetwork.protocol import AnyStreamProtocolType @@ -440,6 +440,7 @@ def run_server_and_wait(run_server: None, server_address: Any) -> None: @pytest_asyncio.fixture @staticmethod async def client_factory_no_handshake( + asyncio_backend: AsyncIOBackend, server_address: tuple[str, int], use_ssl: bool, client_ssl_context: ssl.SSLContext, @@ -449,7 +450,7 @@ async def client_factory_no_handshake( async def factory() -> tuple[asyncio.StreamReader, asyncio.StreamWriter]: async with asyncio.timeout(30): - sock = await create_connection(*server_address) + sock = await AsyncIODNSResolver().create_stream_connection(asyncio_backend, *server_address) reader, writer = await asyncio.open_connection( sock=sock, ssl=client_ssl_context if use_ssl else None, @@ -1136,6 +1137,7 @@ async def test____serve_forever____request_handler_on_connection_is_async_gen___ @pytest.mark.parametrize("ssl_handshake_timeout", [pytest.param(1, id="timeout==1sec")], indirect=True) async def test____serve_forever____ssl_handshake_timeout_error( self, + asyncio_backend: AsyncIOBackend, server_address: tuple[str, int], caplog: pytest.LogCaptureFixture, logger_crash_maximum_nb_lines: dict[str, int], @@ -1145,7 +1147,10 @@ async def test____serve_forever____ssl_handshake_timeout_error( caplog.set_level(logging.ERROR, LOGGER.name) logger_crash_maximum_nb_lines[LOGGER.name] = 1 logger_crash_maximum_nb_lines["easynetwork.lowlevel.api_async.transports.tls"] = 1 - with await create_connection(*server_address) as socket, pytest.raises(OSError): + with ( + await AsyncIODNSResolver().create_stream_connection(asyncio_backend, *server_address) as socket, + pytest.raises(OSError), + ): # The SSL handshake expects the client to send the list of encryption algorithms. # But we won't, so the server will close the connection after 1 second # and raise a TimeoutError or ConnectionAbortedError. diff --git a/tests/functional_test/test_communication/test_end2end.py b/tests/functional_test/test_communication/test_end2end.py new file mode 100644 index 00000000..9947886e --- /dev/null +++ b/tests/functional_test/test_communication/test_end2end.py @@ -0,0 +1,175 @@ +from __future__ import annotations + +from collections.abc import AsyncGenerator, Iterator + +from easynetwork.clients.async_tcp import AsyncTCPNetworkClient +from easynetwork.clients.async_udp import AsyncUDPNetworkClient +from easynetwork.clients.tcp import TCPNetworkClient +from easynetwork.clients.udp import UDPNetworkClient +from easynetwork.lowlevel.api_async.backend.utils import BuiltinAsyncBackendLiteral +from easynetwork.protocol import AnyStreamProtocolType, DatagramProtocol +from easynetwork.servers.abc import AbstractNetworkServer +from easynetwork.servers.handlers import AsyncBaseClientInterface, AsyncDatagramRequestHandler, AsyncStreamRequestHandler +from easynetwork.servers.standalone_tcp import StandaloneTCPNetworkServer +from easynetwork.servers.standalone_udp import StandaloneUDPNetworkServer +from easynetwork.servers.threads_helper import NetworkServerThread + +import pytest + + +class EchoRequestHandler(AsyncStreamRequestHandler[str, str], AsyncDatagramRequestHandler[str, str]): + async def handle(self, client: AsyncBaseClientInterface[str]) -> AsyncGenerator[None, str]: + request = yield + await client.send_packet(request) + + +@pytest.mark.flaky(retries=3, delay=1) +class BaseTestNetworkServer: + @pytest.fixture( + params=[ + pytest.param("asyncio"), + pytest.param("trio", marks=pytest.mark.feature_trio(async_test_auto_mark=False)), + ], + ids=lambda p: f"server_backend=={p!r}", + ) + @staticmethod + def server_backend(request: pytest.FixtureRequest) -> BuiltinAsyncBackendLiteral: + return request.param + + @pytest.fixture( + params=[ + pytest.param("asyncio", marks=pytest.mark.asyncio), + pytest.param("trio", marks=pytest.mark.feature_trio(async_test_auto_mark=True)), + ], + ids=lambda p: f"async_client_backend=={p!r}", + ) + @staticmethod + def async_client_backend(request: pytest.FixtureRequest) -> BuiltinAsyncBackendLiteral: + return request.param + + @pytest.fixture(autouse=True) + @staticmethod + def start_server( + server: AbstractNetworkServer, + ) -> Iterator[NetworkServerThread]: + with server: + server_thread = NetworkServerThread(server, daemon=True) + server_thread.start() + + yield server_thread + + server_thread.join(timeout=1) + + +class TestNetworkTCP(BaseTestNetworkServer): + @pytest.fixture + @staticmethod + def server( + server_backend: BuiltinAsyncBackendLiteral, + stream_protocol: AnyStreamProtocolType[str, str], + ) -> StandaloneTCPNetworkServer[str, str]: + return StandaloneTCPNetworkServer("127.0.0.1", 0, stream_protocol, EchoRequestHandler(), backend=server_backend) + + @pytest.fixture + @staticmethod + def server_address(server: StandaloneTCPNetworkServer[str, str]) -> tuple[str, int]: + port = server.get_addresses()[0].port + return ("localhost", port) + + def test____blocking_client____echo( + self, + server_address: tuple[str, int], + stream_protocol: AnyStreamProtocolType[str, str], + ) -> None: + + with TCPNetworkClient(server_address, stream_protocol, connect_timeout=1) as client: + + # Sequential read/write + for i in range(3): + client.send_packet(f"Hello world {i}") + assert client.recv_packet(timeout=1) == f"Hello world {i}" + + # Several write + for i in range(3): + client.send_packet(f"Hello world {i}") + for i in range(3): + assert client.recv_packet(timeout=1) == f"Hello world {i}" + + async def test____asynchronous_client____echo( + self, + async_client_backend: BuiltinAsyncBackendLiteral, + server_address: tuple[str, int], + stream_protocol: AnyStreamProtocolType[str, str], + ) -> None: + + async with AsyncTCPNetworkClient(server_address, stream_protocol, backend=async_client_backend) as client: + + # Sequential read/write + for i in range(3): + await client.send_packet(f"Hello world {i}") + with client.backend().timeout(1): + assert (await client.recv_packet()) == f"Hello world {i}" + + # Several write + for i in range(3): + await client.send_packet(f"Hello world {i}") + for i in range(3): + with client.backend().timeout(1): + assert (await client.recv_packet()) == f"Hello world {i}" + + +class TestNetworkUDP(BaseTestNetworkServer): + @pytest.fixture + @staticmethod + def server( + server_backend: BuiltinAsyncBackendLiteral, + datagram_protocol: DatagramProtocol[str, str], + ) -> StandaloneUDPNetworkServer[str, str]: + return StandaloneUDPNetworkServer("127.0.0.1", 0, datagram_protocol, EchoRequestHandler(), backend=server_backend) + + @pytest.fixture + @staticmethod + def server_address(server: StandaloneUDPNetworkServer[str, str]) -> tuple[str, int]: + port = server.get_addresses()[0].port + return ("127.0.0.1", port) + + def test____blocking_client____echo( + self, + server_address: tuple[str, int], + datagram_protocol: DatagramProtocol[str, str], + ) -> None: + + with UDPNetworkClient(server_address, datagram_protocol) as client: + + # Sequential read/write + for i in range(3): + client.send_packet(f"Hello world {i}") + assert client.recv_packet(timeout=1) == f"Hello world {i}" + + # Several write + for i in range(3): + client.send_packet(f"Hello world {i}") + for i in range(3): + assert client.recv_packet(timeout=1) == f"Hello world {i}" + + async def test____asynchronous_client____echo( + self, + async_client_backend: BuiltinAsyncBackendLiteral, + server_address: tuple[str, int], + datagram_protocol: DatagramProtocol[str, str], + ) -> None: + + async with AsyncUDPNetworkClient(server_address, datagram_protocol, backend=async_client_backend) as client: + + # Sequential read/write + for i in range(3): + await client.send_packet(f"Hello world {i}") + with client.backend().timeout(1): + assert (await client.recv_packet()) == f"Hello world {i}" + + # Several write + for i in range(3): + await client.send_packet(f"Hello world {i}") + for i in range(3): + with client.backend().timeout(1): + assert (await client.recv_packet()) == f"Hello world {i}" diff --git a/tests/pytest_plugins/async_finalizer.py b/tests/pytest_plugins/async_finalizer.py index 3103232e..a62a551e 100644 --- a/tests/pytest_plugins/async_finalizer.py +++ b/tests/pytest_plugins/async_finalizer.py @@ -7,6 +7,8 @@ import pytest_asyncio +from ..fixtures.trio import trio_fixture + @dataclass(repr=False, eq=False, kw_only=True, frozen=True, slots=True, weakref_slot=True) class AsyncFinalizer: @@ -15,6 +17,24 @@ class AsyncFinalizer: @pytest_asyncio.fixture async def async_finalizer(request: Any) -> AsyncIterator[AsyncFinalizer]: + import asyncio + + asyncio.get_running_loop() + + async with contextlib.AsyncExitStack() as stack: + + def add_finalizer(f: Callable[[], Awaitable[Any]]) -> None: + stack.push_async_callback(f) + + yield AsyncFinalizer(add_finalizer=add_finalizer) + + +@trio_fixture +async def async_finalizer_trio(request: Any) -> AsyncIterator[AsyncFinalizer]: + import trio + + trio.lowlevel.current_trio_token() + async with contextlib.AsyncExitStack() as stack: def add_finalizer(f: Callable[[], Awaitable[Any]]) -> None: diff --git a/tests/pytest_plugins/extra_features.py b/tests/pytest_plugins/extra_features.py index 1538d69b..8d2b9594 100644 --- a/tests/pytest_plugins/extra_features.py +++ b/tests/pytest_plugins/extra_features.py @@ -1,6 +1,7 @@ from __future__ import annotations import functools +import inspect import pytest @@ -24,6 +25,22 @@ def _auto_add_feature_marker(item: pytest.Item) -> None: item.add_marker(pytest.mark.feature) +def __has_marker(item: pytest.Item, name: str) -> bool: + return item.get_closest_marker(name) is not None + + +def _ensure_trio_marker_consistency(item: pytest.Function) -> None: + if (feature_trio_marker := item.get_closest_marker("feature_trio")) is not None: + if item.config.pluginmanager.has_plugin("trio") and not __has_marker(item, "trio"): + auto_mark = feature_trio_marker.kwargs.get("async_test_auto_mark", False) + if auto_mark and inspect.iscoroutinefunction(item.obj): + item.add_marker(pytest.mark.trio) + elif __has_marker(item, "trio"): + item.add_marker(pytest.mark.feature_trio) + + def pytest_collection_modifyitems(items: list[pytest.Item]) -> None: for item in items: + if funcitem := item.getparent(pytest.Function): + _ensure_trio_marker_consistency(funcitem) _auto_add_feature_marker(item) diff --git a/tests/tools.py b/tests/tools.py index d2c0ffe4..e76afe43 100644 --- a/tests/tools.py +++ b/tests/tools.py @@ -5,17 +5,21 @@ import importlib import sys import time -from collections.abc import Generator, Iterator -from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypeVar, assert_never, final +from collections.abc import Callable, Generator, Iterator +from typing import TYPE_CHECKING, Any, Literal, TypeAlias, TypeVar, TypeVarTuple, assert_never, final import pytest if TYPE_CHECKING: + import trio + from _typeshed import WriteableBuffer _T_contra = TypeVar("_T_contra", contravariant=True) _V_co = TypeVar("_V_co", covariant=True) +_T_Args = TypeVarTuple("_T_Args") + def _make_skipif_platform(platform: str, reason: str) -> pytest.MarkDecorator: return pytest.mark.skipif(sys.platform.startswith(platform), reason=reason) @@ -170,3 +174,42 @@ def temporary_task_factory( stack.callback(event_loop.set_task_factory, event_loop.get_task_factory()) event_loop.set_task_factory(task_factory) yield + + +def call_later_with_nursery( + nursery: trio.Nursery, + seconds: float, + func: Callable[[*_T_Args], Any], + /, + *args: *_T_Args, +) -> trio.CancelScope: + from trio import CancelScope, sleep + + scope = CancelScope() + + async def in_nursery_task() -> None: + with scope: + await sleep(seconds) + func(*args) + + nursery.start_soon(in_nursery_task) + return scope + + +def call_soon_with_nursery( + nursery: trio.Nursery, + func: Callable[[*_T_Args], Any], + /, + *args: *_T_Args, +) -> trio.CancelScope: + from trio import CancelScope + + scope = CancelScope() + + async def in_nursery_task() -> None: + with scope: + if not scope.cancel_called: + func(*args) + + nursery.start_soon(in_nursery_task) + return scope diff --git a/tests/unit_test/conftest.py b/tests/unit_test/conftest.py index 6637f403..af1efb5e 100644 --- a/tests/unit_test/conftest.py +++ b/tests/unit_test/conftest.py @@ -37,12 +37,18 @@ def remove_SO_REUSEPORT_support(monkeypatch: pytest.MonkeyPatch) -> None: @pytest.fixture def mock_socket_factory(mocker: MockerFixture) -> Callable[[], MagicMock]: - def factory() -> MagicMock: + def factory(family: int = -1, type: int = -1, proto: int = -1, fileno: int = 123) -> MagicMock: + if family == -1: + family = AF_INET + if type == -1: + type = SOCK_STREAM + if proto == -1: + proto = 0 mock_socket = mocker.NonCallableMagicMock(spec=Socket) - mock_socket.family = AF_INET - mock_socket.type = -1 - mock_socket.proto = 0 - mock_socket.fileno.return_value = 123 + mock_socket.family = family + mock_socket.type = type + mock_socket.proto = proto + mock_socket.fileno.return_value = fileno def close_side_effect() -> None: mock_socket.fileno.return_value = -1 @@ -64,12 +70,9 @@ def original_socket_cls() -> type[Socket]: @pytest.fixture -def mock_tcp_socket_factory(mock_socket_factory: Callable[[], MagicMock]) -> Callable[[], MagicMock]: - def factory() -> MagicMock: - mock_socket = mock_socket_factory() - mock_socket.type = SOCK_STREAM - mock_socket.proto = IPPROTO_TCP - return mock_socket +def mock_tcp_socket_factory(mock_socket_factory: Callable[[int, int, int], MagicMock]) -> Callable[[], MagicMock]: + def factory(family: int = -1) -> MagicMock: + return mock_socket_factory(family, SOCK_STREAM, IPPROTO_TCP) return factory @@ -80,12 +83,9 @@ def mock_tcp_socket(mock_tcp_socket_factory: Callable[[], MagicMock]) -> MagicMo @pytest.fixture -def mock_udp_socket_factory(mock_socket_factory: Callable[[], MagicMock]) -> Callable[[], MagicMock]: - def factory() -> MagicMock: - mock_socket = mock_socket_factory() - mock_socket.type = SOCK_DGRAM - mock_socket.proto = IPPROTO_UDP - return mock_socket +def mock_udp_socket_factory(mock_socket_factory: Callable[[int, int, int], MagicMock]) -> Callable[[], MagicMock]: + def factory(family: int = -1) -> MagicMock: + return mock_socket_factory(family, SOCK_DGRAM, IPPROTO_UDP) return factory @@ -97,9 +97,11 @@ def mock_udp_socket(mock_udp_socket_factory: Callable[[], MagicMock]) -> MagicMo @pytest.fixture def mock_ssl_socket_factory(mocker: MockerFixture) -> Callable[[], MagicMock]: - def factory() -> MagicMock: + def factory(family: int = -1) -> MagicMock: + if family == -1: + family = AF_INET mock_socket = mocker.NonCallableMagicMock(spec=SSLSocket) - mock_socket.family = AF_INET + mock_socket.family = family mock_socket.type = SOCK_STREAM mock_socket.proto = IPPROTO_TCP mock_socket.fileno.return_value = 123 diff --git a/tests/unit_test/test_async/conftest.py b/tests/unit_test/test_async/conftest.py index 17f2e90c..4df0b979 100644 --- a/tests/unit_test/test_async/conftest.py +++ b/tests/unit_test/test_async/conftest.py @@ -27,18 +27,6 @@ class FakeCancellation(BaseException): pass -@pytest.fixture(scope="module", autouse=True) -def __mute_socket_getaddrinfo(module_mocker: MockerFixture) -> None: - from socket import EAI_NONAME, gaierror, getaddrinfo - - module_mocker.patch( - "socket.getaddrinfo", - autospec=True, - wraps=getaddrinfo, - side_effect=gaierror(EAI_NONAME, "Name or service not known"), - ) - - @pytest.fixture(autouse=True) def __increase_event_loop_execution_time_before_warning(event_loop: asyncio.AbstractEventLoop) -> None: event_loop.slow_callback_duration = 5.0 diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_backend.py b/tests/unit_test/test_async/test_asyncio_backend/test_backend.py index bb5fecf5..af121cc2 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_backend.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_backend.py @@ -3,11 +3,12 @@ import asyncio import contextvars from collections.abc import Callable, Coroutine, Sequence -from socket import AF_INET, AF_INET6, AF_UNSPEC, AI_ADDRCONFIG, AI_PASSIVE, IPPROTO_TCP, IPPROTO_UDP, SOCK_DGRAM, SOCK_STREAM +from socket import AF_INET, AF_INET6, AF_UNSPEC, IPPROTO_TCP, IPPROTO_UDP, SOCK_DGRAM, SOCK_STREAM from typing import TYPE_CHECKING, Any, Final from easynetwork.lowlevel.api_async.backend._asyncio.backend import AsyncIOBackend from easynetwork.lowlevel.api_async.backend._asyncio.datagram.listener import DatagramListenerProtocol +from easynetwork.lowlevel.api_async.backend._asyncio.dns_resolver import AsyncIODNSResolver from easynetwork.lowlevel.api_async.backend._asyncio.stream.listener import AbstractAcceptedSocketFactory, AcceptedSocketFactory from easynetwork.lowlevel.api_async.backend._asyncio.stream.socket import StreamReaderBufferedProtocol @@ -192,6 +193,65 @@ async def test____get_current_task____compute_task_info( assert task_info.name == current_task.get_name() assert task_info.coro is current_task.get_coro() + async def test____getaddrinfo____use_loop_getaddrinfo( + self, + backend: AsyncIOBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + event_loop = asyncio.get_running_loop() + mock_loop_getaddrinfo = mocker.patch.object( + event_loop, + "getaddrinfo", + new_callable=mocker.AsyncMock, + return_value=mocker.sentinel.addrinfo_list, + ) + + # Act + addrinfo_list = await backend.getaddrinfo( + host=mocker.sentinel.host, + port=mocker.sentinel.port, + family=mocker.sentinel.family, + type=mocker.sentinel.type, + proto=mocker.sentinel.proto, + flags=mocker.sentinel.flags, + ) + + # Assert + assert addrinfo_list is mocker.sentinel.addrinfo_list + mock_loop_getaddrinfo.assert_awaited_once_with( + mocker.sentinel.host, + mocker.sentinel.port, + family=mocker.sentinel.family, + type=mocker.sentinel.type, + proto=mocker.sentinel.proto, + flags=mocker.sentinel.flags, + ) + + async def test____getnameinfo____use_loop_getnameinfo( + self, + backend: AsyncIOBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + event_loop = asyncio.get_running_loop() + mock_loop_getnameinfo = mocker.patch.object( + event_loop, + "getnameinfo", + new_callable=mocker.AsyncMock, + return_value=mocker.sentinel.resolved_addr, + ) + + # Act + resolved_addr = await backend.getnameinfo( + sockaddr=mocker.sentinel.sockaddr, + flags=mocker.sentinel.flags, + ) + + # Assert + assert resolved_addr is mocker.sentinel.resolved_addr + mock_loop_getnameinfo.assert_awaited_once_with(mocker.sentinel.sockaddr, mocker.sentinel.flags) + @pytest.mark.parametrize("happy_eyeballs_delay", [None, 42], ids=lambda p: f"happy_eyeballs_delay=={p}") async def test____create_tcp_connection____use_loop_create_connection( self, @@ -216,8 +276,9 @@ async def test____create_tcp_connection____use_loop_create_connection( new_callable=mocker.AsyncMock, return_value=(mock_asyncio_transport, mock_protocol), ) - mock_own_create_connection: AsyncMock = mocker.patch( - f"{_ASYNCIO_BACKEND_MODULE}._asyncio_utils.create_connection", + mock_own_create_connection: AsyncMock = mocker.patch.object( + AsyncIODNSResolver, + "create_stream_connection", new_callable=mocker.AsyncMock, return_value=mock_tcp_socket, ) @@ -235,6 +296,7 @@ async def test____create_tcp_connection____use_loop_create_connection( # Assert mock_own_create_connection.assert_awaited_once_with( + backend, *remote_address, happy_eyeballs_delay=expected_happy_eyeballs_delay, local_address=local_address, @@ -286,7 +348,6 @@ async def test____create_tcp_listeners____open_listener_sockets( mocker: MockerFixture, ) -> None: # Arrange - event_loop = asyncio.get_running_loop() remote_host, remote_port = "remote_address", 5000 addrinfo_list = [ ( @@ -297,9 +358,9 @@ async def test____create_tcp_listeners____open_listener_sockets( (remote_host, remote_port), ) ] - mock_getaddrinfo = mocker.patch.object( - event_loop, - "getaddrinfo", + mock_resolve_listener_addresses = mocker.patch.object( + AsyncIODNSResolver, + "resolve_listener_addresses", new_callable=mocker.AsyncMock, return_value=addrinfo_list, ) @@ -322,16 +383,14 @@ async def test____create_tcp_listeners____open_listener_sockets( ) # Assert - mock_getaddrinfo.assert_awaited_once_with( - remote_host, + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + [remote_host], remote_port, - family=AF_UNSPEC, - type=SOCK_STREAM, - proto=0, - flags=AI_PASSIVE | AI_ADDRCONFIG, + SOCK_STREAM, ) mock_open_listeners.assert_called_once_with( - sorted(set(addrinfo_list)), + addrinfo_list, backlog=123456789, reuse_address=mocker.ANY, # Determined according to OS reuse_port=mocker.sentinel.reuse_port, @@ -339,25 +398,19 @@ async def test____create_tcp_listeners____open_listener_sockets( mock_ListenerSocketAdapter.assert_called_once_with(backend, mock_tcp_socket, expected_factory) assert listener_sockets == [mocker.sentinel.listener_socket] - @pytest.mark.parametrize("remote_host", [None, ""], ids=repr) - async def test____create_tcp_listeners____bind_to_any_interfaces( + @pytest.mark.parametrize("remote_host", [None, "", ["::", "0.0.0.0"]], ids=repr) + async def test____create_tcp_listeners____bind_to_all_interfaces( self, - remote_host: str | None, + remote_host: str | list[str] | None, backend: AsyncIOBackend, - mock_tcp_socket: MagicMock, + mock_tcp_socket_factory: Callable[[int], MagicMock], mocker: MockerFixture, ) -> None: # Arrange - event_loop = asyncio.get_running_loop() + mock_tcp_socket_ipv4 = mock_tcp_socket_factory(AF_INET) + mock_tcp_socket_ipv6 = mock_tcp_socket_factory(AF_INET6) remote_port = 5000 addrinfo_list = [ - ( - AF_INET, - SOCK_STREAM, - IPPROTO_TCP, - "", - ("0.0.0.0", remote_port), - ), ( AF_INET6, SOCK_STREAM, @@ -365,62 +418,6 @@ async def test____create_tcp_listeners____bind_to_any_interfaces( "", ("::", remote_port), ), - ] - mock_getaddrinfo = mocker.patch.object( - event_loop, - "getaddrinfo", - new_callable=mocker.AsyncMock, - return_value=addrinfo_list, - ) - mock_open_listeners = mocker.patch( - "easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result", - return_value=[mock_tcp_socket, mock_tcp_socket], - ) - mock_ListenerSocketAdapter: MagicMock = mocker.patch( - f"{_ASYNCIO_BACKEND_MODULE}.stream.listener.ListenerSocketAdapter", - return_value=mocker.sentinel.listener_socket, - ) - expected_factory: AbstractAcceptedSocketFactory[Any] = AcceptedSocketFactory() - - # Act - listener_sockets: Sequence[Any] = await backend.create_tcp_listeners( - remote_host, - remote_port, - backlog=123456789, - reuse_port=mocker.sentinel.reuse_port, - ) - - # Assert - mock_getaddrinfo.assert_awaited_once_with( - None, - remote_port, - family=AF_UNSPEC, - type=SOCK_STREAM, - proto=0, - flags=AI_PASSIVE | AI_ADDRCONFIG, - ) - mock_open_listeners.assert_called_once_with( - sorted(set(addrinfo_list)), - backlog=123456789, - reuse_address=mocker.ANY, # Determined according to OS - reuse_port=mocker.sentinel.reuse_port, - ) - assert mock_ListenerSocketAdapter.call_args_list == [ - mocker.call(backend, mock_tcp_socket, expected_factory) for _ in range(2) - ] - assert listener_sockets == [mocker.sentinel.listener_socket, mocker.sentinel.listener_socket] - - async def test____create_tcp_listeners____bind_to_several_hosts( - self, - backend: AsyncIOBackend, - mock_tcp_socket: MagicMock, - mocker: MockerFixture, - ) -> None: - # Arrange - event_loop = asyncio.get_running_loop() - remote_hosts = ["0.0.0.0", "::"] - remote_port = 5000 - addrinfo_list = [ ( AF_INET, SOCK_STREAM, @@ -428,106 +425,56 @@ async def test____create_tcp_listeners____bind_to_several_hosts( "", ("0.0.0.0", remote_port), ), - ( - AF_INET6, - SOCK_STREAM, - IPPROTO_TCP, - "", - ("::", remote_port), - ), ] - mock_getaddrinfo = mocker.patch.object( - event_loop, - "getaddrinfo", + mock_resolve_listener_addresses = mocker.patch.object( + AsyncIODNSResolver, + "resolve_listener_addresses", new_callable=mocker.AsyncMock, - side_effect=[[info] for info in addrinfo_list], + return_value=addrinfo_list, ) mock_open_listeners = mocker.patch( "easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result", - return_value=[mock_tcp_socket, mock_tcp_socket], + return_value=[mock_tcp_socket_ipv6, mock_tcp_socket_ipv4], ) mock_ListenerSocketAdapter: MagicMock = mocker.patch( f"{_ASYNCIO_BACKEND_MODULE}.stream.listener.ListenerSocketAdapter", - return_value=mocker.sentinel.listener_socket, + side_effect=[mocker.sentinel.listener_socket_ipv6, mocker.sentinel.listener_socket_ipv4], ) expected_factory: AbstractAcceptedSocketFactory[Any] = AcceptedSocketFactory() # Act listener_sockets: Sequence[Any] = await backend.create_tcp_listeners( - remote_hosts, + remote_host, remote_port, backlog=123456789, reuse_port=mocker.sentinel.reuse_port, ) # Assert - assert mock_getaddrinfo.await_args_list == [ - mocker.call( - host, + if isinstance(remote_host, list): + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + remote_host, remote_port, - family=AF_UNSPEC, - type=SOCK_STREAM, - proto=0, - flags=AI_PASSIVE | AI_ADDRCONFIG, + SOCK_STREAM, + ) + else: + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + [None], + remote_port, + SOCK_STREAM, ) - for host in remote_hosts - ] mock_open_listeners.assert_called_once_with( - sorted(set(addrinfo_list)), + addrinfo_list, backlog=123456789, reuse_address=mocker.ANY, # Determined according to OS reuse_port=mocker.sentinel.reuse_port, ) assert mock_ListenerSocketAdapter.call_args_list == [ - mocker.call(backend, mock_tcp_socket, expected_factory) for _ in range(2) + mocker.call(backend, sock, expected_factory) for sock in [mock_tcp_socket_ipv6, mock_tcp_socket_ipv4] ] - assert listener_sockets == [mocker.sentinel.listener_socket, mocker.sentinel.listener_socket] - - async def test____create_tcp_listeners____error_getaddrinfo_returns_empty_list( - self, - backend: AsyncIOBackend, - mocker: MockerFixture, - ) -> None: - # Arrange - event_loop = asyncio.get_running_loop() - remote_host = "remote_address" - remote_port = 5000 - - mock_getaddrinfo = mocker.patch.object( - event_loop, - "getaddrinfo", - new_callable=mocker.AsyncMock, - return_value=[], - ) - mock_open_listeners = mocker.patch( - "easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result", - side_effect=AssertionError, - ) - mock_ListenerSocketAdapter: MagicMock = mocker.patch( - f"{_ASYNCIO_BACKEND_MODULE}.stream.listener.ListenerSocketAdapter", - side_effect=AssertionError, - ) - - # Act - with pytest.raises(OSError, match=r"getaddrinfo\('remote_address'\) returned empty list"): - await backend.create_tcp_listeners( - remote_host, - remote_port, - backlog=123456789, - reuse_port=mocker.sentinel.reuse_port, - ) - - # Assert - mock_getaddrinfo.assert_awaited_once_with( - remote_host, - remote_port, - family=AF_UNSPEC, - type=SOCK_STREAM, - proto=0, - flags=AI_PASSIVE | AI_ADDRCONFIG, - ) - mock_open_listeners.assert_not_called() - mock_ListenerSocketAdapter.assert_not_called() + assert listener_sockets == [mocker.sentinel.listener_socket_ipv6, mocker.sentinel.listener_socket_ipv4] async def test____create_tcp_listeners____invalid_backlog( self, @@ -535,14 +482,13 @@ async def test____create_tcp_listeners____invalid_backlog( mocker: MockerFixture, ) -> None: # Arrange - event_loop = asyncio.get_running_loop() remote_host = "remote_address" remote_port = 5000 - mock_getaddrinfo = mocker.patch.object( - event_loop, - "getaddrinfo", + mock_resolve_listener_addresses = mocker.patch.object( + AsyncIODNSResolver, + "resolve_listener_addresses", new_callable=mocker.AsyncMock, - return_value=[], + side_effect=AssertionError, ) mock_open_listeners = mocker.patch( "easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result", @@ -563,7 +509,7 @@ async def test____create_tcp_listeners____invalid_backlog( ) # Assert - mock_getaddrinfo.assert_not_called() + mock_resolve_listener_addresses.assert_not_called() mock_open_listeners.assert_not_called() mock_ListenerSocketAdapter.assert_not_called() @@ -589,8 +535,9 @@ async def test____create_udp_endpoint____use_loop_create_datagram_endpoint( new_callable=mocker.AsyncMock, return_value=mock_endpoint, ) - mock_own_create_connection: AsyncMock = mocker.patch( - f"{_ASYNCIO_BACKEND_MODULE}._asyncio_utils.create_datagram_connection", + mock_own_create_connection: AsyncMock = mocker.patch.object( + AsyncIODNSResolver, + "create_datagram_connection", new_callable=mocker.AsyncMock, return_value=mock_udp_socket, ) @@ -603,6 +550,7 @@ async def test____create_udp_endpoint____use_loop_create_datagram_endpoint( # Assert mock_own_create_connection.assert_awaited_once_with( + backend, *remote_address, local_address=local_address, family=AF_UNSPEC if socket_family is None else socket_family, @@ -661,9 +609,9 @@ async def test____create_udp_listeners____open_listener_sockets( ] mock_transport = mocker.NonCallableMagicMock(spec=asyncio.DatagramTransport) mock_protocol = mocker.NonCallableMagicMock(spec=DatagramListenerProtocol) - mock_getaddrinfo = mocker.patch.object( - event_loop, - "getaddrinfo", + mock_resolve_listener_addresses = mocker.patch.object( + AsyncIODNSResolver, + "resolve_listener_addresses", new_callable=mocker.AsyncMock, return_value=addrinfo_list, ) @@ -690,16 +638,14 @@ async def test____create_udp_listeners____open_listener_sockets( ) # Assert - mock_getaddrinfo.assert_awaited_once_with( - remote_host, + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + [remote_host], remote_port, - family=AF_UNSPEC, - type=SOCK_DGRAM, - proto=0, - flags=AI_PASSIVE | AI_ADDRCONFIG, + SOCK_DGRAM, ) mock_open_listeners.assert_called_once_with( - sorted(set(addrinfo_list)), + addrinfo_list, backlog=None, reuse_address=False, reuse_port=mocker.sentinel.reuse_port, @@ -711,25 +657,20 @@ async def test____create_udp_listeners____open_listener_sockets( mock_DatagramListenerSocketAdapter.assert_called_once_with(backend, mock_transport, mock_protocol) assert listener_sockets == [mocker.sentinel.listener_socket] - @pytest.mark.parametrize("remote_host", [None, ""], ids=repr) - async def test____create_udp_listeners____bind_to_local_interfaces( + @pytest.mark.parametrize("remote_host", [None, "", ["::", "0.0.0.0"]], ids=repr) + async def test____create_udp_listeners____bind_to_all_interfaces( self, - remote_host: str | None, + remote_host: str | list[str] | None, backend: AsyncIOBackend, - mock_udp_socket: MagicMock, + mock_udp_socket_factory: Callable[[int], MagicMock], mocker: MockerFixture, ) -> None: # Arrange + mock_udp_socket_ipv4 = mock_udp_socket_factory(AF_INET) + mock_udp_socket_ipv6 = mock_udp_socket_factory(AF_INET6) event_loop = asyncio.get_running_loop() remote_port = 5000 addrinfo_list = [ - ( - AF_INET, - SOCK_DGRAM, - IPPROTO_UDP, - "", - ("127.0.0.1", remote_port), - ), ( AF_INET6, SOCK_DGRAM, @@ -737,75 +678,6 @@ async def test____create_udp_listeners____bind_to_local_interfaces( "", ("::1", remote_port), ), - ] - mock_transport = mocker.NonCallableMagicMock(spec=asyncio.DatagramTransport) - mock_protocol = mocker.NonCallableMagicMock(spec=DatagramListenerProtocol) - mock_getaddrinfo = mocker.patch.object( - event_loop, - "getaddrinfo", - new_callable=mocker.AsyncMock, - return_value=addrinfo_list, - ) - mock_open_listeners = mocker.patch( - "easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result", - return_value=[mock_udp_socket, mock_udp_socket], - ) - mock_create_datagram_endpoint: AsyncMock = mocker.patch.object( - event_loop, - "create_datagram_endpoint", - new_callable=mocker.AsyncMock, - return_value=(mock_transport, mock_protocol), - ) - mock_DatagramListenerSocketAdapter: MagicMock = mocker.patch( - f"{_ASYNCIO_BACKEND_MODULE}.datagram.listener.DatagramListenerSocketAdapter", - return_value=mocker.sentinel.listener_socket, - ) - - # Act - listener_sockets: Sequence[Any] = await backend.create_udp_listeners( - remote_host, - remote_port, - reuse_port=mocker.sentinel.reuse_port, - ) - - # Assert - mock_getaddrinfo.assert_awaited_once_with( - None, - remote_port, - family=AF_UNSPEC, - type=SOCK_DGRAM, - proto=0, - flags=AI_PASSIVE | AI_ADDRCONFIG, - ) - mock_open_listeners.assert_called_once_with( - sorted(set(addrinfo_list)), - backlog=None, - reuse_address=False, - reuse_port=mocker.sentinel.reuse_port, - ) - assert mock_create_datagram_endpoint.await_args_list == [ - mocker.call( - partial_eq(DatagramListenerProtocol, loop=event_loop), - sock=mock_udp_socket, - ) - for _ in range(2) - ] - assert mock_DatagramListenerSocketAdapter.call_args_list == [ - mocker.call(backend, mock_transport, mock_protocol) for _ in range(2) - ] - assert listener_sockets == [mocker.sentinel.listener_socket, mocker.sentinel.listener_socket] - - async def test____create_udp_listeners____bind_to_several_hosts( - self, - backend: AsyncIOBackend, - mock_udp_socket: MagicMock, - mocker: MockerFixture, - ) -> None: - # Arrange - event_loop = asyncio.get_running_loop() - remote_hosts = ["127.0.0.1", "::1"] - remote_port = 5000 - addrinfo_list = [ ( AF_INET, SOCK_DGRAM, @@ -813,58 +685,56 @@ async def test____create_udp_listeners____bind_to_several_hosts( "", ("127.0.0.1", remote_port), ), - ( - AF_INET6, - SOCK_DGRAM, - IPPROTO_UDP, - "", - ("::1", remote_port), - ), ] - mock_transport = mocker.NonCallableMagicMock(spec=asyncio.DatagramTransport) - mock_protocol = mocker.NonCallableMagicMock(spec=DatagramListenerProtocol) - mock_getaddrinfo = mocker.patch.object( - event_loop, - "getaddrinfo", + mock_transport_ipv6 = mocker.NonCallableMagicMock(spec=asyncio.DatagramTransport) + mock_transport_ipv4 = mocker.NonCallableMagicMock(spec=asyncio.DatagramTransport) + mock_protocol_ipv6 = mocker.NonCallableMagicMock(spec=DatagramListenerProtocol) + mock_protocol_ipv4 = mocker.NonCallableMagicMock(spec=DatagramListenerProtocol) + mock_resolve_listener_addresses = mocker.patch.object( + AsyncIODNSResolver, + "resolve_listener_addresses", new_callable=mocker.AsyncMock, return_value=addrinfo_list, ) mock_open_listeners = mocker.patch( "easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result", - return_value=[mock_udp_socket, mock_udp_socket], + return_value=[mock_udp_socket_ipv6, mock_udp_socket_ipv4], ) mock_create_datagram_endpoint: AsyncMock = mocker.patch.object( event_loop, "create_datagram_endpoint", new_callable=mocker.AsyncMock, - return_value=(mock_transport, mock_protocol), + side_effect=[(mock_transport_ipv6, mock_protocol_ipv6), (mock_transport_ipv4, mock_protocol_ipv4)], ) mock_DatagramListenerSocketAdapter: MagicMock = mocker.patch( f"{_ASYNCIO_BACKEND_MODULE}.datagram.listener.DatagramListenerSocketAdapter", - return_value=mocker.sentinel.listener_socket, + side_effect=[mocker.sentinel.listener_socket_ipv6, mocker.sentinel.listener_socket_ipv4], ) # Act listener_sockets: Sequence[Any] = await backend.create_udp_listeners( - remote_hosts, + remote_host, remote_port, reuse_port=mocker.sentinel.reuse_port, ) # Assert - assert mock_getaddrinfo.await_args_list == [ - mocker.call( - host, + if isinstance(remote_host, list): + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + remote_host, remote_port, - family=AF_UNSPEC, - type=SOCK_DGRAM, - proto=0, - flags=AI_PASSIVE | AI_ADDRCONFIG, + SOCK_DGRAM, + ) + else: + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + [None], + remote_port, + SOCK_DGRAM, ) - for host in remote_hosts - ] mock_open_listeners.assert_called_once_with( - sorted(set(addrinfo_list)), + addrinfo_list, backlog=None, reuse_address=False, reuse_port=mocker.sentinel.reuse_port, @@ -872,65 +742,18 @@ async def test____create_udp_listeners____bind_to_several_hosts( assert mock_create_datagram_endpoint.await_args_list == [ mocker.call( partial_eq(DatagramListenerProtocol, loop=event_loop), - sock=mock_udp_socket, + sock=sock, ) - for _ in range(2) + for sock in [mock_udp_socket_ipv6, mock_udp_socket_ipv4] ] assert mock_DatagramListenerSocketAdapter.call_args_list == [ - mocker.call(backend, mock_transport, mock_protocol) for _ in range(2) + mocker.call(backend, mock_transport, mock_protocol) + for mock_transport, mock_protocol in [ + (mock_transport_ipv6, mock_protocol_ipv6), + (mock_transport_ipv4, mock_protocol_ipv4), + ] ] - assert listener_sockets == [mocker.sentinel.listener_socket, mocker.sentinel.listener_socket] - - async def test____create_udp_listeners____error_getaddrinfo_returns_empty_list( - self, - backend: AsyncIOBackend, - mocker: MockerFixture, - ) -> None: - # Arrange - event_loop = asyncio.get_running_loop() - remote_host = "remote_address" - remote_port = 5000 - mock_getaddrinfo = mocker.patch.object( - event_loop, - "getaddrinfo", - new_callable=mocker.AsyncMock, - return_value=[], - ) - mock_open_listeners = mocker.patch( - "easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result", - side_effect=AssertionError, - ) - mock_create_datagram_endpoint: AsyncMock = mocker.patch.object( - event_loop, - "create_datagram_endpoint", - new_callable=mocker.AsyncMock, - side_effect=AssertionError, - ) - mock_DatagramListenerSocketAdapter: MagicMock = mocker.patch( - f"{_ASYNCIO_BACKEND_MODULE}.datagram.listener.DatagramListenerSocketAdapter", - side_effect=AssertionError, - ) - - # Act - with pytest.raises(OSError, match=r"getaddrinfo\('remote_address'\) returned empty list"): - await backend.create_udp_listeners( - remote_host, - remote_port, - reuse_port=mocker.sentinel.reuse_port, - ) - - # Assert - mock_getaddrinfo.assert_awaited_once_with( - remote_host, - remote_port, - family=AF_UNSPEC, - type=SOCK_DGRAM, - proto=0, - flags=AI_PASSIVE | AI_ADDRCONFIG, - ) - mock_open_listeners.assert_not_called() - mock_create_datagram_endpoint.assert_not_called() - mock_DatagramListenerSocketAdapter.assert_not_called() + assert listener_sockets == [mocker.sentinel.listener_socket_ipv6, mocker.sentinel.listener_socket_ipv4] async def test____create_lock____use_asyncio_Lock_class( self, diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_dns_resolver.py b/tests/unit_test/test_async/test_asyncio_backend/test_dns_resolver.py new file mode 100644 index 00000000..27d3dd4b --- /dev/null +++ b/tests/unit_test/test_async/test_asyncio_backend/test_dns_resolver.py @@ -0,0 +1,30 @@ +from __future__ import annotations + +import asyncio +from typing import TYPE_CHECKING + +from easynetwork.lowlevel.api_async.backend._asyncio.dns_resolver import AsyncIODNSResolver + +import pytest + +if TYPE_CHECKING: + from unittest.mock import AsyncMock, MagicMock + + from pytest_mock import MockerFixture + + +@pytest.fixture +def mock_sock_connect(event_loop: asyncio.AbstractEventLoop, mocker: MockerFixture) -> AsyncMock: + return mocker.patch.object(event_loop, "sock_connect", new_callable=mocker.AsyncMock, return_value=None) + + +@pytest.mark.asyncio +async def test____AsyncIODNSResolver____connect_socket(mock_tcp_socket: MagicMock, mock_sock_connect: AsyncMock) -> None: + # Arrange + dns_resolver = AsyncIODNSResolver() + + # Act + await dns_resolver.connect_socket(mock_tcp_socket, ("127.0.0.1", 12345)) + + # Assert + mock_sock_connect.assert_awaited_once_with(mock_tcp_socket, ("127.0.0.1", 12345)) diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py index 959af41a..836462d3 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_stream.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_stream.py @@ -174,11 +174,11 @@ def _make_accept_side_effect( mocker: MockerFixture, sleep_time: float = 0, ) -> Callable[[], Coroutine[Any, Any, MagicMock]]: - accept_cb = mocker.MagicMock(side_effect=side_effect) + accept_cb = mocker.AsyncMock(side_effect=side_effect) async def accept_side_effect() -> MagicMock: await asyncio.sleep(sleep_time) - return accept_cb() + return await accept_cb() return accept_side_effect diff --git a/tests/unit_test/test_async/test_asyncio_backend/test_utils.py b/tests/unit_test/test_async/test_asyncio_backend/test_utils.py index 7ad388b8..42e8fb02 100644 --- a/tests/unit_test/test_async/test_asyncio_backend/test_utils.py +++ b/tests/unit_test/test_async/test_asyncio_backend/test_utils.py @@ -2,756 +2,22 @@ import asyncio import asyncio.trsock -import errno -from collections.abc import Callable, Sequence -from socket import ( - AF_INET, - AF_INET6, - AF_UNSPEC, - AI_NUMERICHOST, - AI_NUMERICSERV, - AI_PASSIVE, - EAI_BADFLAGS, - EAI_NONAME, - IPPROTO_TCP, - IPPROTO_UDP, - SOCK_DGRAM, - SOCK_STREAM, - SocketType, - gaierror, -) -from typing import TYPE_CHECKING, Any, Literal, Protocol as TypingProtocol, assert_never +from collections.abc import Callable +from socket import SocketType +from typing import TYPE_CHECKING -from easynetwork.lowlevel._utils import error_from_errno -from easynetwork.lowlevel.api_async.backend._asyncio._asyncio_utils import ( - create_connection, - create_datagram_connection, - ensure_resolved, - wait_until_readable, - wait_until_writable, -) -from easynetwork.lowlevel.api_async.backend._asyncio.tasks import TaskUtils +from easynetwork.lowlevel.api_async.backend._asyncio._asyncio_utils import wait_until_readable, wait_until_writable import pytest from ....tools import is_proactor_event_loop -from ..._utils import datagram_addrinfo_list, stream_addrinfo_list if TYPE_CHECKING: - from unittest.mock import AsyncMock, MagicMock + from unittest.mock import MagicMock from pytest_mock import MockerFixture -@pytest.fixture -def mock_getaddrinfo(event_loop: asyncio.AbstractEventLoop, mocker: MockerFixture) -> AsyncMock: - return mocker.patch.object(event_loop, "getaddrinfo", new_callable=mocker.AsyncMock) - - -@pytest.fixture -def mock_stdlib_socket_getaddrinfo(mocker: MockerFixture) -> AsyncMock: - return mocker.patch("socket.getaddrinfo") - - -@pytest.fixture -def mock_sock_connect(event_loop: asyncio.AbstractEventLoop, mocker: MockerFixture) -> AsyncMock: - return mocker.patch.object(event_loop, "sock_connect", new_callable=mocker.AsyncMock, return_value=None) - - -@pytest.fixture -def mock_socket_ipv4(mock_socket_factory: Callable[[], MagicMock]) -> MagicMock: - return mock_socket_factory() - - -@pytest.fixture -def mock_socket_ipv6(mock_socket_factory: Callable[[], MagicMock]) -> MagicMock: - return mock_socket_factory() - - -@pytest.fixture(autouse=True) -def mock_socket_cls(mock_socket_ipv4: MagicMock, mock_socket_ipv6: MagicMock, mocker: MockerFixture) -> MagicMock: - def side_effect(family: int, type: int, proto: int) -> MagicMock: - if family == AF_INET6: - used_socket = mock_socket_ipv6 - elif family == AF_INET: - used_socket = mock_socket_ipv4 - else: - raise error_from_errno(errno.EAFNOSUPPORT) - - used_socket.family = family - used_socket.type = type - used_socket.proto = proto - return used_socket - - return mocker.patch("socket.socket", side_effect=side_effect) - - -@pytest.mark.asyncio -async def test____ensure_resolved____try_numeric_first( - mock_getaddrinfo: AsyncMock, - mock_stdlib_socket_getaddrinfo: MagicMock, -) -> None: - # Arrange - expected_result = stream_addrinfo_list(8080, families=[AF_INET]) - mock_stdlib_socket_getaddrinfo.return_value = expected_result - - # Act - info = await ensure_resolved("127.0.0.1", 8080, 123456789, SOCK_STREAM, proto=IPPROTO_TCP, flags=AI_PASSIVE) - - # Assert - assert info == expected_result - mock_stdlib_socket_getaddrinfo.assert_called_once_with( - "127.0.0.1", - 8080, - family=123456789, - type=SOCK_STREAM, - proto=IPPROTO_TCP, - flags=AI_PASSIVE | AI_NUMERICHOST | AI_NUMERICSERV, - ) - mock_getaddrinfo.assert_not_awaited() - - -@pytest.mark.asyncio -async def test____ensure_resolved____try_numeric_first____success_but_return_empty_list( - mock_getaddrinfo: AsyncMock, - mock_stdlib_socket_getaddrinfo: MagicMock, -) -> None: - # Arrange - mock_stdlib_socket_getaddrinfo.return_value = [] - - # Act - with pytest.raises(OSError, match=r"^getaddrinfo\('127.0.0.1'\) returned empty list$"): - await ensure_resolved("127.0.0.1", 8080, 123456789, SOCK_STREAM, proto=IPPROTO_TCP, flags=AI_PASSIVE) - - # Assert - mock_stdlib_socket_getaddrinfo.assert_called_once_with( - "127.0.0.1", - 8080, - family=123456789, - type=SOCK_STREAM, - proto=IPPROTO_TCP, - flags=AI_PASSIVE | AI_NUMERICHOST | AI_NUMERICSERV, - ) - mock_getaddrinfo.assert_not_awaited() - - -@pytest.mark.asyncio -async def test____ensure_resolved____fallback_to_async_getaddrinfo( - mock_getaddrinfo: AsyncMock, - mock_stdlib_socket_getaddrinfo: MagicMock, -) -> None: - # Arrange - expected_result = stream_addrinfo_list(8080, families=[AF_INET]) - mock_stdlib_socket_getaddrinfo.side_effect = gaierror(EAI_NONAME, "Name or service not known") - mock_getaddrinfo.return_value = expected_result - - # Act - info = await ensure_resolved("127.0.0.1", 8080, 123456789, SOCK_STREAM, proto=IPPROTO_TCP, flags=AI_PASSIVE) - - # Assert - assert info == expected_result - mock_getaddrinfo.assert_awaited_once_with( - "127.0.0.1", - 8080, - family=123456789, - type=SOCK_STREAM, - proto=IPPROTO_TCP, - flags=AI_PASSIVE, - ) - - -@pytest.mark.asyncio -async def test____ensure_resolved____fallback_to_async_getaddrinfo____success_but_return_empty_list( - mock_getaddrinfo: AsyncMock, - mock_stdlib_socket_getaddrinfo: MagicMock, -) -> None: - # Arrange - mock_stdlib_socket_getaddrinfo.side_effect = gaierror(EAI_NONAME, "Name or service not known") - mock_getaddrinfo.return_value = [] - - # Act - with pytest.raises(OSError, match=r"^getaddrinfo\('127.0.0.1'\) returned empty list$"): - await ensure_resolved("127.0.0.1", 8080, 123456789, SOCK_STREAM, proto=IPPROTO_TCP, flags=AI_PASSIVE) - - # Assert - mock_getaddrinfo.assert_awaited_once_with( - "127.0.0.1", - 8080, - family=123456789, - type=SOCK_STREAM, - proto=IPPROTO_TCP, - flags=AI_PASSIVE, - ) - - -@pytest.mark.asyncio -async def test____ensure_resolved____propagate_unrelated_gaierror( - mock_getaddrinfo: AsyncMock, - mock_stdlib_socket_getaddrinfo: MagicMock, -) -> None: - # Arrange - mock_stdlib_socket_getaddrinfo.side_effect = gaierror(EAI_BADFLAGS, "Invalid flags") - - # Act - with pytest.raises(gaierror): - await ensure_resolved("127.0.0.1", 8080, 123456789, SOCK_STREAM, proto=IPPROTO_TCP, flags=AI_PASSIVE) - - # Assert - mock_stdlib_socket_getaddrinfo.assert_called_once_with( - "127.0.0.1", - 8080, - family=123456789, - type=SOCK_STREAM, - proto=IPPROTO_TCP, - flags=AI_PASSIVE | AI_NUMERICHOST | AI_NUMERICSERV, - ) - mock_getaddrinfo.assert_not_awaited() - - -class _CreateConnectionCallable(TypingProtocol): - async def __call__( - self, - host: str, - port: int, - *, - local_address: tuple[str, int] | None = None, - ) -> SocketType: ... - - -class _AddrInfoListFactory(TypingProtocol): - def __call__( - self, - port: int, - families: Sequence[int] = ..., - ) -> Sequence[tuple[int, int, int, str, tuple[Any, ...]]]: ... - - -@pytest.fixture(params=[SOCK_STREAM, SOCK_DGRAM], ids=lambda sock_type: f"sock_type=={sock_type!r}") -def connection_socktype(request: pytest.FixtureRequest) -> int: - return request.param - - -@pytest.fixture -def create_connection_of_socktype(connection_socktype: int) -> _CreateConnectionCallable: - if connection_socktype == SOCK_STREAM: - return create_connection - if connection_socktype == SOCK_DGRAM: - return create_datagram_connection - pytest.fail("Invalid fixture argument") - - -@pytest.fixture -def addrinfo_list_factory(connection_socktype: int) -> _AddrInfoListFactory: - if connection_socktype == SOCK_STREAM: - return stream_addrinfo_list - if connection_socktype == SOCK_DGRAM: - return datagram_addrinfo_list - pytest.fail("Invalid fixture argument") - - -@pytest.mark.asyncio -@pytest.mark.parametrize("with_local_address", [False, True], ids=lambda boolean: f"with_local_address=={boolean}") -async def test____create_connection____default( - create_connection_of_socktype: _CreateConnectionCallable, - addrinfo_list_factory: _AddrInfoListFactory, - connection_socktype: int, - with_local_address: bool, - mock_socket_cls: MagicMock, - mock_socket_ipv4: MagicMock, - mock_socket_ipv6: MagicMock, - mock_getaddrinfo: AsyncMock, - mock_sock_connect: AsyncMock, - mocker: MockerFixture, -) -> None: - # Arrange - expected_proto = IPPROTO_TCP if connection_socktype == SOCK_STREAM else IPPROTO_UDP - remote_host, remote_port = "localhost", 12345 - local_address: tuple[str, int] | None = ("localhost", 11111) if with_local_address else None - - if local_address is None: - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] - else: - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port), addrinfo_list_factory(local_address[1])] - - # Act - socket = await create_connection_of_socktype(remote_host, remote_port, local_address=local_address) - - # Assert - if local_address is None: - assert mock_getaddrinfo.await_args_list == [ - mocker.call(remote_host, remote_port, family=AF_UNSPEC, type=connection_socktype, proto=0, flags=0), - ] - else: - assert mock_getaddrinfo.await_args_list == [ - mocker.call(remote_host, remote_port, family=AF_UNSPEC, type=connection_socktype, proto=0, flags=0), - mocker.call(*local_address, family=AF_UNSPEC, type=connection_socktype, proto=0, flags=0), - ] - - mock_socket_cls.assert_called_once_with(AF_INET6, connection_socktype, expected_proto) - assert socket is mock_socket_ipv6 - - mock_socket_ipv6.setblocking.assert_called_once_with(False) - if local_address is None: - mock_socket_ipv6.bind.assert_not_called() - else: - mock_socket_ipv6.bind.assert_called_once_with(("::1", 11111, 0, 0)) - mock_sock_connect.assert_awaited_once_with(mock_socket_ipv6, ("::1", 12345, 0, 0)) - mock_socket_ipv6.close.assert_not_called() - - mock_socket_ipv4.setblocking.assert_not_called() - mock_socket_ipv4.bind.assert_not_called() - - -@pytest.mark.asyncio -@pytest.mark.parametrize("fail_on", ["socket", "bind", "connect"], ids=lambda fail_on: f"fail_on=={fail_on}") -async def test____create_connection____first_failed( - fail_on: Literal["socket", "bind", "connect"], - create_connection_of_socktype: _CreateConnectionCallable, - addrinfo_list_factory: _AddrInfoListFactory, - connection_socktype: int, - mock_socket_cls: MagicMock, - mock_socket_ipv4: MagicMock, - mock_socket_ipv6: MagicMock, - mock_getaddrinfo: AsyncMock, - mock_sock_connect: AsyncMock, - mocker: MockerFixture, -) -> None: - # Arrange - remote_host, remote_port = "localhost", 12345 - local_address: tuple[str, int] | None = ("localhost", 11111) if fail_on == "bind" else None - - if local_address is None: - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] - else: - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port), addrinfo_list_factory(local_address[1])] - - match fail_on: - case "socket": - mock_socket_cls.side_effect = [error_from_errno(errno.EAFNOSUPPORT), mock_socket_ipv4] - case "bind": - mock_socket_ipv6.bind.side_effect = error_from_errno(errno.EADDRINUSE) - case "connect": - mock_sock_connect.side_effect = [error_from_errno(errno.ECONNREFUSED), None] - case _: - assert_never(fail_on) - - # Act - socket = await create_connection_of_socktype(remote_host, remote_port, local_address=local_address) - - # Assert - if connection_socktype == SOCK_STREAM: - assert mock_socket_cls.call_args_list == [ - mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), - ] - else: - assert mock_socket_cls.call_args_list == [ - mocker.call(AF_INET6, SOCK_DGRAM, IPPROTO_UDP), - mocker.call(AF_INET, SOCK_DGRAM, IPPROTO_UDP), - ] - assert socket is mock_socket_ipv4 - - if fail_on != "socket": - mock_socket_ipv6.setblocking.assert_called_once_with(False) - if local_address is None: - mock_socket_ipv6.bind.assert_not_called() - else: - mock_socket_ipv6.bind.assert_called_once_with(("::1", 11111, 0, 0)) - match fail_on: - case "bind": - assert mocker.call(mock_socket_ipv6, ("::1", 12345, 0, 0)) not in mock_sock_connect.await_args_list - case "connect": - mock_sock_connect.assert_any_await(mock_socket_ipv6, ("::1", 12345, 0, 0)) - case _: - assert_never(fail_on) - mock_socket_ipv6.close.assert_called_once_with() - - mock_socket_ipv4.setblocking.assert_called_once_with(False) - if local_address is None: - mock_socket_ipv4.bind.assert_not_called() - else: - mock_socket_ipv4.bind.assert_called_once_with(("127.0.0.1", 11111)) - mock_sock_connect.assert_awaited_with(mock_socket_ipv4, ("127.0.0.1", 12345)) - mock_socket_ipv4.close.assert_not_called() - - -@pytest.mark.asyncio -@pytest.mark.parametrize("fail_on", ["socket", "bind", "connect"], ids=lambda fail_on: f"fail_on=={fail_on}") -async def test____create_connection____all_failed( - fail_on: Literal["socket", "bind", "connect"], - create_connection_of_socktype: _CreateConnectionCallable, - addrinfo_list_factory: _AddrInfoListFactory, - connection_socktype: int, - mock_socket_cls: MagicMock, - mock_socket_ipv4: MagicMock, - mock_socket_ipv6: MagicMock, - mock_getaddrinfo: AsyncMock, - mock_sock_connect: AsyncMock, - mocker: MockerFixture, -) -> None: - # Arrange - remote_host, remote_port = "localhost", 12345 - local_address: tuple[str, int] | None = ("localhost", 11111) if fail_on == "bind" else None - - if local_address is None: - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] - else: - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port), addrinfo_list_factory(local_address[1])] - - match fail_on: - case "socket": - mock_socket_cls.side_effect = error_from_errno(errno.EAFNOSUPPORT) - case "bind": - mock_socket_ipv4.bind.side_effect = error_from_errno(errno.EADDRINUSE) - mock_socket_ipv6.bind.side_effect = error_from_errno(errno.EADDRINUSE) - case "connect": - mock_sock_connect.side_effect = error_from_errno(errno.ECONNREFUSED) - case _: - assert_never(fail_on) - - # Act - with pytest.raises(ExceptionGroup) as exc_info: - await create_connection_of_socktype(remote_host, remote_port, local_address=local_address) - - # Assert - os_errors, exc = exc_info.value.split(OSError) - assert exc is None - assert os_errors is not None - assert len(os_errors.exceptions) == 2 - assert all(isinstance(exc, OSError) for exc in os_errors.exceptions) - del os_errors - - if connection_socktype == SOCK_STREAM: - assert mock_socket_cls.call_args_list == [ - mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), - ] - else: - assert mock_socket_cls.call_args_list == [ - mocker.call(AF_INET6, SOCK_DGRAM, IPPROTO_UDP), - mocker.call(AF_INET, SOCK_DGRAM, IPPROTO_UDP), - ] - - if fail_on != "socket": - mock_socket_ipv4.setblocking.assert_called_once_with(False) - mock_socket_ipv6.setblocking.assert_called_once_with(False) - if local_address is None: - mock_socket_ipv4.bind.assert_not_called() - mock_socket_ipv6.bind.assert_not_called() - else: - mock_socket_ipv4.bind.assert_called_once_with(("127.0.0.1", 11111)) - mock_socket_ipv6.bind.assert_called_once_with(("::1", 11111, 0, 0)) - match fail_on: - case "bind" | "socket": - assert mocker.call(mock_socket_ipv4, ("127.0.0.1", 12345)) not in mock_sock_connect.await_args_list - assert mocker.call(mock_socket_ipv6, ("::1", 12345, 0, 0)) not in mock_sock_connect.await_args_list - case "connect": - mock_sock_connect.assert_any_await(mock_socket_ipv4, ("127.0.0.1", 12345)) - mock_sock_connect.assert_any_await(mock_socket_ipv6, ("::1", 12345, 0, 0)) - case _: - assert_never(fail_on) - mock_socket_ipv4.close.assert_called_once_with() - mock_socket_ipv6.close.assert_called_once_with() - - -@pytest.mark.asyncio -@pytest.mark.parametrize("fail_on", ["socket", "connect"], ids=lambda fail_on: f"fail_on=={fail_on}") -async def test____create_connection____unrelated_exception( - fail_on: Literal["socket", "connect"], - connection_socktype: int, - create_connection_of_socktype: _CreateConnectionCallable, - addrinfo_list_factory: _AddrInfoListFactory, - mock_socket_cls: MagicMock, - mock_socket_ipv6: MagicMock, - mock_getaddrinfo: AsyncMock, - mock_sock_connect: AsyncMock, -) -> None: - # Arrange - remote_host, remote_port = "localhost", 12345 - - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] - expected_failure_exception = BaseException() - - match fail_on: - case "socket": - mock_socket_cls.side_effect = expected_failure_exception - case "connect": - mock_sock_connect.side_effect = expected_failure_exception - case _: - assert_never(fail_on) - - # Act - with pytest.raises(BaseException) as exc_info: - await create_connection_of_socktype(remote_host, remote_port) - - # Assert - if connection_socktype == SOCK_STREAM: - assert isinstance(exc_info.value, BaseExceptionGroup) - assert len(exc_info.value.exceptions) == 1 - assert exc_info.value.exceptions[0] is expected_failure_exception - else: - assert exc_info.value is expected_failure_exception - if fail_on != "socket": - mock_socket_ipv6.close.assert_called_once_with() - - -@pytest.mark.asyncio -@pytest.mark.parametrize("fail_on", ["remote_address", "local_address"], ids=lambda fail_on: f"fail_on=={fail_on}") -async def test____create_connection____getaddrinfo_returned_empty_list( - fail_on: Literal["remote_address", "local_address"], - create_connection_of_socktype: _CreateConnectionCallable, - addrinfo_list_factory: _AddrInfoListFactory, - mock_socket_cls: MagicMock, - mock_socket_ipv4: MagicMock, - mock_socket_ipv6: MagicMock, - mock_getaddrinfo: AsyncMock, - mock_sock_connect: AsyncMock, -) -> None: - # Arrange - remote_host, remote_port = "localhost", 12345 - local_address: tuple[str, int] = ("localhost", 11111) - - match fail_on: - case "remote_address": - mock_getaddrinfo.side_effect = [[]] - case "local_address": - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port), []] - case _: - assert_never(fail_on) - - # Act - with pytest.raises(OSError, match=r"^getaddrinfo\('localhost'\) returned empty list$"): - await create_connection_of_socktype(remote_host, remote_port, local_address=local_address) - - # Assert - mock_socket_cls.assert_not_called() - mock_socket_ipv4.bind.assert_not_called() - mock_socket_ipv6.bind.assert_not_called() - mock_sock_connect.assert_not_called() - - -@pytest.mark.asyncio -async def test____create_connection____getaddrinfo_return_mismatch( - create_connection_of_socktype: _CreateConnectionCallable, - addrinfo_list_factory: _AddrInfoListFactory, - mock_socket_ipv4: MagicMock, - mock_socket_ipv6: MagicMock, - mock_getaddrinfo: AsyncMock, - mock_sock_connect: AsyncMock, -) -> None: - # Arrange - remote_host, remote_port = "localhost", 12345 - local_address: tuple[str, int] = ("localhost", 11111) - - mock_getaddrinfo.side_effect = [ - addrinfo_list_factory(remote_port, families=[AF_INET6]), - addrinfo_list_factory(local_address[1], families=[AF_INET]), - ] - - # Act - with pytest.raises(ExceptionGroup) as exc_info: - await create_connection_of_socktype(remote_host, remote_port, local_address=local_address) - - # Assert - os_errors, exc = exc_info.value.split(OSError) - assert exc is None - assert os_errors is not None - assert len(os_errors.exceptions) == 1 - assert str(os_errors.exceptions[0]) == f"no matching local address with family={AF_INET6!r} found" - del os_errors - - mock_socket_ipv4.bind.assert_not_called() - mock_socket_ipv6.bind.assert_not_called() - mock_sock_connect.assert_not_called() - - -@pytest.mark.asyncio -@pytest.mark.parametrize("connection_socktype", [SOCK_STREAM], indirect=True, ids=repr) -@pytest.mark.flaky(retries=3) -async def test____create_connection____happy_eyeballs_delay____connect_cancellation( - addrinfo_list_factory: _AddrInfoListFactory, - mock_socket_cls: MagicMock, - mock_socket_ipv4: MagicMock, - mock_socket_ipv6: MagicMock, - mock_getaddrinfo: AsyncMock, - mock_sock_connect: AsyncMock, - mocker: MockerFixture, -) -> None: - # Arrange - event_loop = asyncio.get_running_loop() - timestamps: list[float] = [] - remote_host, remote_port = "localhost", 12345 - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port, families=[AF_INET6, AF_INET])] - - async def connect_side_effect(sock: SocketType, address: tuple[Any, ...]) -> None: - timestamps.append(event_loop.time()) - if sock.family == AF_INET6: - await asyncio.sleep(1) - else: - await asyncio.sleep(0.01) - - mock_sock_connect.side_effect = connect_side_effect - - # Act - socket = await create_connection(remote_host, remote_port, happy_eyeballs_delay=0.5) - - # Assert - assert socket is mock_socket_ipv4 - assert mock_socket_cls.call_args_list == [ - mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), - ] - - mock_socket_ipv6.close.assert_called_once_with() - mock_socket_ipv4.close.assert_not_called() - - ipv6_start_time, ipv4_start_time = timestamps - assert ipv4_start_time - ipv6_start_time == pytest.approx(0.5, rel=1e-1) - - -@pytest.mark.asyncio -@pytest.mark.parametrize("connection_socktype", [SOCK_STREAM], indirect=True, ids=repr) -async def test____create_connection____happy_eyeballs_delay____connect_too_late( - addrinfo_list_factory: _AddrInfoListFactory, - mock_socket_cls: MagicMock, - mock_socket_ipv4: MagicMock, - mock_socket_ipv6: MagicMock, - mock_getaddrinfo: AsyncMock, - mock_sock_connect: AsyncMock, - mocker: MockerFixture, -) -> None: - # Arrange - remote_host, remote_port = "localhost", 12345 - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port, families=[AF_INET6, AF_INET])] - - async def connect_side_effect(sock: SocketType, address: tuple[Any, ...]) -> None: - try: - await asyncio.sleep(1) - except asyncio.CancelledError: - TaskUtils.current_asyncio_task().uncancel() - await asyncio.sleep(0) - - mock_sock_connect.side_effect = connect_side_effect - - # Act - socket = await create_connection(remote_host, remote_port, happy_eyeballs_delay=0.25) - - # Assert - assert socket is mock_socket_ipv6 - assert mock_socket_cls.call_args_list == [ - mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), - ] - - mock_socket_ipv4.close.assert_called_once_with() - mock_socket_ipv6.close.assert_not_called() - - -@pytest.mark.asyncio -@pytest.mark.parametrize("connection_socktype", [SOCK_STREAM], indirect=True, ids=repr) -async def test____create_connection____happy_eyeballs_delay____winner_closed_because_of_exception_in_another_task( - addrinfo_list_factory: _AddrInfoListFactory, - mock_socket_cls: MagicMock, - mock_socket_ipv4: MagicMock, - mock_socket_ipv6: MagicMock, - mock_getaddrinfo: AsyncMock, - mock_sock_connect: AsyncMock, - mocker: MockerFixture, -) -> None: - # Arrange - expected_failure_exception = BaseException("error") - remote_host, remote_port = "localhost", 12345 - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port, families=[AF_INET6, AF_INET])] - - async def connect_side_effect(sock: SocketType, address: tuple[Any, ...]) -> None: - try: - await asyncio.sleep(0.5) - except asyncio.CancelledError: - raise expected_failure_exception from None - - mock_sock_connect.side_effect = connect_side_effect - - # Act - with pytest.raises(BaseExceptionGroup) as exc_info: - await create_connection(remote_host, remote_port, happy_eyeballs_delay=0.25) - - # Assert - assert mock_socket_cls.call_args_list == [ - mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), - ] - assert list(exc_info.value.exceptions) == [expected_failure_exception] - - mock_socket_ipv4.close.assert_called_once_with() - mock_socket_ipv6.close.assert_called_once_with() - - -@pytest.mark.asyncio -@pytest.mark.parametrize("connection_socktype", [SOCK_STREAM], indirect=True, ids=repr) -async def test____create_connection____happy_eyeballs_delay____addrinfo_reordering____prioritize_ipv6_over_ipv4( - addrinfo_list_factory: _AddrInfoListFactory, - mock_socket_cls: MagicMock, - mock_socket_ipv6: MagicMock, - mock_getaddrinfo: AsyncMock, - mock_sock_connect: AsyncMock, - mocker: MockerFixture, -) -> None: - # Arrange - remote_host, remote_port = "localhost", 12345 - mock_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port, families=[AF_INET, AF_INET6])] - - async def connect_side_effect(sock: SocketType, address: tuple[Any, ...]) -> None: - await asyncio.sleep(0.5) - - mock_sock_connect.side_effect = connect_side_effect - - # Act - socket = await create_connection(remote_host, remote_port, happy_eyeballs_delay=0.25) - - # Assert - assert socket is mock_socket_ipv6 - assert mock_socket_cls.call_args_list == [ - mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), - ] - - -@pytest.mark.asyncio -@pytest.mark.parametrize("connection_socktype", [SOCK_STREAM], indirect=True, ids=repr) -async def test____create_connection____happy_eyeballs_delay____addrinfo_reordering____interleave_families( - addrinfo_list_factory: _AddrInfoListFactory, - mock_socket_cls: MagicMock, - mock_socket_ipv6: MagicMock, - mock_getaddrinfo: AsyncMock, - mock_sock_connect: AsyncMock, - mocker: MockerFixture, -) -> None: - # Arrange - remote_host, remote_port = "localhost", 12345 - mock_getaddrinfo.side_effect = [ - addrinfo_list_factory(remote_port, families=[AF_INET6, AF_INET6, AF_INET6, AF_INET, AF_INET, AF_INET]), - ] - - async def connect_side_effect(sock: SocketType, address: tuple[Any, ...]) -> None: - await asyncio.sleep(1) - - mock_sock_connect.side_effect = connect_side_effect - - # Act - socket = await create_connection(remote_host, remote_port, happy_eyeballs_delay=0.1) - - # Assert - assert socket is mock_socket_ipv6 - assert mock_socket_cls.call_args_list == [ - mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), - mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), - ] - - @pytest.mark.asyncio @pytest.mark.parametrize( ["waiter", "event_loop_add_event_func_name", "event_loop_remove_event_func_name"], diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py index 7bd77dda..3207b29b 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_backend.py @@ -2,16 +2,17 @@ import asyncio from collections.abc import Awaitable -from typing import TYPE_CHECKING, Any, final +from typing import TYPE_CHECKING, Any, final, no_type_check from easynetwork.lowlevel.api_async.backend.abc import TaskInfo import pytest +from ...._utils import partial_eq from ._fake_backends import BaseFakeBackend if TYPE_CHECKING: - from unittest.mock import MagicMock + from unittest.mock import AsyncMock, MagicMock from pytest_mock import MockerFixture @@ -23,6 +24,7 @@ def __init__(self, mocker: MockerFixture) -> None: self.mock_coro_yield: MagicMock = mocker.MagicMock(side_effect=lambda: asyncio.sleep(0)) self.mock_current_time: MagicMock = mocker.MagicMock(return_value=123456789) self.mock_sleep = mocker.AsyncMock(return_value=None) + self.mock_run_in_thread: AsyncMock = mocker.AsyncMock(side_effect=lambda func, /, *args, **options: func(*args)) async def coro_yield(self) -> None: await self.mock_coro_yield() @@ -42,6 +44,10 @@ async def sleep(self, delay: float) -> None: async def ignore_cancellation(self, coroutine: Awaitable[Any]) -> Any: return await coroutine + @no_type_check + async def run_in_thread(self, *args: Any, **kwargs: Any) -> Any: + return await self.mock_run_in_thread(*args, **kwargs) + class TestTaskInfo: def test____equality____between_two_task_info_objects____equal(self, mocker: MockerFixture) -> None: @@ -111,3 +117,61 @@ async def test____sleep_until____deadline_lower_than_current_time( # Assert backend.mock_current_time.assert_called_once_with() backend.mock_sleep.assert_awaited_once_with(0) + + async def test____getaddrinfo____run_stdlib_function_in_thread( + self, + backend: MockBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_getaddrinfo = mocker.patch("socket.getaddrinfo", return_value=mocker.sentinel.addrinfo_list) + + # Act + addrinfo_list = await backend.getaddrinfo( + host=mocker.sentinel.host, + port=mocker.sentinel.port, + family=mocker.sentinel.family, + type=mocker.sentinel.type, + proto=mocker.sentinel.proto, + flags=mocker.sentinel.flags, + ) + + # Assert + assert addrinfo_list is mocker.sentinel.addrinfo_list + backend.mock_run_in_thread.assert_awaited_once_with( + partial_eq( + mock_getaddrinfo, + mocker.sentinel.host, + mocker.sentinel.port, + family=mocker.sentinel.family, + type=mocker.sentinel.type, + proto=mocker.sentinel.proto, + flags=mocker.sentinel.flags, + ), + abandon_on_cancel=True, + ) + + async def test____getnameinfo____run_stdlib_function_in_thread( + self, + backend: MockBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_getnameinfo = mocker.patch("socket.getnameinfo", return_value=mocker.sentinel.resolved_addr) + + # Act + resolved_addr = await backend.getnameinfo( + sockaddr=mocker.sentinel.sockaddr, + flags=mocker.sentinel.flags, + ) + + # Assert + assert resolved_addr is mocker.sentinel.resolved_addr + backend.mock_run_in_thread.assert_awaited_once_with( + partial_eq( + mock_getnameinfo, + mocker.sentinel.sockaddr, + mocker.sentinel.flags, + ), + abandon_on_cancel=True, + ) diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_common_tools/__init__.py b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_common_tools/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_common_tools/test_dns_resolver.py b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_common_tools/test_dns_resolver.py new file mode 100644 index 00000000..cf767510 --- /dev/null +++ b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_common_tools/test_dns_resolver.py @@ -0,0 +1,963 @@ +from __future__ import annotations + +import asyncio +import asyncio.trsock +import errno +from collections.abc import Callable, Sequence +from socket import ( + AF_INET, + AF_INET6, + AF_UNSPEC, + AI_ADDRCONFIG, + AI_NUMERICHOST, + AI_NUMERICSERV, + AI_PASSIVE, + EAI_BADFLAGS, + EAI_NONAME, + IPPROTO_TCP, + IPPROTO_UDP, + SOCK_DGRAM, + SOCK_STREAM, + SocketType, + gaierror, +) +from typing import TYPE_CHECKING, Any, Literal, Protocol as TypingProtocol, assert_never + +from easynetwork.lowlevel._utils import error_from_errno +from easynetwork.lowlevel.api_async.backend._asyncio.tasks import TaskUtils +from easynetwork.lowlevel.api_async.backend._common.dns_resolver import BaseAsyncDNSResolver +from easynetwork.lowlevel.api_async.backend.abc import AsyncBackend + +import pytest + +from ....._utils import datagram_addrinfo_list, stream_addrinfo_list + +if TYPE_CHECKING: + from unittest.mock import AsyncMock, MagicMock + + from pytest_mock import MockerFixture + + +class MockedDNSResolver(BaseAsyncDNSResolver): + def __init__(self, event_loop: asyncio.AbstractEventLoop, mocker: MockerFixture) -> None: + self.mock_async_getaddrinfo: AsyncMock = mocker.patch.object(event_loop, "getaddrinfo", new_callable=mocker.AsyncMock) + self.mock_sock_connect: AsyncMock = mocker.AsyncMock(return_value=None) + + async def connect_socket(self, socket: SocketType, address: tuple[str, int] | tuple[str, int, int, int]) -> None: + return await self.mock_sock_connect(socket, address) + + +@pytest.fixture +def dns_resolver(event_loop: asyncio.AbstractEventLoop, mocker: MockerFixture) -> MockedDNSResolver: + return MockedDNSResolver(event_loop, mocker) + + +@pytest.fixture(autouse=True) +def mock_stdlib_socket_getaddrinfo(mocker: MockerFixture) -> MagicMock: + from socket import EAI_NONAME, gaierror + + return mocker.patch("socket.getaddrinfo", autospec=True, side_effect=gaierror(EAI_NONAME, "Name or service not known")) + + +@pytest.fixture +def mock_socket_ipv4(mock_socket_factory: Callable[[], MagicMock]) -> MagicMock: + return mock_socket_factory() + + +@pytest.fixture +def mock_socket_ipv6(mock_socket_factory: Callable[[], MagicMock]) -> MagicMock: + return mock_socket_factory() + + +@pytest.fixture(autouse=True) +def mock_socket_cls(mock_socket_ipv4: MagicMock, mock_socket_ipv6: MagicMock, mocker: MockerFixture) -> MagicMock: + def side_effect(family: int, type: int, proto: int) -> MagicMock: + if family == AF_INET6: + used_socket = mock_socket_ipv6 + elif family == AF_INET: + used_socket = mock_socket_ipv4 + else: + raise error_from_errno(errno.EAFNOSUPPORT) + + used_socket.family = family + used_socket.type = type + used_socket.proto = proto + return used_socket + + return mocker.patch("socket.socket", side_effect=side_effect) + + +@pytest.mark.asyncio +async def test____ensure_resolved____try_numeric_first( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + mock_stdlib_socket_getaddrinfo: MagicMock, +) -> None: + # Arrange + expected_result = stream_addrinfo_list(8080, families=[AF_INET]) + mock_stdlib_socket_getaddrinfo.side_effect = None + mock_stdlib_socket_getaddrinfo.return_value = expected_result + + # Act + info = await dns_resolver.ensure_resolved( + asyncio_backend, + "127.0.0.1", + 8080, + 123456789, + SOCK_STREAM, + proto=IPPROTO_TCP, + flags=AI_PASSIVE, + ) + + # Assert + assert info == expected_result + mock_stdlib_socket_getaddrinfo.assert_called_once_with( + "127.0.0.1", + 8080, + family=123456789, + type=SOCK_STREAM, + proto=IPPROTO_TCP, + flags=AI_PASSIVE | AI_NUMERICHOST | AI_NUMERICSERV, + ) + dns_resolver.mock_async_getaddrinfo.assert_not_awaited() + + +@pytest.mark.asyncio +async def test____ensure_resolved____try_numeric_first____success_but_return_empty_list( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + mock_stdlib_socket_getaddrinfo: MagicMock, +) -> None: + # Arrange + mock_stdlib_socket_getaddrinfo.side_effect = None + mock_stdlib_socket_getaddrinfo.return_value = [] + + # Act + with pytest.raises(OSError, match=r"^getaddrinfo\('127.0.0.1'\) returned empty list$"): + await dns_resolver.ensure_resolved( + asyncio_backend, + "127.0.0.1", + 8080, + 123456789, + SOCK_STREAM, + proto=IPPROTO_TCP, + flags=AI_PASSIVE, + ) + + # Assert + mock_stdlib_socket_getaddrinfo.assert_called_once_with( + "127.0.0.1", + 8080, + family=123456789, + type=SOCK_STREAM, + proto=IPPROTO_TCP, + flags=AI_PASSIVE | AI_NUMERICHOST | AI_NUMERICSERV, + ) + dns_resolver.mock_async_getaddrinfo.assert_not_awaited() + + +@pytest.mark.asyncio +async def test____ensure_resolved____fallback_to_async_getaddrinfo( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + mock_stdlib_socket_getaddrinfo: MagicMock, +) -> None: + # Arrange + expected_result = stream_addrinfo_list(8080, families=[AF_INET]) + mock_stdlib_socket_getaddrinfo.side_effect = gaierror(EAI_NONAME, "Name or service not known") + dns_resolver.mock_async_getaddrinfo.return_value = expected_result + + # Act + info = await dns_resolver.ensure_resolved( + asyncio_backend, + "127.0.0.1", + 8080, + 123456789, + SOCK_STREAM, + proto=IPPROTO_TCP, + flags=AI_PASSIVE, + ) + + # Assert + assert info == expected_result + dns_resolver.mock_async_getaddrinfo.assert_awaited_once_with( + "127.0.0.1", + 8080, + family=123456789, + type=SOCK_STREAM, + proto=IPPROTO_TCP, + flags=AI_PASSIVE, + ) + + +@pytest.mark.asyncio +async def test____ensure_resolved____fallback_to_async_getaddrinfo____success_but_return_empty_list( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + mock_stdlib_socket_getaddrinfo: MagicMock, +) -> None: + # Arrange + mock_stdlib_socket_getaddrinfo.side_effect = gaierror(EAI_NONAME, "Name or service not known") + dns_resolver.mock_async_getaddrinfo.return_value = [] + + # Act + with pytest.raises(OSError, match=r"^getaddrinfo\('127.0.0.1'\) returned empty list$"): + await dns_resolver.ensure_resolved( + asyncio_backend, + "127.0.0.1", + 8080, + 123456789, + SOCK_STREAM, + proto=IPPROTO_TCP, + flags=AI_PASSIVE, + ) + + # Assert + dns_resolver.mock_async_getaddrinfo.assert_awaited_once_with( + "127.0.0.1", + 8080, + family=123456789, + type=SOCK_STREAM, + proto=IPPROTO_TCP, + flags=AI_PASSIVE, + ) + + +@pytest.mark.asyncio +async def test____ensure_resolved____propagate_unrelated_gaierror( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + mock_stdlib_socket_getaddrinfo: MagicMock, +) -> None: + # Arrange + mock_stdlib_socket_getaddrinfo.side_effect = gaierror(EAI_BADFLAGS, "Invalid flags") + + # Act + with pytest.raises(gaierror): + await dns_resolver.ensure_resolved( + asyncio_backend, + "127.0.0.1", + 8080, + 123456789, + SOCK_STREAM, + proto=IPPROTO_TCP, + flags=AI_PASSIVE, + ) + + # Assert + mock_stdlib_socket_getaddrinfo.assert_called_once_with( + "127.0.0.1", + 8080, + family=123456789, + type=SOCK_STREAM, + proto=IPPROTO_TCP, + flags=AI_PASSIVE | AI_NUMERICHOST | AI_NUMERICSERV, + ) + dns_resolver.mock_async_getaddrinfo.assert_not_awaited() + + +@pytest.mark.asyncio +async def test____resolve_listener_addresses____bind_to_any_interfaces( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, +) -> None: + # Arrange + local_port = 5000 + addrinfo_list = [ + ( + AF_INET, + SOCK_STREAM, + IPPROTO_TCP, + "", + ("0.0.0.0", local_port), + ), + ( + AF_INET6, + SOCK_STREAM, + IPPROTO_TCP, + "", + ("::", local_port), + ), + ] + dns_resolver.mock_async_getaddrinfo.return_value = addrinfo_list + + # Act + listeners_addrinfo = await dns_resolver.resolve_listener_addresses( + asyncio_backend, + hosts=[None], + port=local_port, + socktype=SOCK_STREAM, + ) + + # Assert + dns_resolver.mock_async_getaddrinfo.assert_awaited_once_with( + None, + local_port, + family=AF_UNSPEC, + type=SOCK_STREAM, + proto=0, + flags=AI_PASSIVE | AI_ADDRCONFIG, + ) + assert listeners_addrinfo == sorted(addrinfo_list) + + +@pytest.mark.asyncio +async def test____resolve_listener_addresses____bind_to_several_hosts( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + mocker: MockerFixture, +) -> None: + # Arrange + local_hosts = ["0.0.0.0", "::"] + local_port = 5000 + addrinfo_list = [ + ( + AF_INET, + SOCK_STREAM, + IPPROTO_TCP, + "", + ("0.0.0.0", local_port), + ), + ( + AF_INET6, + SOCK_STREAM, + IPPROTO_TCP, + "", + ("::", local_port), + ), + ] + dns_resolver.mock_async_getaddrinfo.side_effect = [[info] for info in addrinfo_list] + + # Act + listeners_addrinfo = await dns_resolver.resolve_listener_addresses( + asyncio_backend, + hosts=local_hosts, + port=local_port, + socktype=SOCK_STREAM, + ) + + # Assert + assert dns_resolver.mock_async_getaddrinfo.await_args_list == [ + mocker.call( + host, + local_port, + family=AF_UNSPEC, + type=SOCK_STREAM, + proto=0, + flags=AI_PASSIVE | AI_ADDRCONFIG, + ) + for host in local_hosts + ] + assert listeners_addrinfo == sorted(addrinfo_list) + + +@pytest.mark.asyncio +async def test____resolve_listener_addresses____error_getaddrinfo_returns_empty_list( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, +) -> None: + # Arrange + local_host = "localhost" + local_port = 5000 + + dns_resolver.mock_async_getaddrinfo.return_value = [] + + # Act + with pytest.raises(ExceptionGroup) as exc_info: + await dns_resolver.resolve_listener_addresses( + asyncio_backend, + hosts=[local_host], + port=local_port, + socktype=SOCK_STREAM, + ) + + # Assert + os_errors, exc = exc_info.value.split(OSError) + assert exc is None + assert os_errors is not None + assert len(os_errors.exceptions) == 1 + assert str(os_errors.exceptions[0]) == "getaddrinfo('localhost') returned empty list" + del os_errors + + dns_resolver.mock_async_getaddrinfo.assert_awaited_once_with( + local_host, + local_port, + family=AF_UNSPEC, + type=SOCK_STREAM, + proto=0, + flags=AI_PASSIVE | AI_ADDRCONFIG, + ) + + +class _CreateConnectionCallable(TypingProtocol): + async def __call__( + self, + backend: AsyncBackend, + host: str, + port: int, + *, + local_address: tuple[str, int] | None = None, + ) -> SocketType: ... + + +class _AddrInfoListFactory(TypingProtocol): + def __call__( + self, + port: int, + families: Sequence[int] = ..., + ) -> Sequence[tuple[int, int, int, str, tuple[Any, ...]]]: ... + + +@pytest.fixture(params=[SOCK_STREAM, SOCK_DGRAM], ids=lambda sock_type: f"sock_type=={sock_type!r}") +def connection_socktype(request: pytest.FixtureRequest) -> int: + return request.param + + +@pytest.fixture +def addrinfo_list_factory(connection_socktype: int) -> _AddrInfoListFactory: + if connection_socktype == SOCK_STREAM: + return stream_addrinfo_list + if connection_socktype == SOCK_DGRAM: + return datagram_addrinfo_list + pytest.fail("Invalid fixture argument") + + +def create_connection_of_socktype(dns_resolver: MockedDNSResolver, connection_socktype: int) -> _CreateConnectionCallable: + if connection_socktype == SOCK_STREAM: + return dns_resolver.create_stream_connection + if connection_socktype == SOCK_DGRAM: + return dns_resolver.create_datagram_connection + pytest.fail("Invalid fixture argument") + + +@pytest.mark.asyncio +@pytest.mark.parametrize("with_local_address", [False, True], ids=lambda boolean: f"with_local_address=={boolean}") +async def test____create_connection____default( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + addrinfo_list_factory: _AddrInfoListFactory, + connection_socktype: int, + with_local_address: bool, + mock_socket_cls: MagicMock, + mock_socket_ipv4: MagicMock, + mock_socket_ipv6: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + expected_proto = IPPROTO_TCP if connection_socktype == SOCK_STREAM else IPPROTO_UDP + remote_host, remote_port = "localhost", 12345 + local_address: tuple[str, int] | None = ("localhost", 11111) if with_local_address else None + + if local_address is None: + dns_resolver.mock_async_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] + else: + dns_resolver.mock_async_getaddrinfo.side_effect = [ + addrinfo_list_factory(remote_port), + addrinfo_list_factory(local_address[1]), + ] + + # Act + socket = await create_connection_of_socktype(dns_resolver, connection_socktype)( + asyncio_backend, + remote_host, + remote_port, + local_address=local_address, + ) + + # Assert + if local_address is None: + assert dns_resolver.mock_async_getaddrinfo.await_args_list == [ + mocker.call(remote_host, remote_port, family=AF_UNSPEC, type=connection_socktype, proto=0, flags=0), + ] + else: + assert dns_resolver.mock_async_getaddrinfo.await_args_list == [ + mocker.call(remote_host, remote_port, family=AF_UNSPEC, type=connection_socktype, proto=0, flags=0), + mocker.call(*local_address, family=AF_UNSPEC, type=connection_socktype, proto=0, flags=0), + ] + + mock_socket_cls.assert_called_once_with(AF_INET6, connection_socktype, expected_proto) + assert socket is mock_socket_ipv6 + + mock_socket_ipv6.setblocking.assert_called_once_with(False) + if local_address is None: + mock_socket_ipv6.bind.assert_not_called() + else: + mock_socket_ipv6.bind.assert_called_once_with(("::1", 11111, 0, 0)) + dns_resolver.mock_sock_connect.assert_awaited_once_with(mock_socket_ipv6, ("::1", 12345, 0, 0)) + mock_socket_ipv6.close.assert_not_called() + + mock_socket_ipv4.setblocking.assert_not_called() + mock_socket_ipv4.bind.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("fail_on", ["socket", "bind", "connect"], ids=lambda fail_on: f"fail_on=={fail_on}") +async def test____create_connection____first_failed( + fail_on: Literal["socket", "bind", "connect"], + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + addrinfo_list_factory: _AddrInfoListFactory, + connection_socktype: int, + mock_socket_cls: MagicMock, + mock_socket_ipv4: MagicMock, + mock_socket_ipv6: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + remote_host, remote_port = "localhost", 12345 + local_address: tuple[str, int] | None = ("localhost", 11111) if fail_on == "bind" else None + + if local_address is None: + dns_resolver.mock_async_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] + else: + dns_resolver.mock_async_getaddrinfo.side_effect = [ + addrinfo_list_factory(remote_port), + addrinfo_list_factory(local_address[1]), + ] + + match fail_on: + case "socket": + mock_socket_cls.side_effect = [error_from_errno(errno.EAFNOSUPPORT), mock_socket_ipv4] + case "bind": + mock_socket_ipv6.bind.side_effect = error_from_errno(errno.EADDRINUSE) + case "connect": + dns_resolver.mock_sock_connect.side_effect = [error_from_errno(errno.ECONNREFUSED), None] + case _: + assert_never(fail_on) + + # Act + socket = await create_connection_of_socktype(dns_resolver, connection_socktype)( + asyncio_backend, + remote_host, + remote_port, + local_address=local_address, + ) + + # Assert + if connection_socktype == SOCK_STREAM: + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), + ] + else: + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET6, SOCK_DGRAM, IPPROTO_UDP), + mocker.call(AF_INET, SOCK_DGRAM, IPPROTO_UDP), + ] + assert socket is mock_socket_ipv4 + + if fail_on != "socket": + mock_socket_ipv6.setblocking.assert_called_once_with(False) + if local_address is None: + mock_socket_ipv6.bind.assert_not_called() + else: + mock_socket_ipv6.bind.assert_called_once_with(("::1", 11111, 0, 0)) + match fail_on: + case "bind": + assert mocker.call(mock_socket_ipv6, ("::1", 12345, 0, 0)) not in dns_resolver.mock_sock_connect.await_args_list + case "connect": + dns_resolver.mock_sock_connect.assert_any_await(mock_socket_ipv6, ("::1", 12345, 0, 0)) + case _: + assert_never(fail_on) + mock_socket_ipv6.close.assert_called_once_with() + + mock_socket_ipv4.setblocking.assert_called_once_with(False) + if local_address is None: + mock_socket_ipv4.bind.assert_not_called() + else: + mock_socket_ipv4.bind.assert_called_once_with(("127.0.0.1", 11111)) + dns_resolver.mock_sock_connect.assert_awaited_with(mock_socket_ipv4, ("127.0.0.1", 12345)) + mock_socket_ipv4.close.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("fail_on", ["socket", "bind", "connect"], ids=lambda fail_on: f"fail_on=={fail_on}") +async def test____create_connection____all_failed( + fail_on: Literal["socket", "bind", "connect"], + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + addrinfo_list_factory: _AddrInfoListFactory, + connection_socktype: int, + mock_socket_cls: MagicMock, + mock_socket_ipv4: MagicMock, + mock_socket_ipv6: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + remote_host, remote_port = "localhost", 12345 + local_address: tuple[str, int] | None = ("localhost", 11111) if fail_on == "bind" else None + + if local_address is None: + dns_resolver.mock_async_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] + else: + dns_resolver.mock_async_getaddrinfo.side_effect = [ + addrinfo_list_factory(remote_port), + addrinfo_list_factory(local_address[1]), + ] + + match fail_on: + case "socket": + mock_socket_cls.side_effect = error_from_errno(errno.EAFNOSUPPORT) + case "bind": + mock_socket_ipv4.bind.side_effect = error_from_errno(errno.EADDRINUSE) + mock_socket_ipv6.bind.side_effect = error_from_errno(errno.EADDRINUSE) + case "connect": + dns_resolver.mock_sock_connect.side_effect = error_from_errno(errno.ECONNREFUSED) + case _: + assert_never(fail_on) + + # Act + with pytest.raises(ExceptionGroup) as exc_info: + await create_connection_of_socktype(dns_resolver, connection_socktype)( + asyncio_backend, + remote_host, + remote_port, + local_address=local_address, + ) + + # Assert + os_errors, exc = exc_info.value.split(OSError) + assert exc is None + assert os_errors is not None + assert len(os_errors.exceptions) == 2 + assert all(isinstance(exc, OSError) for exc in os_errors.exceptions) + del os_errors + + if connection_socktype == SOCK_STREAM: + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), + ] + else: + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET6, SOCK_DGRAM, IPPROTO_UDP), + mocker.call(AF_INET, SOCK_DGRAM, IPPROTO_UDP), + ] + + if fail_on != "socket": + mock_socket_ipv4.setblocking.assert_called_once_with(False) + mock_socket_ipv6.setblocking.assert_called_once_with(False) + if local_address is None: + mock_socket_ipv4.bind.assert_not_called() + mock_socket_ipv6.bind.assert_not_called() + else: + mock_socket_ipv4.bind.assert_called_once_with(("127.0.0.1", 11111)) + mock_socket_ipv6.bind.assert_called_once_with(("::1", 11111, 0, 0)) + match fail_on: + case "bind" | "socket": + assert mocker.call(mock_socket_ipv4, ("127.0.0.1", 12345)) not in dns_resolver.mock_sock_connect.await_args_list + assert mocker.call(mock_socket_ipv6, ("::1", 12345, 0, 0)) not in dns_resolver.mock_sock_connect.await_args_list + case "connect": + dns_resolver.mock_sock_connect.assert_any_await(mock_socket_ipv4, ("127.0.0.1", 12345)) + dns_resolver.mock_sock_connect.assert_any_await(mock_socket_ipv6, ("::1", 12345, 0, 0)) + case _: + assert_never(fail_on) + mock_socket_ipv4.close.assert_called_once_with() + mock_socket_ipv6.close.assert_called_once_with() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("fail_on", ["socket", "connect"], ids=lambda fail_on: f"fail_on=={fail_on}") +async def test____create_connection____unrelated_exception( + fail_on: Literal["socket", "connect"], + connection_socktype: int, + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + addrinfo_list_factory: _AddrInfoListFactory, + mock_socket_cls: MagicMock, + mock_socket_ipv6: MagicMock, +) -> None: + # Arrange + remote_host, remote_port = "localhost", 12345 + + dns_resolver.mock_async_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port)] + expected_failure_exception = BaseException() + + match fail_on: + case "socket": + mock_socket_cls.side_effect = expected_failure_exception + case "connect": + dns_resolver.mock_sock_connect.side_effect = expected_failure_exception + case _: + assert_never(fail_on) + + # Act + with pytest.raises(BaseException) as exc_info: + await create_connection_of_socktype(dns_resolver, connection_socktype)(asyncio_backend, remote_host, remote_port) + + # Assert + if connection_socktype == SOCK_STREAM: + assert isinstance(exc_info.value, BaseExceptionGroup) + assert len(exc_info.value.exceptions) == 1 + assert exc_info.value.exceptions[0] is expected_failure_exception + else: + assert exc_info.value is expected_failure_exception + if fail_on != "socket": + mock_socket_ipv6.close.assert_called_once_with() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("fail_on", ["remote_address", "local_address"], ids=lambda fail_on: f"fail_on=={fail_on}") +async def test____create_connection____getaddrinfo_returned_empty_list( + fail_on: Literal["remote_address", "local_address"], + connection_socktype: int, + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + addrinfo_list_factory: _AddrInfoListFactory, + mock_socket_cls: MagicMock, + mock_socket_ipv4: MagicMock, + mock_socket_ipv6: MagicMock, +) -> None: + # Arrange + remote_host, remote_port = "localhost", 12345 + local_address: tuple[str, int] = ("localhost", 11111) + + match fail_on: + case "remote_address": + dns_resolver.mock_async_getaddrinfo.side_effect = [[]] + case "local_address": + dns_resolver.mock_async_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port), []] + case _: + assert_never(fail_on) + + # Act + with pytest.raises(OSError, match=r"getaddrinfo\('localhost'\) returned empty list"): + await create_connection_of_socktype(dns_resolver, connection_socktype)( + asyncio_backend, + remote_host, + remote_port, + local_address=local_address, + ) + + # Assert + mock_socket_cls.assert_not_called() + mock_socket_ipv4.bind.assert_not_called() + mock_socket_ipv6.bind.assert_not_called() + dns_resolver.mock_sock_connect.assert_not_called() + + +@pytest.mark.asyncio +async def test____create_connection____getaddrinfo_return_mismatch( + connection_socktype: int, + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + addrinfo_list_factory: _AddrInfoListFactory, + mock_socket_ipv4: MagicMock, + mock_socket_ipv6: MagicMock, +) -> None: + # Arrange + remote_host, remote_port = "localhost", 12345 + local_address: tuple[str, int] = ("localhost", 11111) + + dns_resolver.mock_async_getaddrinfo.side_effect = [ + addrinfo_list_factory(remote_port, families=[AF_INET6]), + addrinfo_list_factory(local_address[1], families=[AF_INET]), + ] + + # Act + with pytest.raises(ExceptionGroup) as exc_info: + await create_connection_of_socktype(dns_resolver, connection_socktype)( + asyncio_backend, + remote_host, + remote_port, + local_address=local_address, + ) + + # Assert + os_errors, exc = exc_info.value.split(OSError) + assert exc is None + assert os_errors is not None + assert len(os_errors.exceptions) == 1 + assert str(os_errors.exceptions[0]) == f"no matching local address with family={AF_INET6!r} found" + del os_errors + + mock_socket_ipv4.bind.assert_not_called() + mock_socket_ipv6.bind.assert_not_called() + dns_resolver.mock_sock_connect.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("connection_socktype", [SOCK_STREAM], indirect=True, ids=repr) +@pytest.mark.flaky(retries=3) +async def test____create_stream_connection____happy_eyeballs_delay____connect_cancellation( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + addrinfo_list_factory: _AddrInfoListFactory, + mock_socket_cls: MagicMock, + mock_socket_ipv4: MagicMock, + mock_socket_ipv6: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + event_loop = asyncio.get_running_loop() + timestamps: list[float] = [] + remote_host, remote_port = "localhost", 12345 + dns_resolver.mock_async_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port, families=[AF_INET6, AF_INET])] + + async def connect_side_effect(sock: SocketType, address: tuple[Any, ...]) -> None: + timestamps.append(event_loop.time()) + if sock.family == AF_INET6: + await asyncio.sleep(1) + else: + await asyncio.sleep(0.01) + + dns_resolver.mock_sock_connect.side_effect = connect_side_effect + + # Act + socket = await dns_resolver.create_stream_connection(asyncio_backend, remote_host, remote_port, happy_eyeballs_delay=0.5) + + # Assert + assert socket is mock_socket_ipv4 + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), + ] + + mock_socket_ipv6.close.assert_called_once_with() + mock_socket_ipv4.close.assert_not_called() + + ipv6_start_time, ipv4_start_time = timestamps + assert ipv4_start_time - ipv6_start_time == pytest.approx(0.5, rel=1e-1) + + +@pytest.mark.asyncio +@pytest.mark.parametrize("connection_socktype", [SOCK_STREAM], indirect=True, ids=repr) +async def test____create_stream_connection____happy_eyeballs_delay____connect_too_late( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + addrinfo_list_factory: _AddrInfoListFactory, + mock_socket_cls: MagicMock, + mock_socket_ipv4: MagicMock, + mock_socket_ipv6: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + remote_host, remote_port = "localhost", 12345 + dns_resolver.mock_async_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port, families=[AF_INET6, AF_INET])] + + async def connect_side_effect(sock: SocketType, address: tuple[Any, ...]) -> None: + try: + await asyncio.sleep(1) + except asyncio.CancelledError: + TaskUtils.current_asyncio_task().uncancel() + await asyncio.sleep(0) + + dns_resolver.mock_sock_connect.side_effect = connect_side_effect + + # Act + socket = await dns_resolver.create_stream_connection(asyncio_backend, remote_host, remote_port, happy_eyeballs_delay=0.25) + + # Assert + assert socket is mock_socket_ipv6 + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), + ] + + mock_socket_ipv4.close.assert_called_once_with() + mock_socket_ipv6.close.assert_not_called() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("connection_socktype", [SOCK_STREAM], indirect=True, ids=repr) +async def test____create_stream_connection____happy_eyeballs_delay____winner_closed_because_of_exception_in_another_task( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + addrinfo_list_factory: _AddrInfoListFactory, + mock_socket_cls: MagicMock, + mock_socket_ipv4: MagicMock, + mock_socket_ipv6: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + expected_failure_exception = BaseException("error") + remote_host, remote_port = "localhost", 12345 + dns_resolver.mock_async_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port, families=[AF_INET6, AF_INET])] + + async def connect_side_effect(sock: SocketType, address: tuple[Any, ...]) -> None: + try: + await asyncio.sleep(0.5) + except asyncio.CancelledError: + raise expected_failure_exception from None + + dns_resolver.mock_sock_connect.side_effect = connect_side_effect + + # Act + with pytest.raises(BaseExceptionGroup) as exc_info: + await dns_resolver.create_stream_connection(asyncio_backend, remote_host, remote_port, happy_eyeballs_delay=0.25) + + # Assert + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), + ] + assert list(exc_info.value.exceptions) == [expected_failure_exception] + + mock_socket_ipv4.close.assert_called_once_with() + mock_socket_ipv6.close.assert_called_once_with() + + +@pytest.mark.asyncio +@pytest.mark.parametrize("connection_socktype", [SOCK_STREAM], indirect=True, ids=repr) +async def test____create_stream_connection____happy_eyeballs_delay____addrinfo_reordering____prioritize_ipv6_over_ipv4( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + addrinfo_list_factory: _AddrInfoListFactory, + mock_socket_cls: MagicMock, + mock_socket_ipv6: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + remote_host, remote_port = "localhost", 12345 + dns_resolver.mock_async_getaddrinfo.side_effect = [addrinfo_list_factory(remote_port, families=[AF_INET, AF_INET6])] + + async def connect_side_effect(sock: SocketType, address: tuple[Any, ...]) -> None: + await asyncio.sleep(0.5) + + dns_resolver.mock_sock_connect.side_effect = connect_side_effect + + # Act + socket = await dns_resolver.create_stream_connection(asyncio_backend, remote_host, remote_port, happy_eyeballs_delay=0.25) + + # Assert + assert socket is mock_socket_ipv6 + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), + ] + + +@pytest.mark.asyncio +@pytest.mark.parametrize("connection_socktype", [SOCK_STREAM], indirect=True, ids=repr) +async def test____create_stream_connection____happy_eyeballs_delay____addrinfo_reordering____interleave_families( + asyncio_backend: AsyncBackend, + dns_resolver: MockedDNSResolver, + addrinfo_list_factory: _AddrInfoListFactory, + mock_socket_cls: MagicMock, + mock_socket_ipv6: MagicMock, + mocker: MockerFixture, +) -> None: + # Arrange + remote_host, remote_port = "localhost", 12345 + dns_resolver.mock_async_getaddrinfo.side_effect = [ + addrinfo_list_factory(remote_port, families=[AF_INET6, AF_INET6, AF_INET6, AF_INET, AF_INET, AF_INET]), + ] + + async def connect_side_effect(sock: SocketType, address: tuple[Any, ...]) -> None: + await asyncio.sleep(1) + + dns_resolver.mock_sock_connect.side_effect = connect_side_effect + + # Act + socket = await dns_resolver.create_stream_connection(asyncio_backend, remote_host, remote_port, happy_eyeballs_delay=0.1) + + # Assert + assert socket is mock_socket_ipv6 + assert mock_socket_cls.call_args_list == [ + mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET6, SOCK_STREAM, IPPROTO_TCP), + mocker.call(AF_INET, SOCK_STREAM, IPPROTO_TCP), + ] diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_utils.py b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_utils.py index 1af026de..295e2ae4 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_utils.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_backend/test_utils.py @@ -3,10 +3,14 @@ from typing import TYPE_CHECKING from easynetwork.lowlevel.api_async.backend._asyncio.backend import AsyncIOBackend -from easynetwork.lowlevel.api_async.backend.utils import ensure_backend +from easynetwork.lowlevel.api_async.backend._trio.backend import TrioBackend +from easynetwork.lowlevel.api_async.backend.abc import AsyncBackend +from easynetwork.lowlevel.api_async.backend.utils import BuiltinAsyncBackendLiteral, ensure_backend, new_builtin_backend import pytest +from ...._utils import mock_import_module_not_found + if TYPE_CHECKING: from unittest.mock import MagicMock @@ -18,14 +22,25 @@ def mock_current_async_library(mocker: MockerFixture) -> MagicMock: return mocker.patch("sniffio.current_async_library", autospec=True) -def test____ensure_backend____valid_string_literal(mock_current_async_library: MagicMock) -> None: +@pytest.mark.parametrize( + ["name", "expected_type"], + [ + pytest.param("asyncio", AsyncIOBackend, id="asyncio"), + pytest.param("trio", TrioBackend, id="trio", marks=pytest.mark.feature_trio), + ], +) +def test____ensure_backend____valid_string_literal( + name: BuiltinAsyncBackendLiteral, + expected_type: type[AsyncBackend], + mock_current_async_library: MagicMock, +) -> None: # Arrange # Act - backend = ensure_backend("asyncio") + backend = ensure_backend(name) # Assert - assert isinstance(backend, AsyncIOBackend) + assert isinstance(backend, expected_type) mock_current_async_library.assert_not_called() @@ -54,28 +69,85 @@ def test____ensure_backend____invalid_string_literal(mock_current_async_library: # Arrange # Act & Assert - with pytest.raises(NotImplementedError, match=r"^trio$"): - _ = ensure_backend("trio") # type: ignore[arg-type] + with pytest.raises(NotImplementedError, match=r"^curio$"): + _ = ensure_backend("curio") # type: ignore[arg-type] mock_current_async_library.assert_not_called() -def test____ensure_backend____None____current_async_library_is_supported(mock_current_async_library: MagicMock) -> None: +@pytest.mark.parametrize( + ["name", "expected_type"], + [ + pytest.param("asyncio", AsyncIOBackend, id="asyncio"), + pytest.param("trio", TrioBackend, id="trio", marks=pytest.mark.feature_trio), + ], +) +def test____ensure_backend____None____current_async_library_is_supported( + name: BuiltinAsyncBackendLiteral, + expected_type: type[AsyncBackend], + mock_current_async_library: MagicMock, +) -> None: # Arrange - mock_current_async_library.return_value = "asyncio" + mock_current_async_library.return_value = name # Act backend = ensure_backend(None) # Assert - assert isinstance(backend, AsyncIOBackend) + assert isinstance(backend, expected_type) mock_current_async_library.assert_called_once() def test____ensure_backend____None____current_async_library_is_not_supported(mock_current_async_library: MagicMock) -> None: # Arrange - mock_current_async_library.return_value = "trio" + mock_current_async_library.return_value = "curio" # Act & Assert - with pytest.raises(NotImplementedError, match=r"^trio$"): + with pytest.raises(NotImplementedError, match=r"^curio$"): _ = ensure_backend(None) mock_current_async_library.assert_called_once() + + +@pytest.mark.parametrize( + ["name", "expected_type"], + [ + pytest.param("asyncio", AsyncIOBackend, id="asyncio"), + pytest.param("trio", TrioBackend, id="trio", marks=pytest.mark.feature_trio), + ], +) +def test____new_builtin_backend____valid_string_literal( + name: BuiltinAsyncBackendLiteral, + expected_type: type[AsyncBackend], +) -> None: + # Arrange + + # Act + backend = new_builtin_backend(name) + + # Assert + assert isinstance(backend, expected_type) + + +def test____new_builtin_backend____invalid_string_literal() -> None: + # Arrange + + # Act & Assert + with pytest.raises(NotImplementedError, match=r"^curio$"): + _ = ensure_backend("curio") # type: ignore[arg-type] + + +def test____new_builtin_backend_____trio_dependency_missing(mocker: MockerFixture) -> None: + # Arrange + mock_import = mock_import_module_not_found({"trio"}, mocker) + + # Act + with pytest.raises(ModuleNotFoundError) as exc_info: + try: + _ = new_builtin_backend("trio") + finally: + mocker.stop(mock_import) + + # Assert + mock_import.assert_any_call("trio", mocker.ANY, mocker.ANY, None, 0) + assert exc_info.value.args[0] == "trio dependencies are missing. Consider adding 'trio' extra" + assert exc_info.value.__notes__ == ['example: pip install "easynetwork[trio]"'] + assert isinstance(exc_info.value.__cause__, ModuleNotFoundError) diff --git a/tests/unit_test/test_async/test_trio_backend/__init__.py b/tests/unit_test/test_async/test_trio_backend/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/unit_test/test_async/test_trio_backend/conftest.py b/tests/unit_test/test_async/test_trio_backend/conftest.py new file mode 100644 index 00000000..79a9c9da --- /dev/null +++ b/tests/unit_test/test_async/test_trio_backend/conftest.py @@ -0,0 +1,146 @@ +from __future__ import annotations + +from collections.abc import Callable +from socket import AF_INET, IPPROTO_TCP, IPPROTO_UDP, SOCK_DGRAM, SOCK_STREAM +from typing import TYPE_CHECKING + +from easynetwork.lowlevel.api_async.backend._trio.backend import TrioBackend + +import pytest + +if TYPE_CHECKING: + from unittest.mock import MagicMock + + from pytest_mock import MockerFixture + + +@pytest.fixture(scope="package") +def trio_backend() -> TrioBackend: + return TrioBackend() + + +@pytest.fixture +def mock_trio_socket_from_stdlib(mocker: MockerFixture) -> MagicMock: + return mocker.patch("trio.socket.from_stdlib_socket", autospec=True) + + +@pytest.fixture +def mock_trio_socket_factory(mocker: MockerFixture) -> Callable[[], MagicMock]: + from types import MethodType + + import trio + + def factory(family: int = -1, type: int = -1, proto: int = -1, fileno: int = 123) -> MagicMock: + if family == -1: + family = AF_INET + if type == -1: + type = SOCK_STREAM + if proto == -1: + proto = 0 + mock_socket = mocker.NonCallableMagicMock(spec=trio.socket.SocketType) + mock_socket.family = family + mock_socket.type = type + mock_socket.proto = proto + mock_socket.fileno.return_value = fileno + + def close_side_effect() -> None: + mock_socket.fileno.return_value = -1 + + def detached_side_effect() -> int: + to_return, mock_socket.fileno.return_value = mock_socket.fileno.return_value, -1 + return to_return + + mock_socket.close.side_effect = close_side_effect + mock_socket.detach.side_effect = detached_side_effect + for async_method in ("recv", "recv_into", "recvfrom", "recvfrom_into", "send", "sendto", "sendmsg"): + if hasattr(mock_socket, async_method): + setattr( + mock_socket, + async_method, + mocker.AsyncMock(spec=MethodType(getattr(trio.socket.SocketType, async_method), mock_socket)), + ) + return mock_socket + + return factory + + +@pytest.fixture +def mock_trio_tcp_socket_factory(mock_trio_socket_factory: Callable[[int, int, int], MagicMock]) -> Callable[[], MagicMock]: + def factory(family: int = -1) -> MagicMock: + return mock_trio_socket_factory(family, SOCK_STREAM, IPPROTO_TCP) + + return factory + + +@pytest.fixture +def mock_trio_tcp_socket(mock_trio_tcp_socket_factory: Callable[[], MagicMock]) -> MagicMock: + return mock_trio_tcp_socket_factory() + + +@pytest.fixture +def mock_trio_udp_socket_factory(mock_trio_socket_factory: Callable[[int, int, int], MagicMock]) -> Callable[[], MagicMock]: + def factory(family: int = -1) -> MagicMock: + return mock_trio_socket_factory(family, SOCK_DGRAM, IPPROTO_UDP) + + return factory + + +@pytest.fixture +def mock_trio_udp_socket(mock_trio_udp_socket_factory: Callable[[], MagicMock]) -> MagicMock: + return mock_trio_udp_socket_factory() + + +@pytest.fixture +def mock_trio_socket_stream_factory( + mocker: MockerFixture, + mock_trio_tcp_socket_factory: Callable[[], MagicMock], +) -> Callable[[], MagicMock]: + import trio + + def factory(socket: trio.socket.SocketType | None = None) -> MagicMock: + mock_stream = mocker.NonCallableMagicMock(spec=trio.SocketStream) + if socket is None: + mock_stream.socket = mock_trio_tcp_socket_factory() + else: + mock_stream.socket = socket + + async def aclose_side_effect() -> None: + mock_stream.socket.close() + + mock_stream.aclose.side_effect = aclose_side_effect + return mock_stream + + return factory + + +@pytest.fixture +def mock_trio_socket_stream(mock_trio_socket_stream_factory: Callable[[], MagicMock]) -> MagicMock: + return mock_trio_socket_stream_factory() + + +@pytest.fixture +def mock_trio_socket_listener_factory( + mocker: MockerFixture, + mock_trio_tcp_socket_factory: Callable[[], MagicMock], +) -> Callable[[], MagicMock]: + import trio + + def factory(socket: trio.socket.SocketType | None = None) -> MagicMock: + mock_stream = mocker.NonCallableMagicMock(spec=trio.SocketListener) + if socket is None: + mock_stream.socket = mock_trio_tcp_socket_factory() + else: + mock_stream.socket = socket + + async def aclose_side_effect() -> None: + mock_stream.socket.close() + + mock_stream.aclose.side_effect = aclose_side_effect + return mock_stream + + return factory + + +@pytest.fixture +def mock_trio_socket_listener(mock_trio_socket_listener_factory: Callable[[], MagicMock]) -> MagicMock: + return mock_trio_socket_listener_factory() diff --git a/tests/unit_test/test_async/test_trio_backend/test_backend.py b/tests/unit_test/test_async/test_trio_backend/test_backend.py new file mode 100644 index 00000000..0c5c6af6 --- /dev/null +++ b/tests/unit_test/test_async/test_trio_backend/test_backend.py @@ -0,0 +1,701 @@ +from __future__ import annotations + +from collections.abc import Callable, Sequence +from socket import AF_INET, AF_INET6, AF_UNSPEC, IPPROTO_TCP, IPPROTO_UDP, SOCK_DGRAM, SOCK_STREAM +from typing import TYPE_CHECKING, Any, Final + +from easynetwork.lowlevel.api_async.backend._trio.backend import TrioBackend + +import pytest + +if TYPE_CHECKING: + from unittest.mock import AsyncMock, MagicMock + + from pytest_mock import MockerFixture + + +_TRIO_BACKEND_MODULE: Final[str] = "easynetwork.lowlevel.api_async.backend._trio" + + +@pytest.mark.feature_trio(async_test_auto_mark=True) +class TestTrioBackend: + @pytest.fixture + @staticmethod + def backend() -> TrioBackend: + return TrioBackend() + + @pytest.fixture(params=[("local_address", 12345), None], ids=lambda addr: f"local_address=={addr}") + @staticmethod + def local_address(request: pytest.FixtureRequest) -> tuple[str, int] | None: + return request.param + + @pytest.fixture(params=[("remote_address", 5000)], ids=lambda addr: f"remote_address=={addr}") + @staticmethod + def remote_address(request: pytest.FixtureRequest) -> tuple[str, int] | None: + return request.param + + async def test____get_cancelled_exc_class____returns_trio_Cancelled( + self, + backend: TrioBackend, + ) -> None: + # Arrange + import trio + + # Act & Assert + assert backend.get_cancelled_exc_class() is trio.Cancelled + + async def test____current_time____use_event_loop_time( + self, + backend: TrioBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + import trio + + mock_current_time: MagicMock = mocker.patch("trio.current_time", side_effect=trio.current_time) + + # Act + current_time = backend.current_time() + + # Assert + mock_current_time.assert_called_once_with() + assert current_time > 0 + + async def test____sleep____use_trio_sleep( + self, + backend: TrioBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_sleep: AsyncMock = mocker.patch("trio.sleep", autospec=True) + + # Act + await backend.sleep(123456789) + + # Assert + mock_sleep.assert_awaited_once_with(123456789) + + async def test____get_current_task____compute_task_info( + self, + backend: TrioBackend, + ) -> None: + # Arrange + import trio + + current_task = trio.lowlevel.current_task() + + # Act + task_info = backend.get_current_task() + + # Assert + assert task_info.id == id(current_task) + assert task_info.name == current_task.name + assert task_info.coro is current_task.coro + + async def test____getaddrinfo____use_loop_getaddrinfo( + self, + backend: TrioBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_trio_getaddrinfo = mocker.patch( + "trio.socket.getaddrinfo", + new_callable=mocker.AsyncMock, + return_value=mocker.sentinel.addrinfo_list, + ) + + # Act + addrinfo_list = await backend.getaddrinfo( + host=mocker.sentinel.host, + port=mocker.sentinel.port, + family=mocker.sentinel.family, + type=mocker.sentinel.type, + proto=mocker.sentinel.proto, + flags=mocker.sentinel.flags, + ) + + # Assert + assert addrinfo_list is mocker.sentinel.addrinfo_list + mock_trio_getaddrinfo.assert_awaited_once_with( + mocker.sentinel.host, + mocker.sentinel.port, + family=mocker.sentinel.family, + type=mocker.sentinel.type, + proto=mocker.sentinel.proto, + flags=mocker.sentinel.flags, + ) + + async def test____getnameinfo____use_loop_getnameinfo( + self, + backend: TrioBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_trio_getnameinfo = mocker.patch( + "trio.socket.getnameinfo", + new_callable=mocker.AsyncMock, + return_value=mocker.sentinel.resolved_addr, + ) + + # Act + resolved_addr = await backend.getnameinfo( + sockaddr=mocker.sentinel.sockaddr, + flags=mocker.sentinel.flags, + ) + + # Assert + assert resolved_addr is mocker.sentinel.resolved_addr + mock_trio_getnameinfo.assert_awaited_once_with(mocker.sentinel.sockaddr, mocker.sentinel.flags) + + @pytest.mark.parametrize("happy_eyeballs_delay", [None, 42], ids=lambda p: f"happy_eyeballs_delay=={p}") + async def test____create_tcp_connection____create_trio_stream( + self, + happy_eyeballs_delay: float | None, + local_address: tuple[str, int] | None, + remote_address: tuple[str, int], + backend: TrioBackend, + mock_tcp_socket: MagicMock, + mock_trio_tcp_socket: MagicMock, + mock_trio_socket_stream: MagicMock, + mock_trio_socket_from_stdlib: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.dns_resolver import TrioDNSResolver + + mock_trio_socket_stream.socket = mock_trio_tcp_socket + mock_TrioStreamSocketAdapter: MagicMock = mocker.patch( + f"{_TRIO_BACKEND_MODULE}.stream.socket.TrioStreamSocketAdapter", + return_value=mocker.sentinel.socket, + ) + mock_trio_SocketStream = mocker.patch( + "trio.SocketStream", + return_value=mock_trio_socket_stream, + ) + mock_trio_socket_from_stdlib.return_value = mock_trio_tcp_socket + mock_own_create_connection: AsyncMock = mocker.patch.object( + TrioDNSResolver, + "create_stream_connection", + new_callable=mocker.AsyncMock, + return_value=mock_tcp_socket, + ) + + expected_happy_eyeballs_delay: float = 0.25 + if happy_eyeballs_delay is not None: + expected_happy_eyeballs_delay = happy_eyeballs_delay + + # Act + socket = await backend.create_tcp_connection( + *remote_address, + happy_eyeballs_delay=happy_eyeballs_delay, + local_address=local_address, + ) + + # Assert + mock_own_create_connection.assert_awaited_once_with( + backend, + *remote_address, + happy_eyeballs_delay=expected_happy_eyeballs_delay, + local_address=local_address, + ) + mock_trio_socket_from_stdlib.assert_called_once_with(mock_tcp_socket) + mock_trio_SocketStream.assert_called_once_with(mock_trio_tcp_socket) + mock_TrioStreamSocketAdapter.assert_called_once_with(backend, mock_trio_socket_stream) + assert socket is mocker.sentinel.socket + + async def test____wrap_stream_socket____create_trio_stream( + self, + backend: TrioBackend, + mock_tcp_socket: MagicMock, + mock_trio_tcp_socket: MagicMock, + mock_trio_socket_stream: MagicMock, + mock_trio_socket_from_stdlib: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_trio_socket_stream.socket = mock_trio_tcp_socket + mock_TrioStreamSocketAdapter: MagicMock = mocker.patch( + f"{_TRIO_BACKEND_MODULE}.stream.socket.TrioStreamSocketAdapter", + return_value=mocker.sentinel.socket, + ) + mock_trio_SocketStream = mocker.patch( + "trio.SocketStream", + return_value=mock_trio_socket_stream, + ) + mock_trio_socket_from_stdlib.return_value = mock_trio_tcp_socket + + # Act + socket = await backend.wrap_stream_socket(mock_tcp_socket) + + # Assert + mock_trio_socket_from_stdlib.assert_called_once_with(mock_tcp_socket) + mock_trio_SocketStream.assert_called_once_with(mock_trio_tcp_socket) + mock_TrioStreamSocketAdapter.assert_called_once_with(backend, mock_trio_socket_stream) + assert socket is mocker.sentinel.socket + + async def test____create_tcp_listeners____open_listener_sockets( + self, + backend: TrioBackend, + mock_tcp_socket: MagicMock, + mock_trio_tcp_socket: MagicMock, + mock_trio_socket_listener: MagicMock, + mock_trio_socket_from_stdlib: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.dns_resolver import TrioDNSResolver + + remote_host, remote_port = "remote_address", 5000 + addrinfo_list = [ + ( + mocker.sentinel.family, + mocker.sentinel.type, + mocker.sentinel.proto, + mocker.sentinel.canonical_name, + (remote_host, remote_port), + ) + ] + mock_resolve_listener_addresses = mocker.patch.object( + TrioDNSResolver, + "resolve_listener_addresses", + new_callable=mocker.AsyncMock, + return_value=addrinfo_list, + ) + mock_open_listeners = mocker.patch( + "easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result", + return_value=[mock_tcp_socket], + ) + mock_trio_socket_from_stdlib.side_effect = [mock_trio_tcp_socket] + mock_trio_SocketListener: MagicMock = mocker.patch("trio.SocketListener", side_effect=[mock_trio_socket_listener]) + mock_ListenerSocketAdapter: MagicMock = mocker.patch( + f"{_TRIO_BACKEND_MODULE}.stream.listener.TrioListenerSocketAdapter", + return_value=mocker.sentinel.listener_socket, + ) + + # Act + listener_sockets: Sequence[Any] = await backend.create_tcp_listeners( + remote_host, + remote_port, + backlog=123456789, + reuse_port=mocker.sentinel.reuse_port, + ) + + # Assert + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + [remote_host], + remote_port, + SOCK_STREAM, + ) + mock_open_listeners.assert_called_once_with( + addrinfo_list, + backlog=123456789, + reuse_address=mocker.ANY, # Determined according to OS + reuse_port=mocker.sentinel.reuse_port, + ) + mock_trio_socket_from_stdlib.assert_called_once_with(mock_tcp_socket) + mock_trio_SocketListener.assert_called_once_with(mock_trio_tcp_socket) + mock_ListenerSocketAdapter.assert_called_once_with(backend, mock_trio_socket_listener) + assert listener_sockets == [mocker.sentinel.listener_socket] + + @pytest.mark.parametrize("remote_host", [None, "", ["::", "0.0.0.0"]], ids=repr) + async def test____create_tcp_listeners____bind_to_all_interfaces( + self, + remote_host: str | list[str] | None, + backend: TrioBackend, + mock_trio_socket_from_stdlib: MagicMock, + mock_tcp_socket_factory: Callable[[int], MagicMock], + mock_trio_tcp_socket_factory: Callable[[int], MagicMock], + mock_trio_socket_listener_factory: Callable[[MagicMock], MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.dns_resolver import TrioDNSResolver + + mock_tcp_socket_ipv4 = mock_tcp_socket_factory(AF_INET) + mock_tcp_socket_ipv6 = mock_tcp_socket_factory(AF_INET6) + mock_trio_tcp_socket_ipv4 = mock_trio_tcp_socket_factory(AF_INET) + mock_trio_tcp_socket_ipv6 = mock_trio_tcp_socket_factory(AF_INET6) + mock_trio_socket_listener_ipv4 = mock_trio_socket_listener_factory(mock_trio_tcp_socket_ipv4) + mock_trio_socket_listener_ipv6 = mock_trio_socket_listener_factory(mock_trio_tcp_socket_ipv6) + remote_port = 5000 + addrinfo_list = [ + ( + AF_INET6, + SOCK_STREAM, + IPPROTO_TCP, + "", + ("::", remote_port), + ), + ( + AF_INET, + SOCK_STREAM, + IPPROTO_TCP, + "", + ("0.0.0.0", remote_port), + ), + ] + mock_resolve_listener_addresses = mocker.patch.object( + TrioDNSResolver, + "resolve_listener_addresses", + new_callable=mocker.AsyncMock, + return_value=addrinfo_list, + ) + mock_open_listeners = mocker.patch( + "easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result", + return_value=[mock_tcp_socket_ipv6, mock_tcp_socket_ipv4], + ) + mock_trio_socket_from_stdlib.side_effect = [mock_trio_tcp_socket_ipv6, mock_trio_tcp_socket_ipv4] + mock_trio_SocketListener: MagicMock = mocker.patch( + "trio.SocketListener", + side_effect=[mock_trio_socket_listener_ipv6, mock_trio_socket_listener_ipv4], + ) + mock_ListenerSocketAdapter: MagicMock = mocker.patch( + f"{_TRIO_BACKEND_MODULE}.stream.listener.TrioListenerSocketAdapter", + side_effect=[mocker.sentinel.listener_socket_ipv6, mocker.sentinel.listener_socket_ipv4], + ) + + # Act + listener_sockets: Sequence[Any] = await backend.create_tcp_listeners( + remote_host, + remote_port, + backlog=123456789, + reuse_port=mocker.sentinel.reuse_port, + ) + + # Assert + if isinstance(remote_host, list): + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + remote_host, + remote_port, + SOCK_STREAM, + ) + else: + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + [None], + remote_port, + SOCK_STREAM, + ) + mock_open_listeners.assert_called_once_with( + addrinfo_list, + backlog=123456789, + reuse_address=mocker.ANY, # Determined according to OS + reuse_port=mocker.sentinel.reuse_port, + ) + assert mock_trio_socket_from_stdlib.call_args_list == [ + mocker.call(sock) for sock in [mock_tcp_socket_ipv6, mock_tcp_socket_ipv4] + ] + assert mock_trio_SocketListener.call_args_list == [ + mocker.call(trio_sock) for trio_sock in [mock_trio_tcp_socket_ipv6, mock_trio_tcp_socket_ipv4] + ] + assert mock_ListenerSocketAdapter.call_args_list == [ + mocker.call(backend, listener) for listener in [mock_trio_socket_listener_ipv6, mock_trio_socket_listener_ipv4] + ] + assert listener_sockets == [mocker.sentinel.listener_socket_ipv6, mocker.sentinel.listener_socket_ipv4] + + @pytest.mark.parametrize("socket_family", [None, AF_INET, AF_INET6], ids=lambda p: f"family=={p}") + async def test____create_udp_endpoint____create_datagram_socket( + self, + local_address: tuple[str, int] | None, + remote_address: tuple[str, int], + socket_family: int | None, + backend: TrioBackend, + mock_udp_socket: MagicMock, + mock_trio_udp_socket: MagicMock, + mock_trio_socket_from_stdlib: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.dns_resolver import TrioDNSResolver + + mock_TrioDatagramSocketAdapter: MagicMock = mocker.patch( + f"{_TRIO_BACKEND_MODULE}.datagram.socket.TrioDatagramSocketAdapter", + return_value=mocker.sentinel.socket, + ) + mock_trio_socket_from_stdlib.return_value = mock_trio_udp_socket + mock_own_create_connection: AsyncMock = mocker.patch.object( + TrioDNSResolver, + "create_datagram_connection", + new_callable=mocker.AsyncMock, + return_value=mock_udp_socket, + ) + + # Act + if socket_family is None: + socket = await backend.create_udp_endpoint(*remote_address, local_address=local_address) + else: + socket = await backend.create_udp_endpoint(*remote_address, local_address=local_address, family=socket_family) + + # Assert + mock_own_create_connection.assert_awaited_once_with( + backend, + *remote_address, + local_address=local_address, + family=AF_UNSPEC if socket_family is None else socket_family, + ) + mock_trio_socket_from_stdlib.assert_called_once_with(mock_udp_socket) + mock_TrioDatagramSocketAdapter.assert_called_once_with(backend, mock_trio_udp_socket) + + assert socket is mocker.sentinel.socket + + async def test____wrap_connected_datagram_socket____create_datagram_socket( + self, + backend: TrioBackend, + mock_udp_socket: MagicMock, + mock_trio_udp_socket: MagicMock, + mock_trio_socket_from_stdlib: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_TrioDatagramSocketAdapter: MagicMock = mocker.patch( + f"{_TRIO_BACKEND_MODULE}.datagram.socket.TrioDatagramSocketAdapter", + return_value=mocker.sentinel.socket, + ) + mock_trio_socket_from_stdlib.return_value = mock_trio_udp_socket + + # Act + socket = await backend.wrap_connected_datagram_socket(mock_udp_socket) + + # Assert + mock_trio_socket_from_stdlib.assert_called_once_with(mock_udp_socket) + mock_TrioDatagramSocketAdapter.assert_called_once_with(backend, mock_trio_udp_socket) + + assert socket is mocker.sentinel.socket + + async def test____create_udp_listeners____open_listener_sockets( + self, + backend: TrioBackend, + mock_udp_socket: MagicMock, + mock_trio_udp_socket: MagicMock, + mock_trio_socket_from_stdlib: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.dns_resolver import TrioDNSResolver + + remote_host, remote_port = "remote_address", 5000 + addrinfo_list = [ + ( + mocker.sentinel.family, + mocker.sentinel.type, + mocker.sentinel.proto, + mocker.sentinel.canonical_name, + (remote_host, remote_port), + ) + ] + mock_resolve_listener_addresses = mocker.patch.object( + TrioDNSResolver, + "resolve_listener_addresses", + new_callable=mocker.AsyncMock, + return_value=addrinfo_list, + ) + mock_open_listeners = mocker.patch( + "easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result", + return_value=[mock_udp_socket], + ) + mock_trio_socket_from_stdlib.side_effect = [mock_trio_udp_socket] + mock_DatagramListenerSocketAdapter: MagicMock = mocker.patch( + f"{_TRIO_BACKEND_MODULE}.datagram.listener.TrioDatagramListenerSocketAdapter", + return_value=mocker.sentinel.listener_socket, + ) + + # Act + listener_sockets: Sequence[Any] = await backend.create_udp_listeners( + remote_host, + remote_port, + reuse_port=mocker.sentinel.reuse_port, + ) + + # Assert + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + [remote_host], + remote_port, + SOCK_DGRAM, + ) + mock_open_listeners.assert_called_once_with( + addrinfo_list, + backlog=None, + reuse_address=False, + reuse_port=mocker.sentinel.reuse_port, + ) + mock_trio_socket_from_stdlib.assert_called_once_with(mock_udp_socket) + mock_DatagramListenerSocketAdapter.assert_called_once_with(backend, mock_trio_udp_socket) + assert listener_sockets == [mocker.sentinel.listener_socket] + + @pytest.mark.parametrize("remote_host", [None, "", ["::", "0.0.0.0"]], ids=repr) + async def test____create_udp_listeners____bind_to_all_interfaces( + self, + remote_host: str | list[str] | None, + backend: TrioBackend, + mock_udp_socket_factory: Callable[[int], MagicMock], + mock_trio_udp_socket_factory: Callable[[int], MagicMock], + mock_trio_socket_from_stdlib: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.dns_resolver import TrioDNSResolver + + mock_udp_socket_ipv4 = mock_udp_socket_factory(AF_INET) + mock_udp_socket_ipv6 = mock_udp_socket_factory(AF_INET6) + mock_trio_udp_socket_ipv4 = mock_trio_udp_socket_factory(AF_INET) + mock_trio_udp_socket_ipv6 = mock_trio_udp_socket_factory(AF_INET6) + remote_port = 5000 + addrinfo_list = [ + ( + AF_INET6, + SOCK_DGRAM, + IPPROTO_UDP, + "", + ("::1", remote_port), + ), + ( + AF_INET, + SOCK_DGRAM, + IPPROTO_UDP, + "", + ("127.0.0.1", remote_port), + ), + ] + mock_resolve_listener_addresses = mocker.patch.object( + TrioDNSResolver, + "resolve_listener_addresses", + new_callable=mocker.AsyncMock, + return_value=addrinfo_list, + ) + mock_open_listeners = mocker.patch( + "easynetwork.lowlevel._utils.open_listener_sockets_from_getaddrinfo_result", + return_value=[mock_udp_socket_ipv6, mock_udp_socket_ipv4], + ) + mock_trio_socket_from_stdlib.side_effect = [mock_trio_udp_socket_ipv6, mock_trio_udp_socket_ipv4] + mock_DatagramListenerSocketAdapter: MagicMock = mocker.patch( + f"{_TRIO_BACKEND_MODULE}.datagram.listener.TrioDatagramListenerSocketAdapter", + side_effect=[mocker.sentinel.listener_socket_ipv6, mocker.sentinel.listener_socket_ipv4], + ) + + # Act + listener_sockets: Sequence[Any] = await backend.create_udp_listeners( + remote_host, + remote_port, + reuse_port=mocker.sentinel.reuse_port, + ) + + # Assert + if isinstance(remote_host, list): + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + remote_host, + remote_port, + SOCK_DGRAM, + ) + else: + mock_resolve_listener_addresses.assert_awaited_once_with( + backend, + [None], + remote_port, + SOCK_DGRAM, + ) + mock_open_listeners.assert_called_once_with( + addrinfo_list, + backlog=None, + reuse_address=False, + reuse_port=mocker.sentinel.reuse_port, + ) + assert mock_trio_socket_from_stdlib.call_args_list == [ + mocker.call(sock) for sock in [mock_udp_socket_ipv6, mock_udp_socket_ipv4] + ] + assert mock_DatagramListenerSocketAdapter.call_args_list == [ + mocker.call(backend, trio_sock) for trio_sock in [mock_trio_udp_socket_ipv6, mock_trio_udp_socket_ipv4] + ] + assert listener_sockets == [mocker.sentinel.listener_socket_ipv6, mocker.sentinel.listener_socket_ipv4] + + async def test____create_lock____use_trio_Lock_class( + self, + backend: TrioBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_Lock = mocker.patch("trio.Lock", return_value=mocker.sentinel.lock) + + # Act + lock = backend.create_lock() + + # Assert + mock_Lock.assert_called_once_with() + assert lock is mocker.sentinel.lock + + async def test____create_event____use_trio_Event_class( + self, + backend: TrioBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_Event = mocker.patch("trio.Event", return_value=mocker.sentinel.event) + + # Act + event = backend.create_event() + + # Assert + mock_Event.assert_called_once_with() + assert event is mocker.sentinel.event + + @pytest.mark.parametrize( + "use_lock", + [ + pytest.param(False, id="None"), + pytest.param(True, id="trio.Lock"), + ], + ) + async def test____create_condition_var____use_trio_Condition_class( + self, + use_lock: bool, + backend: TrioBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + import trio + + mock_lock: MagicMock | None = None if not use_lock else mocker.NonCallableMagicMock(spec=trio.Lock) + mock_Condition = mocker.patch("trio.Condition", return_value=mocker.sentinel.condition_var) + + # Act + condition = backend.create_condition_var(mock_lock) + + # Assert + mock_Condition.assert_called_once_with(mock_lock) + assert condition is mocker.sentinel.condition_var + + @pytest.mark.parametrize("abandon_on_cancel", [False, True], ids=lambda p: f"abandon_on_cancel=={p}") + async def test____run_in_thread____use_loop_run_in_executor( + self, + abandon_on_cancel: bool, + backend: TrioBackend, + mocker: MockerFixture, + ) -> None: + # Arrange + func_stub = mocker.stub() + mock_run_in_executor = mocker.patch( + "trio.to_thread.run_sync", + new_callable=mocker.AsyncMock, + return_value=mocker.sentinel.return_value, + ) + + # Act + ret_val = await backend.run_in_thread( + func_stub, + mocker.sentinel.arg1, + mocker.sentinel.arg2, + abandon_on_cancel=abandon_on_cancel, + ) + + # Assert + mock_run_in_executor.assert_called_once_with( + func_stub, + mocker.sentinel.arg1, + mocker.sentinel.arg2, + abandon_on_cancel=abandon_on_cancel, + ) + func_stub.assert_not_called() + assert ret_val is mocker.sentinel.return_value diff --git a/tests/unit_test/test_async/test_trio_backend/test_datagram.py b/tests/unit_test/test_async/test_trio_backend/test_datagram.py new file mode 100644 index 00000000..822c6f8f --- /dev/null +++ b/tests/unit_test/test_async/test_trio_backend/test_datagram.py @@ -0,0 +1,336 @@ +from __future__ import annotations + +import contextlib +from collections.abc import AsyncIterator, Callable, Coroutine +from typing import TYPE_CHECKING, Any + +from easynetwork.lowlevel.api_async.backend._trio.backend import TrioBackend +from easynetwork.lowlevel.api_async.backend.abc import TaskGroup +from easynetwork.lowlevel.socket import SocketAttribute, SocketProxy + +import pytest + +from ....fixtures.trio import trio_fixture +from ...base import BaseTestSocket + +if TYPE_CHECKING: + from unittest.mock import AsyncMock, MagicMock + + from easynetwork.lowlevel.api_async.backend._trio.datagram.listener import TrioDatagramListenerSocketAdapter + from easynetwork.lowlevel.api_async.backend._trio.datagram.socket import TrioDatagramSocketAdapter + + from pytest_mock import MockerFixture + + +@pytest.mark.feature_trio(async_test_auto_mark=True) +class TestTrioDatagramSocketAdapter(BaseTestSocket): + @pytest.fixture + @classmethod + def mock_trio_udp_socket(cls, mock_trio_udp_socket: MagicMock) -> MagicMock: + cls.set_local_address_to_socket_mock(mock_trio_udp_socket, mock_trio_udp_socket.family, ("127.0.0.1", 11111)) + cls.set_remote_address_to_socket_mock(mock_trio_udp_socket, mock_trio_udp_socket.family, ("127.0.0.1", 12345)) + return mock_trio_udp_socket + + @trio_fixture + @staticmethod + async def transport( + trio_backend: TrioBackend, + mock_trio_udp_socket: MagicMock, + ) -> AsyncIterator[TrioDatagramSocketAdapter]: + from easynetwork.lowlevel.api_async.backend._trio.datagram.socket import TrioDatagramSocketAdapter + + transport = TrioDatagramSocketAdapter(trio_backend, mock_trio_udp_socket) + async with transport: + yield transport + + async def test____dunder_init____invalid_socket_type( + self, + trio_backend: TrioBackend, + mock_trio_tcp_socket: MagicMock, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.datagram.socket import TrioDatagramSocketAdapter + + # Act & Assert + with pytest.raises(ValueError, match=r"^A 'SOCK_DGRAM' socket is expected$"): + _ = TrioDatagramSocketAdapter(trio_backend, mock_trio_tcp_socket) + + async def test____dunder_del____ResourceWarning( + self, + trio_backend: TrioBackend, + mock_trio_udp_socket: MagicMock, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.datagram.socket import TrioDatagramSocketAdapter + + transport = TrioDatagramSocketAdapter(trio_backend, mock_trio_udp_socket) + + # Act & Assert + with pytest.warns(ResourceWarning, match=r"^unclosed transport .+$"): + del transport + + mock_trio_udp_socket.close.assert_called() + + async def test____aclose____close_transport_and_wait( + self, + transport: TrioDatagramSocketAdapter, + mock_trio_udp_socket: MagicMock, + ) -> None: + # Arrange + import trio.testing + + assert not transport.is_closing() + + # Act + with trio.testing.assert_checkpoints(): + await transport.aclose() + + # Assert + mock_trio_udp_socket.close.assert_called_once() + assert transport.is_closing() + + async def test____recv____read_from_reader( + self, + transport: TrioDatagramSocketAdapter, + mock_trio_udp_socket: MagicMock, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.datagram.socket import TrioDatagramSocketAdapter + + mock_trio_udp_socket.recv.return_value = b"data" + + # Act + data: bytes = await transport.recv() + + # Assert + mock_trio_udp_socket.recv.assert_awaited_once_with(TrioDatagramSocketAdapter.MAX_DATAGRAM_BUFSIZE) + assert data == b"data" + + async def test____send____write_on_socket( + self, + transport: TrioDatagramSocketAdapter, + mock_trio_udp_socket: MagicMock, + ) -> None: + # Arrange + mock_trio_udp_socket.send.side_effect = lambda data: memoryview(data).nbytes + + # Act + await transport.send(b"data to send") + + # Assert + mock_trio_udp_socket.send.assert_awaited_once_with(b"data to send") + + async def test____get_backend____returns_linked_instance( + self, + transport: TrioDatagramSocketAdapter, + trio_backend: TrioBackend, + ) -> None: + # Arrange + + # Act & Assert + assert transport.backend() is trio_backend + + async def test____extra_attributes____returns_socket_info( + self, + transport: TrioDatagramSocketAdapter, + mock_trio_udp_socket: MagicMock, + ) -> None: + # Arrange + + # Act & Assert + trsock = transport.extra(SocketAttribute.socket) + assert isinstance(trsock, SocketProxy) + assert transport.extra(SocketAttribute.family) == mock_trio_udp_socket.family + assert transport.extra(SocketAttribute.sockname) == ("127.0.0.1", 11111) + assert transport.extra(SocketAttribute.peername) == ("127.0.0.1", 12345) + + mock_trio_udp_socket.reset_mock() + trsock.fileno() + mock_trio_udp_socket.fileno.assert_called_once() + + +@pytest.mark.feature_trio(async_test_auto_mark=True) +class TestTrioDatagramListenerSocketAdapter(BaseTestSocket): + @pytest.fixture + @classmethod + def mock_trio_udp_listener_socket( + cls, + mock_trio_udp_socket_factory: Callable[[], MagicMock], + ) -> MagicMock: + mock_socket = mock_trio_udp_socket_factory() + cls.set_local_address_to_socket_mock(mock_socket, mock_socket.family, ("127.0.0.1", 11111)) + cls.configure_socket_mock_to_raise_ENOTCONN(mock_socket) + return mock_socket + + @pytest.fixture + @staticmethod + def handler(mocker: MockerFixture) -> AsyncMock: + handler = mocker.async_stub("handler") + handler.return_value = None + return handler + + @trio_fixture + @staticmethod + async def listener( + trio_backend: TrioBackend, + mock_trio_udp_listener_socket: MagicMock, + ) -> AsyncIterator[TrioDatagramListenerSocketAdapter]: + from easynetwork.lowlevel.api_async.backend._trio.datagram.listener import TrioDatagramListenerSocketAdapter + + listener = TrioDatagramListenerSocketAdapter(trio_backend, mock_trio_udp_listener_socket) + async with listener: + yield listener + + @staticmethod + def _make_recvfrom_into_side_effect( + side_effect: Any, + mocker: MockerFixture, + sleep_time: float = 0, + ) -> Callable[[bytearray | memoryview], Coroutine[Any, Any, tuple[int, Any]]]: + import trio + + next_datagram_cb = mocker.AsyncMock(side_effect=side_effect) + + def write_in_buffer(buffer: memoryview, to_write: bytes) -> int: + nbytes = len(to_write) + buffer[:nbytes] = to_write + return nbytes + + async def recvfrom_into_side_effect(buffer: bytearray | memoryview) -> tuple[int, Any]: + await trio.sleep(sleep_time) + datagram: bytes + address: tuple[Any, ...] + datagram, address = await next_datagram_cb() + with memoryview(buffer) as buffer: + return write_in_buffer(buffer, datagram), address + + return recvfrom_into_side_effect + + @staticmethod + async def _get_cancelled_exc() -> BaseException: + import outcome + import trio + + with trio.move_on_after(0): + result = await outcome.acapture(trio.sleep_forever) + + assert isinstance(result, outcome.Error) + return result.error.with_traceback(None) + + async def test____dunder_init____invalid_socket_type( + self, + trio_backend: TrioBackend, + mock_trio_tcp_socket: MagicMock, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.datagram.listener import TrioDatagramListenerSocketAdapter + + # Act & Assert + with pytest.raises(ValueError, match=r"^A 'SOCK_DGRAM' socket is expected$"): + _ = TrioDatagramListenerSocketAdapter(trio_backend, mock_trio_tcp_socket) + + async def test____dunder_del____ResourceWarning( + self, + trio_backend: TrioBackend, + mock_trio_udp_listener_socket: MagicMock, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.datagram.listener import TrioDatagramListenerSocketAdapter + + listener = TrioDatagramListenerSocketAdapter(trio_backend, mock_trio_udp_listener_socket) + + # Act & Assert + with pytest.warns(ResourceWarning, match=r"^unclosed listener .+$"): + del listener + + mock_trio_udp_listener_socket.close.assert_called() + + async def test____aclose____close_socket( + self, + listener: TrioDatagramListenerSocketAdapter, + mock_trio_udp_listener_socket: MagicMock, + ) -> None: + # Arrange + import trio.testing + + assert not listener.is_closing() + + # Act + with trio.testing.assert_checkpoints(): + await listener.aclose() + + # Assert + assert listener.is_closing() + mock_trio_udp_listener_socket.close.assert_called_once_with() + + @pytest.mark.parametrize("external_group", [True, False], ids=lambda p: f"external_group=={p}") + async def test____serve____default( + self, + trio_backend: TrioBackend, + listener: TrioDatagramListenerSocketAdapter, + external_group: bool, + handler: AsyncMock, + mock_trio_udp_listener_socket: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + + import trio + + mock_trio_udp_listener_socket.recvfrom_into.side_effect = self._make_recvfrom_into_side_effect( + [(b"received_datagram", ("127.0.0.1", 12345)), (await self._get_cancelled_exc())], + mocker, + sleep_time=0.1, + ) + + # Act + task_group: TaskGroup | None + async with trio_backend.create_task_group() if external_group else contextlib.nullcontext() as task_group: # type: ignore[attr-defined] + with pytest.raises(trio.Cancelled): + await listener.serve(handler, task_group) + + # Assert + handler.assert_awaited_once_with(b"received_datagram", ("127.0.0.1", 12345)) + + async def test____send_to____write_on_socket( + self, + listener: TrioDatagramListenerSocketAdapter, + mock_trio_udp_listener_socket: MagicMock, + ) -> None: + # Arrange + mock_trio_udp_listener_socket.sendto.side_effect = lambda data, *args: memoryview(data).nbytes + + # Act + await listener.send_to(b"data to send", ("127.0.0.1", 12345)) + + # Assert + mock_trio_udp_listener_socket.sendto.assert_awaited_once_with(b"data to send", ("127.0.0.1", 12345)) + + async def test____get_backend____returns_linked_instance( + self, + trio_backend: TrioBackend, + listener: TrioDatagramListenerSocketAdapter, + ) -> None: + # Arrange + + # Act & Assert + assert listener.backend() is trio_backend + + async def test____extra_attributes____returns_socket_info( + self, + listener: TrioDatagramListenerSocketAdapter, + mock_trio_udp_listener_socket: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + + # Act & Assert + trsock = listener.extra(SocketAttribute.socket) + assert isinstance(trsock, SocketProxy) + assert listener.extra(SocketAttribute.family) == mock_trio_udp_listener_socket.family + assert listener.extra(SocketAttribute.sockname) == ("127.0.0.1", 11111) + assert listener.extra(SocketAttribute.peername, mocker.sentinel.no_value) is mocker.sentinel.no_value + + mock_trio_udp_listener_socket.reset_mock() + trsock.fileno() + mock_trio_udp_listener_socket.fileno.assert_called_once() diff --git a/tests/unit_test/test_async/test_trio_backend/test_dns_resolver.py b/tests/unit_test/test_async/test_trio_backend/test_dns_resolver.py new file mode 100644 index 00000000..bece30ec --- /dev/null +++ b/tests/unit_test/test_async/test_trio_backend/test_dns_resolver.py @@ -0,0 +1,103 @@ +from __future__ import annotations + +import socket +from collections.abc import AsyncIterator +from typing import TYPE_CHECKING + +import pytest + +from ....fixtures.trio import trio_fixture + +if TYPE_CHECKING: + from trio import SocketListener, SocketStream + + from easynetwork.lowlevel.api_async.backend._trio.dns_resolver import TrioDNSResolver + + +@pytest.mark.feature_trio(async_test_auto_mark=True) +class TestTrioDNSResolver: + @trio_fixture + @staticmethod + async def listener() -> AsyncIterator[SocketListener]: + import trio + + async with (await trio.open_tcp_listeners(0, host="127.0.0.1"))[0] as listener: + yield listener + + @pytest.fixture + @staticmethod + def listener_address(listener: SocketListener) -> tuple[str, int]: + return listener.socket.getsockname() + + @pytest.fixture + @staticmethod + def client_sock(listener: SocketListener) -> socket.socket: + sock = socket.socket(family=listener.socket.family, type=listener.socket.type) + sock.setblocking(False) + return sock + + @pytest.fixture + @staticmethod + def dns_resolver() -> TrioDNSResolver: + from easynetwork.lowlevel.api_async.backend._trio.dns_resolver import TrioDNSResolver + + return TrioDNSResolver() + + async def test____connect_socket____works( + self, + dns_resolver: TrioDNSResolver, + listener: SocketListener, + listener_address: tuple[str, int], + client_sock: socket.socket, + ) -> None: + # Arrange + import trio + + # Act + server_stream: SocketStream | None = None + async with trio.open_nursery() as nursery: + nursery.cancel_scope.deadline = trio.current_time() + 1 + nursery.start_soon(dns_resolver.connect_socket, client_sock, listener_address) + + await trio.sleep(0.5) + server_stream = await listener.accept() + + # Assert + assert server_stream is not None + assert client_sock.fileno() > 0 + + async with server_stream, trio.SocketStream(trio.socket.from_stdlib_socket(client_sock)) as client_stream: + await client_stream.send_all(b"data") + assert (await server_stream.receive_some()) == b"data" + + async def test____connect_socket____close_on_cancel( + self, + dns_resolver: TrioDNSResolver, + listener_address: tuple[str, int], + client_sock: socket.socket, + ) -> None: + # Arrange + import trio + + # Act + with trio.move_on_after(0) as scope: + await dns_resolver.connect_socket(client_sock, listener_address) + + # Assert + assert scope.cancelled_caught + assert client_sock.fileno() < 0 + + async def test____connect_socket____close_on_error( + self, + dns_resolver: TrioDNSResolver, + client_sock: socket.socket, + ) -> None: + # Arrange + listener_address = ("unknown_address", 12345) + + # Act + with pytest.raises(OSError): + await dns_resolver.connect_socket(client_sock, listener_address) + + # Assert + assert client_sock.fileno() < 0 diff --git a/tests/unit_test/test_async/test_trio_backend/test_stream.py b/tests/unit_test/test_async/test_trio_backend/test_stream.py new file mode 100644 index 00000000..5ee7e1a2 --- /dev/null +++ b/tests/unit_test/test_async/test_trio_backend/test_stream.py @@ -0,0 +1,688 @@ +from __future__ import annotations + +import contextlib +import errno +import os +from collections.abc import AsyncIterator, Callable, Coroutine, Iterable, Iterator +from typing import TYPE_CHECKING, Any + +from easynetwork.lowlevel.api_async.backend._trio.backend import TrioBackend +from easynetwork.lowlevel.api_async.backend.abc import TaskGroup +from easynetwork.lowlevel.constants import CLOSED_SOCKET_ERRNOS +from easynetwork.lowlevel.socket import SocketAttribute, SocketProxy + +import pytest + +from ....fixtures.trio import trio_fixture +from ...base import BaseTestSocket, MixinTestSocketSendMSG + +if TYPE_CHECKING: + from unittest.mock import AsyncMock, MagicMock + + from trio import BrokenResourceError as _BrokenResourceError, ClosedResourceError as _ClosedResourceError + + from easynetwork.lowlevel.api_async.backend._trio.stream.listener import TrioListenerSocketAdapter + from easynetwork.lowlevel.api_async.backend._trio.stream.socket import TrioStreamSocketAdapter + + from _typeshed import ReadableBuffer + from pytest_mock import MockerFixture + + +class BaseTestTransportStreamSocket(BaseTestSocket): + @staticmethod + def _make_broken_resource_error(connection_error_errno: int) -> _BrokenResourceError: + import trio + + cause = OSError(connection_error_errno, os.strerror(connection_error_errno)) + + exc = trio.BrokenResourceError() + exc.__context__ = cause + exc.__cause__ = cause + return exc + + @staticmethod + def _make_closed_resource_error(closed_socket_errno: int = errno.EBADF) -> _ClosedResourceError: + import trio + + exc = trio.ClosedResourceError() + exc.__context__ = OSError(closed_socket_errno, os.strerror(closed_socket_errno)) + exc.__cause__ = None + return exc + + +@pytest.mark.feature_trio(async_test_auto_mark=True) +class TestTrioStreamSocketAdapter(BaseTestTransportStreamSocket, MixinTestSocketSendMSG): + @pytest.fixture + @classmethod + def mock_trio_tcp_socket(cls, mock_trio_tcp_socket: MagicMock, mocker: MockerFixture) -> MagicMock: + cls.set_local_address_to_socket_mock(mock_trio_tcp_socket, mock_trio_tcp_socket.family, ("127.0.0.1", 11111)) + cls.set_remote_address_to_socket_mock(mock_trio_tcp_socket, mock_trio_tcp_socket.family, ("127.0.0.1", 12345)) + + # Always create a new mock instance because sendmsg() is not available on all platforms + # therefore the mocker's autospec will consider sendmsg() unknown on these ones. + mock_trio_tcp_socket.sendmsg = mocker.AsyncMock( + spec=lambda *args: None, + side_effect=lambda buffers, *args: sum(map(len, map(memoryview, buffers))), + ) + return mock_trio_tcp_socket + + @pytest.fixture + @staticmethod + def mock_trio_socket_stream( + mock_trio_tcp_socket: MagicMock, + mock_trio_socket_stream_factory: Callable[[MagicMock], MagicMock], + ) -> MagicMock: + mock_trio_socket_stream = mock_trio_socket_stream_factory(mock_trio_tcp_socket) + assert mock_trio_socket_stream.socket is mock_trio_tcp_socket + + mock_trio_socket_stream.send_all.return_value = None + + return mock_trio_socket_stream + + @trio_fixture + @staticmethod + async def transport( + trio_backend: TrioBackend, + mock_trio_socket_stream: MagicMock, + ) -> AsyncIterator[TrioStreamSocketAdapter]: + from easynetwork.lowlevel.api_async.backend._trio.stream.socket import TrioStreamSocketAdapter + + transport = TrioStreamSocketAdapter(trio_backend, mock_trio_socket_stream) + async with transport: + yield transport + + async def test____dunder_del____ResourceWarning( + self, + trio_backend: TrioBackend, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.stream.socket import TrioStreamSocketAdapter + + transport = TrioStreamSocketAdapter(trio_backend, mock_trio_socket_stream) + + # Act & Assert + with pytest.warns(ResourceWarning, match=r"^unclosed transport .+$"): + del transport + + mock_trio_socket_stream.socket.close.assert_called() + + async def test____aclose____close_transport_and_wait( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + assert not transport.is_closing() + + # Act + await transport.aclose() + + # Assert + mock_trio_socket_stream.aclose.assert_awaited_once() + assert transport.is_closing() + + async def test____recv____read_from_reader( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + mock_trio_socket_stream.receive_some.return_value = b"data" + + # Act + data: bytes = await transport.recv(1024) + + # Assert + mock_trio_socket_stream.receive_some.assert_awaited_once_with(1024) + assert data == b"data" + + async def test____recv____null_bufsize( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + mock_trio_socket_stream.receive_some.return_value = b"" + + # Act + data: bytes = await transport.recv(0) + + # Assert + mock_trio_socket_stream.receive_some.assert_awaited_once_with(0) + assert data == b"" + + @pytest.mark.parametrize("closed_socket_errno", sorted(CLOSED_SOCKET_ERRNOS), ids=errno.errorcode.__getitem__) + async def test____recv____convert_trio_ClosedResourceError( + self, + closed_socket_errno: int, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + mock_trio_socket_stream.receive_some.side_effect = self._make_closed_resource_error(closed_socket_errno) + + # Act + with pytest.raises(OSError) as exc_info: + _ = await transport.recv(1024) + + # Assert + assert exc_info.value.errno == errno.EBADF + + @pytest.mark.parametrize( + "connection_error_errno", + [ + errno.ECONNABORTED, + errno.ECONNRESET, + errno.EPIPE, + ], + ids=errno.errorcode.__getitem__, + ) + async def test____recv____convert_trio_BrokenResourceError( + self, + connection_error_errno: int, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + mock_trio_socket_stream.receive_some.side_effect = self._make_broken_resource_error(connection_error_errno) + + # Act + with pytest.raises(OSError) as exc_info: + _ = await transport.recv(1024) + + # Assert + assert exc_info.value.errno == connection_error_errno + + async def test____recv_into____read_from_reader( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + mock_trio_socket_stream.socket.recv_into.return_value = 4 + buffer = bytearray(4) + + # Act + nbytes = await transport.recv_into(buffer) + + # Assert + mock_trio_socket_stream.socket.recv_into.assert_awaited_once_with(buffer) + assert nbytes == 4 + + async def test____recv_into____null_buffer( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + mock_trio_socket_stream.socket.recv_into.return_value = 0 + buffer = bytearray() + + # Act + nbytes = await transport.recv_into(buffer) + + # Assert + mock_trio_socket_stream.socket.recv_into.assert_awaited_once_with(buffer) + assert nbytes == 0 + + async def test____send_all____use_stream_send_all( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + + # Act + await transport.send_all(b"data to send") + + # Assert + mock_trio_socket_stream.send_all.assert_awaited_once_with(b"data to send") + + @pytest.mark.parametrize("closed_socket_errno", sorted(CLOSED_SOCKET_ERRNOS), ids=errno.errorcode.__getitem__) + async def test____send_all____convert_trio_ClosedResourceError( + self, + closed_socket_errno: int, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + mock_trio_socket_stream.send_all.side_effect = self._make_closed_resource_error(closed_socket_errno) + + # Act + with pytest.raises(OSError) as exc_info: + await transport.send_all(b"data to send") + + # Assert + assert exc_info.value.errno == errno.EBADF + + @pytest.mark.parametrize( + "connection_error_errno", + [ + errno.ECONNABORTED, + errno.ECONNRESET, + errno.EPIPE, + ], + ids=errno.errorcode.__getitem__, + ) + async def test____send_all____convert_trio_BrokenResourceError( + self, + connection_error_errno: int, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + mock_trio_socket_stream.send_all.side_effect = self._make_broken_resource_error(connection_error_errno) + + # Act + with pytest.raises(OSError) as exc_info: + await transport.send_all(b"data to send") + + # Assert + assert exc_info.value.errno == connection_error_errno + + async def test____send_all_from_iterable____use_socket_sendmsg_when_available( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + chunks: list[list[bytes]] = [] + + def sendmsg_side_effect(buffers: Iterable[ReadableBuffer]) -> int: + # Ensure we are not giving the islice directly. + assert not isinstance(buffers, Iterator) + + buffers = list(buffers) + chunks.append(list(map(bytes, buffers))) + return sum(map(len, map(memoryview, buffers))) + + mock_trio_socket_stream.socket.sendmsg.side_effect = sendmsg_side_effect + + # Act + await transport.send_all_from_iterable(iter([b"data", b"to", b"send"])) + + # Assert + mock_trio_socket_stream.send_all.assert_not_called() + mock_trio_socket_stream.socket.sendmsg.assert_called_once() + assert chunks == [[b"data", b"to", b"send"]] + + @pytest.mark.parametrize("SC_IOV_MAX", [2], ids=lambda p: f"SC_IOV_MAX=={p}", indirect=True) + async def test____send_all_from_iterable____nb_buffers_greather_than_SC_IOV_MAX( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + chunks: list[list[bytes]] = [] + + def sendmsg_side_effect(buffers: Iterable[ReadableBuffer]) -> int: + buffers = list(buffers) + chunks.append(list(map(bytes, buffers))) + return sum(map(len, map(memoryview, buffers))) + + mock_trio_socket_stream.socket.sendmsg.side_effect = sendmsg_side_effect + + # Act + await transport.send_all_from_iterable(iter([b"a", b"b", b"c", b"d", b"e"])) + + # Assert + assert mock_trio_socket_stream.socket.sendmsg.await_count == 3 + assert chunks == [ + [b"a", b"b"], + [b"c", b"d"], + [b"e"], + ] + + async def test____send_all_from_iterable____adjust_leftover_buffer( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + chunks: list[list[bytes]] = [] + + def sendmsg_side_effect(buffers: Iterable[ReadableBuffer]) -> int: + buffers = list(buffers) + chunks.append(list(map(bytes, buffers))) + return min(sum(map(len, map(memoryview, buffers))), 3) + + mock_trio_socket_stream.socket.sendmsg.side_effect = sendmsg_side_effect + + # Act + await transport.send_all_from_iterable(iter([b"abcd", b"efg", b"hijkl", b"mnop"])) + + # Assert + assert mock_trio_socket_stream.socket.sendmsg.await_count == 6 + assert chunks == [ + [b"abcd", b"efg", b"hijkl", b"mnop"], + [b"d", b"efg", b"hijkl", b"mnop"], + [b"g", b"hijkl", b"mnop"], + [b"jkl", b"mnop"], + [b"mnop"], + [b"p"], + ] + + async def test____send_all_from_iterable____fallback_to_send_all____sendmsg_unavailable( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + del mock_trio_socket_stream.socket.sendmsg + + # Act + await transport.send_all_from_iterable(iter([b"data", b"to", b"send"])) + + # Assert + assert mock_trio_socket_stream.send_all.await_args_list == [ + mocker.call(b"".join([b"data", b"to", b"send"])), + ] + + @pytest.mark.parametrize("SC_IOV_MAX", [-1, 0], ids=lambda p: f"SC_IOV_MAX=={p}", indirect=True) + async def test____send_all_from_iterable____fallback_to_send_all____sendmsg_available_but_no_defined_limit( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + + # Act + await transport.send_all_from_iterable(iter([b"data", b"to", b"send"])) + + # Assert + assert mock_trio_socket_stream.send_all.await_args_list == [ + mocker.call(b"".join([b"data", b"to", b"send"])), + ] + + async def test____send_all_from_iterable____fallback_to_send_all____empty_buffer_list( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + + # Act + await transport.send_all_from_iterable(iter([])) + + # Assert + mock_trio_socket_stream.send_all.assert_awaited_once_with(b"") + + async def test____send_eo____use_stream_eof( + self, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + + # Act + await transport.send_eof() + + # Assert + mock_trio_socket_stream.send_eof.assert_awaited_once_with() + + @pytest.mark.parametrize("closed_socket_errno", sorted(CLOSED_SOCKET_ERRNOS), ids=errno.errorcode.__getitem__) + async def test____send_eof____convert_trio_ClosedResourceError( + self, + closed_socket_errno: int, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + mock_trio_socket_stream.send_eof.side_effect = self._make_closed_resource_error(closed_socket_errno) + + # Act + with pytest.raises(OSError) as exc_info: + await transport.send_eof() + + # Assert + assert exc_info.value.errno == errno.EBADF + + @pytest.mark.parametrize( + "connection_error_errno", + [ + errno.ECONNABORTED, + errno.ECONNRESET, + errno.EPIPE, + ], + ids=errno.errorcode.__getitem__, + ) + async def test____send_eof____convert_trio_BrokenResourceError( + self, + connection_error_errno: int, + transport: TrioStreamSocketAdapter, + mock_trio_socket_stream: MagicMock, + ) -> None: + # Arrange + mock_trio_socket_stream.send_eof.side_effect = self._make_broken_resource_error(connection_error_errno) + + # Act + with pytest.raises(OSError) as exc_info: + await transport.send_eof() + + # Assert + assert exc_info.value.errno == connection_error_errno + + async def test____get_backend____returns_linked_instance( + self, + transport: TrioStreamSocketAdapter, + trio_backend: TrioBackend, + ) -> None: + # Arrange + + # Act & Assert + assert transport.backend() is trio_backend + + async def test____extra_attributes____returns_socket_info( + self, + transport: TrioStreamSocketAdapter, + mock_trio_tcp_socket: MagicMock, + ) -> None: + # Arrange + + # Act & Assert + trsock = transport.extra(SocketAttribute.socket) + assert isinstance(trsock, SocketProxy) + assert transport.extra(SocketAttribute.family) == mock_trio_tcp_socket.family + assert transport.extra(SocketAttribute.sockname) == ("127.0.0.1", 11111) + assert transport.extra(SocketAttribute.peername) == ("127.0.0.1", 12345) + + mock_trio_tcp_socket.reset_mock() + trsock.fileno() + mock_trio_tcp_socket.fileno.assert_called_once() + + +@pytest.mark.feature_trio(async_test_auto_mark=True) +class TestTrioListenerSocketAdapter(BaseTestTransportStreamSocket): + @pytest.fixture + @classmethod + def mock_trio_tcp_listener_socket( + cls, + mock_trio_tcp_socket_factory: Callable[[], MagicMock], + ) -> MagicMock: + mock_socket = mock_trio_tcp_socket_factory() + cls.set_local_address_to_socket_mock(mock_socket, mock_socket.family, ("127.0.0.1", 11111)) + cls.configure_socket_mock_to_raise_ENOTCONN(mock_socket) + return mock_socket + + @pytest.fixture + @staticmethod + def mock_trio_socket_listener( + mock_trio_tcp_listener_socket: MagicMock, + mock_trio_socket_listener_factory: Callable[[MagicMock], MagicMock], + ) -> MagicMock: + mock_trio_socket_listener = mock_trio_socket_listener_factory(mock_trio_tcp_listener_socket) + assert mock_trio_socket_listener.socket is mock_trio_tcp_listener_socket + return mock_trio_socket_listener + + @pytest.fixture + @staticmethod + def handler(mocker: MockerFixture) -> AsyncMock: + handler = mocker.async_stub("handler") + handler.return_value = None + return handler + + @trio_fixture + @staticmethod + async def listener( + trio_backend: TrioBackend, + mock_trio_socket_listener: MagicMock, + ) -> AsyncIterator[TrioListenerSocketAdapter]: + from easynetwork.lowlevel.api_async.backend._trio.stream.listener import TrioListenerSocketAdapter + + listener = TrioListenerSocketAdapter(trio_backend, mock_trio_socket_listener) + async with listener: + yield listener + + @staticmethod + def _make_accept_side_effect( + side_effect: Any, + mocker: MockerFixture, + sleep_time: float = 0, + ) -> Callable[[], Coroutine[Any, Any, MagicMock]]: + import trio + + accept_cb = mocker.AsyncMock(side_effect=side_effect) + + async def accept_side_effect() -> MagicMock: + await trio.sleep(sleep_time) + return await accept_cb() + + return accept_side_effect + + @staticmethod + async def _get_cancelled_exc() -> BaseException: + import outcome + import trio + + with trio.move_on_after(0): + result = await outcome.acapture(trio.sleep_forever) + + assert isinstance(result, outcome.Error) + return result.error.with_traceback(None) + + async def test____dunder_del____ResourceWarning( + self, + trio_backend: TrioBackend, + mock_trio_socket_listener: MagicMock, + ) -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio.stream.listener import TrioListenerSocketAdapter + + listener = TrioListenerSocketAdapter(trio_backend, mock_trio_socket_listener) + + # Act & Assert + with pytest.warns(ResourceWarning, match=r"^unclosed listener .+$"): + del listener + + mock_trio_socket_listener.socket.close.assert_called() + + async def test____aclose____close_socket( + self, + listener: TrioListenerSocketAdapter, + mock_trio_socket_listener: MagicMock, + ) -> None: + # Arrange + assert not listener.is_closing() + + # Act + await listener.aclose() + + # Assert + assert listener.is_closing() + mock_trio_socket_listener.aclose.assert_awaited_once_with() + + @pytest.mark.parametrize("external_group", [True, False], ids=lambda p: f"external_group=={p}") + async def test____serve____default( + self, + trio_backend: TrioBackend, + listener: TrioListenerSocketAdapter, + external_group: bool, + handler: AsyncMock, + mock_trio_socket_listener: MagicMock, + mock_trio_socket_stream_factory: Callable[[], MagicMock], + mocker: MockerFixture, + ) -> None: + # Arrange + import trio + + from easynetwork.lowlevel.api_async.backend._trio.stream.socket import TrioStreamSocketAdapter + + accepted_client_transport = mocker.NonCallableMagicMock(spec=TrioStreamSocketAdapter) + accepted_client_stream = mock_trio_socket_stream_factory() + mock_trio_socket_listener.accept.side_effect = self._make_accept_side_effect( + [accepted_client_stream, (await self._get_cancelled_exc())], + mocker, + sleep_time=0.1, + ) + mock_TrioStreamSocketAdapter: MagicMock = mocker.patch( + "easynetwork.lowlevel.api_async.backend._trio.stream.listener.TrioStreamSocketAdapter", + side_effect=[accepted_client_transport], + ) + + # Act + task_group: TaskGroup | None + async with trio_backend.create_task_group() if external_group else contextlib.nullcontext() as task_group: # type: ignore[attr-defined] + with pytest.raises(trio.Cancelled): + await listener.serve(handler, task_group) + + # Assert + mock_TrioStreamSocketAdapter.assert_called_once_with(trio_backend, accepted_client_stream) + handler.assert_awaited_once_with(accepted_client_transport) + + @pytest.mark.parametrize("closed_socket_errno", sorted(CLOSED_SOCKET_ERRNOS), ids=errno.errorcode.__getitem__) + async def test____serve____convert_trio_ClosedResourceError( + self, + closed_socket_errno: int, + trio_backend: TrioBackend, + listener: TrioListenerSocketAdapter, + handler: AsyncMock, + mock_trio_socket_listener: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + mock_trio_socket_listener.accept.side_effect = self._make_accept_side_effect( + self._make_closed_resource_error(closed_socket_errno), + mocker, + sleep_time=0.1, + ) + + # Act + async with trio_backend.create_task_group() as task_group: + with pytest.raises(OSError) as exc_info: + await listener.serve(handler, task_group) + + # Assert + assert exc_info.value.errno == errno.EBADF + handler.assert_not_awaited() + + async def test____get_backend____returns_linked_instance( + self, + trio_backend: TrioBackend, + listener: TrioListenerSocketAdapter, + ) -> None: + # Arrange + + # Act & Assert + assert listener.backend() is trio_backend + + async def test____extra_attributes____returns_socket_info( + self, + listener: TrioListenerSocketAdapter, + mock_trio_tcp_listener_socket: MagicMock, + mocker: MockerFixture, + ) -> None: + # Arrange + + # Act & Assert + trsock = listener.extra(SocketAttribute.socket) + assert isinstance(trsock, SocketProxy) + assert listener.extra(SocketAttribute.family) == mock_trio_tcp_listener_socket.family + assert listener.extra(SocketAttribute.sockname) == ("127.0.0.1", 11111) + assert listener.extra(SocketAttribute.peername, mocker.sentinel.no_value) is mocker.sentinel.no_value + + mock_trio_tcp_listener_socket.reset_mock() + trsock.fileno() + mock_trio_tcp_listener_socket.fileno.assert_called_once() diff --git a/tests/unit_test/test_async/test_trio_backend/test_tasks.py b/tests/unit_test/test_async/test_trio_backend/test_tasks.py new file mode 100644 index 00000000..edebfde6 --- /dev/null +++ b/tests/unit_test/test_async/test_trio_backend/test_tasks.py @@ -0,0 +1,471 @@ +from __future__ import annotations + +from collections.abc import Coroutine +from typing import TYPE_CHECKING, Any + +from easynetwork.lowlevel.api_async.backend.abc import TaskInfo + +import pytest + +from ....tools import call_later_with_nursery + +if TYPE_CHECKING: + from unittest.mock import MagicMock + + from trio import Nursery + + from easynetwork.lowlevel.api_async.backend._trio.tasks import Task, _OutcomeCell + + from pytest_mock import MockerFixture + + +@pytest.mark.feature_trio(async_test_auto_mark=True) +class TestTask: + @pytest.fixture + @staticmethod + def mock_trio_task(mocker: MockerFixture) -> MagicMock: + import trio + + mock = mocker.NonCallableMagicMock(spec=trio.lowlevel.Task) + mock.name = "mock_asyncio_task" + mock.coro = mocker.NonCallableMagicMock(spec=Coroutine) + return mock + + @pytest.fixture + @staticmethod + def mock_trio_scope(mocker: MockerFixture) -> MagicMock: + import trio + + return mocker.NonCallableMagicMock(spec=trio.CancelScope) + + @pytest.fixture + @staticmethod + def outcome_cell() -> _OutcomeCell[Any]: + from easynetwork.lowlevel.api_async.backend._trio.tasks import _OutcomeCell + + return _OutcomeCell() + + @pytest.fixture + @staticmethod + def task(mock_trio_task: MagicMock, mock_trio_scope: MagicMock, outcome_cell: _OutcomeCell[Any]) -> Task[Any]: + from easynetwork.lowlevel.api_async.backend._trio.tasks import Task + + return Task(task=mock_trio_task, scope=mock_trio_scope, outcome=outcome_cell) + + @staticmethod + async def _get_cancelled_exc() -> BaseException: + import outcome + import trio + + with trio.move_on_after(0): + result = await outcome.acapture(trio.sleep_forever) + + assert isinstance(result, outcome.Error) + return result.error.with_traceback(None) + + def test____info_property____trio_task_introspection( + self, + task: Task[Any], + mock_trio_task: MagicMock, + ) -> None: + # Arrange + + # Act + task_info = task.info + + # Assert + assert isinstance(task_info, TaskInfo) + assert task_info.name == "mock_asyncio_task" + assert task_info.id == id(mock_trio_task) + assert task_info.coro is mock_trio_task.coro + + def test____done____outcome_set_result( + self, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + mocker: MockerFixture, + ) -> None: + # Arrange + import outcome + + assert not task.done() + + # Act + outcome_cell.set(outcome.Value(mocker.sentinel.result)) + + # Assert + assert task.done() + + def test____done____outcome_set_error( + self, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + + assert not task.done() + + # Act + outcome_cell.set(outcome.Error(ValueError("error"))) + + # Assert + assert task.done() + + def test____cancel____task_not_done( + self, + task: Task[Any], + mock_trio_scope: MagicMock, + ) -> None: + # Arrange + assert not task.done() + + # Act + cancel_called = task.cancel() + + # Assert + assert cancel_called + mock_trio_scope.cancel.assert_called_once() + + def test____cancel____task_done( + self, + task: Task[Any], + mock_trio_scope: MagicMock, + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + + outcome_cell.set(outcome.Value(42)) + assert task.done() + + # Act + cancel_called = task.cancel() + + # Assert + assert not cancel_called + mock_trio_scope.cancel.assert_not_called() + + def test____cancelled____task_not_done( + self, + task: Task[Any], + ) -> None: + # Arrange + assert not task.done() + + # Act + is_cancelled = task.cancelled() + + # Assert + assert not is_cancelled + + def test____cancelled____task_done_with_result( + self, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + + outcome_cell.set(outcome.Value(42)) + assert task.done() + + # Act + is_cancelled = task.cancelled() + + # Assert + assert not is_cancelled + + def test____cancelled____task_done_with_error( + self, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + + outcome_cell.set(outcome.Error(ValueError("error"))) + assert task.done() + + # Act + is_cancelled = task.cancelled() + + # Assert + assert not is_cancelled + + async def test____cancelled____task_done_with_trio_Cancelled_error( + self, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + + outcome_cell.set(outcome.Error(await self._get_cancelled_exc())) + assert task.done() + + # Act + is_cancelled = task.cancelled() + + # Assert + assert is_cancelled + + async def test____wait____until_result( + self, + nursery: Nursery, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + import trio + + call_later_with_nursery(nursery, 0.5, outcome_cell.set, outcome.Value(42)) + + # Act + with trio.fail_after(2): + await task.wait() + + # Assert + assert task.done() + + async def test____wait____until_error( + self, + nursery: Nursery, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + import trio + + call_later_with_nursery(nursery, 0.5, outcome_cell.set, outcome.Error(ValueError("error"))) + + # Act + with trio.fail_after(2): + await task.wait() + + # Assert + assert task.done() + + async def test____wait____until_cancellation( + self, + nursery: Nursery, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + import trio + + call_later_with_nursery(nursery, 0.5, outcome_cell.set, outcome.Error(await self._get_cancelled_exc())) + + # Act + with trio.fail_after(2): + await task.wait() + + # Assert + assert task.done() + + async def test____wait____timeout( + self, + nursery: Nursery, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + import trio + + call_later_with_nursery(nursery, 1, outcome_cell.set, outcome.Value(42)) + + # Act + with trio.move_on_after(0.5) as scope: + await task.wait() + + # Assert + assert scope.cancelled_caught + assert not task.done() + + async def test____wait____already_done( + self, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + import trio.testing + + outcome_cell.set(outcome.Value(42)) + + # Act & Assert + with trio.testing.assert_no_checkpoints(): + await task.wait() + + @pytest.mark.parametrize( + "cancellable", + [ + pytest.param(False, id="join"), + pytest.param(True, id="join_or_cancel"), + ], + ) + async def test____join____until_result( + self, + cancellable: bool, + nursery: Nursery, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + import trio + + call_later_with_nursery(nursery, 0.5, outcome_cell.set, outcome.Value(42)) + + # Act + value: int = 0 + with trio.fail_after(2): + if cancellable: + value = await task.join() + else: + value = await task.join_or_cancel() + # trio.fail_after() would never see the cancellation without this. + await trio.lowlevel.checkpoint_if_cancelled() + + # Assert + assert value == 42 + + @pytest.mark.parametrize( + "cancellable", + [ + pytest.param(False, id="join"), + pytest.param(True, id="join_or_cancel"), + ], + ) + async def test____join____until_error( + self, + cancellable: bool, + nursery: Nursery, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + import trio + + exc = ValueError("error") + call_later_with_nursery(nursery, 0.5, outcome_cell.set, outcome.Error(exc)) + + # Act + with pytest.raises(ValueError) as exc_info: + with trio.fail_after(2): + if cancellable: + await task.join() + else: + await task.join_or_cancel() + # trio.fail_after() would never see the cancellation without this. + await trio.lowlevel.checkpoint_if_cancelled() + + # Assert + assert exc_info.value is exc + + @pytest.mark.parametrize( + "cancellable", + [ + pytest.param(False, id="join"), + pytest.param(True, id="join_or_cancel"), + ], + ) + async def test____join____until_cancellation( + self, + cancellable: bool, + nursery: Nursery, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + import trio + + exc = await self._get_cancelled_exc() + call_later_with_nursery(nursery, 0.5, outcome_cell.set, outcome.Error(exc)) + + # Act + with pytest.raises(trio.Cancelled) as exc_info: + with trio.fail_after(2): + if cancellable: + await task.join() + else: + await task.join_or_cancel() + # trio.fail_after() would never see the cancellation without this. + await trio.lowlevel.checkpoint_if_cancelled() + + # Assert + assert exc_info.value is exc + + @pytest.mark.parametrize( + "cancellable", + [ + pytest.param(False, id="join"), + pytest.param(True, id="join_or_cancel"), + ], + ) + async def test____join____timeout( + self, + cancellable: bool, + nursery: Nursery, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + import trio + + call_later_with_nursery(nursery, 1, outcome_cell.set, outcome.Value(42)) + + # Act + value: int | None = None + with trio.move_on_after(0.5) as scope: + if cancellable: + value = await task.join() + else: + value = await task.join_or_cancel() + + # Assert + if cancellable: + assert scope.cancelled_caught + assert not task.done() + assert value is None + else: + assert not scope.cancelled_caught + assert value == 42 + + @pytest.mark.parametrize( + "cancellable", + [ + pytest.param(False, id="join"), + pytest.param(True, id="join_or_cancel"), + ], + ) + async def test____join____already_done( + self, + cancellable: bool, + task: Task[Any], + outcome_cell: _OutcomeCell[Any], + ) -> None: + # Arrange + import outcome + import trio.testing + + outcome_cell.set(outcome.Value(42)) + + # Act & Assert + value: int | None = None + with trio.testing.assert_no_checkpoints(): + if cancellable: + value = await task.join() + else: + value = await task.join_or_cancel() + + # Assert + assert value == 42 diff --git a/tests/unit_test/test_async/test_trio_backend/test_threads.py b/tests/unit_test/test_async/test_trio_backend/test_threads.py new file mode 100644 index 00000000..b80a134c --- /dev/null +++ b/tests/unit_test/test_async/test_trio_backend/test_threads.py @@ -0,0 +1,32 @@ +from __future__ import annotations + +import time +from typing import TYPE_CHECKING + +import pytest + +if TYPE_CHECKING: + from trio import Nursery + + +@pytest.mark.feature_trio(async_test_auto_mark=True) +class TestRunSyncSoonWaiter: + + async def test____aclose____wait_for_thread_to_call_detach(self, nursery: Nursery) -> None: + # Arrange + import trio + + from easynetwork.lowlevel.api_async.backend._trio.threads import _PortalRunSyncSoonWaiter + + waiter = _PortalRunSyncSoonWaiter() + + waiter.attach_from_any_thread() + + def in_thread() -> None: + time.sleep(0.5) + trio.from_thread.run_sync(waiter.detach_in_trio_thread) + + nursery.start_soon(trio.to_thread.run_sync, in_thread) + + with trio.fail_after(2): + await waiter.aclose() diff --git a/tests/unit_test/test_async/test_trio_backend/test_utils.py b/tests/unit_test/test_async/test_trio_backend/test_utils.py new file mode 100644 index 00000000..565c3d5f --- /dev/null +++ b/tests/unit_test/test_async/test_trio_backend/test_utils.py @@ -0,0 +1,105 @@ +from __future__ import annotations + +import errno + +from easynetwork.lowlevel._utils import error_from_errno + +import pytest + + +@pytest.mark.feature_trio +def test____convert_trio_resource_errors____ClosedResourceError() -> None: + # Arrange + import trio + + from easynetwork.lowlevel.api_async.backend._trio._trio_utils import convert_trio_resource_errors + + # Act + with pytest.raises(OSError) as exc_info: + with convert_trio_resource_errors(broken_resource_errno=errno.ECONNABORTED): + try: + raise error_from_errno(errno.EBADF) + except OSError: + raise trio.ClosedResourceError from None + + # Assert + assert exc_info.value.errno == errno.EBADF + assert isinstance(exc_info.value.__cause__, trio.ClosedResourceError) + assert exc_info.value.__suppress_context__ + + +@pytest.mark.feature_trio +def test____convert_trio_resource_errors____BusyResourceError() -> None: + # Arrange + import trio + + from easynetwork.lowlevel.api_async.backend._trio._trio_utils import convert_trio_resource_errors + + # Act + with pytest.raises(OSError) as exc_info: + with convert_trio_resource_errors(broken_resource_errno=errno.ECONNABORTED): + raise trio.BusyResourceError + + # Assert + assert exc_info.value.errno == errno.EBUSY + assert isinstance(exc_info.value.__cause__, trio.BusyResourceError) + assert exc_info.value.__suppress_context__ + + +@pytest.mark.feature_trio +@pytest.mark.parametrize("broken_resource_errno", [errno.ECONNABORTED, errno.EPIPE]) +def test____convert_trio_resource_errors____BrokenResourceError____arbitrary(broken_resource_errno: int) -> None: + # Arrange + import trio + + from easynetwork.lowlevel.api_async.backend._trio._trio_utils import convert_trio_resource_errors + + # Act + with pytest.raises(OSError) as exc_info: + with convert_trio_resource_errors(broken_resource_errno=broken_resource_errno): + raise trio.BrokenResourceError + + # Assert + assert exc_info.value.errno == broken_resource_errno + assert isinstance(exc_info.value.__cause__, trio.BrokenResourceError) + assert exc_info.value.__suppress_context__ + + +@pytest.mark.feature_trio +@pytest.mark.parametrize("broken_resource_errno", [errno.ECONNABORTED, errno.ECONNRESET, errno.EPIPE, errno.ENOTCONN]) +def test____convert_trio_resource_errors____BrokenResourceError____because_of_OSError(broken_resource_errno: int) -> None: + # Arrange + import trio + + from easynetwork.lowlevel.api_async.backend._trio._trio_utils import convert_trio_resource_errors + + initial_error = error_from_errno(broken_resource_errno) + + # Act + with pytest.raises(OSError) as exc_info: + with convert_trio_resource_errors(broken_resource_errno=errno.ECONNABORTED): + try: + raise initial_error + except OSError as exc: + raise trio.BrokenResourceError from exc + + # Assert + assert exc_info.value is initial_error + assert exc_info.value.__cause__ is None + assert exc_info.value.__suppress_context__ + + +@pytest.mark.feature_trio +def test____convert_trio_resource_errors____other_exceptions() -> None: + # Arrange + from easynetwork.lowlevel.api_async.backend._trio._trio_utils import convert_trio_resource_errors + + initial_error = ValueError("invalid bufsize") + + # Act + with pytest.raises(ValueError) as exc_info: + with convert_trio_resource_errors(broken_resource_errno=errno.ECONNABORTED): + raise initial_error + + # Assert + assert exc_info.value is initial_error diff --git a/tests/unit_test/test_tools/test_utils.py b/tests/unit_test/test_tools/test_utils.py index bdb04af2..75042105 100644 --- a/tests/unit_test/test_tools/test_utils.py +++ b/tests/unit_test/test_tools/test_utils.py @@ -11,7 +11,7 @@ from collections.abc import Callable from errno import EINVAL, ENOTCONN, errorcode as errno_errorcode from socket import SO_ERROR, SOL_SOCKET -from typing import TYPE_CHECKING, Any +from typing import TYPE_CHECKING, Any, Literal from easynetwork.exceptions import BusyResourceError from easynetwork.lowlevel._final import runtime_final_class @@ -41,6 +41,7 @@ replace_kwargs, set_reuseport, supports_socket_sendmsg, + validate_listener_hosts, validate_timeout_delay, ) from easynetwork.lowlevel.constants import NOT_CONNECTED_SOCKET_ERRNOS @@ -723,6 +724,45 @@ def test____set_reuseport____not_supported____defined_but_not_implemented( mock_tcp_socket.setsockopt.assert_called_once_with(SOL_SOCKET, SO_REUSEPORT, True) +@pytest.mark.parametrize("host", ["", None], ids=repr) +def test____validate_listener_hosts____any_address(host: Literal[""] | None) -> None: + # Arrange + + # Act + host_list = validate_listener_hosts(host) + + # Assert + assert host_list == [None] + + +def test____validate_listener_hosts____single_address() -> None: + # Arrange + + # Act + host_list = validate_listener_hosts("local_address") + + # Assert + assert host_list == ["local_address"] + + +def test____validate_listener_hosts____sequence_of_addresses() -> None: + # Arrange + + # Act + host_list = validate_listener_hosts(["local_address_1", "local_address_2"]) + + # Assert + assert host_list == ["local_address_1", "local_address_2"] + + +def test____validate_listener_hosts____mix_is_forbidden() -> None: + # Arrange + + # Act & Assert + with pytest.raises(TypeError): + validate_listener_hosts(["local_address_1", None]) # type: ignore[list-item] + + def test____exception_with_notes____one_note() -> None: # Arrange exception = Exception() diff --git a/tox.ini b/tox.ini index af89f9dc..e1427c03 100644 --- a/tox.ini +++ b/tox.ini @@ -6,12 +6,13 @@ envlist = mypy-{full,test,docs,benchmark_server,micro_benchmarks} # Build build + doc-html # Tests (3.11) py311-other-{tests,docstrings} - py311-{unit,functional}-{standard,cbor,msgpack} + py311-{unit,functional}-{standard,cbor,msgpack,trio} py311-functional-{asyncio_proactor,uvloop} # Tests (3.12) - py312-{unit,functional}-{standard,cbor,msgpack} + py312-{unit,functional}-{standard,cbor,msgpack,trio} py312-functional-{asyncio_proactor,uvloop} # Report coverage @@ -26,6 +27,7 @@ setenv = [docs] root_dir = {toxinidir}{/}docs source_dir = {[docs]root_dir}{/}source +build_dir = {[docs]root_dir}{/}build extensions_dir = {[docs]source_dir}{/}_extensions examples_dir = {[docs]source_dir}{/}_include{/}examples @@ -46,6 +48,7 @@ platform = docstrings: linux groups = test + trio # Needed to import trio backend package setenv = {[base]setenv} {[pytest-conf]setenv} @@ -57,7 +60,7 @@ commands = docstrings: pytest --doctest-modules {posargs} {[docs]examples_dir}{/}tutorials{/}ftp_server docstrings: pytest --doctest-glob="*.rst" {posargs} {[docs]source_dir} -[testenv:{py311,py312}-{unit,functional}-{standard,cbor,msgpack}] +[testenv:{py311,py312}-{unit,functional}-{standard,cbor,msgpack,trio}] package = wheel wheel_build_env = {[base]wheel_build_env} groups = @@ -65,6 +68,8 @@ groups = coverage cbor: cbor msgpack: msgpack + trio: trio + trio: test-trio setenv = {[base]setenv} {[pytest-conf]setenv} @@ -80,6 +85,7 @@ commands = standard: pytest -n "{env:PYTEST_MAX_WORKERS:auto}" -m "not feature" {posargs} {env:TESTS_ROOTDIR} cbor: pytest -m "feature_cbor" {posargs} {env:TESTS_ROOTDIR} msgpack: pytest -m "feature_msgpack" {posargs} {env:TESTS_ROOTDIR} + trio: pytest -n "{env:PYTEST_MAX_WORKERS:auto}" -m "feature_trio" {posargs} {env:TESTS_ROOTDIR} [testenv:{py311,py312}-functional-{asyncio_proactor,uvloop}] package = wheel @@ -109,7 +115,7 @@ commands = [testenv:coverage] skip_install = true depends = - {py311,py312}-{unit,functional}-{standard,cbor,msgpack} + {py311,py312}-{unit,functional}-{standard,cbor,msgpack,trio} {py311,py312}-functional-{asyncio_proactor,uvloop} parallel_show_output = True groups = @@ -134,6 +140,15 @@ passenv = commands = python -m build --outdir {toxinidir}{/}dist +[testenv:doc-html] +package = editable +groups = + doc +setenv = + {[base]setenv} +commands = + sphinx-build -T -b html {[docs]source_dir} {[docs]build_dir} {posargs} + [testenv:mypy-{full,test,docs,benchmark_server,micro_benchmarks}] package = wheel wheel_build_env = {[base]wheel_build_env} @@ -145,6 +160,7 @@ groups = full,test,micro_benchmarks: cbor full,test,micro_benchmarks: msgpack full,test,micro_benchmarks: types-msgpack + full,test,docs: trio docs: doc benchmark_server: benchmark-servers benchmark_server: benchmark-servers-deps @@ -203,7 +219,7 @@ setenv = commands = pytest -c pytest-benchmark.ini {posargs:--benchmark-histogram=benchmark_reports{/}micro_benches{/}benchmark} -[testenv:benchmark-server-{tcpecho,sslecho,readline,udpecho}] +[testenv:benchmark-server-{tcpecho,sslecho,readline,udpecho}-{easynetwork,stdlib}] skip_install = true groups = benchmark-servers @@ -217,10 +233,18 @@ setenv = BENCHMARK_IMAGE_TAG = easynetwork/benchmark-{env:BENCHMARK_PYTHON_VERSION} # Benchmark name - tcpecho: BENCHMARK_PATTERN = ^tcpecho - sslecho: BENCHMARK_PATTERN = ^sslecho - udpecho: BENCHMARK_PATTERN = ^udpecho - readline: BENCHMARK_PATTERN = ^readline + ## TCP echo + tcpecho-easynetwork: BENCHMARK_PATTERN = ^tcpecho-easynetwork + tcpecho-stdlib: BENCHMARK_PATTERN = ^tcpecho-(?!easynetwork) + ## SSL echo + sslecho-easynetwork: BENCHMARK_PATTERN = ^sslecho-easynetwork + sslecho-stdlib: BENCHMARK_PATTERN = ^sslecho-(?!easynetwork) + ## UDP echo + udpecho-easynetwork: BENCHMARK_PATTERN = ^udpecho-easynetwork + udpecho-stdlib: BENCHMARK_PATTERN = ^udpecho-(?!easynetwork) + ## TCP readline + readline-easynetwork: BENCHMARK_PATTERN = ^readline-easynetwork + readline-stdlib: BENCHMARK_PATTERN = ^readline-(?!easynetwork) # Report files BENCHMARK_REPORT_JSON = {toxinidir}{/}benchmark_reports{/}server_benches{/}json{/}{envname}-{env:BENCHMARK_PYTHON_VERSION}-report.json @@ -234,4 +258,7 @@ interrupt_timeout = 3.0 # seconds commands_pre = python .{/}benchmark_server{/}build_benchmark_image --tag="{env:BENCHMARK_IMAGE_TAG}" --python-version="{env:BENCHMARK_PYTHON_VERSION}" commands = - python .{/}benchmark_server{/}run_benchmark {posargs:--add-date-to-report-file} -J "{env:BENCHMARK_REPORT_JSON}" -H "{env:BENCHMARK_REPORT_HTML}" -b "{env:BENCHMARK_PATTERN}" -t "{env:BENCHMARK_IMAGE_TAG}" + # For easynetwork benchmarks, also dump in JSON format + easynetwork: python .{/}benchmark_server{/}run_benchmark {posargs:--add-date-to-report-file} -J "{env:BENCHMARK_REPORT_JSON}" -H "{env:BENCHMARK_REPORT_HTML}" -b "{env:BENCHMARK_PATTERN}" -t "{env:BENCHMARK_IMAGE_TAG}" + # For other benchmarks, only create HTML report + !easynetwork: python .{/}benchmark_server{/}run_benchmark {posargs:--add-date-to-report-file} -H "{env:BENCHMARK_REPORT_HTML}" -b "{env:BENCHMARK_PATTERN}" -t "{env:BENCHMARK_IMAGE_TAG}"