From 1e990fde4d86abf8fcd618498baa5e35785b37d2 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE <francis.clairicia-rose-claire-josephine@epitech.eu> Date: Sun, 29 Oct 2023 11:23:17 +0100 Subject: [PATCH] [FIX] Servers: Improved asynchronous generators usage --- README.md | 2 +- src/easynetwork/api_async/server/tcp.py | 39 ++++--------- src/easynetwork/api_async/server/udp.py | 40 +++++++++---- .../_tools/actions.py => _asyncgen.py} | 44 +++++--------- .../api_async/servers/_tools/__init__.py | 19 ------- .../lowlevel/api_async/servers/datagram.py | 46 ++++++--------- .../lowlevel/api_async/servers/stream.py | 45 +++++++-------- tests/scripts/async_server_test.py | 2 +- .../test_asyncgen.py} | 57 ++++++------------- 9 files changed, 107 insertions(+), 187 deletions(-) rename src/easynetwork/lowlevel/{api_async/servers/_tools/actions.py => _asyncgen.py} (50%) delete mode 100644 src/easynetwork/lowlevel/api_async/servers/_tools/__init__.py rename tests/unit_test/{test_async/test_lowlevel_api/test_servers/test_actions.py => test_tools/test_asyncgen.py} (55%) diff --git a/README.md b/README.md index 9a8a0bc5..ba7646e7 100644 --- a/README.md +++ b/README.md @@ -88,7 +88,7 @@ class EchoRequestHandler(AsyncStreamRequestHandler[RequestType, ResponseType]): await client.send_packet({"error": "Invalid JSON", "code": "parse_error"}) return - self.logger.info(f"{client.address} sent {request!r}") + self.logger.info(f"{client!r} sent {request!r}") # As a good echo handler, the request is sent back to the client response: ResponseType = request diff --git a/src/easynetwork/api_async/server/tcp.py b/src/easynetwork/api_async/server/tcp.py index 5781fe2c..4d00ba97 100644 --- a/src/easynetwork/api_async/server/tcp.py +++ b/src/easynetwork/api_async/server/tcp.py @@ -28,7 +28,7 @@ from ..._typevars import _RequestT, _ResponseT from ...exceptions import ClientClosedError, ServerAlreadyRunning, ServerClosedError -from ...lowlevel import _utils, constants +from ...lowlevel import _asyncgen, _utils, constants from ...lowlevel.api_async.backend.factory import AsyncBackendFactory from ...lowlevel.api_async.servers import stream as lowlevel_stream_server from ...lowlevel.socket import ( @@ -393,23 +393,22 @@ async def disconnect_client() -> None: del client_exit_stack - from ...lowlevel.api_async.servers._tools.actions import ErrorAction, RequestAction - try: - if client.is_closing(): - return - if request_handler_generator is None: - request_handler_generator = await self.__new_request_handler(client) + action: _asyncgen.AsyncGenAction[None, _RequestT] + while not client.is_closing(): if request_handler_generator is None: - return - action: RequestAction[_RequestT] | ErrorAction - while True: + request_handler_generator = self.__request_handler.handle(client) + try: + await anext(request_handler_generator) + except StopAsyncIteration: + request_handler_generator = None + break try: - action = RequestAction((yield)) + action = _asyncgen.SendAction((yield)) except ConnectionError: - return + break except BaseException as exc: - action = ErrorAction(exc) + action = _asyncgen.ThrowAction(exc) try: await action.asend(request_handler_generator) except StopAsyncIteration: @@ -417,24 +416,10 @@ async def disconnect_client() -> None: finally: del action await backend.cancel_shielded_coro_yield() - if client.is_closing(): - break - if request_handler_generator is None: - request_handler_generator = await self.__new_request_handler(client) - if request_handler_generator is None: - break finally: if request_handler_generator is not None: await request_handler_generator.aclose() - async def __new_request_handler(self, client: _ConnectedClientAPI[_ResponseT]) -> AsyncGenerator[None, _RequestT] | None: - request_handler_generator = self.__request_handler.handle(client) - try: - await anext(request_handler_generator) - except StopAsyncIteration: - return None - return request_handler_generator - @staticmethod def __set_socket_linger_if_not_closed(socket: ISocket) -> None: with contextlib.suppress(OSError): diff --git a/src/easynetwork/api_async/server/udp.py b/src/easynetwork/api_async/server/udp.py index 1ed54141..ecb80b04 100644 --- a/src/easynetwork/api_async/server/udp.py +++ b/src/easynetwork/api_async/server/udp.py @@ -21,13 +21,13 @@ import contextlib import logging import weakref -from collections import deque +from collections import defaultdict, deque from collections.abc import AsyncGenerator, Callable, Coroutine, Iterator, Mapping, Sequence from typing import TYPE_CHECKING, Any, Generic, NoReturn, final from ..._typevars import _RequestT, _ResponseT from ...exceptions import ClientClosedError, ServerAlreadyRunning, ServerClosedError -from ...lowlevel import _utils +from ...lowlevel import _asyncgen, _utils from ...lowlevel.api_async.backend.factory import AsyncBackendFactory from ...lowlevel.api_async.servers import datagram as lowlevel_datagram_server from ...lowlevel.api_async.transports.abc import AsyncDatagramListener @@ -55,6 +55,7 @@ class AsyncUDPNetworkServer(AbstractAsyncNetworkServer, Generic[_RequestT, _Resp "__is_shutdown", "__shutdown_asked", "__clients_cache", + "__send_locks_cache", "__servers_tasks", "__mainloop_task", "__logger", @@ -121,7 +122,16 @@ def __init__( self.__servers_tasks: deque[Task[NoReturn]] = deque() # type: ignore[assignment] self.__mainloop_task: Task[NoReturn] | None = None self.__logger: logging.Logger = logger or logging.getLogger(__name__) - self.__clients_cache: weakref.WeakValueDictionary[SocketAddress, _ClientAPI[_ResponseT]] = weakref.WeakValueDictionary() + self.__clients_cache: defaultdict[ + lowlevel_datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], + weakref.WeakValueDictionary[SocketAddress, _ClientAPI[_ResponseT]], + ] + self.__send_locks_cache: defaultdict[ + lowlevel_datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]], + weakref.WeakValueDictionary[SocketAddress, ILock], + ] + self.__clients_cache = defaultdict(weakref.WeakValueDictionary) + self.__send_locks_cache = defaultdict(weakref.WeakValueDictionary) def is_serving(self) -> bool: return self.__servers is not None and all(not server.is_closing() for server in self.__servers) @@ -210,6 +220,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> ################ # Initialize request handler + server_exit_stack.callback(self.__send_locks_cache.clear) server_exit_stack.callback(self.__clients_cache.clear) await self.__request_handler.service_init( await server_exit_stack.enter_async_context(contextlib.AsyncExitStack()), @@ -243,7 +254,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) -> finally: self.__mainloop_task = None - raise AssertionError("received_datagrams() does not return") + raise AssertionError("sleep_forever() does not return") serve_forever.__doc__ = AbstractAsyncNetworkServer.serve_forever.__doc__ @@ -262,26 +273,31 @@ async def __datagram_received_coroutine( ) -> AsyncGenerator[None, _RequestT]: address = new_socket_address(address, server.extra(INETSocketAttribute.family)) with self.__suppress_and_log_remaining_exception(client_address=address): + send_locks_cache = self.__send_locks_cache[server] try: - client = self.__clients_cache[address] + send_lock = send_locks_cache[address] except KeyError: - self.__clients_cache[address] = client = _ClientAPI(address, server, self.__backend.create_lock(), self.__logger) + send_locks_cache[address] = send_lock = self.__backend.create_lock() + + clients_cache = self.__clients_cache[server] + try: + client = clients_cache[address] + except KeyError: + clients_cache[address] = client = _ClientAPI(address, server, send_lock, self.__logger) async with contextlib.aclosing(self.__request_handler.handle(client)) as request_handler_generator: - del client + del client, send_lock try: await anext(request_handler_generator) except StopAsyncIteration: return - from ...lowlevel.api_async.servers._tools.actions import ErrorAction, RequestAction - - action: RequestAction[_RequestT] | ErrorAction + action: _asyncgen.AsyncGenAction[None, _RequestT] while True: try: - action = RequestAction((yield)) + action = _asyncgen.SendAction((yield)) except BaseException as exc: - action = ErrorAction(exc) + action = _asyncgen.ThrowAction(exc) try: await action.asend(request_handler_generator) except StopAsyncIteration: diff --git a/src/easynetwork/lowlevel/api_async/servers/_tools/actions.py b/src/easynetwork/lowlevel/_asyncgen.py similarity index 50% rename from src/easynetwork/lowlevel/api_async/servers/_tools/actions.py rename to src/easynetwork/lowlevel/_asyncgen.py index ae08c44f..8a330493 100644 --- a/src/easynetwork/lowlevel/api_async/servers/_tools/actions.py +++ b/src/easynetwork/lowlevel/_asyncgen.py @@ -12,7 +12,7 @@ # limitations under the License. # # -"""Low-level asynchronous server module""" +"""Async generators helper module""" from __future__ import annotations @@ -20,56 +20,38 @@ import dataclasses from abc import ABCMeta, abstractmethod -from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable +from collections.abc import AsyncGenerator from typing import Any, Generic, TypeVar -_T = TypeVar("_T") +_T_Send = TypeVar("_T_Send") +_T_Yield = TypeVar("_T_Yield") -class Action(Generic[_T], metaclass=ABCMeta): +class AsyncGenAction(Generic[_T_Yield, _T_Send], metaclass=ABCMeta): __slots__ = () @abstractmethod - async def asend(self, generator: AsyncGenerator[None, _T]) -> None: + async def asend(self, generator: AsyncGenerator[_T_Yield, _T_Send]) -> _T_Yield: raise NotImplementedError @dataclasses.dataclass(slots=True) -class RequestAction(Action[_T]): - request: _T +class SendAction(AsyncGenAction[_T_Yield, _T_Send]): + value: _T_Send - async def asend(self, generator: AsyncGenerator[None, _T]) -> None: + async def asend(self, generator: AsyncGenerator[_T_Yield, _T_Send]) -> _T_Yield: try: - await generator.asend(self.request) + return await generator.asend(self.value) finally: del self @dataclasses.dataclass(slots=True) -class ErrorAction(Action[Any]): +class ThrowAction(AsyncGenAction[_T_Yield, Any]): exception: BaseException - async def asend(self, generator: AsyncGenerator[None, Any]) -> None: + async def asend(self, generator: AsyncGenerator[_T_Yield, Any]) -> _T_Yield: try: - await generator.athrow(self.exception) + return await generator.athrow(self.exception) finally: del generator, self # Needed to avoid circular reference with raised exception - - -@dataclasses.dataclass(slots=True) -class ActionIterator(Generic[_T]): - request_factory: Callable[[], Awaitable[_T]] - - def __aiter__(self) -> AsyncIterator[Action[_T]]: - return self - - async def __anext__(self) -> Action[_T]: - try: - request = await self.request_factory() - except StopAsyncIteration: - raise - except BaseException as exc: - return ErrorAction(exc) - finally: - del self - return RequestAction(request) diff --git a/src/easynetwork/lowlevel/api_async/servers/_tools/__init__.py b/src/easynetwork/lowlevel/api_async/servers/_tools/__init__.py deleted file mode 100644 index d7bcb5e7..00000000 --- a/src/easynetwork/lowlevel/api_async/servers/_tools/__init__.py +++ /dev/null @@ -1,19 +0,0 @@ -# Copyright 2021-2023, Francis Clairicia-Rose-Claire-Josephine -# Licensed under the Apache License, Version 2.0 (the "License"); -# you may not use this file except in compliance with the License. -# You may obtain a copy of the License at -# -# http://www.apache.org/licenses/LICENSE-2.0 -# -# Unless required by applicable law or agreed to in writing, software -# distributed under the License is distributed on an "AS IS" BASIS, -# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. -# See the License for the specific language governing permissions and -# limitations under the License. -# -# -"""Low-level asynchronous server module""" - -from __future__ import annotations - -__all__ = [] # type: list[str] diff --git a/src/easynetwork/lowlevel/api_async/servers/datagram.py b/src/easynetwork/lowlevel/api_async/servers/datagram.py index 28abd0ca..15026ae3 100644 --- a/src/easynetwork/lowlevel/api_async/servers/datagram.py +++ b/src/easynetwork/lowlevel/api_async/servers/datagram.py @@ -29,10 +29,9 @@ from .... import protocol as protocol_module from ...._typevars import _RequestT, _ResponseT -from ... import _utils, typed_attr +from ... import _asyncgen, _utils, typed_attr from ..backend.abc import AsyncBackend, ICondition, ILock, TaskGroup from ..transports import abc as transports -from ._tools.actions import ActionIterator as _ActionIterator _T_Address = TypeVar("_T_Address", bound=Hashable) _KT = TypeVar("_KT") @@ -123,6 +122,7 @@ async def serve( client_coroutine = self.__client_coroutine client_manager = self.__client_manager backend = self.__backend + listener = self.__listener async def handler(datagram: bytes, address: _T_Address, /) -> None: with client_manager.datagram_queue(address) as datagram_queue: @@ -151,7 +151,7 @@ async def handler(datagram: bytes, address: _T_Address, /) -> None: assert_never(client_state) while True: - datagram, address = await self.__listener.recv_from() + datagram, address = await listener.recv_from() task_group.start_soon(handler, datagram, address) del datagram, address await backend.cancel_shielded_coro_yield() @@ -200,15 +200,17 @@ async def __client_coroutine( del datagram_queue[0] return - request_factory = _utils.make_callback( - self.__request_factory, - datagram_queue, - address, - client_manager, - condition, - self.__protocol, - ) - async for action in _ActionIterator(request_factory): + protocol = self.__protocol + action: _asyncgen.AsyncGenAction[None, _RequestT] + while True: + try: + if not datagram_queue: + with client_manager.set_client_state(address, _ClientState.TASK_WAITING): + await condition.wait() + self.__check_datagram_queue_not_empty(datagram_queue) + action = _asyncgen.SendAction(protocol.build_packet_from_datagram(datagram_queue.popleft())) + except BaseException as exc: + action = _asyncgen.ThrowAction(exc) try: await action.asend(request_handler_generator) except StopAsyncIteration: @@ -241,24 +243,8 @@ def __clear_queue_on_error( exc_tb: TracebackType | None, /, ) -> None: - if exc_type is not None and not issubclass(exc_type, self.__backend.get_cancelled_exc_class()): - datagram_queue.clear() # pragma: no cover - - @classmethod - async def __request_factory( - cls, - datagram_queue: deque[bytes], - address: _T_Address, - client_manager: _ClientManager[_T_Address], - condition: ICondition, - protocol: protocol_module.DatagramProtocol[_ResponseT, _RequestT], - /, - ) -> _RequestT: - if not datagram_queue: - with client_manager.set_client_state(address, _ClientState.TASK_WAITING): - await condition.wait() - cls.__check_datagram_queue_not_empty(datagram_queue) - return protocol.build_packet_from_datagram(datagram_queue.popleft()) + if exc_type is not None: + datagram_queue.clear() @staticmethod def __check_datagram_queue_not_empty(datagram_queue: deque[bytes]) -> None: diff --git a/src/easynetwork/lowlevel/api_async/servers/stream.py b/src/easynetwork/lowlevel/api_async/servers/stream.py index 3be13b82..efe34d28 100644 --- a/src/easynetwork/lowlevel/api_async/servers/stream.py +++ b/src/easynetwork/lowlevel/api_async/servers/stream.py @@ -24,10 +24,9 @@ from .... import protocol as protocol_module from ...._typevars import _RequestT, _ResponseT -from ... import _stream, _utils, typed_attr +from ... import _asyncgen, _stream, _utils, typed_attr from ..backend.abc import AsyncBackend, TaskGroup from ..transports import abc as transports, utils as transports_utils -from ._tools.actions import ActionIterator as _ActionIterator class AsyncStreamClient(typed_attr.TypedAttributeProvider, Generic[_ResponseT]): @@ -179,8 +178,24 @@ async def __client_coroutine( except StopAsyncIteration: return - request_factory = _utils.make_callback(self.__request_factory, transport, consumer, self.__max_recv_size) - async for action in _ActionIterator(request_factory): + bufsize: int = self.__max_recv_size + action: _asyncgen.AsyncGenAction[None, _RequestT] + while not transport.is_closing(): + try: + try: + action = _asyncgen.SendAction(next(consumer)) + except StopIteration: + data: bytes = await transport.recv(bufsize) + if not data: # Closed connection (EOF) + break + try: + consumer.feed(data) + finally: + del data + continue + except BaseException as exc: + action = _asyncgen.ThrowAction(exc) + try: await action.asend(request_handler_generator) except StopAsyncIteration: @@ -188,28 +203,6 @@ async def __client_coroutine( finally: del action - @classmethod - async def __request_factory( - cls, - transport: transports.AsyncStreamReadTransport, - consumer: _stream.StreamDataConsumer[_RequestT], - bufsize: int, - /, - ) -> _RequestT: - while not transport.is_closing(): - try: - return next(consumer) - except StopIteration: - pass - data: bytes = await transport.recv(bufsize) - if not data: # Closed connection (EOF) - break - try: - consumer.feed(data) - finally: - del data - raise StopAsyncIteration - @property def max_recv_size(self) -> int: """Read buffer size. Read-only attribute.""" diff --git a/tests/scripts/async_server_test.py b/tests/scripts/async_server_test.py index 53630462..88b66a4f 100644 --- a/tests/scripts/async_server_test.py +++ b/tests/scripts/async_server_test.py @@ -24,7 +24,7 @@ async def service_init(self, exit_stack: contextlib.AsyncExitStack, server: Abst async def handle(self, client: AsyncBaseClientInterface[str]) -> AsyncGenerator[None, str]: request: str = yield - logger.debug(f"Received {request!r}") + logger.debug(f"Received {request!r} from {client!r}") if request == "wait:": request = (yield) + " after wait" await client.send_packet(request.upper()) diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_actions.py b/tests/unit_test/test_tools/test_asyncgen.py similarity index 55% rename from tests/unit_test/test_async/test_lowlevel_api/test_servers/test_actions.py rename to tests/unit_test/test_tools/test_asyncgen.py index b6fabfdc..1fa73e56 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_actions.py +++ b/tests/unit_test/test_tools/test_asyncgen.py @@ -2,20 +2,20 @@ import sys from collections.abc import AsyncGenerator -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any -from easynetwork.lowlevel.api_async.servers._tools.actions import ActionIterator, ErrorAction, RequestAction +from easynetwork.lowlevel._asyncgen import SendAction, ThrowAction import pytest if TYPE_CHECKING: - from unittest.mock import AsyncMock, MagicMock + from unittest.mock import MagicMock from pytest_mock import MockerFixture @pytest.mark.asyncio -class TestAction: +class TestAsyncGenAction: @pytest.fixture @staticmethod def mock_generator(mocker: MockerFixture) -> MagicMock: @@ -25,34 +25,38 @@ def mock_generator(mocker: MockerFixture) -> MagicMock: mock_generator.aclose.return_value = None return mock_generator - async def test____RequestAction____send_request(self, mock_generator: MagicMock, mocker: MockerFixture) -> None: + async def test____SendAction____send_value(self, mock_generator: MagicMock, mocker: MockerFixture) -> None: # Arrange - action = RequestAction(mocker.sentinel.request) + mock_generator.asend.return_value = mocker.sentinel.to_yield_from_send + action: SendAction[Any, Any] = SendAction(mocker.sentinel.value) # Act - await action.asend(mock_generator) + to_yield = await action.asend(mock_generator) # Assert - assert mock_generator.mock_calls == [mocker.call.asend(mocker.sentinel.request)] - mock_generator.asend.assert_awaited_once_with(mocker.sentinel.request) + assert mock_generator.mock_calls == [mocker.call.asend(mocker.sentinel.value)] + mock_generator.asend.assert_awaited_once_with(mocker.sentinel.value) + assert to_yield is mocker.sentinel.to_yield_from_send - async def test____ErrorAction____throw_exception(self, mock_generator: MagicMock, mocker: MockerFixture) -> None: + async def test____ThrowAction____throw_exception(self, mock_generator: MagicMock, mocker: MockerFixture) -> None: # Arrange exc = ValueError("abc") - action = ErrorAction(exc) + mock_generator.athrow.return_value = mocker.sentinel.to_yield_from_throw + action: ThrowAction[Any] = ThrowAction(exc) # Act - await action.asend(mock_generator) + to_yield = await action.asend(mock_generator) # Assert assert mock_generator.mock_calls == [mocker.call.athrow(exc)] mock_generator.athrow.assert_awaited_once_with(exc) + assert to_yield is mocker.sentinel.to_yield_from_throw # Test taken from "outcome" project (https://github.com/python-trio/outcome) async def test____ErrorAction____does_not_create_reference_cycles(self, mock_generator: MagicMock) -> None: # Arrange exc = ValueError("abc") - action = ErrorAction(exc) + action: ThrowAction[Any] = ThrowAction(exc) mock_generator.athrow.side_effect = exc # Act @@ -71,30 +75,3 @@ async def test____ErrorAction____does_not_create_reference_cycles(self, mock_gen unwrap_frame = exc.__traceback__.tb_next.tb_frame assert unwrap_frame.f_code.co_name == "asend" assert unwrap_frame.f_locals == {} - - -@pytest.mark.asyncio -class TestActionIterator: - @pytest.fixture - @staticmethod - def request_factory(mocker: MockerFixture) -> AsyncMock: - return mocker.AsyncMock(spec=lambda: None) - - async def test____dunder_aiter____return_self(self, request_factory: AsyncMock) -> None: - # Arrange - action_iterator = ActionIterator(request_factory) - - # Act & Assert - assert aiter(action_iterator) is action_iterator - - async def test____dunder_anext____yield_actions(self, request_factory: AsyncMock, mocker: MockerFixture) -> None: - # Arrange - exc = BaseException() - request_factory.side_effect = [mocker.sentinel.request, exc, StopAsyncIteration] - - # Act - actions = [action async for action in ActionIterator(request_factory)] - - # Assert - assert request_factory.await_count == 3 - assert actions == [RequestAction(mocker.sentinel.request), ErrorAction(exc)]