diff --git a/src/easynetwork/api_async/server/tcp.py b/src/easynetwork/api_async/server/tcp.py index 4d00ba97..a8d45431 100644 --- a/src/easynetwork/api_async/server/tcp.py +++ b/src/easynetwork/api_async/server/tcp.py @@ -371,9 +371,9 @@ async def __client_coroutine( request_handler_generator: AsyncGenerator[None, _RequestT] | None = None _on_connection_hook = self.__request_handler.on_connection(client) - if inspect.isasyncgen(_on_connection_hook): + if isinstance(_on_connection_hook, AsyncGenerator): try: - await _on_connection_hook.asend(None) + await anext(_on_connection_hook) except StopAsyncIteration: pass else: diff --git a/src/easynetwork/lowlevel/api_async/servers/stream.py b/src/easynetwork/lowlevel/api_async/servers/stream.py index efe34d28..6e0e7433 100644 --- a/src/easynetwork/lowlevel/api_async/servers/stream.py +++ b/src/easynetwork/lowlevel/api_async/servers/stream.py @@ -20,7 +20,7 @@ import contextlib from collections.abc import AsyncGenerator, Callable, Mapping -from typing import Any, Generic, NoReturn +from typing import Any, Generic, NoReturn, Self from .... import protocol as protocol_module from ...._typevars import _RequestT, _ResponseT @@ -178,24 +178,7 @@ async def __client_coroutine( except StopAsyncIteration: return - bufsize: int = self.__max_recv_size - action: _asyncgen.AsyncGenAction[None, _RequestT] - while not transport.is_closing(): - try: - try: - action = _asyncgen.SendAction(next(consumer)) - except StopIteration: - data: bytes = await transport.recv(bufsize) - if not data: # Closed connection (EOF) - break - try: - consumer.feed(data) - finally: - del data - continue - except BaseException as exc: - action = _asyncgen.ThrowAction(exc) - + async for action in _RequestReceiver(transport, consumer, self.__max_recv_size): try: await action.asend(request_handler_generator) except StopAsyncIteration: @@ -211,3 +194,42 @@ def max_recv_size(self) -> int: @property def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]: return self.__listener.extra_attributes + + +class _RequestReceiver(Generic[_RequestT]): + __slots__ = ("__consumer", "__transport", "__max_recv_size") + + def __init__( + self, + transport: transports.AsyncStreamReadTransport, + consumer: _stream.StreamDataConsumer[_RequestT], + max_recv_size: int, + ) -> None: + assert max_recv_size > 0, f"{max_recv_size=}" # nosec assert_used + self.__transport: transports.AsyncStreamReadTransport = transport + self.__consumer: _stream.StreamDataConsumer[_RequestT] = consumer + self.__max_recv_size: int = max_recv_size + + def __aiter__(self) -> Self: + return self + + async def __anext__(self) -> _asyncgen.AsyncGenAction[None, _RequestT]: + transport: transports.AsyncStreamReadTransport = self.__transport + consumer: _stream.StreamDataConsumer[_RequestT] = self.__consumer + bufsize: int = self.__max_recv_size + try: + while not transport.is_closing(): + try: + return _asyncgen.SendAction(next(consumer)) + except StopIteration: + pass + data: bytes = await transport.recv(bufsize) + if not data: # Closed connection (EOF) + break + try: + consumer.feed(data) + finally: + del data + except BaseException as exc: + return _asyncgen.ThrowAction(exc) + raise StopAsyncIteration