Skip to content

Commit

Permalink
Fix: Improved server performance (slightly) (#304)
Browse files Browse the repository at this point in the history
  • Loading branch information
francis-clairicia authored Jun 23, 2024
1 parent 0f88537 commit bb44709
Show file tree
Hide file tree
Showing 4 changed files with 118 additions and 136 deletions.
144 changes: 66 additions & 78 deletions src/easynetwork/lowlevel/api_async/servers/datagram.py
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import weakref
from collections import deque
from collections.abc import AsyncGenerator, Callable, Hashable, Mapping
from contextlib import AsyncExitStack, ExitStack
from contextlib import AsyncExitStack
from typing import Any, Generic, NoReturn, TypeVar

from ...._typevars import _T_Request, _T_Response
Expand Down Expand Up @@ -185,15 +185,13 @@ async def handler(datagram: bytes, address: _T_Address, /) -> None:
client = client_cache[address]
except KeyError:
client_cache[address] = client = _ClientToken(DatagramClientContext(address, self), _ClientData(backend))
new_client_task = True
else:
new_client_task = client.data.state is None

if new_client_task:
await client.data.push_datagram(datagram)

if client.data.state is None:
del datagram
client.data.mark_pending()
await self.__client_coroutine(datagram_received_cb, datagram, client, task_group, default_context)
else:
await client.data.push_datagram(datagram)
await self.__client_coroutine(datagram_received_cb, client, task_group, default_context)

await listener.serve(handler, task_group)

Expand All @@ -204,63 +202,67 @@ async def __client_coroutine(
datagram_received_cb: Callable[
[DatagramClientContext[_T_Response, _T_Address]], AsyncGenerator[float | None, _T_Request]
],
datagram: bytes,
client: _ClientToken[_T_Response, _T_Address],
task_group: TaskGroup,
default_context: contextvars.Context,
) -> None:
client_data = client.data
async with client_data.task_lock:
with ExitStack() as exit_stack:
#####################################################################################################
# CRITICAL SECTION
# This block must not have any asynchronous function calls
# or add any asynchronous callbacks/contexts to the exit stack.
client_data.mark_running()
exit_stack.callback(
self.__on_task_done,
async with client.data.task_lock:
client.data.mark_running()
try:
await self.__client_coroutine_inner_loop(
request_handler_generator=datagram_received_cb(client.ctx),
client_data=client.data,
)
finally:
self.__on_task_done(
datagram_received_cb=datagram_received_cb,
client=client,
task_group=task_group,
default_context=default_context,
)
#####################################################################################################

request_handler_generator = datagram_received_cb(client.ctx)
timeout: float | None

async def __client_coroutine_inner_loop(
self,
*,
request_handler_generator: AsyncGenerator[float | None, _T_Request],
client_data: _ClientData,
) -> None:
timeout: float | None
datagram: bytes = client_data.pop_datagram_no_wait()
try:
# Ignore sent timeout here, we already have the datagram.
await anext(request_handler_generator)
except StopAsyncIteration:
return
else:
action: AsyncGenAction[_T_Request] | None
action = self.__parse_datagram(datagram, self.__protocol)
try:
timeout = await action.asend(request_handler_generator)
except StopAsyncIteration:
return
finally:
action = None

del datagram
null_timeout_ctx = contextlib.nullcontext()
while True:
try:
# Ignore sent timeout here, we already have the datagram.
await anext(request_handler_generator)
with null_timeout_ctx if timeout is None else client_data.backend.timeout(timeout):
datagram = await client_data.pop_datagram()
action = self.__parse_datagram(datagram, self.__protocol)
except BaseException as exc:
action = ThrowAction(exc)
finally:
datagram = b""
try:
timeout = await action.asend(request_handler_generator)
except StopAsyncIteration:
return
else:
action: AsyncGenAction[_T_Request] = self.__parse_datagram(datagram, self.__protocol)
try:
timeout = await action.asend(request_handler_generator)
except StopAsyncIteration:
return
finally:
del action

del datagram
while True:
try:
with contextlib.nullcontext() if timeout is None else client_data.backend.timeout(timeout):
datagram = await client_data.pop_datagram()

action = self.__parse_datagram(datagram, self.__protocol)
del datagram
except BaseException as exc:
action = ThrowAction(exc)
try:
timeout = await action.asend(request_handler_generator)
except StopAsyncIteration:
break
finally:
del action
break
finally:
await request_handler_generator.aclose()
action = None
finally:
await request_handler_generator.aclose()

def __on_task_done(
self,
Expand All @@ -272,17 +274,14 @@ def __on_task_done(
default_context: contextvars.Context,
) -> None:
client.data.mark_done()
try:
pending_datagram = client.data.pop_datagram_no_wait()
except IndexError:
if client.data.queue_is_empty():
return

client.data.mark_pending()
default_context.run(
task_group.start_soon,
self.__client_coroutine,
datagram_received_cb,
pending_datagram,
client,
task_group,
default_context,
Expand Down Expand Up @@ -347,8 +346,8 @@ def __init__(self, backend: AsyncBackend) -> None:
self.__backend: AsyncBackend = backend
self.__task_lock: ILock = backend.create_lock()
self.__state: _ClientState | None = None
self._queue_condition: ICondition | None = None
self._datagram_queue: deque[bytes] | None = None
self._queue_condition: ICondition = backend.create_condition_var()
self._datagram_queue: deque[bytes] = deque()

@property
def backend(self) -> AsyncBackend:
Expand All @@ -362,21 +361,20 @@ def task_lock(self) -> ILock:
def state(self) -> _ClientState | None:
return self.__state

def queue_is_empty(self) -> bool:
    """Return True when no datagram is currently buffered for this client."""
    return len(self._datagram_queue) == 0

async def push_datagram(self, datagram: bytes) -> None:
self.__ensure_queue().append(datagram)
if (queue_condition := self._queue_condition) is not None:
async with queue_condition:
queue_condition.notify()
self._datagram_queue.append(datagram)
async with (queue_condition := self._queue_condition):
queue_condition.notify()

def pop_datagram_no_wait(self) -> bytes:
if not (queue := self._datagram_queue):
raise IndexError("pop from an empty deque")
return queue.popleft()
return self._datagram_queue.popleft()

async def pop_datagram(self) -> bytes:
queue_condition = self.__ensure_queue_condition_var()
async with queue_condition:
queue = self.__ensure_queue()
async with (queue_condition := self._queue_condition):
queue = self._datagram_queue
while not queue:
await queue_condition.wait()
return queue.popleft()
Expand All @@ -396,16 +394,6 @@ def mark_running(self) -> None:
self.handle_inconsistent_state_error()
self.__state = _ClientState.TASK_RUNNING

def __ensure_queue(self) -> deque[bytes]:
if (queue := self._datagram_queue) is None:
self._datagram_queue = queue = deque()
return queue

def __ensure_queue_condition_var(self) -> ICondition:
if (cond := self._queue_condition) is None:
self._queue_condition = cond = self.__backend.create_condition_var()
return cond

@staticmethod
def handle_inconsistent_state_error() -> NoReturn:
msg = "The server has created too many tasks and ends up in an inconsistent state."
Expand Down
88 changes: 47 additions & 41 deletions src/easynetwork/lowlevel/api_async/servers/stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -216,10 +216,7 @@ async def __client_coroutine(
consumer=consumer,
)
else:
_warn_msg = f"The transport implementation {transport!r} does not implement AsyncBufferedStreamReadTransport interface."
_warn_msg = f"{_warn_msg} Consider using StreamProtocol instead of BufferedStreamProtocol."
warnings.warn(_warn_msg, category=ManualBufferAllocationWarning, stacklevel=1)
del _warn_msg
self.__manual_buffer_allocation_warning(transport)
consumer = _stream.StreamDataConsumer(self.__protocol.into_data_protocol())
request_receiver = _RequestReceiver(
transport=transport,
Expand Down Expand Up @@ -247,91 +244,100 @@ async def __client_coroutine(
)
)

del client_exit_stack, task_exit_stack, client_connected_cb

timeout: float | None
try:
timeout = await anext(request_handler_generator)
except StopAsyncIteration:
return
else:
while True:
try:
try:
action: AsyncGenAction[_T_Request] | None
while True:
action = await request_receiver.next(timeout)
except StopAsyncIteration:
break
try:
timeout = await action.asend(request_handler_generator)
except StopAsyncIteration:
break
finally:
del action
try:
timeout = await action.asend(request_handler_generator)
finally:
action = None
except StopAsyncIteration:
return
finally:
await request_handler_generator.aclose()

@staticmethod
def __manual_buffer_allocation_warning(transport: AsyncStreamTransport) -> None:
    """Emit a ManualBufferAllocationWarning for a transport lacking buffered reads.

    Called when *transport* does not implement the
    AsyncBufferedStreamReadTransport interface; stacklevel=2 points the
    warning at this helper's caller.
    """
    # Assemble the two sentences into the exact message text.
    sentences = [
        f"The transport implementation {transport!r} does not implement AsyncBufferedStreamReadTransport interface.",
        "Consider using StreamProtocol instead of BufferedStreamProtocol.",
    ]
    warnings.warn(" ".join(sentences), category=ManualBufferAllocationWarning, stacklevel=2)

@property
@_utils.inherit_doc(AsyncBaseTransport)
def extra_attributes(self) -> Mapping[Any, Callable[[], Any]]:
return self.__listener.extra_attributes


@dataclasses.dataclass(kw_only=True, eq=False, frozen=True, slots=True)
@dataclasses.dataclass(kw_only=True, eq=False, slots=True)
class _RequestReceiver(Generic[_T_Request]):
transport: AsyncStreamReadTransport
consumer: _stream.StreamDataConsumer[_T_Request]
max_recv_size: int
__null_timeout_ctx: contextlib.nullcontext[None] = dataclasses.field(init=False, default_factory=contextlib.nullcontext)
__backend: AsyncBackend = dataclasses.field(init=False)

def __post_init__(self) -> None:
    """Validate the receive buffer size and cache the transport's backend."""
    # Guard against a non-positive chunk size; deliberately an assert
    # (marked for bandit) rather than a runtime exception.
    assert self.max_recv_size > 0, f"{self.max_recv_size=}"  # nosec assert_used
    # Cache the backend once so later calls avoid repeated transport.backend() lookups.
    self.__backend = self.transport.backend()

async def next(self, timeout: float | None) -> AsyncGenAction[_T_Request]:
try:
consumer = self.consumer
try:
request = consumer.next(None)
except StopIteration:
pass
else:
return SendAction(request)

with self.__null_timeout_ctx if timeout is None else self.transport.backend().timeout(timeout):
while data := await self.transport.recv(self.max_recv_size):
with self.__null_timeout_ctx if timeout is None else self.__backend.timeout(timeout):
data: bytes | None = None
while True:
try:
request = consumer.next(data)
except StopIteration:
continue
pass
else:
return SendAction(request)
finally:
del data
return SendAction(request)
data = None
data = await self.transport.recv(self.max_recv_size)
if not data:
break
except BaseException as exc:
return ThrowAction(exc)
raise StopAsyncIteration


@dataclasses.dataclass(kw_only=True, eq=False, frozen=True, slots=True)
@dataclasses.dataclass(kw_only=True, eq=False, slots=True)
class _BufferedRequestReceiver(Generic[_T_Request]):
transport: AsyncBufferedStreamReadTransport
consumer: _stream.BufferedStreamDataConsumer[_T_Request]
__null_timeout_ctx: contextlib.nullcontext[None] = dataclasses.field(init=False, default_factory=contextlib.nullcontext)
__backend: AsyncBackend = dataclasses.field(init=False)

def __post_init__(self) -> None:
    """Cache the transport's backend for reuse on each receive call."""
    # Cache the backend once so later calls avoid repeated transport.backend() lookups.
    self.__backend = self.transport.backend()

async def next(self, timeout: float | None) -> AsyncGenAction[_T_Request]:
try:
consumer = self.consumer
try:
request = consumer.next(None)
except StopIteration:
pass
else:
return SendAction(request)

with self.__null_timeout_ctx if timeout is None else self.transport.backend().timeout(timeout):
while nbytes := await self.transport.recv_into(consumer.get_write_buffer()):
with self.__null_timeout_ctx if timeout is None else self.__backend.timeout(timeout):
nbytes: int | None = None
while True:
try:
request = consumer.next(nbytes)
except StopIteration:
continue
return SendAction(request)
pass
else:
return SendAction(request)
nbytes = await self.transport.recv_into(consumer.get_write_buffer())
if not nbytes:
break
except BaseException as exc:
return ThrowAction(exc)
raise StopAsyncIteration
6 changes: 5 additions & 1 deletion src/easynetwork/servers/misc.py
Original file line number Diff line number Diff line change
Expand Up @@ -29,7 +29,7 @@

from .._typevars import _T_Request, _T_Response
from ..lowlevel import _utils
from ..lowlevel._asyncgen import AsyncGenAction, SendAction, ThrowAction
from ..lowlevel._asyncgen import AsyncGenAction
from ..lowlevel.api_async.servers import datagram as _lowlevel_datagram_server, stream as _lowlevel_stream_server
from .handlers import AsyncDatagramClient, AsyncDatagramRequestHandler, AsyncStreamClient, AsyncStreamRequestHandler

Expand Down Expand Up @@ -63,6 +63,8 @@ def build_lowlevel_stream_server_handler(
if logger is None:
logger = logging.getLogger(__name__)

from ..lowlevel._asyncgen import SendAction, ThrowAction

async def handler(
lowlevel_client: _lowlevel_stream_server.Client[_T_Response], /
) -> AsyncGenerator[float | None, _T_Request]:
Expand Down Expand Up @@ -179,6 +181,8 @@ def build_lowlevel_datagram_server_handler(
an :term:`asynchronous generator` function.
"""

from ..lowlevel._asyncgen import SendAction, ThrowAction

async def handler(
lowlevel_client: _lowlevel_datagram_server.DatagramClientContext[_T_Response, _T_Address], /
) -> AsyncGenerator[float | None, _T_Request]:
Expand Down
Loading

0 comments on commit bb44709

Please sign in to comment.