diff --git a/src/easynetwork/lowlevel/api_async/servers/datagram.py b/src/easynetwork/lowlevel/api_async/servers/datagram.py index a3870103..acc2c073 100644 --- a/src/easynetwork/lowlevel/api_async/servers/datagram.py +++ b/src/easynetwork/lowlevel/api_async/servers/datagram.py @@ -26,7 +26,7 @@ import weakref from collections import deque from collections.abc import AsyncGenerator, Callable, Hashable, Mapping -from contextlib import AsyncExitStack, ExitStack +from contextlib import AsyncExitStack from typing import Any, Generic, NoReturn, TypeVar from ...._typevars import _T_Request, _T_Response @@ -185,15 +185,13 @@ async def handler(datagram: bytes, address: _T_Address, /) -> None: client = client_cache[address] except KeyError: client_cache[address] = client = _ClientToken(DatagramClientContext(address, self), _ClientData(backend)) - new_client_task = True - else: - new_client_task = client.data.state is None - if new_client_task: + await client.data.push_datagram(datagram) + + if client.data.state is None: + del datagram client.data.mark_pending() - await self.__client_coroutine(datagram_received_cb, datagram, client, task_group, default_context) - else: - await client.data.push_datagram(datagram) + await self.__client_coroutine(datagram_received_cb, client, task_group, default_context) await listener.serve(handler, task_group) @@ -204,63 +202,67 @@ async def __client_coroutine( datagram_received_cb: Callable[ [DatagramClientContext[_T_Response, _T_Address]], AsyncGenerator[float | None, _T_Request] ], - datagram: bytes, client: _ClientToken[_T_Response, _T_Address], task_group: TaskGroup, default_context: contextvars.Context, ) -> None: - client_data = client.data - async with client_data.task_lock: - with ExitStack() as exit_stack: - ##################################################################################################### - # CRITICAL SECTION - # This block must not have any asynchronous function calls - # or add any asynchronous callbacks/contexts to the exit stack. - client_data.mark_running() - exit_stack.callback( - self.__on_task_done, + async with client.data.task_lock: + client.data.mark_running() + try: + await self.__client_coroutine_inner_loop( + request_handler_generator=datagram_received_cb(client.ctx), + client_data=client.data, + ) + finally: + self.__on_task_done( datagram_received_cb=datagram_received_cb, client=client, task_group=task_group, default_context=default_context, ) - ##################################################################################################### - - request_handler_generator = datagram_received_cb(client.ctx) - timeout: float | None + async def __client_coroutine_inner_loop( + self, + *, + request_handler_generator: AsyncGenerator[float | None, _T_Request], + client_data: _ClientData, + ) -> None: + timeout: float | None + datagram: bytes = client_data.pop_datagram_no_wait() + try: + # Ignore sent timeout here, we already have the datagram. + await anext(request_handler_generator) + except StopAsyncIteration: + return + else: + action: AsyncGenAction[_T_Request] | None + action = self.__parse_datagram(datagram, self.__protocol) + try: + timeout = await action.asend(request_handler_generator) + except StopAsyncIteration: + return + finally: + action = None + + del datagram + null_timeout_ctx = contextlib.nullcontext() + while True: try: - # Ignore sent timeout here, we already have the datagram. - await anext(request_handler_generator) + with null_timeout_ctx if timeout is None else client_data.backend.timeout(timeout): + datagram = await client_data.pop_datagram() + action = self.__parse_datagram(datagram, self.__protocol) + except BaseException as exc: + action = ThrowAction(exc) + finally: + datagram = b"" + try: + timeout = await action.asend(request_handler_generator) except StopAsyncIteration: - return - else: - action: AsyncGenAction[_T_Request] = self.__parse_datagram(datagram, self.__protocol) - try: - timeout = await action.asend(request_handler_generator) - except StopAsyncIteration: - return - finally: - del action - - del datagram - while True: - try: - with contextlib.nullcontext() if timeout is None else client_data.backend.timeout(timeout): - datagram = await client_data.pop_datagram() - - action = self.__parse_datagram(datagram, self.__protocol) - del datagram - except BaseException as exc: - action = ThrowAction(exc) - try: - timeout = await action.asend(request_handler_generator) - except StopAsyncIteration: - break - finally: - del action + break finally: - await request_handler_generator.aclose() + action = None + finally: + await request_handler_generator.aclose() def __on_task_done( self, @@ -272,9 +274,7 @@ def __on_task_done( default_context: contextvars.Context, ) -> None: client.data.mark_done() - try: - pending_datagram = client.data.pop_datagram_no_wait() - except IndexError: + if client.data.queue_is_empty(): return client.data.mark_pending() @@ -282,7 +282,6 @@ def __on_task_done( task_group.start_soon, self.__client_coroutine, datagram_received_cb, - pending_datagram, client, task_group, default_context, @@ -347,8 +346,8 @@ def __init__(self, backend: AsyncBackend) -> None: self.__backend: AsyncBackend = backend self.__task_lock: ILock = backend.create_lock() self.__state: _ClientState | None = None - self._queue_condition: ICondition | None = None - self._datagram_queue: deque[bytes] | None = None + self._queue_condition: ICondition = backend.create_condition_var() + self._datagram_queue: deque[bytes] = deque() @property def backend(self) -> AsyncBackend: @@ -362,21 +361,20 @@ def task_lock(self) -> ILock: def state(self) -> _ClientState | None: return self.__state + def queue_is_empty(self) -> bool: + return not self._datagram_queue + async def push_datagram(self, datagram: bytes) -> None: - self.__ensure_queue().append(datagram) - if (queue_condition := self._queue_condition) is not None: - async with queue_condition: - queue_condition.notify() + self._datagram_queue.append(datagram) + async with (queue_condition := self._queue_condition): + queue_condition.notify() def pop_datagram_no_wait(self) -> bytes: - if not (queue := self._datagram_queue): - raise IndexError("pop from an empty deque") - return queue.popleft() + return self._datagram_queue.popleft() async def pop_datagram(self) -> bytes: - queue_condition = self.__ensure_queue_condition_var() - async with queue_condition: - queue = self.__ensure_queue() + async with (queue_condition := self._queue_condition): + queue = self._datagram_queue while not queue: await queue_condition.wait() return queue.popleft() @@ -396,16 +394,6 @@ def mark_running(self) -> None: self.handle_inconsistent_state_error() self.__state = _ClientState.TASK_RUNNING - def __ensure_queue(self) -> deque[bytes]: - if (queue := self._datagram_queue) is None: - self._datagram_queue = queue = deque() - return queue - - def __ensure_queue_condition_var(self) -> ICondition: - if (cond := self._queue_condition) is None: - self._queue_condition = cond = self.__backend.create_condition_var() - return cond - @staticmethod def handle_inconsistent_state_error() -> NoReturn: msg = "The server has created too many tasks and ends up in an inconsistent state." diff --git a/src/easynetwork/lowlevel/api_async/servers/stream.py b/src/easynetwork/lowlevel/api_async/servers/stream.py index 6c950f5a..2cdf6225 100644 --- a/src/easynetwork/lowlevel/api_async/servers/stream.py +++ b/src/easynetwork/lowlevel/api_async/servers/stream.py @@ -216,10 +216,7 @@ async def __client_coroutine( consumer=consumer, ) else: - _warn_msg = f"The transport implementation {transport!r} does not implement AsyncBufferedStreamReadTransport interface." - _warn_msg = f"{_warn_msg} Consider using StreamProtocol instead of BufferedStreamProtocol." - warnings.warn(_warn_msg, category=ManualBufferAllocationWarning, stacklevel=1) - del _warn_msg + self.__manual_buffer_allocation_warning(transport) consumer = _stream.StreamDataConsumer(self.__protocol.into_data_protocol()) request_receiver = _RequestReceiver( transport=transport, @@ -247,91 +244,100 @@ async def __client_coroutine( ) ) - del client_exit_stack, task_exit_stack, client_connected_cb - timeout: float | None try: timeout = await anext(request_handler_generator) except StopAsyncIteration: return else: - while True: - try: + try: + action: AsyncGenAction[_T_Request] | None + while True: action = await request_receiver.next(timeout) - except StopAsyncIteration: - break - try: - timeout = await action.asend(request_handler_generator) - except StopAsyncIteration: - break - finally: - del action + try: + timeout = await action.asend(request_handler_generator) + finally: + action = None + except StopAsyncIteration: + return finally: await request_handler_generator.aclose() + @staticmethod + def __manual_buffer_allocation_warning(transport: AsyncStreamTransport) -> None: + _warn_msg = " ".join( + [ + f"The transport implementation {transport!r} does not implement AsyncBufferedStreamReadTransport interface.", + "Consider using StreamProtocol instead of BufferedStreamProtocol.", + ] + ) + warnings.warn(_warn_msg, category=ManualBufferAllocationWarning, stacklevel=2) + @property @_utils.inherit_doc(AsyncBaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__listener.extra_attributes -@dataclasses.dataclass(kw_only=True, eq=False, frozen=True, slots=True) +@dataclasses.dataclass(kw_only=True, eq=False, slots=True) class _RequestReceiver(Generic[_T_Request]): transport: AsyncStreamReadTransport consumer: _stream.StreamDataConsumer[_T_Request] max_recv_size: int __null_timeout_ctx: contextlib.nullcontext[None] = dataclasses.field(init=False, default_factory=contextlib.nullcontext) + __backend: AsyncBackend = dataclasses.field(init=False) def __post_init__(self) -> None: assert self.max_recv_size > 0, f"{self.max_recv_size=}" # nosec assert_used + self.__backend = self.transport.backend() async def next(self, timeout: float | None) -> AsyncGenAction[_T_Request]: try: consumer = self.consumer - try: - request = consumer.next(None) - except StopIteration: - pass - else: - return SendAction(request) - - with self.__null_timeout_ctx if timeout is None else self.transport.backend().timeout(timeout): - while data := await self.transport.recv(self.max_recv_size): + with self.__null_timeout_ctx if timeout is None else self.__backend.timeout(timeout): + data: bytes | None = None + while True: try: request = consumer.next(data) except StopIteration: - continue + pass + else: + return SendAction(request) finally: - del data - return SendAction(request) + data = None + data = await self.transport.recv(self.max_recv_size) + if not data: + break except BaseException as exc: return ThrowAction(exc) raise StopAsyncIteration -@dataclasses.dataclass(kw_only=True, eq=False, frozen=True, slots=True) +@dataclasses.dataclass(kw_only=True, eq=False, slots=True) class _BufferedRequestReceiver(Generic[_T_Request]): transport: AsyncBufferedStreamReadTransport consumer: _stream.BufferedStreamDataConsumer[_T_Request] __null_timeout_ctx: contextlib.nullcontext[None] = dataclasses.field(init=False, default_factory=contextlib.nullcontext) + __backend: AsyncBackend = dataclasses.field(init=False) + + def __post_init__(self) -> None: + self.__backend = self.transport.backend() async def next(self, timeout: float | None) -> AsyncGenAction[_T_Request]: try: consumer = self.consumer - try: - request = consumer.next(None) - except StopIteration: - pass - else: - return SendAction(request) - - with self.__null_timeout_ctx if timeout is None else self.transport.backend().timeout(timeout): - while nbytes := await self.transport.recv_into(consumer.get_write_buffer()): + with self.__null_timeout_ctx if timeout is None else self.__backend.timeout(timeout): + nbytes: int | None = None + while True: try: request = consumer.next(nbytes) except StopIteration: - continue - return SendAction(request) + pass + else: + return SendAction(request) + nbytes = await self.transport.recv_into(consumer.get_write_buffer()) + if not nbytes: + break except BaseException as exc: return ThrowAction(exc) raise StopAsyncIteration diff --git a/src/easynetwork/servers/misc.py b/src/easynetwork/servers/misc.py index 571af44a..3cbc68ae 100644 --- a/src/easynetwork/servers/misc.py +++ b/src/easynetwork/servers/misc.py @@ -29,7 +29,7 @@ from .._typevars import _T_Request, _T_Response from ..lowlevel import _utils -from ..lowlevel._asyncgen import AsyncGenAction, SendAction, ThrowAction +from ..lowlevel._asyncgen import AsyncGenAction from ..lowlevel.api_async.servers import datagram as _lowlevel_datagram_server, stream as _lowlevel_stream_server from .handlers import AsyncDatagramClient, AsyncDatagramRequestHandler, AsyncStreamClient, AsyncStreamRequestHandler @@ -63,6 +63,8 @@ def build_lowlevel_stream_server_handler( if logger is None: logger = logging.getLogger(__name__) + from ..lowlevel._asyncgen import SendAction, ThrowAction + async def handler( lowlevel_client: _lowlevel_stream_server.Client[_T_Response], / ) -> AsyncGenerator[float | None, _T_Request]: @@ -179,6 +181,8 @@ def build_lowlevel_datagram_server_handler( an :term:`asynchronous generator` function. """ + from ..lowlevel._asyncgen import SendAction, ThrowAction + async def handler( lowlevel_client: _lowlevel_datagram_server.DatagramClientContext[_T_Response, _T_Address], / ) -> AsyncGenerator[float | None, _T_Request]: diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py index 465491c2..177625fc 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py @@ -206,10 +206,6 @@ def client_data(mock_backend: MagicMock) -> _ClientData: def get_client_state(client_data: _ClientData) -> _ClientState | None: return client_data.state - @staticmethod - def get_client_queue(client_data: _ClientData) -> deque[bytes] | None: - return client_data._datagram_queue - def test____dunder_init____default( self, client_data: _ClientData, @@ -219,8 +215,6 @@ def test____dunder_init____default( # Act & Assert assert isinstance(client_data.task_lock, asyncio.Lock) assert client_data.state is None - assert client_data._datagram_queue is None - assert client_data._queue_condition is None def test____client_state____regular_state_transition( self, @@ -279,7 +273,6 @@ async def test____datagram_queue____push_datagram( client_data: _ClientData, ) -> None: # Arrange - assert self.get_client_queue(client_data) is None # Act await client_data.push_datagram(b"datagram_1") @@ -287,8 +280,6 @@ async def test____datagram_queue____push_datagram( await client_data.push_datagram(b"datagram_3") # Assert - assert client_data._datagram_queue is not None - assert client_data._queue_condition is None assert list(client_data._datagram_queue) == [b"datagram_1", b"datagram_2", b"datagram_3"] @pytest.mark.asyncio @@ -313,19 +304,12 @@ async def test____datagram_queue____pop_datagram( # Assert assert len(client_data._datagram_queue) == 0 - if no_wait: - assert client_data._queue_condition is None - else: - assert client_data._queue_condition is not None - @pytest.mark.parametrize("queue", [deque(), None], ids=lambda p: f"queue=={p!r}") def test____datagram_queue____pop_datagram_no_wait____empty_list( self, - queue: deque[bytes] | None, client_data: _ClientData, ) -> None: # Arrange - client_data._datagram_queue = queue # Act & Assert with pytest.raises(IndexError):