diff --git a/src/easynetwork/lowlevel/api_async/servers/datagram.py b/src/easynetwork/lowlevel/api_async/servers/datagram.py index 6319d19d..d05704fc 100644 --- a/src/easynetwork/lowlevel/api_async/servers/datagram.py +++ b/src/easynetwork/lowlevel/api_async/servers/datagram.py @@ -33,7 +33,7 @@ from ....protocol import DatagramProtocol from ... import _utils from ..._asyncgen import AsyncGenAction, SendAction, ThrowAction, anext_without_asyncgen_hook -from ..backend.abc import AsyncBackend, ICondition, ILock, TaskGroup +from ..backend.abc import AsyncBackend, ICondition, TaskGroup from ..transports import abc as _transports _T_Address = TypeVar("_T_Address", bound=Hashable) @@ -75,7 +75,6 @@ class AsyncDatagramServer(_transports.AsyncBaseTransport, Generic[_T_Request, _T __slots__ = ( "__listener", "__protocol", - "__sendto_lock", "__serve_guard", ) @@ -96,7 +95,6 @@ def __init__( self.__listener: _transports.AsyncDatagramListener[_T_Address] = listener self.__protocol: DatagramProtocol[_T_Response, _T_Request] = protocol - self.__sendto_lock: ILock = listener.backend().create_lock() self.__serve_guard: _utils.ResourceGuard = _utils.ResourceGuard("another task is currently receiving datagrams") def __del__(self, *, _warn: _utils.WarnCallback = warnings.warn) -> None: @@ -142,8 +140,8 @@ async def send_packet_to(self, packet: _T_Response, address: _T_Address) -> None datagram: bytes = self.__protocol.make_datagram(packet) except Exception as exc: raise RuntimeError("protocol.make_datagram() crashed") from exc - async with self.__sendto_lock: - await self.__listener.send_to(datagram, address) + + await self.__listener.send_to(datagram, address) async def serve( self, @@ -212,20 +210,19 @@ async def __client_coroutine( task_group: TaskGroup, default_context: contextvars.Context, ) -> None: - async with client.data.task_lock: - client.data.mark_running() - try: - await self.__client_coroutine_inner_loop( - request_handler_generator=datagram_received_cb(client.ctx), - client_data=client.data, - ) - finally: - self.__on_task_done( - datagram_received_cb=datagram_received_cb, - client=client, - task_group=task_group, - default_context=default_context, - ) + client.data.mark_running() + try: + await self.__client_coroutine_inner_loop( + request_handler_generator=datagram_received_cb(client.ctx), + client_data=client.data, + ) + finally: + self.__on_task_done( + datagram_received_cb=datagram_received_cb, + client=client, + task_group=task_group, + default_context=default_context, + ) async def __client_coroutine_inner_loop( self, @@ -284,7 +281,18 @@ def __on_task_done( return client.data.mark_pending() - default_context.run( + + # Why copy the context before calling run()? + # Short answer: asyncio.eager_task_factory :) + # + # If asyncio's eager task is enabled in this event loop, there is a chance + # to have a nested call if the request handler does not yield + # and we end up with this error: + # RuntimeError: cannot enter context: <_contextvars.Context object at ...> is already entered + # To avoid that, we always use a new context. The performance cost is negligible. + # See this functional test for a real situation: + # test____serve_forever____too_many_datagrams_while_request_handle_is_performed + default_context.copy().run( task_group.start_soon, self.__client_coroutine, datagram_received_cb, @@ -342,7 +350,6 @@ class _ClientToken(Generic[_T_Response, _T_Address]): class _ClientData: __slots__ = ( "__backend", - "__task_lock", "__state", "_queue_condition", "_datagram_queue", @@ -350,7 +357,6 @@ class _ClientData: def __init__(self, backend: AsyncBackend) -> None: self.__backend: AsyncBackend = backend - self.__task_lock: ILock = backend.create_lock() self.__state: _ClientState | None = None self._queue_condition: ICondition = backend.create_condition_var() self._datagram_queue: deque[bytes] = deque() @@ -359,10 +365,6 @@ def __init__(self, backend: AsyncBackend) -> None: def backend(self) -> AsyncBackend: return self.__backend - @property - def task_lock(self) -> ILock: - return self.__task_lock - @property def state(self) -> _ClientState | None: return self.__state diff --git a/src/easynetwork/lowlevel/api_async/transports/abc.py b/src/easynetwork/lowlevel/api_async/transports/abc.py index 97bea488..4b3f79ce 100644 --- a/src/easynetwork/lowlevel/api_async/transports/abc.py +++ b/src/easynetwork/lowlevel/api_async/transports/abc.py @@ -288,6 +288,9 @@ async def send_to(self, data: bytes | bytearray | memoryview, address: _T_Addres """ Send the `data` bytes to the remote peer `address`. + Important: + This method should be safe to call from multiple tasks. + Parameters: data: the bytes to send. address: the remote peer. diff --git a/tests/functional_test/test_communication/test_async/conftest.py b/tests/functional_test/test_communication/test_async/conftest.py index be734138..3a949431 100644 --- a/tests/functional_test/test_communication/test_async/conftest.py +++ b/tests/functional_test/test_communication/test_async/conftest.py @@ -1,14 +1,18 @@ from __future__ import annotations import asyncio -import sys import pytest -@pytest.fixture(params=[True, False] if sys.version_info >= (3, 12) else [False], ids=lambda p: f"enable_eager_tasks=={p}") +@pytest.fixture(params=[True, False], ids=lambda p: f"enable_eager_tasks=={p}") def enable_eager_tasks(request: pytest.FixtureRequest, event_loop: asyncio.AbstractEventLoop) -> bool: enable_eager_tasks = bool(request.param) if enable_eager_tasks: - event_loop.set_task_factory(getattr(asyncio, "eager_task_factory")) + try: + eager_task_factory = getattr(asyncio, "eager_task_factory") + except AttributeError: + pytest.skip("asyncio.eager_task_factory not implemented") + else: + event_loop.set_task_factory(eager_task_factory) return enable_eager_tasks diff --git a/tests/functional_test/test_communication/test_async/test_server/test_udp.py b/tests/functional_test/test_communication/test_async/test_server/test_udp.py index ebfea429..8ef2445c 100644 --- a/tests/functional_test/test_communication/test_async/test_server/test_udp.py +++ b/tests/functional_test/test_communication/test_async/test_server/test_udp.py @@ -166,16 +166,18 @@ async def handle(self, client: AsyncDatagramClient[str]) -> AsyncGenerator[None, class ConcurrencyTestRequestHandler(AsyncDatagramRequestHandler[str, str]): - sleep_time_before_second_yield: float = 0.0 - sleep_time_before_response: float = 0.0 + sleep_time_before_second_yield: float | None = None + sleep_time_before_response: float | None = None recreate_generator: bool = True async def handle(self, client: AsyncDatagramClient[str]) -> AsyncGenerator[None, str]: while True: assert (yield) == "something" - await asyncio.sleep(self.sleep_time_before_second_yield) + if self.sleep_time_before_second_yield is not None: + await asyncio.sleep(self.sleep_time_before_second_yield) request = yield - await asyncio.sleep(self.sleep_time_before_response) + if self.sleep_time_before_response is not None: + await asyncio.sleep(self.sleep_time_before_response) await client.send_packet(f"After wait: {request}") if self.recreate_generator: break @@ -660,9 +662,16 @@ async def test____serve_forever____too_many_datagrams_while_request_handle_is_pe await endpoint.sendto(b"something", None) await asyncio.sleep(0.1) await endpoint.sendto(b"hello, world.", None) + for i in range(3): + await endpoint.sendto(b"something", None) + await endpoint.sendto(f"hello, world {i+2} times.".encode(), None) await endpoint.sendto(b"something", None) await asyncio.sleep(0.1) + request_handler.sleep_time_before_response = None await endpoint.sendto(b"hello, world. new game +", None) async with asyncio.timeout(5): assert (await endpoint.recvfrom())[0] == b"After wait: hello, world." + assert (await endpoint.recvfrom())[0] == b"After wait: hello, world 2 times." + assert (await endpoint.recvfrom())[0] == b"After wait: hello, world 3 times." + assert (await endpoint.recvfrom())[0] == b"After wait: hello, world 4 times." assert (await endpoint.recvfrom())[0] == b"After wait: hello, world. new game +" diff --git a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py index 78522555..225ce4fc 100644 --- a/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py +++ b/tests/unit_test/test_async/test_lowlevel_api/test_servers/test_datagram.py @@ -213,7 +213,6 @@ def test____dunder_init____default( # Arrange # Act & Assert - assert isinstance(client_data.task_lock, asyncio.Lock) assert client_data.state is None def test____client_state____regular_state_transition(