Datagram servers: Removed task locks because they are useless (#336)
francis-clairicia authored Aug 1, 2024
1 parent c319364 commit 94f5e72
Showing 5 changed files with 51 additions and 34 deletions.
54 changes: 28 additions & 26 deletions src/easynetwork/lowlevel/api_async/servers/datagram.py
@@ -33,7 +33,7 @@
from ....protocol import DatagramProtocol
from ... import _utils
from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction, anext_without_asyncgen_hook
-from ..backend.abc import AsyncBackend, ICondition, ILock, TaskGroup
+from ..backend.abc import AsyncBackend, ICondition, TaskGroup
from ..transports import abc as _transports

_T_Address = TypeVar("_T_Address", bound=Hashable)
@@ -75,7 +75,6 @@ class AsyncDatagramServer(_transports.AsyncBaseTransport, Generic[_T_Request, _T
__slots__ = (
    "__listener",
    "__protocol",
-   "__sendto_lock",
    "__serve_guard",
)

@@ -96,7 +95,6 @@ def __init__(

self.__listener: _transports.AsyncDatagramListener[_T_Address] = listener
self.__protocol: DatagramProtocol[_T_Response, _T_Request] = protocol
-self.__sendto_lock: ILock = listener.backend().create_lock()
self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently receiving datagrams")

def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None:
@@ -142,8 +140,8 @@ async def send_packet_to(self, packet: _T_Response, address: _T_Address) -> None
datagram: bytes = self.__protocol.make_datagram(packet)
except Exception as exc:
raise RuntimeError("protocol.make_datagram() crashed") from exc
-async with self.__sendto_lock:
-    await self.__listener.send_to(datagram, address)
+
+await self.__listener.send_to(datagram, address)

async def serve(
self,
@@ -212,20 +210,19 @@ async def __client_coroutine(
task_group: TaskGroup,
default_context: contextvars.Context,
) -> None:
-async with client.data.task_lock:
-    client.data.mark_running()
-    try:
-        await self.__client_coroutine_inner_loop(
-            request_handler_generator=datagram_received_cb(client.ctx),
-            client_data=client.data,
-        )
-    finally:
-        self.__on_task_done(
-            datagram_received_cb=datagram_received_cb,
-            client=client,
-            task_group=task_group,
-            default_context=default_context,
-        )
+client.data.mark_running()
+try:
+    await self.__client_coroutine_inner_loop(
+        request_handler_generator=datagram_received_cb(client.ctx),
+        client_data=client.data,
+    )
+finally:
+    self.__on_task_done(
+        datagram_received_cb=datagram_received_cb,
+        client=client,
+        task_group=task_group,
+        default_context=default_context,
+    )

async def __client_coroutine_inner_loop(
self,
@@ -284,7 +281,18 @@ def __on_task_done(
return

client.data.mark_pending()
-default_context.run(
+
+# Why copy the context before calling run()?
+# Short answer: asyncio.eager_task_factory :)
+#
+# If asyncio's eager task is enabled in this event loop, there is a chance
+# to have a nested call if the request handler does not yield
+# and we end up with this error:
+# RuntimeError: cannot enter context: <_contextvars.Context object at ...> is already entered
+# To avoid that, we always use a new context. The performance cost is negligible.
+# See this functional test for a real situation:
+# test____serve_forever____too_many_datagrams_while_request_handle_is_performed
+default_context.copy().run(
task_group.start_soon,
self.__client_coroutine,
datagram_received_cb,
@@ -342,15 +350,13 @@ class _ClientToken(Generic[_T_Response, _T_Address]):
class _ClientData:
    __slots__ = (
        "__backend",
-       "__task_lock",
        "__state",
        "_queue_condition",
        "_datagram_queue",
    )

def __init__(self, backend: AsyncBackend) -> None:
self.__backend: AsyncBackend = backend
-self.__task_lock: ILock = backend.create_lock()
self.__state: _ClientState | None = None
self._queue_condition: ICondition = backend.create_condition_var()
self._datagram_queue: deque[bytes] = deque()
@@ -359,10 +365,6 @@ def __init__(self, backend: AsyncBackend) -> None:
def backend(self) -> AsyncBackend:
return self.__backend

-@property
-def task_lock(self) -> ILock:
-    return self.__task_lock
-
@property
def state(self) -> _ClientState | None:
return self.__state
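
The comment added in __on_task_done above explains why default_context.copy().run() replaces default_context.run(). Below is a minimal, self-contained sketch of that failure mode; it is not taken from the repository and assumes Python 3.12+ for asyncio.eager_task_factory. With an eager task factory, a task's first step runs synchronously inside create_task(), so if that first step calls Context.run() on the very Context object the caller is still running in, Python raises "cannot enter context: ... is already entered". Copying the context first sidesteps the re-entry.

import asyncio
import contextvars


async def main() -> None:
    loop = asyncio.get_running_loop()
    loop.set_task_factory(asyncio.eager_task_factory)  # Python 3.12+

    ctx = contextvars.copy_context()
    tasks: list[asyncio.Task[None]] = []

    async def client_task() -> None:
        # With eager tasks, this body runs synchronously while spawn() below
        # still holds `ctx` entered, so re-entering the same object fails.
        ctx.run(lambda: None)  # RuntimeError: cannot enter context: ... is already entered
        # ctx.copy().run(lambda: None)  # a fresh copy would work

    def spawn() -> None:
        tasks.append(loop.create_task(client_task()))

    ctx.run(spawn)  # enters `ctx`, then the eager first step tries to enter it again
    try:
        await tasks[0]
    except RuntimeError as exc:
        print(exc)


asyncio.run(main())

In the server code, default_context plays the role of ctx here, and the eagerly started client coroutine is what would otherwise re-enter it.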
3 changes: 3 additions & 0 deletions src/easynetwork/lowlevel/api_async/transports/abc.py
@@ -288,6 +288,9 @@ async def send_to(self, data: bytes | bytearray | memoryview, address: _T_Addres
"""
Send the `data` bytes to the remote peer `address`.
+Important:
+    This method should be safe to call from multiple tasks.
+
Parameters:
data: the bytes to send.
address: the remote peer.
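
The docstring addition above makes task-safety part of the send_to() contract, which is what lets send_packet_to() drop its lock. A rough illustration of what a listener implementation now has to tolerate; the broadcast() helper is hypothetical and not from the repository, and `listener` is assumed to be any AsyncDatagramListener implementation:

import asyncio
from collections.abc import Iterable


async def broadcast(listener, payload: bytes, addresses: Iterable) -> None:
    # Several tasks call send_to() at the same time, with no lock around the
    # calls; the listener implementation is expected to cope with this.
    async with asyncio.TaskGroup() as tg:  # Python 3.11+
        for address in addresses:
            tg.create_task(listener.send_to(payload, address))

Whether the underlying socket needs its own synchronization is now the listener's concern rather than the server's.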
10 changes: 7 additions & 3 deletions tests/functional_test/test_communication/test_async/conftest.py
@@ -1,14 +1,18 @@
from __future__ import annotations

import asyncio
-import sys

import pytest


-@pytest.fixture(params=[True, False] if sys.version_info >= (3, 12) else [False], ids=lambda p: f"enable_eager_tasks=={p}")
+@pytest.fixture(params=[True, False], ids=lambda p: f"enable_eager_tasks=={p}")
def enable_eager_tasks(request: pytest.FixtureRequest, event_loop: asyncio.AbstractEventLoop) -> bool:
enable_eager_tasks = bool(request.param)
if enable_eager_tasks:
-event_loop.set_task_factory(getattr(asyncio, "eager_task_factory"))
+try:
+    eager_task_factory = getattr(asyncio, "eager_task_factory")
+except AttributeError:
+    pytest.skip("asyncio.eager_task_factory not implemented")
+else:
+    event_loop.set_task_factory(eager_task_factory)
return enable_eager_tasks
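
For context on what the enable_eager_tasks fixture toggles, here is a rough, standalone sketch (not part of the commit) of the behavioural difference: with asyncio.eager_task_factory, available from Python 3.12, a task's first step runs synchronously inside create_task(), before control returns to the caller.

import asyncio


async def main(eager: bool) -> None:
    loop = asyncio.get_running_loop()
    if eager:
        loop.set_task_factory(asyncio.eager_task_factory)  # AttributeError before 3.12

    started = False

    async def job() -> None:
        nonlocal started
        started = True
        await asyncio.sleep(0)

    task = loop.create_task(job())
    # Lazy tasks have not run at all yet at this point; eager tasks already have.
    print(f"eager={eager}: started before first await -> {started}")
    await task


asyncio.run(main(eager=False))  # started before first await -> False
asyncio.run(main(eager=True))   # started before first await -> True

The fixture now always parametrizes both values and skips the eager case on interpreters without eager_task_factory, instead of dropping the parameter entirely as the old version did.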
Original file line number Diff line number Diff line change
@@ -166,16 +166,18 @@ async def handle(self, client: AsyncDatagramClient[str]) -> AsyncGenerator[None,


class ConcurrencyTestRequestHandler(AsyncDatagramRequestHandler[str, str]):
-sleep_time_before_second_yield: float = 0.0
-sleep_time_before_response: float = 0.0
+sleep_time_before_second_yield: float | None = None
+sleep_time_before_response: float | None = None
recreate_generator: bool = True

async def handle(self, client: AsyncDatagramClient[str]) -> AsyncGenerator[None, str]:
while True:
assert (yield) == "something"
-await asyncio.sleep(self.sleep_time_before_second_yield)
+if self.sleep_time_before_second_yield is not None:
+    await asyncio.sleep(self.sleep_time_before_second_yield)
request = yield
-await asyncio.sleep(self.sleep_time_before_response)
+if self.sleep_time_before_response is not None:
+    await asyncio.sleep(self.sleep_time_before_response)
await client.send_packet(f"After wait: {request}")
if self.recreate_generator:
break
@@ -660,9 +662,16 @@ async def test____serve_forever____too_many_datagrams_while_request_handle_is_pe
await endpoint.sendto(b"something", None)
await asyncio.sleep(0.1)
await endpoint.sendto(b"hello, world.", None)
+for i in range(3):
+    await endpoint.sendto(b"something", None)
+    await endpoint.sendto(f"hello, world {i+2} times.".encode(), None)
await endpoint.sendto(b"something", None)
await asyncio.sleep(0.1)
+request_handler.sleep_time_before_response = None
await endpoint.sendto(b"hello, world. new game +", None)
async with asyncio.timeout(5):
    assert (await endpoint.recvfrom())[0] == b"After wait: hello, world."
+   assert (await endpoint.recvfrom())[0] == b"After wait: hello, world 2 times."
+   assert (await endpoint.recvfrom())[0] == b"After wait: hello, world 3 times."
+   assert (await endpoint.recvfrom())[0] == b"After wait: hello, world 4 times."
    assert (await endpoint.recvfrom())[0] == b"After wait: hello, world. new game +"
Original file line number Diff line number Diff line change
@@ -213,7 +213,6 @@ def test____dunder_init____default(
# Arrange

# Act & Assert
-assert isinstance(client_data.task_lock, asyncio.Lock)
assert client_data.state is None

def test____client_state____regular_state_transition(