diff --git a/src/easynetwork/lowlevel/api_async/servers/datagram.py b/src/easynetwork/lowlevel/api_async/servers/datagram.py index a3870103..3cc0e2de 100644 --- a/src/easynetwork/lowlevel/api_async/servers/datagram.py +++ b/src/easynetwork/lowlevel/api_async/servers/datagram.py @@ -235,30 +235,33 @@ async def __client_coroutine( except StopAsyncIteration: return else: - action: AsyncGenAction[_T_Request] = self.__parse_datagram(datagram, self.__protocol) + action: AsyncGenAction[_T_Request] | None + action = self.__parse_datagram(datagram, self.__protocol) try: timeout = await action.asend(request_handler_generator) except StopAsyncIteration: return finally: - del action + action = None del datagram - while True: - try: - with contextlib.nullcontext() if timeout is None else client_data.backend.timeout(timeout): - datagram = await client_data.pop_datagram() - - action = self.__parse_datagram(datagram, self.__protocol) - del datagram - except BaseException as exc: - action = ThrowAction(exc) - try: - timeout = await action.asend(request_handler_generator) - except StopAsyncIteration: - break - finally: - del action + null_timeout_ctx = contextlib.nullcontext() + try: + while True: + try: + with null_timeout_ctx if timeout is None else client_data.backend.timeout(timeout): + datagram = await client_data.pop_datagram() + action = self.__parse_datagram(datagram, self.__protocol) + except BaseException as exc: + action = ThrowAction(exc) + finally: + datagram = b"" + try: + timeout = await action.asend(request_handler_generator) + finally: + action = None + except StopAsyncIteration: + pass finally: await request_handler_generator.aclose() diff --git a/src/easynetwork/lowlevel/api_async/servers/stream.py b/src/easynetwork/lowlevel/api_async/servers/stream.py index 6c950f5a..2cdf6225 100644 --- a/src/easynetwork/lowlevel/api_async/servers/stream.py +++ b/src/easynetwork/lowlevel/api_async/servers/stream.py @@ -216,10 +216,7 @@ async def __client_coroutine( consumer=consumer, ) else: - _warn_msg = f"The transport implementation {transport!r} does not implement AsyncBufferedStreamReadTransport interface." - _warn_msg = f"{_warn_msg} Consider using StreamProtocol instead of BufferedStreamProtocol." - warnings.warn(_warn_msg, category=ManualBufferAllocationWarning, stacklevel=1) - del _warn_msg + self.__manual_buffer_allocation_warning(transport) consumer = _stream.StreamDataConsumer(self.__protocol.into_data_protocol()) request_receiver = _RequestReceiver( transport=transport, @@ -247,91 +244,100 @@ async def __client_coroutine( ) ) - del client_exit_stack, task_exit_stack, client_connected_cb - timeout: float | None try: timeout = await anext(request_handler_generator) except StopAsyncIteration: return else: - while True: - try: + try: + action: AsyncGenAction[_T_Request] | None + while True: action = await request_receiver.next(timeout) - except StopAsyncIteration: - break - try: - timeout = await action.asend(request_handler_generator) - except StopAsyncIteration: - break - finally: - del action + try: + timeout = await action.asend(request_handler_generator) + finally: + action = None + except StopAsyncIteration: + return finally: await request_handler_generator.aclose() + @staticmethod + def __manual_buffer_allocation_warning(transport: AsyncStreamTransport) -> None: + _warn_msg = " ".join( + [ + f"The transport implementation {transport!r} does not implement AsyncBufferedStreamReadTransport interface.", + "Consider using StreamProtocol instead of BufferedStreamProtocol.", + ] + ) + warnings.warn(_warn_msg, category=ManualBufferAllocationWarning, stacklevel=2) + @property @_utils.inherit_doc(AsyncBaseTransport) def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__listener.extra_attributes -@dataclasses.dataclass(kw_only=True, eq=False, frozen=True, slots=True) +@dataclasses.dataclass(kw_only=True, eq=False, slots=True) class _RequestReceiver(Generic[_T_Request]): transport: AsyncStreamReadTransport consumer: _stream.StreamDataConsumer[_T_Request] max_recv_size: int __null_timeout_ctx: contextlib.nullcontext[None] = dataclasses.field(init=False, default_factory=contextlib.nullcontext) + __backend: AsyncBackend = dataclasses.field(init=False) def __post_init__(self) -> None: assert self.max_recv_size > 0, f"{self.max_recv_size=}" # nosec assert_used + self.__backend = self.transport.backend() async def next(self, timeout: float | None) -> AsyncGenAction[_T_Request]: try: consumer = self.consumer - try: - request = consumer.next(None) - except StopIteration: - pass - else: - return SendAction(request) - - with self.__null_timeout_ctx if timeout is None else self.transport.backend().timeout(timeout): - while data := await self.transport.recv(self.max_recv_size): + with self.__null_timeout_ctx if timeout is None else self.__backend.timeout(timeout): + data: bytes | None = None + while True: try: request = consumer.next(data) except StopIteration: - continue + pass + else: + return SendAction(request) finally: - del data - return SendAction(request) + data = None + data = await self.transport.recv(self.max_recv_size) + if not data: + break except BaseException as exc: return ThrowAction(exc) raise StopAsyncIteration -@dataclasses.dataclass(kw_only=True, eq=False, frozen=True, slots=True) +@dataclasses.dataclass(kw_only=True, eq=False, slots=True) class _BufferedRequestReceiver(Generic[_T_Request]): transport: AsyncBufferedStreamReadTransport consumer: _stream.BufferedStreamDataConsumer[_T_Request] __null_timeout_ctx: contextlib.nullcontext[None] = dataclasses.field(init=False, default_factory=contextlib.nullcontext) + __backend: AsyncBackend = dataclasses.field(init=False) + + def __post_init__(self) -> None: + self.__backend = self.transport.backend() async def next(self, timeout: float | None) -> AsyncGenAction[_T_Request]: try: consumer = self.consumer - try: - request = consumer.next(None) - except StopIteration: - pass - else: - return SendAction(request) - - with self.__null_timeout_ctx if timeout is None else self.transport.backend().timeout(timeout): - while nbytes := await self.transport.recv_into(consumer.get_write_buffer()): + with self.__null_timeout_ctx if timeout is None else self.__backend.timeout(timeout): + nbytes: int | None = None + while True: try: request = consumer.next(nbytes) except StopIteration: - continue - return SendAction(request) + pass + else: + return SendAction(request) + nbytes = await self.transport.recv_into(consumer.get_write_buffer()) + if not nbytes: + break except BaseException as exc: return ThrowAction(exc) raise StopAsyncIteration diff --git a/src/easynetwork/servers/misc.py b/src/easynetwork/servers/misc.py index 571af44a..3cbc68ae 100644 --- a/src/easynetwork/servers/misc.py +++ b/src/easynetwork/servers/misc.py @@ -29,7 +29,7 @@ from .._typevars import _T_Request, _T_Response from ..lowlevel import _utils -from ..lowlevel._asyncgen import AsyncGenAction, SendAction, ThrowAction +from ..lowlevel._asyncgen import AsyncGenAction from ..lowlevel.api_async.servers import datagram as _lowlevel_datagram_server, stream as _lowlevel_stream_server from .handlers import AsyncDatagramClient, AsyncDatagramRequestHandler, AsyncStreamClient, AsyncStreamRequestHandler @@ -63,6 +63,8 @@ def build_lowlevel_stream_server_handler( if logger is None: logger = logging.getLogger(__name__) + from ..lowlevel._asyncgen import SendAction, ThrowAction + async def handler( lowlevel_client: _lowlevel_stream_server.Client[_T_Response], / ) -> AsyncGenerator[float | None, _T_Request]: @@ -179,6 +181,8 @@ def build_lowlevel_datagram_server_handler( an :term:`asynchronous generator` function. """ + from ..lowlevel._asyncgen import SendAction, ThrowAction + async def handler( lowlevel_client: _lowlevel_datagram_server.DatagramClientContext[_T_Response, _T_Address], / ) -> AsyncGenerator[float | None, _T_Request]: