Skip to content

Commit

Permalink
[FIX] Servers: Improved asynchronous generators usage
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia committed Oct 29, 2023
1 parent 6a44f60 commit 1e990fd
Show file tree
Hide file tree
Showing 9 changed files with 107 additions and 187 deletions.
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -88,7 +88,7 @@ class EchoRequestHandler(AsyncStreamRequestHandler[RequestType, ResponseType]):
await client.send_packet({"error": "Invalid JSON", "code": "parse_error"})
return

self.logger.info(f"{client.address} sent {request!r}")
self.logger.info(f"{client!r} sent {request!r}")

# As a good echo handler, the request is sent back to the client
response: ResponseType = request
Expand Down
39 changes: 12 additions & 27 deletions src/easynetwork/api_async/server/tcp.py
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@

from ..._typevars import _RequestT, _ResponseT
from ...exceptions import ClientClosedError, ServerAlreadyRunning, ServerClosedError
from ...lowlevel import _utils, constants
from ...lowlevel import _asyncgen, _utils, constants
from ...lowlevel.api_async.backend.factory import AsyncBackendFactory
from ...lowlevel.api_async.servers import stream as lowlevel_stream_server
from ...lowlevel.socket import (
Expand Down Expand Up @@ -393,48 +393,33 @@ async def disconnect_client() -> None:

del client_exit_stack

from ...lowlevel.api_async.servers._tools.actions import ErrorAction, RequestAction

try:
if client.is_closing():
return
if request_handler_generator is None:
request_handler_generator = await self.__new_request_handler(client)
action: _asyncgen.AsyncGenAction[None, _RequestT]
while not client.is_closing():
if request_handler_generator is None:
return
action: RequestAction[_RequestT] | ErrorAction
while True:
request_handler_generator = self.__request_handler.handle(client)
try:
await anext(request_handler_generator)
except StopAsyncIteration:
request_handler_generator = None
break
try:
action = RequestAction((yield))
action = _asyncgen.SendAction((yield))
except ConnectionError:
return
break
except BaseException as exc:
action = ErrorAction(exc)
action = _asyncgen.ThrowAction(exc)
try:
await action.asend(request_handler_generator)
except StopAsyncIteration:
request_handler_generator = None
finally:
del action
await backend.cancel_shielded_coro_yield()
if client.is_closing():
break
if request_handler_generator is None:
request_handler_generator = await self.__new_request_handler(client)
if request_handler_generator is None:
break
finally:
if request_handler_generator is not None:
await request_handler_generator.aclose()

async def __new_request_handler(self, client: _ConnectedClientAPI[_ResponseT]) -> AsyncGenerator[None, _RequestT] | None:
request_handler_generator = self.__request_handler.handle(client)
try:
await anext(request_handler_generator)
except StopAsyncIteration:
return None
return request_handler_generator

@staticmethod
def __set_socket_linger_if_not_closed(socket: ISocket) -> None:
with contextlib.suppress(OSError):
Expand Down
40 changes: 28 additions & 12 deletions src/easynetwork/api_async/server/udp.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,13 +21,13 @@
import contextlib
import logging
import weakref
from collections import deque
from collections import defaultdict, deque
from collections.abc import AsyncGenerator, Callable, Coroutine, Iterator, Mapping, Sequence
from typing import TYPE_CHECKING, Any, Generic, NoReturn, final

from ..._typevars import _RequestT, _ResponseT
from ...exceptions import ClientClosedError, ServerAlreadyRunning, ServerClosedError
from ...lowlevel import _utils
from ...lowlevel import _asyncgen, _utils
from ...lowlevel.api_async.backend.factory import AsyncBackendFactory
from ...lowlevel.api_async.servers import datagram as lowlevel_datagram_server
from ...lowlevel.api_async.transports.abc import AsyncDatagramListener
Expand Down Expand Up @@ -55,6 +55,7 @@ class AsyncUDPNetworkServer(AbstractAsyncNetworkServer, Generic[_RequestT, _Resp
"__is_shutdown",
"__shutdown_asked",
"__clients_cache",
"__send_locks_cache",
"__servers_tasks",
"__mainloop_task",
"__logger",
Expand Down Expand Up @@ -121,7 +122,16 @@ def __init__(
self.__servers_tasks: deque[Task[NoReturn]] = deque() # type: ignore[assignment]
self.__mainloop_task: Task[NoReturn] | None = None
self.__logger: logging.Logger = logger or logging.getLogger(__name__)
self.__clients_cache: weakref.WeakValueDictionary[SocketAddress, _ClientAPI[_ResponseT]] = weakref.WeakValueDictionary()
self.__clients_cache: defaultdict[
lowlevel_datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]],
weakref.WeakValueDictionary[SocketAddress, _ClientAPI[_ResponseT]],
]
self.__send_locks_cache: defaultdict[
lowlevel_datagram_server.AsyncDatagramServer[_RequestT, _ResponseT, tuple[Any, ...]],
weakref.WeakValueDictionary[SocketAddress, ILock],
]
self.__clients_cache = defaultdict(weakref.WeakValueDictionary)
self.__send_locks_cache = defaultdict(weakref.WeakValueDictionary)

def is_serving(self) -> bool:
return self.__servers is not None and all(not server.is_closing() for server in self.__servers)
Expand Down Expand Up @@ -210,6 +220,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) ->
################

# Initialize request handler
server_exit_stack.callback(self.__send_locks_cache.clear)
server_exit_stack.callback(self.__clients_cache.clear)
await self.__request_handler.service_init(
await server_exit_stack.enter_async_context(contextlib.AsyncExitStack()),
Expand Down Expand Up @@ -243,7 +254,7 @@ async def serve_forever(self, *, is_up_event: SupportsEventSet | None = None) ->
finally:
self.__mainloop_task = None

raise AssertionError("received_datagrams() does not return")
raise AssertionError("sleep_forever() does not return")

serve_forever.__doc__ = AbstractAsyncNetworkServer.serve_forever.__doc__

Expand All @@ -262,26 +273,31 @@ async def __datagram_received_coroutine(
) -> AsyncGenerator[None, _RequestT]:
address = new_socket_address(address, server.extra(INETSocketAttribute.family))
with self.__suppress_and_log_remaining_exception(client_address=address):
send_locks_cache = self.__send_locks_cache[server]
try:
client = self.__clients_cache[address]
send_lock = send_locks_cache[address]
except KeyError:
self.__clients_cache[address] = client = _ClientAPI(address, server, self.__backend.create_lock(), self.__logger)
send_locks_cache[address] = send_lock = self.__backend.create_lock()

clients_cache = self.__clients_cache[server]
try:
client = clients_cache[address]
except KeyError:
clients_cache[address] = client = _ClientAPI(address, server, send_lock, self.__logger)

async with contextlib.aclosing(self.__request_handler.handle(client)) as request_handler_generator:
del client
del client, send_lock
try:
await anext(request_handler_generator)
except StopAsyncIteration:
return

from ...lowlevel.api_async.servers._tools.actions import ErrorAction, RequestAction

action: RequestAction[_RequestT] | ErrorAction
action: _asyncgen.AsyncGenAction[None, _RequestT]
while True:
try:
action = RequestAction((yield))
action = _asyncgen.SendAction((yield))
except BaseException as exc:
action = ErrorAction(exc)
action = _asyncgen.ThrowAction(exc)
try:
await action.asend(request_handler_generator)
except StopAsyncIteration:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,64 +12,46 @@
# limitations under the License.
#
#
"""Low-level asynchronous server module"""
"""Async generators helper module"""

from __future__ import annotations

__all__ = [] # type: list[str]

import dataclasses
from abc import ABCMeta, abstractmethod
from collections.abc import AsyncGenerator, AsyncIterator, Awaitable, Callable
from collections.abc import AsyncGenerator
from typing import Any, Generic, TypeVar

_T = TypeVar("_T")
_T_Send = TypeVar("_T_Send")
_T_Yield = TypeVar("_T_Yield")


class Action(Generic[_T], metaclass=ABCMeta):
class AsyncGenAction(Generic[_T_Yield, _T_Send], metaclass=ABCMeta):
__slots__ = ()

@abstractmethod
async def asend(self, generator: AsyncGenerator[None, _T]) -> None:
async def asend(self, generator: AsyncGenerator[_T_Yield, _T_Send]) -> _T_Yield:
raise NotImplementedError


@dataclasses.dataclass(slots=True)
class RequestAction(Action[_T]):
request: _T
class SendAction(AsyncGenAction[_T_Yield, _T_Send]):
value: _T_Send

async def asend(self, generator: AsyncGenerator[None, _T]) -> None:
async def asend(self, generator: AsyncGenerator[_T_Yield, _T_Send]) -> _T_Yield:
try:
await generator.asend(self.request)
return await generator.asend(self.value)
finally:
del self


@dataclasses.dataclass(slots=True)
class ErrorAction(Action[Any]):
class ThrowAction(AsyncGenAction[_T_Yield, Any]):
exception: BaseException

async def asend(self, generator: AsyncGenerator[None, Any]) -> None:
async def asend(self, generator: AsyncGenerator[_T_Yield, Any]) -> _T_Yield:
try:
await generator.athrow(self.exception)
return await generator.athrow(self.exception)
finally:
del generator, self # Needed to avoid circular reference with raised exception


@dataclasses.dataclass(slots=True)
class ActionIterator(Generic[_T]):
request_factory: Callable[[], Awaitable[_T]]

def __aiter__(self) -> AsyncIterator[Action[_T]]:
return self

async def __anext__(self) -> Action[_T]:
try:
request = await self.request_factory()
except StopAsyncIteration:
raise
except BaseException as exc:
return ErrorAction(exc)
finally:
del self
return RequestAction(request)
19 changes: 0 additions & 19 deletions src/easynetwork/lowlevel/api_async/servers/_tools/__init__.py

This file was deleted.

46 changes: 16 additions & 30 deletions src/easynetwork/lowlevel/api_async/servers/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,9 @@

from .... import protocol as protocol_module
from ...._typevars import _RequestT, _ResponseT
from ... import _utils, typed_attr
from ... import _asyncgen, _utils, typed_attr
from ..backend.abc import AsyncBackend, ICondition, ILock, TaskGroup
from ..transports import abc as transports
from ._tools.actions import ActionIterator as _ActionIterator

_T_Address = TypeVar("_T_Address", bound=Hashable)
_KT = TypeVar("_KT")
Expand Down Expand Up @@ -123,6 +122,7 @@ async def serve(
client_coroutine = self.__client_coroutine
client_manager = self.__client_manager
backend = self.__backend
listener = self.__listener

async def handler(datagram: bytes, address: _T_Address, /) -> None:
with client_manager.datagram_queue(address) as datagram_queue:
Expand Down Expand Up @@ -151,7 +151,7 @@ async def handler(datagram: bytes, address: _T_Address, /) -> None:
assert_never(client_state)

while True:
datagram, address = await self.__listener.recv_from()
datagram, address = await listener.recv_from()
task_group.start_soon(handler, datagram, address)
del datagram, address
await backend.cancel_shielded_coro_yield()
Expand Down Expand Up @@ -200,15 +200,17 @@ async def __client_coroutine(
del datagram_queue[0]
return

request_factory = _utils.make_callback(
self.__request_factory,
datagram_queue,
address,
client_manager,
condition,
self.__protocol,
)
async for action in _ActionIterator(request_factory):
protocol = self.__protocol
action: _asyncgen.AsyncGenAction[None, _RequestT]
while True:
try:
if not datagram_queue:
with client_manager.set_client_state(address, _ClientState.TASK_WAITING):
await condition.wait()
self.__check_datagram_queue_not_empty(datagram_queue)
action = _asyncgen.SendAction(protocol.build_packet_from_datagram(datagram_queue.popleft()))
except BaseException as exc:
action = _asyncgen.ThrowAction(exc)
try:
await action.asend(request_handler_generator)
except StopAsyncIteration:
Expand Down Expand Up @@ -241,24 +243,8 @@ def __clear_queue_on_error(
exc_tb: TracebackType | None,
/,
) -> None:
if exc_type is not None and not issubclass(exc_type, self.__backend.get_cancelled_exc_class()):
datagram_queue.clear() # pragma: no cover

@classmethod
async def __request_factory(
cls,
datagram_queue: deque[bytes],
address: _T_Address,
client_manager: _ClientManager[_T_Address],
condition: ICondition,
protocol: protocol_module.DatagramProtocol[_ResponseT, _RequestT],
/,
) -> _RequestT:
if not datagram_queue:
with client_manager.set_client_state(address, _ClientState.TASK_WAITING):
await condition.wait()
cls.__check_datagram_queue_not_empty(datagram_queue)
return protocol.build_packet_from_datagram(datagram_queue.popleft())
if exc_type is not None:
datagram_queue.clear()

@staticmethod
def __check_datagram_queue_not_empty(datagram_queue: deque[bytes]) -> None:
Expand Down
Loading

0 comments on commit 1e990fd

Please sign in to comment.