From df84bc18c3d12ab5934ded56f34613708fb50245 Mon Sep 17 00:00:00 2001 From: Francis CLAIRICIA-ROSE-CLAIRE-JOSEPHINE Date: Sun, 23 Jun 2024 18:39:36 +0200 Subject: [PATCH] Fixed bug in DatagramServer where a datagram could never be freed after usage --- .../lowlevel/api_async/servers/datagram.py | 141 ++++++++---------- .../test_servers/test_datagram.py | 16 -- 2 files changed, 65 insertions(+), 92 deletions(-) diff --git a/src/easynetwork/lowlevel/api_async/servers/datagram.py b/src/easynetwork/lowlevel/api_async/servers/datagram.py index f38f84f9..b1f0eeff 100644 --- a/src/easynetwork/lowlevel/api_async/servers/datagram.py +++ b/src/easynetwork/lowlevel/api_async/servers/datagram.py @@ -26,7 +26,7 @@ import weakref from collections import deque from collections.abc import AsyncGenerator, Callable, Hashable, Mapping -from contextlib import AsyncExitStack, ExitStack +from contextlib import AsyncExitStack from typing import Any, Generic, NoReturn, TypeVar from ...._typevars import _T_Request, _T_Response @@ -189,11 +189,12 @@ async def handler(datagram: bytes, address: _T_Address, /) -> None: else: new_client_task = client.data.state is None + await client.data.push_datagram(datagram) + if new_client_task: + del datagram client.data.mark_pending() - await self.__client_coroutine(datagram_received_cb, datagram, client, task_group, default_context) - else: - await client.data.push_datagram(datagram) + await self.__client_coroutine(datagram_received_cb, client, task_group, default_context) await listener.serve(handler, task_group) @@ -204,65 +205,67 @@ async def __client_coroutine( datagram_received_cb: Callable[ [DatagramClientContext[_T_Response, _T_Address]], AsyncGenerator[float | None, _T_Request] ], - datagram: bytes, client: _ClientToken[_T_Response, _T_Address], task_group: TaskGroup, default_context: contextvars.Context, ) -> None: - client_data = client.data - async with client_data.task_lock: - with ExitStack() as exit_stack: - ##################################################################################################### - # CRITICAL SECTION - # This block must not have any asynchronous function calls - # or add any asynchronous callbacks/contexts to the exit stack. - client_data.mark_running() - exit_stack.callback( - self.__on_task_done, + async with client.data.task_lock: + client.data.mark_running() + try: + await self.__client_coroutine_inner_loop( + request_handler_generator=datagram_received_cb(client.ctx), + client_data=client.data, + ) + finally: + self.__on_task_done( datagram_received_cb=datagram_received_cb, client=client, task_group=task_group, default_context=default_context, ) - ##################################################################################################### - - request_handler_generator = datagram_received_cb(client.ctx) - timeout: float | None + async def __client_coroutine_inner_loop( + self, + *, + request_handler_generator: AsyncGenerator[float | None, _T_Request], + client_data: _ClientData, + ) -> None: + timeout: float | None + datagram: bytes = client_data.pop_datagram_no_wait() + try: + # Ignore sent timeout here, we already have the datagram. + await anext(request_handler_generator) + except StopAsyncIteration: + return + else: + action: AsyncGenAction[_T_Request] | None + action = self.__parse_datagram(datagram, self.__protocol) + try: + timeout = await action.asend(request_handler_generator) + except StopAsyncIteration: + return + finally: + action = None + + del datagram + null_timeout_ctx = contextlib.nullcontext() + while True: try: - # Ignore sent timeout here, we already have the datagram. - await anext(request_handler_generator) - except StopAsyncIteration: - return - else: - action: AsyncGenAction[_T_Request] | None + with null_timeout_ctx if timeout is None else client_data.backend.timeout(timeout): + datagram = await client_data.pop_datagram() action = self.__parse_datagram(datagram, self.__protocol) - try: - timeout = await action.asend(request_handler_generator) - except StopAsyncIteration: - return - finally: - action = None - - del datagram - null_timeout_ctx = contextlib.nullcontext() - while True: - try: - with null_timeout_ctx if timeout is None else client_data.backend.timeout(timeout): - datagram = await client_data.pop_datagram() - action = self.__parse_datagram(datagram, self.__protocol) - except BaseException as exc: - action = ThrowAction(exc) - finally: - datagram = b"" - try: - timeout = await action.asend(request_handler_generator) - except StopAsyncIteration: - break - finally: - action = None + except BaseException as exc: + action = ThrowAction(exc) + finally: + datagram = b"" + try: + timeout = await action.asend(request_handler_generator) + except StopAsyncIteration: + break finally: - await request_handler_generator.aclose() + action = None + finally: + await request_handler_generator.aclose() def __on_task_done( self, @@ -274,9 +277,7 @@ def __on_task_done( default_context: contextvars.Context, ) -> None: client.data.mark_done() - try: - pending_datagram = client.data.pop_datagram_no_wait() - except IndexError: + if client.data.queue_is_empty(): return client.data.mark_pending() @@ -284,7 +285,6 @@ def __on_task_done( task_group.start_soon, self.__client_coroutine, datagram_received_cb, - pending_datagram, client, task_group, default_context, @@ -349,8 +349,8 @@ def __init__(self, backend: AsyncBackend) -> None: self.__backend: AsyncBackend = backend self.__task_lock: ILock = backend.create_lock() self.__state: _ClientState | None = None - self._queue_condition: ICondition | None = None - self._datagram_queue: deque[bytes] | None = None + self._queue_condition: ICondition = backend.create_condition_var() + self._datagram_queue: deque[bytes] = deque() @property def backend(self) -> AsyncBackend: @@ -364,21 +364,20 @@ def task_lock(self) -> ILock: def state(self) -> _ClientState | None: return self.__state + def queue_is_empty(self) -> bool: + return not self._datagram_queue + async def push_datagram(self, datagram: bytes) -> None: - self.__ensure_queue().append(datagram) - if (queue_condition := self._queue_condition) is not None: - async with queue_condition: - queue_condition.notify() + self._datagram_queue.append(datagram) + async with (queue_condition := self._queue_condition): + queue_condition.notify() def pop_datagram_no_wait(self) -> bytes: - if not (queue := self._datagram_queue): - raise IndexError("pop from an empty deque") - return queue.popleft() + return self._datagram_queue.popleft() async def pop_datagram(self) -> bytes: - queue_condition = self.__ensure_queue_condition_var() - async with queue_condition: - queue = self.__ensure_queue() + async with (queue_condition := self._queue_condition): + queue = self._datagram_queue while not queue: await queue_condition.wait() return queue.popleft() @@ -398,16 +397,6 @@ def mark_running(self) -> None: self.handle_inconsistent_state_error() self.__state = _ClientState.TASK_RUNNING - def __ensure_queue(self) -> deque[bytes]: - if (queue := self._datagram_queue) is None: - self._datagram_queue = queue = deque() - return queue - - def __ensure_queue_condition_var(self) -> ICondition: - if (cond := self._queue_condition) is None: - self._queue_condition = cond = self.__backend.create_condition_var() - return cond - @staticmethod def handle_inconsistent_state_error() -> NoReturn: msg = "The server has created too many tasks and ends up in an inconsistent state." diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py index 465491c2..177625fc 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py @@ -206,10 +206,6 @@ def client_data(mock_backend: MagicMock) -> _ClientData: def get_client_state(client_data: _ClientData) -> _ClientState | None: return client_data.state - @staticmethod - def get_client_queue(client_data: _ClientData) -> deque[bytes] | None: - return client_data._datagram_queue - def test____dunder_init____default( self, client_data: _ClientData, @@ -219,8 +215,6 @@ def test____dunder_init____default( # Act & Assert assert isinstance(client_data.task_lock, asyncio.Lock) assert client_data.state is None - assert client_data._datagram_queue is None - assert client_data._queue_condition is None def test____client_state____regular_state_transition( self, @@ -279,7 +273,6 @@ async def test____datagram_queue____push_datagram( client_data: _ClientData, ) -> None: # Arrange - assert self.get_client_queue(client_data) is None # Act await client_data.push_datagram(b"datagram_1") @@ -287,8 +280,6 @@ async def test____datagram_queue____push_datagram( await client_data.push_datagram(b"datagram_3") # Assert - assert client_data._datagram_queue is not None - assert client_data._queue_condition is None assert list(client_data._datagram_queue) == [b"datagram_1", b"datagram_2", b"datagram_3"] @pytest.mark.asyncio @@ -313,19 +304,12 @@ async def test____datagram_queue____pop_datagram( # Assert assert len(client_data._datagram_queue) == 0 - if no_wait: - assert client_data._queue_condition is None - else: - assert client_data._queue_condition is not None - @pytest.mark.parametrize("queue", [deque(), None], ids=lambda p: f"queue=={p!r}") def test____datagram_queue____pop_datagram_no_wait____empty_list( self, - queue: deque[bytes] | None, client_data: _ClientData, ) -> None: # Arrange - client_data._datagram_queue = queue # Act & Assert with pytest.raises(IndexError):