Skip to content

Commit

Permalink
[FIX] TCP server: Avoid awaiting in 'except StopIteration' clause
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia committed Nov 1, 2023
1 parent 93fd6a8 commit 1b44ff5
Show file tree
Hide file tree
Showing 2 changed files with 43 additions and 21 deletions.
4 changes: 2 additions & 2 deletions src/easynetwork/api_async/server/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -371,9 +371,9 @@ async def __client_coroutine(

request_handler_generator: AsyncGenerator[None, _RequestT] | None = None
_on_connection_hook = self.__request_handler.on_connection(client)
if inspect.isasyncgen(_on_connection_hook):
if isinstance(_on_connection_hook, AsyncGenerator):
try:
await _on_connection_hook.asend(None)
await anext(_on_connection_hook)
except StopAsyncIteration:
pass
else:
Expand Down
60 changes: 41 additions & 19 deletions src/easynetwork/lowlevel/api_async/servers/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -20,7 +20,7 @@

import contextlib
from collections.abc import AsyncGenerator, Callable, Mapping
from typing import Any, Generic, NoReturn
from typing import Any, Generic, NoReturn, Self

from .... import protocol as protocol_module
from ...._typevars import _RequestT, _ResponseT
Expand Down Expand Up @@ -178,24 +178,7 @@ async def __client_coroutine(
except StopAsyncIteration:
return

bufsize: int = self.__max_recv_size
action: _asyncgen.AsyncGenAction[None, _RequestT]
while not transport.is_closing():
try:
try:
action = _asyncgen.SendAction(next(consumer))
except StopIteration:
data: bytes = await transport.recv(bufsize)
if not data: # Closed connection (EOF)
break
try:
consumer.feed(data)
finally:
del data
continue
except BaseException as exc:
action = _asyncgen.ThrowAction(exc)

async for action in _RequestReceiver(transport, consumer, self.__max_recv_size):
try:
await action.asend(request_handler_generator)
except StopAsyncIteration:
Expand All @@ -211,3 +194,42 @@ def max_recv_size(self) -> int:
@property
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return self.__listener.extra_attributes


class _RequestReceiver(Generic[_RequestT]):
__slots__ = ("__consumer", "__transport", "__max_recv_size")

def __init__(
self,
transport: transports.AsyncStreamReadTransport,
consumer: _stream.StreamDataConsumer[_RequestT],
max_recv_size: int,
) -> None:
assert max_recv_size > 0, f"{max_recv_size=}" # nosec assert_used
self.__transport: transports.AsyncStreamReadTransport = transport
self.__consumer: _stream.StreamDataConsumer[_RequestT] = consumer
self.__max_recv_size: int = max_recv_size

def __aiter__(self) -> Self:
return self

async def __anext__(self) -> _asyncgen.AsyncGenAction[None, _RequestT]:
transport: transports.AsyncStreamReadTransport = self.__transport
consumer: _stream.StreamDataConsumer[_RequestT] = self.__consumer
bufsize: int = self.__max_recv_size
try:
while not transport.is_closing():
try:
return _asyncgen.SendAction(next(consumer))
except StopIteration:
pass
data: bytes = await transport.recv(bufsize)
if not data: # Closed connection (EOF)
break
try:
consumer.feed(data)
finally:
del data
except BaseException as exc:
return _asyncgen.ThrowAction(exc)
raise StopAsyncIteration

0 comments on commit 1b44ff5

Please sign in to comment.